diff --git a/CHANGELOG.md b/CHANGELOG.md index 774064191714b..5bcd8cbb0a5d3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -293,6 +293,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated Accelerator collective API `barrier`, `broadcast`, and `all_gather`, call `TrainingTypePlugin` collective API directly ([#9677](https://github.com/PyTorchLightning/pytorch-lightning/pull/9677)) +- Deprecated passing `resume_from_checkpoint` to the `Trainer` constructor in favor of `trainer.fit(ckpt_path=)` ([#9693](https://github.com/PyTorchLightning/pytorch-lightning/pull/9693)) + + ### Removed - Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/)) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index b8931c415553b..d14cfdfcd6f4b 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -270,7 +270,7 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: difference = callback_states.keys() - current_callbacks_keys if difference: rank_zero_warn( - "Be aware that when using `resume_from_checkpoint`," + "Be aware that when using `ckpt_path`," " callbacks used to create the checkpoint need to be provided." f" Please add the following callbacks: {list(difference)}.", UserWarning, diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 1dd87e3d4b5be..89734527b5dcd 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -37,7 +37,12 @@ class CheckpointConnector: def __init__(self, trainer: "pl.Trainer", resume_from_checkpoint: Optional[_PATH] = None) -> None: self.trainer = trainer - self.resume_checkpoint_path = resume_from_checkpoint + self.resume_checkpoint_path: Optional[_PATH] = resume_from_checkpoint + if resume_from_checkpoint is not None: + rank_zero_deprecation( + "Setting `Trainer(resume_from_checkpoint=)` is deprecated in v1.5 and" + " will be removed in v1.7. Please pass `Trainer.fit(ckpt_path=)` directly instead." + ) self._loaded_checkpoint: Dict[str, Any] = {} @property @@ -47,17 +52,17 @@ def hpc_resume_path(self) -> Optional[str]: if max_version is not None: return os.path.join(dir_path_hpc, f"hpc_ckpt_{max_version}.ckpt") - def resume_start(self) -> None: + def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None: """Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority: 1. from HPC weights if found - 2. from `resume_from_checkpoint` file if provided + 2. from `checkpoint_path` file if provided 3. don't restore Raises: FileNotFoundError: If the path to the checkpoint file is provided but the file does not exist. """ - self.resume_checkpoint_path = self.hpc_resume_path or self.resume_checkpoint_path + self.resume_checkpoint_path = self.hpc_resume_path or checkpoint_path checkpoint_path = self.resume_checkpoint_path if not checkpoint_path: return @@ -96,7 +101,7 @@ def restore(self, checkpoint_path: Optional[_PATH] = None) -> None: state-restore, in this priority: 1. from HPC weights if found - 2. from `resume_from_checkpoint` file if provided + 2. from `checkpoint_path` file if provided 3. don't restore All restored states are listed in return value description of `dump_checkpoint`. 
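A minimal sketch of the source-resolution order that the reworked `CheckpointConnector.resume_start(checkpoint_path)` follows per the docstring above — `_resolve_resume_path` and its parameter names are illustrative only, not the connector's actual method:

def _resolve_resume_path(hpc_resume_path, checkpoint_path):
    # 1. an HPC checkpoint, if one was found, always takes precedence
    # 2. otherwise use the `ckpt_path` handed to `trainer.fit()`
    # 3. otherwise return None and restore nothing
    return hpc_resume_path or checkpoint_path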
@@ -104,8 +109,7 @@ def restore(self, checkpoint_path: Optional[_PATH] = None) -> None: Args: checkpoint_path: Path to a PyTorch Lightning checkpoint file. """ - self.resume_checkpoint_path = checkpoint_path - self.resume_start() + self.resume_start(checkpoint_path) # restore module states self.restore_datamodule() @@ -154,15 +158,6 @@ def restore_model(self) -> None: if isinstance(module, Metric): module.reset() - def restore_model_weights(self, checkpoint_path: Optional[_PATH]) -> None: - """Restore only the model weights.""" - checkpoint = self._loaded_checkpoint - if checkpoint_path is not None: - checkpoint = self._load_and_validate_checkpoint(checkpoint_path) - - self.trainer.lightning_module.on_load_checkpoint(checkpoint) - self.trainer.training_type_plugin.load_model_state_dict(checkpoint) - def restore_training_state(self) -> None: """Restore the trainer state from the pre-loaded checkpoint. diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7a85afeb70928..6c53fb9769f9d 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -342,6 +342,10 @@ def __init__( no checkpoint file at the path, start from scratch. If resuming from mid-epoch checkpoint, training will start from the beginning of the next epoch. + .. deprecated:: v1.5 + ``resume_from_checkpoint`` is deprecated in v1.5 and will be removed in v1.7. + Please use ``Trainer.fit(ckpt_path)`` instead. + sync_batchnorm: Synchronize batch norm layers between process groups/whole world. terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the @@ -574,6 +578,7 @@ def fit( val_dataloaders: Optional[EVAL_DATALOADERS] = None, datamodule: Optional[LightningDataModule] = None, train_dataloader=None, # TODO: remove with 1.6 + ckpt_path: Optional[str] = None, ) -> None: r""" Runs the full optimization routine. @@ -587,6 +592,10 @@ def fit( val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples. + ckpt_path: Path/URL of the checkpoint from which training is resumed. If there is + no checkpoint file at the path, start from scratch. If resuming from mid-epoch checkpoint, + training will start from the beginning of the next epoch. + datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. """ if train_dataloader is not None: @@ -595,7 +604,9 @@ def fit( " Use `trainer.fit(train_dataloaders)` instead. 
HINT: added 's'" ) train_dataloaders = train_dataloader - self._call_and_handle_interrupt(self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule) + self._call_and_handle_interrupt( + self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path + ) def _fit_impl( self, @@ -603,6 +614,7 @@ def _fit_impl( train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, val_dataloaders: Optional[EVAL_DATALOADERS] = None, datamodule: Optional[LightningDataModule] = None, + ckpt_path: Optional[str] = None, ) -> None: Trainer._log_api_event("fit") @@ -625,7 +637,9 @@ def _fit_impl( model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule ) - self._run(model) + # TODO: ckpt_path only in v1.7 + ckpt_path = ckpt_path or self.resume_from_checkpoint + self._run(model, ckpt_path) assert self.state.stopped self.training = False @@ -711,7 +725,7 @@ def _validate_impl( ) # run validate - results = self._run(model) + results = self._run(model, self.validated_ckpt_path) assert self.state.stopped self.validating = False @@ -800,7 +814,7 @@ def _test_impl( ) # run test - results = self._run(model) + results = self._run(model, self.tested_ckpt_path) assert self.state.stopped self.testing = False @@ -882,7 +896,7 @@ def _predict_impl( ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None ) - results = self._run(model) + results = self._run(model, self.predicted_ckpt_path) assert self.state.stopped self.predicting = False @@ -951,24 +965,18 @@ def tune( return result - def _restore_modules_and_callbacks(self) -> None: + def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None: # restore modules after setup - if self.state.fn == TrainerFn.FITTING: - self.checkpoint_connector.resume_start() - self.checkpoint_connector.restore_datamodule() + self.checkpoint_connector.resume_start(checkpoint_path) self.checkpoint_connector.restore_model() - # restore callback states - self.checkpoint_connector.restore_callbacks() - - def _load_checkpoint_weights(self): - # only one process running at this point for TPUs, as spawn isn't triggered yet - # todo: move this logic internally within the barrier. 
- if not self._device_type == DeviceType.TPU: - self.training_type_plugin.barrier() - rank_zero_info(f"Loading model weights from checkpoint at {self._ckpt_path}") - self.checkpoint_connector.restore_model_weights(self._ckpt_path) - - def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: + if self.state.fn == TrainerFn.FITTING: + self.checkpoint_connector.restore_datamodule() + # restore callback states + self.checkpoint_connector.restore_callbacks() + + def _run( + self, model: "pl.LightningModule", ckpt_path: Optional[str] = None + ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: # clean hparams if hasattr(model, "hparams"): parsing.clean_namespace(model.hparams) @@ -985,9 +993,6 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, self.data_connector.prepare_data() self.callback_connector._attach_model_callbacks() - if self._ckpt_path and not self.accelerator.restore_checkpoint_after_pre_dispatch: - self._load_checkpoint_weights() - # ---------------------------- # SET UP TRAINING # ---------------------------- @@ -997,7 +1002,7 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, # check if we should delay restoring checkpoint till later if not self.accelerator.restore_checkpoint_after_pre_dispatch: - self._restore_modules_and_callbacks() + self._restore_modules_and_callbacks(ckpt_path) self._call_configure_sharded_model() # allow user to setup in model sharded environment self.accelerator.setup(self) @@ -1046,12 +1051,13 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, self._pre_dispatch() if self.accelerator.restore_checkpoint_after_pre_dispatch: - if self._ckpt_path: - self._load_checkpoint_weights() - self._restore_modules_and_callbacks() + self._restore_modules_and_callbacks(ckpt_path) # restore optimizers, etc. - self.checkpoint_connector.restore_training_state() + if self.state.fn == TrainerFn.FITTING: + self.checkpoint_connector.restore_training_state() + + self.checkpoint_connector.resume_end() # dispatch `start_training` or `start_evaluating` or `start_predicting` self._dispatch() @@ -1152,8 +1158,6 @@ def _pre_training_routine(self): # register signals self.signal_connector.register_signal_handlers() - self.checkpoint_connector.resume_end() - # -------------------------- # Pre-train # -------------------------- @@ -1742,6 +1746,10 @@ def checkpoint_callbacks(self) -> List[ModelCheckpoint]: @property def resume_from_checkpoint(self) -> Optional[Union[str, Path]]: + rank_zero_deprecation( + "`trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7." + " Specify fit ckpt_path with `trainer.fit(ckpt_path=)` instead." 
+ ) return self.checkpoint_connector.resume_checkpoint_path def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None: @@ -1974,15 +1982,6 @@ def train_loop(self) -> FitLoop: ) return self.fit_loop - @property - def _ckpt_path(self) -> Optional[str]: - if self.state.fn == TrainerFn.VALIDATING: - return self.validated_ckpt_path - if self.state.fn == TrainerFn.TESTING: - return self.tested_ckpt_path - if self.state.fn == TrainerFn.PREDICTING: - return self.predicted_ckpt_path - """ Logging properties """ diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index f95d182f9e5e1..be66730efdb5c 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -80,10 +80,8 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: assert accelerator.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch assert plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch - trainer = Trainer( - default_root_dir=tmpdir, accelerator=accelerator, fast_dev_run=True, resume_from_checkpoint=checkpoint_path - ) - trainer.fit(model) + trainer = Trainer(default_root_dir=tmpdir, accelerator=accelerator, fast_dev_run=True) + trainer.fit(model, ckpt_path=checkpoint_path) for func in (trainer.test, trainer.validate, trainer.predict): accelerator.training_type_plugin.predispatched_called = False func(model, ckpt_path=checkpoint_path) diff --git a/tests/accelerators/test_tpu_backend.py b/tests/accelerators/test_tpu_backend.py index 7f7bad327f515..0d7c0d5d8c3f8 100644 --- a/tests/accelerators/test_tpu_backend.py +++ b/tests/accelerators/test_tpu_backend.py @@ -61,10 +61,8 @@ def test_resume_training_on_cpu(tmpdir): assert weight_tensor.device == torch.device("cpu") # Verify that training is resumed on CPU - trainer = Trainer( - resume_from_checkpoint=model_path, checkpoint_callback=True, max_epochs=1, default_root_dir=tmpdir - ) - trainer.fit(model) + trainer = Trainer(checkpoint_callback=True, max_epochs=1, default_root_dir=tmpdir) + trainer.fit(model, ckpt_path=model_path) assert trainer.state.finished, f"Training failed with {trainer.state}" diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 78e21d821b810..b349b7cf5a487 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -132,8 +132,8 @@ def test_resume_callback_state_saved_by_type(tmpdir): assert ckpt_path.exists() callback = OldStatefulCallback(state=222) - trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path) - trainer.fit(model) + trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback]) + trainer.fit(model, ckpt_path=ckpt_path) assert callback.state == 111 @@ -153,16 +153,14 @@ def test_resume_incomplete_callbacks_list_warning(tmpdir): default_root_dir=tmpdir, max_steps=1, callbacks=[callback1], # one callback is missing! 
- resume_from_checkpoint=ckpt_path, ) with pytest.warns(UserWarning, match=escape(f"Please add the following callbacks: [{repr(callback0.state_key)}]")): - trainer.fit(model) + trainer.fit(model, ckpt_path=ckpt_path) trainer = Trainer( default_root_dir=tmpdir, max_steps=1, callbacks=[callback1, callback0], # all callbacks here, order switched - resume_from_checkpoint=ckpt_path, ) with no_warning_call(UserWarning, match="Please add the following callbacks:"): - trainer.fit(model) + trainer.fit(model, ckpt_path=ckpt_path) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index fe6873c8f43bf..48fbcc058ec2b 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -90,12 +90,11 @@ def test_resume_early_stopping_from_checkpoint(tmpdir): new_trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - resume_from_checkpoint=checkpoint_filepath, callbacks=[early_stop_callback], ) with pytest.raises(MisconfigurationException, match=r"You restored a checkpoint with current_epoch"): - new_trainer.fit(model) + new_trainer.fit(model, ckpt_path=checkpoint_filepath) def test_early_stopping_no_extraneous_invocations(tmpdir): diff --git a/tests/callbacks/test_finetuning_callback.py b/tests/callbacks/test_finetuning_callback.py index 31b8b0e160132..1d736483641f9 100644 --- a/tests/callbacks/test_finetuning_callback.py +++ b/tests/callbacks/test_finetuning_callback.py @@ -128,8 +128,8 @@ def configure_optimizers(self): trainer.fit(model) assert model.backbone.has_been_used - trainer = Trainer(max_epochs=3, resume_from_checkpoint=chk.last_model_path) - trainer.fit(model) + trainer = Trainer(max_epochs=3) + trainer.fit(model, ckpt_path=chk.last_model_path) def test_freeze_unfreeze_function(tmpdir): @@ -258,9 +258,9 @@ def configure_optimizers(self): model = FreezeModel() cb = OnEpochLayerFinetuning() - trainer = Trainer(max_epochs=10, resume_from_checkpoint=chk.last_model_path, callbacks=[cb]) + trainer = Trainer(max_epochs=10, callbacks=[cb]) with pytest.raises(IndexError, match="index 6 is out of range"): - trainer.fit(model) + trainer.fit(model, ckpt_path=chk.last_model_path) def test_on_before_accelerator_backend_setup(tmpdir): @@ -400,10 +400,9 @@ def test_callbacks_restore(tmpdir): } trainer_kwargs["max_epochs"] = 3 - trainer_kwargs["resume_from_checkpoint"] = chk.last_model_path trainer = Trainer(**trainer_kwargs) - trainer.fit(model) + trainer.fit(model, ckpt_path=chk.last_model_path) def test_callbacks_restore_backbone(tmpdir): @@ -438,6 +437,5 @@ def forward(self, x): max_epochs=3, enable_progress_bar=False, callbacks=BackboneFinetuning(unfreeze_backbone_at_epoch=1), - resume_from_checkpoint=ckpt.last_model_path, ) - trainer.fit(BackboneBoringModel()) + trainer.fit(BackboneBoringModel(), ckpt_path=ckpt.last_model_path) diff --git a/tests/callbacks/test_lambda_function.py b/tests/callbacks/test_lambda_function.py index 88752d56bf697..f2fa040b43c78 100644 --- a/tests/callbacks/test_lambda_function.py +++ b/tests/callbacks/test_lambda_function.py @@ -48,6 +48,8 @@ def call(hook, *_, **__): ) trainer.fit(model) + ckpt_path = trainer.checkpoint_callback.best_model_path + # raises KeyboardInterrupt and loads from checkpoint trainer = Trainer( default_root_dir=tmpdir, @@ -56,10 +58,9 @@ def call(hook, *_, **__): limit_val_batches=1, limit_test_batches=1, limit_predict_batches=1, - resume_from_checkpoint=trainer.checkpoint_callback.best_model_path, callbacks=[LambdaCallback(**hooks_args)], ) - trainer.fit(model) + 
trainer.fit(model, ckpt_path=ckpt_path) trainer.test(model) trainer.predict(model) diff --git a/tests/callbacks/test_timer.py b/tests/callbacks/test_timer.py index 94ee3e87bc3e3..a307a72bdde3b 100644 --- a/tests/callbacks/test_timer.py +++ b/tests/callbacks/test_timer.py @@ -173,9 +173,8 @@ def test_timer_resume_training(tmpdir): trainer = Trainer( default_root_dir=tmpdir, callbacks=[timer, checkpoint_callback], - resume_from_checkpoint=checkpoint_callback.best_model_path, ) - trainer.fit(model) + trainer.fit(model, ckpt_path=checkpoint_callback.best_model_path) assert timer._offset > 0 assert trainer.global_step == saved_global_step + 1 diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py index 0910959fc7e7c..cd47ac00e7a3a 100644 --- a/tests/checkpointing/test_legacy_checkpoints.py +++ b/tests/checkpointing/test_legacy_checkpoints.py @@ -83,10 +83,9 @@ def test_resume_legacy_checkpoints(tmpdir, pl_version: str): callbacks=[es, stop], max_epochs=21, accumulate_grad_batches=2, - resume_from_checkpoint=path_ckpt, ) torch.backends.cudnn.deterministic = True - trainer.fit(model, datamodule=dm) + trainer.fit(model, datamodule=dm, ckpt_path=path_ckpt) res = trainer.test(model, datamodule=dm) assert res[0]["test_loss"] <= 0.7 assert res[0]["test_acc"] >= 0.85 diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 1bb8d6b63ec9b..66a4512c2d4e9 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -905,11 +905,10 @@ def validation_step(self, batch, batch_idx): limit_train_batches=2, limit_val_batches=2, limit_test_batches=2, - resume_from_checkpoint=checkpoint_callback.best_model_path, weights_summary=None, enable_progress_bar=False, ) - trainer.fit(model) + trainer.fit(model, ckpt_path=checkpoint_callback.best_model_path) trainer.test(model, verbose=False) assert set(os.listdir(tmpdir)) == {"epoch=00.ckpt", "lightning_logs"} assert set(os.listdir(tmpdir.join("lightning_logs"))) == {f"version_{i}" for i in range(4)} @@ -981,17 +980,16 @@ def assert_checkpoint_log_dir(idx): # load from checkpoint trainer_config["callbacks"] = [ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)] - trainer = pl.Trainer(**trainer_config, resume_from_checkpoint=chk) + trainer = pl.Trainer(**trainer_config) assert_trainer_init(trainer) model = ExtendedBoringModel() trainer.test(model) - # resume_from_checkpoint is resumed when calling `.fit` assert trainer.global_step == 0 assert trainer.current_epoch == 0 - trainer.fit(model) + trainer.fit(model, ckpt_path=chk) assert trainer.global_step == epochs * limit_train_batches assert trainer.current_epoch == epochs assert_checkpoint_log_dir(idx) diff --git a/tests/checkpointing/test_trainer_checkpoint.py b/tests/checkpointing/test_trainer_checkpoint.py index a617162a4daed..e023a4a3472ff 100644 --- a/tests/checkpointing/test_trainer_checkpoint.py +++ b/tests/checkpointing/test_trainer_checkpoint.py @@ -22,7 +22,7 @@ from tests.helpers import BoringModel -def test_finetuning_with_resume_from_checkpoint(tmpdir): +def test_finetuning_with_ckpt_path(tmpdir): """This test validates that generated ModelCheckpoint is pointing to the right best_model_path during test.""" checkpoint_callback = ModelCheckpoint(monitor="val_loss", dirpath=tmpdir, filename="{epoch:02d}", save_top_k=-1) @@ -63,10 +63,9 @@ def validation_step(self, batch, batch_idx): limit_train_batches=12, limit_val_batches=12, limit_test_batches=12, - 
resume_from_checkpoint=best_model_paths[-1], enable_progress_bar=False, ) - trainer.fit(model) + trainer.fit(model, ckpt_path=best_model_paths[-1]) trainer.test() results.append(deepcopy(trainer.callback_metrics)) best_model_paths.append(trainer.checkpoint_callback.best_model_path) @@ -78,7 +77,7 @@ def validation_step(self, batch, batch_idx): assert f"epoch={idx + 1}" in best_model_path -def test_accumulated_gradient_batches_with_resume_from_checkpoint(tmpdir): +def test_accumulated_gradient_batches_with_ckpt_path(tmpdir): """This test validates that accumulated gradient is properly recomputed and reset on the trainer.""" ckpt = ModelCheckpoint(dirpath=tmpdir, save_last=True) @@ -90,6 +89,5 @@ def test_accumulated_gradient_batches_with_resume_from_checkpoint(tmpdir): trainer.fit(model) trainer_kwargs["max_epochs"] = 2 - trainer_kwargs["resume_from_checkpoint"] = ckpt.last_model_path trainer = Trainer(**trainer_kwargs) - trainer.fit(model) + trainer.fit(model, ckpt_path=ckpt.last_model_path) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 1d79a0a0194f8..c374b9e4f0ee6 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -463,10 +463,9 @@ def on_epoch_end(self) -> None: else trainer_kwargs["default_root_dir"] ) ckpt_path = os.path.join(tmpdir, ".pl_auto_save.ckpt") - trainer_kwargs["resume_from_checkpoint"] = ckpt_path trainer = Trainer(**trainer_kwargs) - trainer.fit(model) + trainer.fit(model, ckpt_path=ckpt_path) assert model.has_validated_sum diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 17facc032994d..795d7f00236fb 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -255,3 +255,12 @@ def test_v1_7_0_deprecate_lightning_distributed(tmpdir): from pytorch_lightning.distributed.dist import LightningDistributed _ = LightningDistributed() + + +def test_v1_7_0_resume_from_checkpoint_trainer_constructor(tmpdir): + with pytest.deprecated_call(match=r"Setting `Trainer\(resume_from_checkpoint=\)` is deprecated in v1.5"): + trainer = Trainer(resume_from_checkpoint="a") + with pytest.deprecated_call( + match=r"trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7." 
+ ): + _ = trainer.resume_from_checkpoint diff --git a/tests/loops/optimization/test_optimizer_loop.py b/tests/loops/optimization/test_optimizer_loop.py index 4f7f8ac7d48ce..9051c2e32c757 100644 --- a/tests/loops/optimization/test_optimizer_loop.py +++ b/tests/loops/optimization/test_optimizer_loop.py @@ -227,7 +227,6 @@ def configure_optimizers(self): model.training_epoch_end = None model.optimizer_step = Mock(wraps=model.optimizer_step) trainer = Trainer( - resume_from_checkpoint=str(tmpdir / ".pl_auto_save.ckpt"), default_root_dir=tmpdir, max_epochs=n_epochs, limit_train_batches=n_batches, @@ -236,7 +235,7 @@ def configure_optimizers(self): logger=False, checkpoint_callback=False, ) - trainer.fit(model) + trainer.fit(model, ckpt_path=str(tmpdir / ".pl_auto_save.ckpt")) weights_resumed = model.parameters() # check that the final weights of a resumed run match the weights of a run that never failed diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index de849f90a079f..945778e1dda1f 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -584,14 +584,13 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir): limit_val_batches=0, enable_progress_bar=False, weights_summary=None, - resume_from_checkpoint=best_model_path, callbacks=[callback], ) assert called == [ dict(name="Callback.on_init_start", args=(trainer,)), dict(name="Callback.on_init_end", args=(trainer,)), ] - trainer.fit(model) + trainer.fit(model, ckpt_path=best_model_path) saved_ckpt = { "callbacks": ANY, "epoch": 2, # TODO: wrong saved epoch diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 5a68f757e8744..61cabf1b866b9 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -104,7 +104,7 @@ def validation_step_end(self, outputs): self.log("val_acc", self.valid_acc(outputs["logits"], outputs["y"])) -def test_model_properties_resume_from_checkpoint(tmpdir): +def test_model_properties_fit_ckpt_path(tmpdir): """Test that properties like `current_epoch` and `global_step` in model and trainer are always the same.""" model = BoringModel() checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True) @@ -120,17 +120,17 @@ def test_model_properties_resume_from_checkpoint(tmpdir): trainer.fit(model) trainer_args.update(max_epochs=2) - trainer = Trainer(**trainer_args, resume_from_checkpoint=str(tmpdir / "last.ckpt")) - trainer.fit(model) + trainer = Trainer(**trainer_args) + trainer.fit(model, ckpt_path=str(tmpdir / "last.ckpt")) def test_try_resume_from_non_existing_checkpoint(tmpdir): - """Test that trying to resume from non-existing `resume_from_checkpoint` fails with an error.""" + """Test that trying to resume from non-existing `ckpt_path` fails with an error.""" model = BoringModel() - trainer = Trainer(resume_from_checkpoint=str(tmpdir / "non_existing.ckpt")) + trainer = Trainer() with pytest.raises(FileNotFoundError, match="Aborting training"): - trainer.fit(model) + trainer.fit(model, ckpt_path=str(tmpdir / "non_existing.ckpt")) class CaptureCallbacksBeforeTraining(Callback): @@ -140,7 +140,7 @@ def on_pretrain_routine_end(self, trainer, pl_module): self.callbacks = deepcopy(trainer.callbacks) -def test_callbacks_state_resume_from_checkpoint(tmpdir): +def test_callbacks_state_fit_ckpt_path(tmpdir): """Test that resuming from a checkpoint restores callbacks that persist state.""" dm = ClassifDataModule() model = ClassificationModel() @@ -165,8 +165,8 @@ def get_trainer_args(): 
callbacks_before_resume = deepcopy(trainer.callbacks) # resumed training - trainer = Trainer(**get_trainer_args(), resume_from_checkpoint=str(tmpdir / "last.ckpt")) - trainer.fit(model, datamodule=dm) + trainer = Trainer(**get_trainer_args()) + trainer.fit(model, datamodule=dm, ckpt_path=str(tmpdir / "last.ckpt")) assert len(callbacks_before_resume) == len(callback_capture.callbacks) @@ -176,7 +176,7 @@ def get_trainer_args(): assert before.best_model_score == after.best_model_score -def test_callbacks_references_resume_from_checkpoint(tmpdir): +def test_callbacks_references_fit_ckpt_path(tmpdir): """Test that resuming from a checkpoint sets references as expected.""" dm = ClassifDataModule() model = ClassificationModel() @@ -198,10 +198,10 @@ def test_callbacks_references_resume_from_checkpoint(tmpdir): new_checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True) # pass in a new checkpoint object, which should take # precedence over the one in the last.ckpt file - trainer = Trainer(**args, callbacks=[new_checkpoint], resume_from_checkpoint=str(tmpdir / "last.ckpt")) + trainer = Trainer(**args, callbacks=[new_checkpoint]) assert checkpoint is not new_checkpoint assert new_checkpoint is trainer.callbacks[-1] is trainer.checkpoint_callback - trainer.fit(model, datamodule=dm) + trainer.fit(model, datamodule=dm, ckpt_path=str(tmpdir / "last.ckpt")) @RunIf(min_gpus=2) diff --git a/tests/plugins/test_custom_plugin.py b/tests/plugins/test_custom_plugin.py index 939c05d1b7afe..9d23f48bed98e 100644 --- a/tests/plugins/test_custom_plugin.py +++ b/tests/plugins/test_custom_plugin.py @@ -62,8 +62,6 @@ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: model = BoringModel() plugin = TestPlugin(torch.device("cpu")) - trainer = Trainer( - default_root_dir=tmpdir, fast_dev_run=True, plugins=plugin, resume_from_checkpoint=checkpoint_path - ) - trainer.fit(model) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins=plugin) + trainer.fit(model, ckpt_path=checkpoint_path) assert plugin.load_optimizer_state_dict_called == restore_optimizer_and_schedulers diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index bdc6bca16d495..baac266efbb2a 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -620,7 +620,6 @@ def test_deepspeed_multigpu_stage_3_warns_resume_training(tmpdir): plugins=DeepSpeedPlugin(stage=3, load_full_weights=True), gpus=1, precision=16, - resume_from_checkpoint=checkpoint_path, ) with pytest.warns( UserWarning, @@ -628,7 +627,7 @@ def test_deepspeed_multigpu_stage_3_warns_resume_training(tmpdir): "scheduler states can not be restored. 
If you'd like to restore these states, you must " "provide a path to the originally saved DeepSpeed checkpoint.", ): - trainer.fit(model, datamodule=dm) + trainer.fit(model, datamodule=dm, ckpt_path=checkpoint_path) @RunIf(min_gpus=1, deepspeed=True, special=True) @@ -680,10 +679,9 @@ def on_train_batch_start( plugins=DeepSpeedPlugin(stage=3), gpus=1, precision=16, - resume_from_checkpoint=ck.best_model_path, callbacks=TestCallback(), ) - trainer.fit(model, datamodule=dm) + trainer.fit(model, datamodule=dm, ckpt_path=ck.best_model_path) @RunIf(min_gpus=2, deepspeed=True, special=True) diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index f2c99fa4cb17a..960ba2c88680c 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -132,7 +132,7 @@ def test_ddp_sharded_plugin_finetune(tmpdir): @RunIf(skip_windows=True, fairscale=True) -def test_ddp_sharded_plugin_resume_from_checkpoint(tmpdir): +def test_ddp_sharded_plugin_fit_ckpt_path(tmpdir): """Test to ensure that resuming from checkpoint works.""" model = BoringModel() trainer = Trainer(accelerator="ddp_sharded_spawn", num_processes=2, fast_dev_run=True) @@ -144,17 +144,15 @@ def test_ddp_sharded_plugin_resume_from_checkpoint(tmpdir): model = BoringModel() - trainer = Trainer( - accelerator="ddp_sharded_spawn", num_processes=2, fast_dev_run=True, resume_from_checkpoint=checkpoint_path - ) + trainer = Trainer(accelerator="ddp_sharded_spawn", num_processes=2, fast_dev_run=True) - trainer.fit(model) + trainer.fit(model, ckpt_path=checkpoint_path) @pytest.mark.skip(reason="Not a critical test, skip till drone CI performance improves.") # todo @pytest.mark.skip(reason="Currently unsupported restarting training on different number of devices.") @RunIf(min_gpus=2, skip_windows=True, fairscale=True) -def test_ddp_sharded_plugin_resume_from_checkpoint_downsize_gpus(tmpdir): +def test_ddp_sharded_plugin_fit_ckpt_path_downsize_gpus(tmpdir): """Test to ensure that resuming from checkpoint works when downsizing number of GPUS.""" model = BoringModel() trainer = Trainer(accelerator="ddp_sharded_spawn", fast_dev_run=True, gpus=2) @@ -166,15 +164,13 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_downsize_gpus(tmpdir): model = BoringModel() - trainer = Trainer( - accelerator="ddp_sharded_spawn", fast_dev_run=True, gpus=1, resume_from_checkpoint=checkpoint_path - ) + trainer = Trainer(accelerator="ddp_sharded_spawn", fast_dev_run=True, gpus=1) - trainer.fit(model) + trainer.fit(model, ckpt_path=checkpoint_path) @RunIf(min_gpus=1, skip_windows=True, fairscale=True) -def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir): +def test_ddp_sharded_plugin_fit_ckpt_path_gpu_to_cpu(tmpdir): """Test to ensure that resuming from checkpoint works when going from GPUs- > CPU.""" model = BoringModel() trainer = Trainer(accelerator="ddp_sharded_spawn", gpus=1, fast_dev_run=True) @@ -186,11 +182,9 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir): model = BoringModel() - trainer = Trainer( - accelerator="ddp_sharded_spawn", num_processes=2, fast_dev_run=True, resume_from_checkpoint=checkpoint_path - ) + trainer = Trainer(accelerator="ddp_sharded_spawn", num_processes=2, fast_dev_run=True) - trainer.fit(model) + trainer.fit(model, ckpt_path=checkpoint_path) @RunIf(skip_windows=True, special=True, fairscale=True) diff --git a/tests/trainer/connectors/test_checkpoint_connector.py b/tests/trainer/connectors/test_checkpoint_connector.py index 
83a45f02224d5..7c8a4e2dc9bfb 100644 --- a/tests/trainer/connectors/test_checkpoint_connector.py +++ b/tests/trainer/connectors/test_checkpoint_connector.py @@ -60,7 +60,6 @@ def test_preloaded_checkpoint_lifecycle(tmpdir): connector = trainer.checkpoint_connector - assert not trainer.resume_from_checkpoint assert not connector.resume_checkpoint_path assert not connector._loaded_checkpoint @@ -72,9 +71,9 @@ def test_preloaded_checkpoint_lifecycle(tmpdir): assert not connector._loaded_checkpoint ckpt_path = trainer.checkpoint_callback.best_model_path - trainer = Trainer(default_root_dir=tmpdir, max_steps=2, resume_from_checkpoint=ckpt_path) + trainer = Trainer(default_root_dir=tmpdir, max_steps=2) connector = trainer.checkpoint_connector - connector.resume_start() + connector.resume_start(ckpt_path) assert connector.resume_checkpoint_path == ckpt_path assert connector._loaded_checkpoint assert isinstance(connector._loaded_checkpoint, dict) @@ -106,8 +105,8 @@ def test_hpc_restore_attempt(tmpdir): torch.nn.init.constant_(param, 0) # case 2: explicit resume path provided, restore hpc anyway - trainer = Trainer(default_root_dir=tmpdir, max_steps=3, resume_from_checkpoint="not existing") - trainer.fit(model) + trainer = Trainer(default_root_dir=tmpdir, max_steps=3) + trainer.fit(model, ckpt_path="not existing") for param in model.parameters(): assert param.abs().sum() > 0 diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index a1e6ce01107db..41fbbe75677bb 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -461,7 +461,7 @@ def test_model_freeze_unfreeze(): @pytest.mark.parametrize("url_ckpt", [True, False]) -def test_resume_from_checkpoint_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ckpt): +def test_fit_ckpt_path_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ckpt): """Verify resuming from checkpoint runs the right number of epochs.""" # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir monkeypatch.setenv("TORCH_HOME", tmpdir) @@ -512,8 +512,8 @@ def on_load_checkpoint(self, _): state = pl_load(ckpt) # Resume training - new_trainer = Trainer(default_root_dir=tmpdir, resume_from_checkpoint=ckpt, max_epochs=2) - new_trainer.fit(next_model) + new_trainer = Trainer(default_root_dir=tmpdir, max_epochs=2) + new_trainer.fit(next_model, ckpt_path=ckpt) assert state["global_step"] + next_model.num_batches_seen == trainer.num_training_batches * trainer.max_epochs assert next_model.num_on_load_checkpoint_called == 1 @@ -1876,9 +1876,9 @@ def test_on_load_checkpoint_missing_callbacks(tmpdir): trainer = Trainer(default_root_dir=tmpdir, max_epochs=3, callbacks=[chk, CustomCallbackOnLoadCheckpoint()]) trainer.fit(model) - trainer = Trainer(default_root_dir=tmpdir, max_epochs=5, resume_from_checkpoint=chk.last_model_path) + trainer = Trainer(default_root_dir=tmpdir, max_epochs=5) with pytest.warns(UserWarning, match="CustomCallbackOnLoadCheckpoint"): - trainer.fit(model) + trainer.fit(model, ckpt_path=chk.last_model_path) def test_module_current_fx_attributes_reset(tmpdir): diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index b64cad8613232..2d8a316c0d746 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -907,7 +907,7 @@ def configure_optimizers(self): return torch.optim.SGD(self.layer.parameters(), lr=0.1) -def _run_training(trainer_kwargs, dataset_classes, fail_on_step: int = -1): +def _run_training(trainer_kwargs, dataset_classes, 
fail_on_step: int = -1, ckpt_path=None): seed_everything(1) train_dataloader = [ DataLoader(dataset_class(3, 1), batch_size=1, num_workers=0) for dataset_class in dataset_classes @@ -916,7 +916,7 @@ def _run_training(trainer_kwargs, dataset_classes, fail_on_step: int = -1): model = TestModel(fail_on_step=fail_on_step) trainer = Trainer(**trainer_kwargs) with suppress(CustomException): - trainer.fit(model, train_dataloader=train_dataloader) + trainer.fit(model, train_dataloader=train_dataloader, ckpt_path=ckpt_path) return model.seen_batches, model.parameters() @@ -958,8 +958,9 @@ def test_dataset_rng_states_restart_with_lightning(tmpdir, dataset_classes, mult assert os.path.exists(checkpoint_path) # Resume after failure - trainer_kwargs.update(resume_from_checkpoint=checkpoint_path) - resumed_batches, weights1 = _run_training(trainer_kwargs, dataset_classes, fail_on_step=-1) + resumed_batches, weights1 = _run_training( + trainer_kwargs, dataset_classes, fail_on_step=-1, ckpt_path=checkpoint_path + ) assert len(resumed_batches) == 5 # the resumed batches should match the batches of the successful training diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index ace76bd9374c8..1e0dfbc0e74a5 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -137,7 +137,6 @@ def _raise(): log_gpu_memory=None, distributed_backend=None, weights_save_path=None, - resume_from_checkpoint=None, profiler=None, ), ), @@ -907,7 +906,7 @@ def test_lightning_cli_model_choices(): ) as run: cli = LightningCLI(trainer_defaults={"fast_dev_run": 1}) assert isinstance(cli.model, BoringModel) - run.assert_called_once_with(cli.model, ANY, ANY, ANY) + run.assert_called_once_with(cli.model, ANY, ANY, ANY, ANY) with mock.patch("sys.argv", ["any.py", "--model=TestModel", "--model.foo", "123"]): cli = LightningCLI(run=False)
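A hedged usage sketch of the migration this change asks for — the checkpoint path and trainer arguments below are placeholders, not values taken from this diff; `BoringModel` is the test helper already used in the tests above:

from pytorch_lightning import Trainer
from tests.helpers import BoringModel  # test helper, usable inside the repo's test suite

model = BoringModel()

# Deprecated since v1.5, slated for removal in v1.7: still honored (forwarded to fit
# internally), but emits a deprecation warning. "path/to/last.ckpt" is a placeholder.
trainer = Trainer(max_epochs=2, resume_from_checkpoint="path/to/last.ckpt")
trainer.fit(model)

# Preferred: pass the checkpoint path at fit time instead.
trainer = Trainer(max_epochs=2)
trainer.fit(model, ckpt_path="path/to/last.ckpt")

# validate/test/predict already accept ckpt_path directly, as the updated tests show.
trainer.validate(model, ckpt_path="path/to/last.ckpt")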