From 6ecf45bb755f8c1e3e468b0f27b13e52fa8b58d4 Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Thu, 23 Sep 2021 18:16:51 -0700 Subject: [PATCH 01/12] first commit wip --- .../connectors/checkpoint_connector.py | 23 +++-- pytorch_lightning/trainer/trainer.py | 85 +++++++++---------- tests/accelerators/test_cpu.py | 4 +- tests/callbacks/test_callbacks.py | 10 +-- tests/callbacks/test_early_stopping.py | 3 +- tests/callbacks/test_finetuning_callback.py | 14 ++- tests/callbacks/test_lambda_function.py | 6 +- .../checkpointing/test_legacy_checkpoints.py | 3 +- tests/checkpointing/test_model_checkpoint.py | 8 +- tests/core/test_metric_result_integration.py | 3 +- tests/models/test_hooks.py | 3 +- tests/plugins/test_custom_plugin.py | 4 +- .../connectors/test_checkpoint_connector.py | 10 +-- tests/trainer/test_trainer.py | 16 ++-- tests/utilities/test_auto_restart.py | 7 +- tests/utilities/test_cli.py | 3 +- 16 files changed, 97 insertions(+), 105 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index b750b0f81b26f..ee2b6d32e0fc0 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -46,17 +46,21 @@ def hpc_resume_path(self) -> Optional[str]: if max_version is not None: return os.path.join(dir_path_hpc, f"hpc_ckpt_{max_version}.ckpt") - def resume_start(self) -> None: + def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None: """Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority: 1. from HPC weights if found - 2. from `resume_from_checkpoint` file if provided - 3. don't restore + 2. from `resume_from_checkpoint` file if provided + .. deprecated:: v1.5 + `Trainer(resume_from_checkpoint=)` is deprecated in v1.5 and will be removed in v1.7. + Please use `Trainer.fit(ckpt_path=)` instead. + 3. from `checkpoint_path` file if provided + 4. don't restore Raises: FileNotFoundError: If the path to the checkpoint file is provided but the file does not exist. """ - self.resume_checkpoint_path = self.hpc_resume_path or self.resume_checkpoint_path + self.resume_checkpoint_path = self.hpc_resume_path or self.resume_checkpoint_path or checkpoint_path checkpoint_path = self.resume_checkpoint_path if not checkpoint_path: return @@ -94,7 +98,7 @@ def restore(self, checkpoint_path: Optional[_PATH] = None) -> None: state-restore, in this priority: 1. from HPC weights if found - 2. from `resume_from_checkpoint` file if provided + 2. from `checkpoint_path` file if provided 3. don't restore All restored states are listed in return value description of `dump_checkpoint`. @@ -103,7 +107,7 @@ def restore(self, checkpoint_path: Optional[_PATH] = None) -> None: checkpoint_path: Path to a PyTorch Lightning checkpoint file. """ self.resume_checkpoint_path = checkpoint_path - self.resume_start() + self.resume_start(checkpoint_path) # restore module states self.restore_datamodule() @@ -153,7 +157,12 @@ def restore_model(self) -> None: module.reset() def restore_model_weights(self, checkpoint_path: Optional[_PATH]) -> None: - """Restore only the model weights.""" + """Restore only the model weights. + + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.7. + Please use `CheckpointConnector.restore_model` instead. 
+ """ checkpoint = self._loaded_checkpoint if checkpoint_path is not None: checkpoint = self._load_and_validate_checkpoint(checkpoint_path) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2e115decf3ade..df98eb924f0d2 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -333,6 +333,10 @@ def __init__( no checkpoint file at the path, start from scratch. If resuming from mid-epoch checkpoint, training will start from the beginning of the next epoch. + .. deprecated:: v1.5 + ``resume_from_checkpoint`` is deprecated in v1.5 and will be removed in v1.7. + Please use ``Trainer.fit(ckpt_path)`` instead. + sync_batchnorm: Synchronize batch norm layers between process groups/whole world. terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the @@ -435,11 +439,6 @@ def __init__( # Needed because of LightningOptimizer self._lightning_optimizers = None - # .validate() and .test() set this when they load a checkpoint - self.validated_ckpt_path: Optional[str] = None - self.tested_ckpt_path: Optional[str] = None - self.predicted_ckpt_path: Optional[str] = None - # init callbacks # Declare attributes to be set in callback_connector on_trainer_init self.callback_connector.on_trainer_init( @@ -561,6 +560,7 @@ def fit( model: "pl.LightningModule", train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, val_dataloaders: Optional[EVAL_DATALOADERS] = None, + ckpt_path: Optional[str] = None, datamodule: Optional[LightningDataModule] = None, train_dataloader=None, # TODO: remove with 1.6 ) -> None: @@ -576,6 +576,10 @@ def fit( val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples. + ckpt_path: Path/URL of the checkpoint from which training is resumed. If there is + no checkpoint file at the path, start from scratch. If resuming from mid-epoch checkpoint, + training will start from the beginning of the next epoch. + datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. """ if train_dataloader is not None: @@ -584,13 +588,14 @@ def fit( " Use `trainer.fit(train_dataloaders)` instead. 
HINT: added 's'" ) train_dataloaders = train_dataloader - self._call_and_handle_interrupt(self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule) + self._call_and_handle_interrupt(self._fit_impl, model, train_dataloaders, val_dataloaders, ckpt_path, datamodule) def _fit_impl( self, model: "pl.LightningModule", train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, val_dataloaders: Optional[EVAL_DATALOADERS] = None, + ckpt_path: Optional[str] = None, datamodule: Optional[LightningDataModule] = None, ) -> None: Trainer._log_api_event("fit") @@ -614,7 +619,7 @@ def _fit_impl( model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule ) - self._run(model) + self._run(model, ckpt_path) assert self.state.stopped self.training = False @@ -695,12 +700,12 @@ def _validate_impl( # links data to the trainer self.data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule) - self.validated_ckpt_path = self.__set_ckpt_path( + ckpt_path = self.__set_ckpt_path( ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None ) # run validate - results = self._run(model) + results = self._run(model, ckpt_path) assert self.state.stopped self.validating = False @@ -784,12 +789,12 @@ def _test_impl( # links data to the trainer self.data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule) - self.tested_ckpt_path = self.__set_ckpt_path( + ckpt_path = self.__set_ckpt_path( ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None ) # run test - results = self._run(model) + results = self._run(model, ckpt_path) assert self.state.stopped self.testing = False @@ -867,11 +872,11 @@ def _predict_impl( # links data to the trainer self.data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule) - self.predicted_ckpt_path = self.__set_ckpt_path( + ckpt_path = self.__set_ckpt_path( ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None ) - results = self._run(model) + results = self._run(model, ckpt_path) assert self.state.stopped self.predicting = False @@ -940,24 +945,16 @@ def tune( return result - def _restore_modules_and_callbacks(self) -> None: + def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None: # restore modules after setup - if self.state.fn == TrainerFn.FITTING: - self.checkpoint_connector.resume_start() - self.checkpoint_connector.restore_datamodule() + self.checkpoint_connector.resume_start(checkpoint_path) self.checkpoint_connector.restore_model() - # restore callback states - self.checkpoint_connector.restore_callbacks() - - def _load_checkpoint_weights(self): - # only one process running at this point for TPUs, as spawn isn't triggered yet - # todo: move this logic internally within the barrier. 
- if not self._device_type == DeviceType.TPU: - self.accelerator.barrier() - rank_zero_info(f"Loading model weights from checkpoint at {self._ckpt_path}") - self.checkpoint_connector.restore_model_weights(self._ckpt_path) - - def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: + if self.state.fn == TrainerFn.FITTING: + self.checkpoint_connector.restore_datamodule() + # restore callback states + self.checkpoint_connector.restore_callbacks() + + def _run(self, model: "pl.LightningModule", ckpt_path: Optional[str] = None) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: # clean hparams if hasattr(model, "hparams"): parsing.clean_namespace(model.hparams) @@ -974,9 +971,6 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, self.data_connector.prepare_data() self.callback_connector._attach_model_callbacks() - if self._ckpt_path and not self.accelerator.restore_checkpoint_after_pre_dispatch: - self._load_checkpoint_weights() - # ---------------------------- # SET UP TRAINING # ---------------------------- @@ -986,7 +980,7 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, # check if we should delay restoring checkpoint till later if not self.accelerator.restore_checkpoint_after_pre_dispatch: - self._restore_modules_and_callbacks() + self._restore_modules_and_callbacks(ckpt_path) self._call_configure_sharded_model() # allow user to setup in model sharded environment self.accelerator.setup(self) @@ -1030,12 +1024,11 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, self._pre_dispatch() if self.accelerator.restore_checkpoint_after_pre_dispatch: - if self._ckpt_path: - self._load_checkpoint_weights() - self._restore_modules_and_callbacks() + self._restore_modules_and_callbacks(ckpt_path) # restore optimizers, etc. - self.checkpoint_connector.restore_training_state() + if self.state.fn == TrainerFn.FITTING: + self.checkpoint_connector.restore_training_state() # dispatch `start_training` or `start_evaluating` or `start_predicting` self._dispatch() @@ -1730,6 +1723,15 @@ def checkpoint_callbacks(self) -> List[ModelCheckpoint]: @property def resume_from_checkpoint(self) -> Optional[Union[str, Path]]: + """ + .. deprecated:: v1.5 + This method is deprecated in v1.5 and will be removed in v1.7. + Please use `Trainer.resume_checkpoint_path` instead. 
+ """ + return self.resume_checkpoint_path + + @property + def resume_checkpoint_path(self) -> Optional[Union[str, Path]]: return self.checkpoint_connector.resume_checkpoint_path def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None: @@ -1962,15 +1964,6 @@ def train_loop(self) -> FitLoop: ) return self.fit_loop - @property - def _ckpt_path(self) -> Optional[str]: - if self.state.fn == TrainerFn.VALIDATING: - return self.validated_ckpt_path - if self.state.fn == TrainerFn.TESTING: - return self.tested_ckpt_path - if self.state.fn == TrainerFn.PREDICTING: - return self.predicted_ckpt_path - """ Logging properties """ diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index f95d182f9e5e1..3a83c2577aa8f 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -81,9 +81,9 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: assert plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch trainer = Trainer( - default_root_dir=tmpdir, accelerator=accelerator, fast_dev_run=True, resume_from_checkpoint=checkpoint_path + default_root_dir=tmpdir, accelerator=accelerator, fast_dev_run=True ) - trainer.fit(model) + trainer.fit(model, ckpt_path=checkpoint_path) for func in (trainer.test, trainer.validate, trainer.predict): accelerator.training_type_plugin.predispatched_called = False func(model, ckpt_path=checkpoint_path) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 5803db051c659..77c09a1b3355e 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -134,8 +134,8 @@ def test_resume_callback_state_saved_by_type(tmpdir): assert ckpt_path.exists() callback = OldStatefulCallback(state=222) - trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path) - trainer.fit(model) + trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback]) + trainer.fit(model, ckpt_path=ckpt_path) assert callback.state == 111 @@ -155,16 +155,14 @@ def test_resume_incomplete_callbacks_list_warning(tmpdir): default_root_dir=tmpdir, max_steps=1, callbacks=[callback1], # one callback is missing! 
- resume_from_checkpoint=ckpt_path, ) with pytest.warns(UserWarning, match=escape(f"Please add the following callbacks: [{repr(callback0.state_key)}]")): - trainer.fit(model) + trainer.fit(model, ckpt_path=ckpt_path) trainer = Trainer( default_root_dir=tmpdir, max_steps=1, callbacks=[callback1, callback0], # all callbacks here, order switched - resume_from_checkpoint=ckpt_path, ) with no_warning_call(UserWarning, match="Please add the following callbacks:"): - trainer.fit(model) + trainer.fit(model, ckpt_path=ckpt_path) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index 7cab0d8776056..66b03e9d13aa2 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -90,12 +90,11 @@ def test_resume_early_stopping_from_checkpoint(tmpdir): new_trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - resume_from_checkpoint=checkpoint_filepath, callbacks=[early_stop_callback], ) with pytest.raises(MisconfigurationException, match=r"You restored a checkpoint with current_epoch"): - new_trainer.fit(model) + new_trainer.fit(model, ckpt_path=checkpoint_filepath) def test_early_stopping_no_extraneous_invocations(tmpdir): diff --git a/tests/callbacks/test_finetuning_callback.py b/tests/callbacks/test_finetuning_callback.py index c014c8e736874..013380376b94c 100644 --- a/tests/callbacks/test_finetuning_callback.py +++ b/tests/callbacks/test_finetuning_callback.py @@ -128,8 +128,8 @@ def configure_optimizers(self): trainer.fit(model) assert model.backbone.has_been_used - trainer = Trainer(max_epochs=3, resume_from_checkpoint=chk.last_model_path) - trainer.fit(model) + trainer = Trainer(max_epochs=3) + trainer.fit(model, ckpt_path=chk.last_model_path) def test_freeze_unfreeze_function(tmpdir): @@ -258,9 +258,9 @@ def configure_optimizers(self): model = FreezeModel() cb = OnEpochLayerFinetuning() - trainer = Trainer(max_epochs=10, resume_from_checkpoint=chk.last_model_path, callbacks=[cb]) + trainer = Trainer(max_epochs=10, callbacks=[cb]) with pytest.raises(IndexError, match="index 6 is out of range"): - trainer.fit(model) + trainer.fit(model, ckpt_path=chk.last_model_path) def test_on_before_accelerator_backend_setup(tmpdir): @@ -400,10 +400,9 @@ def test_callbacks_restore(tmpdir): } trainer_kwargs["max_epochs"] = 3 - trainer_kwargs["resume_from_checkpoint"] = chk.last_model_path trainer = Trainer(**trainer_kwargs) - trainer.fit(model) + trainer.fit(model, ckpt_path=chk.last_model_path) def test_callbacks_restore_backbone(tmpdir): @@ -438,6 +437,5 @@ def forward(self, x): max_epochs=3, progress_bar_refresh_rate=0, callbacks=BackboneFinetuning(unfreeze_backbone_at_epoch=1), - resume_from_checkpoint=ckpt.last_model_path, ) - trainer.fit(BackboneBoringModel()) + trainer.fit(BackboneBoringModel(), ckpt_path=ckpt.last_model_path) diff --git a/tests/callbacks/test_lambda_function.py b/tests/callbacks/test_lambda_function.py index 88752d56bf697..932acdccad346 100644 --- a/tests/callbacks/test_lambda_function.py +++ b/tests/callbacks/test_lambda_function.py @@ -56,11 +56,13 @@ def call(hook, *_, **__): limit_val_batches=1, limit_test_batches=1, limit_predict_batches=1, - resume_from_checkpoint=trainer.checkpoint_callback.best_model_path, callbacks=[LambdaCallback(**hooks_args)], ) - trainer.fit(model) + trainer.fit(model, ckpt_path=trainer.checkpoint_callback.best_model_path) trainer.test(model) trainer.predict(model) + # assert len(checker) == len(hooks) + # 47 == 48 assert checker == hooks + # extra items in the right set: 
'on_load_checkpoint' \ No newline at end of file diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py index 040cd642556cf..ca1b638c6f2d4 100644 --- a/tests/checkpointing/test_legacy_checkpoints.py +++ b/tests/checkpointing/test_legacy_checkpoints.py @@ -84,9 +84,8 @@ def test_resume_legacy_checkpoints(tmpdir, pl_version: str): max_epochs=21, accumulate_grad_batches=2, deterministic=True, - resume_from_checkpoint=path_ckpt, ) - trainer.fit(model, datamodule=dm) + trainer.fit(model, datamodule=dm, ckpt_path=path_ckpt) res = trainer.test(model, datamodule=dm) assert res[0]["test_loss"] <= 0.7 assert res[0]["test_acc"] >= 0.85 diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 0c1a6fbd51268..0a3f448b3f327 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -903,11 +903,10 @@ def validation_step(self, batch, batch_idx): limit_train_batches=2, limit_val_batches=2, limit_test_batches=2, - resume_from_checkpoint=checkpoint_callback.best_model_path, weights_summary=None, progress_bar_refresh_rate=0, ) - trainer.fit(model) + trainer.fit(model, ckpt_path=checkpoint_callback.best_model_path) trainer.test(model, verbose=False) assert set(os.listdir(tmpdir)) == {"epoch=00.ckpt", "lightning_logs"} assert set(os.listdir(tmpdir.join("lightning_logs"))) == {f"version_{i}" for i in range(4)} @@ -979,17 +978,16 @@ def assert_checkpoint_log_dir(idx): # load from checkpoint trainer_config["callbacks"] = [ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)] - trainer = pl.Trainer(**trainer_config, resume_from_checkpoint=chk) + trainer = pl.Trainer(**trainer_config) assert_trainer_init(trainer) model = ExtendedBoringModel() trainer.test(model) - # resume_from_checkpoint is resumed when calling `.fit` assert trainer.global_step == 0 assert trainer.current_epoch == 0 - trainer.fit(model) + trainer.fit(model, ckpt_path=chk) assert trainer.global_step == epochs * limit_train_batches assert trainer.current_epoch == epochs assert_checkpoint_log_dir(idx) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 1d79a0a0194f8..c374b9e4f0ee6 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -463,10 +463,9 @@ def on_epoch_end(self) -> None: else trainer_kwargs["default_root_dir"] ) ckpt_path = os.path.join(tmpdir, ".pl_auto_save.ckpt") - trainer_kwargs["resume_from_checkpoint"] = ckpt_path trainer = Trainer(**trainer_kwargs) - trainer.fit(model) + trainer.fit(model, ckpt_path=ckpt_path) assert model.has_validated_sum diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 6ea8b76d253fa..3f9c676c38c88 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -584,14 +584,13 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir): limit_val_batches=0, progress_bar_refresh_rate=0, weights_summary=None, - resume_from_checkpoint=best_model_path, callbacks=[callback], ) assert called == [ dict(name="Callback.on_init_start", args=(trainer,)), dict(name="Callback.on_init_end", args=(trainer,)), ] - trainer.fit(model) + trainer.fit(model, ckpt_path=best_model_path) saved_ckpt = { "callbacks": ANY, "epoch": 2, # TODO: wrong saved epoch diff --git a/tests/plugins/test_custom_plugin.py b/tests/plugins/test_custom_plugin.py index 939c05d1b7afe..bebd29736e840 100644 --- 
a/tests/plugins/test_custom_plugin.py +++ b/tests/plugins/test_custom_plugin.py @@ -63,7 +63,7 @@ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: model = BoringModel() plugin = TestPlugin(torch.device("cpu")) trainer = Trainer( - default_root_dir=tmpdir, fast_dev_run=True, plugins=plugin, resume_from_checkpoint=checkpoint_path + default_root_dir=tmpdir, fast_dev_run=True, plugins=plugin ) - trainer.fit(model) + trainer.fit(model, ckpt_path=checkpoint_path) assert plugin.load_optimizer_state_dict_called == restore_optimizer_and_schedulers diff --git a/tests/trainer/connectors/test_checkpoint_connector.py b/tests/trainer/connectors/test_checkpoint_connector.py index 83a45f02224d5..f5fbfef8b1cf3 100644 --- a/tests/trainer/connectors/test_checkpoint_connector.py +++ b/tests/trainer/connectors/test_checkpoint_connector.py @@ -60,7 +60,7 @@ def test_preloaded_checkpoint_lifecycle(tmpdir): connector = trainer.checkpoint_connector - assert not trainer.resume_from_checkpoint + assert not trainer.resume_checkpoint_path assert not connector.resume_checkpoint_path assert not connector._loaded_checkpoint @@ -72,9 +72,9 @@ def test_preloaded_checkpoint_lifecycle(tmpdir): assert not connector._loaded_checkpoint ckpt_path = trainer.checkpoint_callback.best_model_path - trainer = Trainer(default_root_dir=tmpdir, max_steps=2, resume_from_checkpoint=ckpt_path) + trainer = Trainer(default_root_dir=tmpdir, max_steps=2) connector = trainer.checkpoint_connector - connector.resume_start() + connector.resume_start(ckpt_path) assert connector.resume_checkpoint_path == ckpt_path assert connector._loaded_checkpoint assert isinstance(connector._loaded_checkpoint, dict) @@ -106,8 +106,8 @@ def test_hpc_restore_attempt(tmpdir): torch.nn.init.constant_(param, 0) # case 2: explicit resume path provided, restore hpc anyway - trainer = Trainer(default_root_dir=tmpdir, max_steps=3, resume_from_checkpoint="not existing") - trainer.fit(model) + trainer = Trainer(default_root_dir=tmpdir, max_steps=3) + trainer.fit(model, ckpt_path="not existing") for param in model.parameters(): assert param.abs().sum() > 0 diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 5f1bdd1f34541..a12fea2cc989b 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -684,8 +684,8 @@ def predict_step(self, batch, *_): trainer.fit(model) trainer_fn = getattr(trainer, fn) - path_attr = f"{fn}{'d' if fn == 'validate' else 'ed'}_ckpt_path" - assert getattr(trainer, path_attr) is None + # path_attr = f"{fn}{'d' if fn == 'validate' else 'ed'}_ckpt_path" + assert trainer.resume_checkpoint_path is None if ckpt_path == "best": # ckpt_path is 'best', meaning we load the best weights @@ -696,20 +696,20 @@ def predict_step(self, batch, *_): trainer_fn(model, ckpt_path=ckpt_path) else: trainer_fn(ckpt_path=ckpt_path) - assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path + assert trainer.resume_checkpoint_path == trainer.checkpoint_callback.best_model_path trainer_fn(model, ckpt_path=ckpt_path) - assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path + assert trainer.resume_checkpoint_path == trainer.checkpoint_callback.best_model_path elif ckpt_path is None: # ckpt_path is None, meaning we don't load any checkpoints and use the provided model trainer_fn(model, ckpt_path=ckpt_path) - assert getattr(trainer, path_attr) is None + assert trainer.resume_checkpoint_path is None if save_top_k > 0: # ckpt_path is None with no model 
provided means load the best weights with pytest.warns(UserWarning, match="The best model of the previous `fit` call will be used"): trainer_fn(ckpt_path=ckpt_path) - assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path + assert trainer.resume_checkpoint_path == trainer.checkpoint_callback.best_model_path else: # specific checkpoint, pick one from saved ones if save_top_k == 0: @@ -722,10 +722,10 @@ def predict_step(self, batch, *_): ].absolute() ) trainer_fn(ckpt_path=ckpt_path) - assert getattr(trainer, path_attr) == ckpt_path + assert trainer.resume_checkpoint_path == ckpt_path trainer_fn(model, ckpt_path=ckpt_path) - assert getattr(trainer, path_attr) == ckpt_path + assert trainer.resume_checkpoint_path == ckpt_path def test_disabled_training(tmpdir): diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 5500fe5393f27..c887c36763794 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -906,7 +906,7 @@ def configure_optimizers(self): return torch.optim.SGD(self.layer.parameters(), lr=0.1) -def _run_training(trainer_kwargs, dataset_classes, fail_on_step: int = -1): +def _run_training(trainer_kwargs, dataset_classes, fail_on_step: int = -1, ckpt_path = None): seed_everything(1) train_dataloader = [ DataLoader(dataset_class(3, 1), batch_size=1, num_workers=0) for dataset_class in dataset_classes @@ -915,7 +915,7 @@ def _run_training(trainer_kwargs, dataset_classes, fail_on_step: int = -1): model = TestModel(fail_on_step=fail_on_step) trainer = Trainer(**trainer_kwargs) with suppress(CustomException): - trainer.fit(model, train_dataloader=train_dataloader) + trainer.fit(model, train_dataloader=train_dataloader, ckpt_path=ckpt_path) return model.seen_batches, model.parameters() @@ -957,8 +957,7 @@ def test_dataset_rng_states_restart_with_lightning(tmpdir, dataset_classes, mult assert os.path.exists(checkpoint_path) # Resume after failure - trainer_kwargs.update(resume_from_checkpoint=checkpoint_path) - resumed_batches, weights1 = _run_training(trainer_kwargs, dataset_classes, fail_on_step=-1) + resumed_batches, weights1 = _run_training(trainer_kwargs, dataset_classes, fail_on_step=-1, ckpt_path=checkpoint_path) assert len(resumed_batches) == 5 # the resumed batches should match the batches of the successful training diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index ace76bd9374c8..1e0dfbc0e74a5 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -137,7 +137,6 @@ def _raise(): log_gpu_memory=None, distributed_backend=None, weights_save_path=None, - resume_from_checkpoint=None, profiler=None, ), ), @@ -907,7 +906,7 @@ def test_lightning_cli_model_choices(): ) as run: cli = LightningCLI(trainer_defaults={"fast_dev_run": 1}) assert isinstance(cli.model, BoringModel) - run.assert_called_once_with(cli.model, ANY, ANY, ANY) + run.assert_called_once_with(cli.model, ANY, ANY, ANY, ANY) with mock.patch("sys.argv", ["any.py", "--model=TestModel", "--model.foo", "123"]): cli = LightningCLI(run=False) From 5a0d60c033ef4a9fa5627fbd3c3901a8aed63325 Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Thu, 23 Sep 2021 21:28:01 -0700 Subject: [PATCH 02/12] test_lambda_fix --- tests/callbacks/test_lambda_function.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/callbacks/test_lambda_function.py b/tests/callbacks/test_lambda_function.py index 932acdccad346..f2fa040b43c78 100644 --- 
a/tests/callbacks/test_lambda_function.py +++ b/tests/callbacks/test_lambda_function.py @@ -48,6 +48,8 @@ def call(hook, *_, **__): ) trainer.fit(model) + ckpt_path = trainer.checkpoint_callback.best_model_path + # raises KeyboardInterrupt and loads from checkpoint trainer = Trainer( default_root_dir=tmpdir, @@ -58,11 +60,8 @@ def call(hook, *_, **__): limit_predict_batches=1, callbacks=[LambdaCallback(**hooks_args)], ) - trainer.fit(model, ckpt_path=trainer.checkpoint_callback.best_model_path) + trainer.fit(model, ckpt_path=ckpt_path) trainer.test(model) trainer.predict(model) - # assert len(checker) == len(hooks) - # 47 == 48 assert checker == hooks - # extra items in the right set: 'on_load_checkpoint' \ No newline at end of file From 5b7df74c9fb94d05f370add236be61770f53b8cc Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Fri, 24 Sep 2021 00:37:36 -0700 Subject: [PATCH 03/12] more test updates --- CHANGELOG.md | 2 ++ pytorch_lightning/trainer/callback_hook.py | 2 +- .../connectors/checkpoint_connector.py | 21 +++++----------- pytorch_lightning/trainer/trainer.py | 9 ++++--- tests/accelerators/test_tpu_backend.py | 4 ++-- tests/callbacks/test_timer.py | 3 +-- .../checkpointing/test_trainer_checkpoint.py | 10 ++++---- tests/deprecated_api/test_remove_1-7.py | 4 ++++ .../loops/optimization/test_optimizer_loop.py | 3 +-- tests/models/test_restore.py | 24 +++++++++---------- tests/plugins/test_deepspeed_plugin.py | 6 ++--- tests/plugins/test_sharded_plugin.py | 18 +++++++------- tests/trainer/test_trainer.py | 10 ++++---- 13 files changed, 53 insertions(+), 63 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cbc28db556601..94fd7f68b1349 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -252,6 +252,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `LightningLoggerBase.close`, `LoggerCollection.close` in favor of `LightningLoggerBase.finalize`, `LoggerCollection.finalize` ([#9422](https://github.com/PyTorchLightning/pytorch-lightning/pull/9422)) +- Deprecated passing `resume_from_checkpoint` to the `Trainer` constructor in favor of `trainer.fit(ckpt_path=)` ([#123](https://github.com/PyTorchLightning/pytorch-lightning/pull/123)) + ### Removed diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index b8931c415553b..d14cfdfcd6f4b 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -270,7 +270,7 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: difference = callback_states.keys() - current_callbacks_keys if difference: rank_zero_warn( - "Be aware that when using `resume_from_checkpoint`," + "Be aware that when using `ckpt_path`," " callbacks used to create the checkpoint need to be provided." 
f" Please add the following callbacks: {list(difference)}.", UserWarning, diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index ee2b6d32e0fc0..d55eae098bdcf 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -37,6 +37,11 @@ class CheckpointConnector: def __init__(self, trainer: "pl.Trainer", resume_from_checkpoint: Optional[_PATH] = None) -> None: self.trainer = trainer self.resume_checkpoint_path = resume_from_checkpoint + if resume_from_checkpoint is not None: + rank_zero_deprecation( + "Setting `Trainer(resume_from_checkpoint=)` is deprecated in v1.5 and" + " will be removed in v1.7. Please pass `Trainer.fit(ckpt_path=)` directly instead." + ) self._loaded_checkpoint: Dict[str, Any] = {} @property @@ -50,7 +55,7 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None: """Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority: 1. from HPC weights if found - 2. from `resume_from_checkpoint` file if provided + 2. from `resume_from_checkpoint` file if provided .. deprecated:: v1.5 `Trainer(resume_from_checkpoint=)` is deprecated in v1.5 and will be removed in v1.7. Please use `Trainer.fit(ckpt_path=)` instead. @@ -156,20 +161,6 @@ def restore_model(self) -> None: if isinstance(module, Metric): module.reset() - def restore_model_weights(self, checkpoint_path: Optional[_PATH]) -> None: - """Restore only the model weights. - - .. deprecated:: v1.5 - This method is deprecated in v1.5 and will be removed in v1.7. - Please use `CheckpointConnector.restore_model` instead. - """ - checkpoint = self._loaded_checkpoint - if checkpoint_path is not None: - checkpoint = self._load_and_validate_checkpoint(checkpoint_path) - - self.trainer.lightning_module.on_load_checkpoint(checkpoint) - self.trainer.training_type_plugin.load_model_state_dict(checkpoint) - def restore_training_state(self) -> None: """Restore the trainer state from the pre-loaded checkpoint. diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index df98eb924f0d2..83555f453f896 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1723,11 +1723,10 @@ def checkpoint_callbacks(self) -> List[ModelCheckpoint]: @property def resume_from_checkpoint(self) -> Optional[Union[str, Path]]: - """ - .. deprecated:: v1.5 - This method is deprecated in v1.5 and will be removed in v1.7. - Please use `Trainer.resume_checkpoint_path` instead. - """ + rank_zero_deprecation( + "`trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7." + " Use `trainer.resume_checkpoint_path` instead." 
+ ) return self.resume_checkpoint_path @property diff --git a/tests/accelerators/test_tpu_backend.py b/tests/accelerators/test_tpu_backend.py index 7f7bad327f515..6d2d5d38dafb0 100644 --- a/tests/accelerators/test_tpu_backend.py +++ b/tests/accelerators/test_tpu_backend.py @@ -62,9 +62,9 @@ def test_resume_training_on_cpu(tmpdir): # Verify that training is resumed on CPU trainer = Trainer( - resume_from_checkpoint=model_path, checkpoint_callback=True, max_epochs=1, default_root_dir=tmpdir + checkpoint_callback=True, max_epochs=1, default_root_dir=tmpdir ) - trainer.fit(model) + trainer.fit(model, ckpt_path=model_path) assert trainer.state.finished, f"Training failed with {trainer.state}" diff --git a/tests/callbacks/test_timer.py b/tests/callbacks/test_timer.py index 94ee3e87bc3e3..a307a72bdde3b 100644 --- a/tests/callbacks/test_timer.py +++ b/tests/callbacks/test_timer.py @@ -173,9 +173,8 @@ def test_timer_resume_training(tmpdir): trainer = Trainer( default_root_dir=tmpdir, callbacks=[timer, checkpoint_callback], - resume_from_checkpoint=checkpoint_callback.best_model_path, ) - trainer.fit(model) + trainer.fit(model, ckpt_path=checkpoint_callback.best_model_path) assert timer._offset > 0 assert trainer.global_step == saved_global_step + 1 diff --git a/tests/checkpointing/test_trainer_checkpoint.py b/tests/checkpointing/test_trainer_checkpoint.py index 739dc98a22834..4aef278631aac 100644 --- a/tests/checkpointing/test_trainer_checkpoint.py +++ b/tests/checkpointing/test_trainer_checkpoint.py @@ -22,7 +22,7 @@ from tests.helpers import BoringModel -def test_finetuning_with_resume_from_checkpoint(tmpdir): +def test_finetuning_with_ckpt_path(tmpdir): """This test validates that generated ModelCheckpoint is pointing to the right best_model_path during test.""" checkpoint_callback = ModelCheckpoint(monitor="val_loss", dirpath=tmpdir, filename="{epoch:02d}", save_top_k=-1) @@ -63,10 +63,9 @@ def validation_step(self, batch, batch_idx): limit_train_batches=12, limit_val_batches=12, limit_test_batches=12, - resume_from_checkpoint=best_model_paths[-1], progress_bar_refresh_rate=0, ) - trainer.fit(model) + trainer.fit(model, ckpt_path=best_model_paths[-1]) trainer.test() results.append(deepcopy(trainer.callback_metrics)) best_model_paths.append(trainer.checkpoint_callback.best_model_path) @@ -78,7 +77,7 @@ def validation_step(self, batch, batch_idx): assert f"epoch={idx + 1}" in best_model_path -def test_accumulated_gradient_batches_with_resume_from_checkpoint(tmpdir): +def test_accumulated_gradient_batches_with_ckpt_path(tmpdir): """This test validates that accumulated gradient is properly recomputed and reset on the trainer.""" ckpt = ModelCheckpoint(dirpath=tmpdir, save_last=True) @@ -90,6 +89,5 @@ def test_accumulated_gradient_batches_with_resume_from_checkpoint(tmpdir): trainer.fit(model) trainer_kwargs["max_epochs"] = 2 - trainer_kwargs["resume_from_checkpoint"] = ckpt.last_model_path trainer = Trainer(**trainer_kwargs) - trainer.fit(model) + trainer.fit(model, ckpt_path=ckpt.last_model_path) diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 25ab1ea5fd3cc..8e8eeec032d32 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -238,3 +238,7 @@ def test_v1_7_0_lightning_logger_base_close(tmpdir): ): logger = LoggerCollection([logger]) logger.close() + +def test_v1_7_0_resume_from_checkpoint_trainer_constructor(tmpdir): + with pytest.deprecated_call(match=r"Setting `Trainer\(resume_from_checkpoint=\)` 
is deprecated in v1.5"): + _ = Trainer(resume_from_checkpoint="a") diff --git a/tests/loops/optimization/test_optimizer_loop.py b/tests/loops/optimization/test_optimizer_loop.py index beaf0d6daaf40..62858b7638162 100644 --- a/tests/loops/optimization/test_optimizer_loop.py +++ b/tests/loops/optimization/test_optimizer_loop.py @@ -223,7 +223,6 @@ def configure_optimizers(self): model.training_epoch_end = None model.optimizer_step = Mock(wraps=model.optimizer_step) trainer = Trainer( - resume_from_checkpoint=str(tmpdir / ".pl_auto_save.ckpt"), default_root_dir=tmpdir, max_epochs=n_epochs, limit_train_batches=n_batches, @@ -232,7 +231,7 @@ def configure_optimizers(self): logger=False, checkpoint_callback=False, ) - trainer.fit(model) + trainer.fit(model, ckpt_path=str(tmpdir / ".pl_auto_save.ckpt")) weights_resumed = model.parameters() # check that the final weights of a resumed run match the weights of a run that never failed diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index c9a784ed0a0f5..c1461354560e8 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -104,7 +104,7 @@ def validation_step_end(self, outputs): self.log("val_acc", self.valid_acc(outputs["logits"], outputs["y"])) -def test_model_properties_resume_from_checkpoint(tmpdir): +def test_model_properties_fit_ckpt_path(tmpdir): """Test that properties like `current_epoch` and `global_step` in model and trainer are always the same.""" model = BoringModel() checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True) @@ -120,17 +120,17 @@ def test_model_properties_resume_from_checkpoint(tmpdir): trainer.fit(model) trainer_args.update(max_epochs=2) - trainer = Trainer(**trainer_args, resume_from_checkpoint=str(tmpdir / "last.ckpt")) - trainer.fit(model) + trainer = Trainer(**trainer_args) + trainer.fit(model, ckpt_path=str(tmpdir / "last.ckpt")) def test_try_resume_from_non_existing_checkpoint(tmpdir): - """Test that trying to resume from non-existing `resume_from_checkpoint` fails with an error.""" + """Test that trying to resume from non-existing `ckpt_path` fails with an error.""" model = BoringModel() - trainer = Trainer(resume_from_checkpoint=str(tmpdir / "non_existing.ckpt")) + trainer = Trainer() with pytest.raises(FileNotFoundError, match="Aborting training"): - trainer.fit(model) + trainer.fit(model, ckpt_path=str(tmpdir / "non_existing.ckpt")) class CaptureCallbacksBeforeTraining(Callback): @@ -140,7 +140,7 @@ def on_pretrain_routine_end(self, trainer, pl_module): self.callbacks = deepcopy(trainer.callbacks) -def test_callbacks_state_resume_from_checkpoint(tmpdir): +def test_callbacks_state_fit_ckpt_path(tmpdir): """Test that resuming from a checkpoint restores callbacks that persist state.""" dm = ClassifDataModule() model = ClassificationModel() @@ -165,8 +165,8 @@ def get_trainer_args(): callbacks_before_resume = deepcopy(trainer.callbacks) # resumed training - trainer = Trainer(**get_trainer_args(), resume_from_checkpoint=str(tmpdir / "last.ckpt")) - trainer.fit(model, datamodule=dm) + trainer = Trainer(**get_trainer_args()) + trainer.fit(model, datamodule=dm, ckpt_path=str(tmpdir / "last.ckpt")) assert len(callbacks_before_resume) == len(callback_capture.callbacks) @@ -176,7 +176,7 @@ def get_trainer_args(): assert before.best_model_score == after.best_model_score -def test_callbacks_references_resume_from_checkpoint(tmpdir): +def test_callbacks_references_fit_ckpt_path(tmpdir): """Test that resuming from a checkpoint sets references as 
expected.""" dm = ClassifDataModule() model = ClassificationModel() @@ -198,10 +198,10 @@ def test_callbacks_references_resume_from_checkpoint(tmpdir): new_checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True) # pass in a new checkpoint object, which should take # precedence over the one in the last.ckpt file - trainer = Trainer(**args, callbacks=[new_checkpoint], resume_from_checkpoint=str(tmpdir / "last.ckpt")) + trainer = Trainer(**args, callbacks=[new_checkpoint]) assert checkpoint is not new_checkpoint assert new_checkpoint is trainer.callbacks[-1] is trainer.checkpoint_callback - trainer.fit(model, datamodule=dm) + trainer.fit(model, datamodule=dm, ckpt_path=str(tmpdir / "last.ckpt")) @RunIf(min_gpus=2) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index c7ccaab3e72f4..d719793a0dd49 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -620,7 +620,6 @@ def test_deepspeed_multigpu_stage_3_warns_resume_training(tmpdir): plugins=DeepSpeedPlugin(stage=3, load_full_weights=True), gpus=1, precision=16, - resume_from_checkpoint=checkpoint_path, ) with pytest.warns( UserWarning, @@ -628,7 +627,7 @@ def test_deepspeed_multigpu_stage_3_warns_resume_training(tmpdir): "scheduler states can not be restored. If you'd like to restore these states, you must " "provide a path to the originally saved DeepSpeed checkpoint.", ): - trainer.fit(model, datamodule=dm) + trainer.fit(model, datamodule=dm, ckpt_path=checkpoint_path) @RunIf(min_gpus=1, deepspeed=True, special=True) @@ -680,10 +679,9 @@ def on_train_batch_start( plugins=DeepSpeedPlugin(stage=3), gpus=1, precision=16, - resume_from_checkpoint=ck.best_model_path, callbacks=TestCallback(), ) - trainer.fit(model, datamodule=dm) + trainer.fit(model, datamodule=dm, ckpt_path=ck.best_model_path) @RunIf(min_gpus=2, deepspeed=True, special=True) diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index 6926e07c32b85..e95e35c71da52 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -132,7 +132,7 @@ def test_ddp_sharded_plugin_finetune(tmpdir): @RunIf(skip_windows=True, fairscale=True) -def test_ddp_sharded_plugin_resume_from_checkpoint(tmpdir): +def test_ddp_sharded_plugin_fit_ckpt_path(tmpdir): """Test to ensure that resuming from checkpoint works.""" model = BoringModel() trainer = Trainer(accelerator="ddp_sharded_spawn", num_processes=2, fast_dev_run=True) @@ -145,16 +145,16 @@ def test_ddp_sharded_plugin_resume_from_checkpoint(tmpdir): model = BoringModel() trainer = Trainer( - accelerator="ddp_sharded_spawn", num_processes=2, fast_dev_run=True, resume_from_checkpoint=checkpoint_path + accelerator="ddp_sharded_spawn", num_processes=2, fast_dev_run=True ) - trainer.fit(model) + trainer.fit(model, ckpt_path=checkpoint_path) @pytest.mark.skip(reason="Not a critical test, skip till drone CI performance improves.") # todo @pytest.mark.skip(reason="Currently unsupported restarting training on different number of devices.") @RunIf(min_gpus=2, skip_windows=True, fairscale=True) -def test_ddp_sharded_plugin_resume_from_checkpoint_downsize_gpus(tmpdir): +def test_ddp_sharded_plugin_fit_ckpt_path_downsize_gpus(tmpdir): """Test to ensure that resuming from checkpoint works when downsizing number of GPUS.""" model = BoringModel() trainer = Trainer(accelerator="ddp_sharded_spawn", fast_dev_run=True, gpus=2) @@ -167,14 +167,14 @@ def 
test_ddp_sharded_plugin_resume_from_checkpoint_downsize_gpus(tmpdir): model = BoringModel() trainer = Trainer( - accelerator="ddp_sharded_spawn", fast_dev_run=True, gpus=1, resume_from_checkpoint=checkpoint_path + accelerator="ddp_sharded_spawn", fast_dev_run=True, gpus=1 ) - trainer.fit(model) + trainer.fit(model, ckpt_path=checkpoint_path) @RunIf(min_gpus=1, skip_windows=True, fairscale=True) -def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir): +def test_ddp_sharded_plugin_fit_ckpt_path_gpu_to_cpu(tmpdir): """Test to ensure that resuming from checkpoint works when going from GPUs- > CPU.""" model = BoringModel() trainer = Trainer(accelerator="ddp_sharded_spawn", gpus=1, fast_dev_run=True) @@ -187,10 +187,10 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir): model = BoringModel() trainer = Trainer( - accelerator="ddp_sharded_spawn", num_processes=2, fast_dev_run=True, resume_from_checkpoint=checkpoint_path + accelerator="ddp_sharded_spawn", num_processes=2, fast_dev_run=True ) - trainer.fit(model) + trainer.fit(model, ckpt_path=checkpoint_path) @RunIf(skip_windows=True, special=True, fairscale=True) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index a12fea2cc989b..d8e98b0d206c4 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -404,7 +404,7 @@ def test_model_freeze_unfreeze(): @pytest.mark.parametrize("url_ckpt", [True, False]) -def test_resume_from_checkpoint_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ckpt): +def test_fit_ckpt_path_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ckpt): """Verify resuming from checkpoint runs the right number of epochs.""" # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir monkeypatch.setenv("TORCH_HOME", tmpdir) @@ -455,8 +455,8 @@ def on_load_checkpoint(self, _): state = pl_load(ckpt) # Resume training - new_trainer = Trainer(default_root_dir=tmpdir, resume_from_checkpoint=ckpt, max_epochs=2) - new_trainer.fit(next_model) + new_trainer = Trainer(default_root_dir=tmpdir, max_epochs=2) + new_trainer.fit(next_model, ckpt_path=ckpt) assert state["global_step"] + next_model.num_batches_seen == trainer.num_training_batches * trainer.max_epochs assert next_model.num_on_load_checkpoint_called == 1 @@ -1792,10 +1792,10 @@ def test_on_load_checkpoint_missing_callbacks(tmpdir): trainer.fit(model) trainer = Trainer( - default_root_dir=tmpdir, max_epochs=5, resume_from_checkpoint=chk.last_model_path, progress_bar_refresh_rate=1 + default_root_dir=tmpdir, max_epochs=5, progress_bar_refresh_rate=1 ) with pytest.warns(UserWarning, match="CustomCallbackOnLoadCheckpoint"): - trainer.fit(model) + trainer.fit(model, ckpt_path=chk.last_model_path) def test_module_current_fx_attributes_reset(tmpdir): From b5dee8e82586e3a1bcc86c8e587c8a3b39729db6 Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Fri, 24 Sep 2021 12:14:34 -0700 Subject: [PATCH 04/12] updates --- .../trainer/connectors/checkpoint_connector.py | 5 ++--- pytorch_lightning/trainer/trainer.py | 13 +++++-------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index d55eae098bdcf..961671fd3b82e 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -36,7 +36,7 @@ class CheckpointConnector: def __init__(self, trainer: "pl.Trainer", resume_from_checkpoint: 
Optional[_PATH] = None) -> None: self.trainer = trainer - self.resume_checkpoint_path = resume_from_checkpoint + self.resume_checkpoint_path = None if resume_from_checkpoint is not None: rank_zero_deprecation( "Setting `Trainer(resume_from_checkpoint=)` is deprecated in v1.5 and" @@ -65,7 +65,7 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None: Raises: FileNotFoundError: If the path to the checkpoint file is provided but the file does not exist. """ - self.resume_checkpoint_path = self.hpc_resume_path or self.resume_checkpoint_path or checkpoint_path + self.resume_checkpoint_path = self.hpc_resume_path or checkpoint_path checkpoint_path = self.resume_checkpoint_path if not checkpoint_path: return @@ -111,7 +111,6 @@ def restore(self, checkpoint_path: Optional[_PATH] = None) -> None: Args: checkpoint_path: Path to a PyTorch Lightning checkpoint file. """ - self.resume_checkpoint_path = checkpoint_path self.resume_start(checkpoint_path) # restore module states diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 83555f453f896..8e60f5fc9d57c 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -411,6 +411,9 @@ def __init__( self.signal_connector = SignalConnector(self) self.tuner = Tuner(self) + # TODO: remove in v1.7 + self._resume_from_checkpoint = resume_from_checkpoint + # max_epochs won't default to 1000 if max_steps/max_time are specified (including being set to -1). fit_loop = FitLoop( min_epochs=(1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs), @@ -619,6 +622,8 @@ def _fit_impl( model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule ) + # TODO: ckpt_path only in v1.7 + ckpt_path = ckpt_path or self._resume_from_checkpoint self._run(model, ckpt_path) assert self.state.stopped @@ -1721,14 +1726,6 @@ def checkpoint_callbacks(self) -> List[ModelCheckpoint]: in the Trainer.callbacks list.""" return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)] - @property - def resume_from_checkpoint(self) -> Optional[Union[str, Path]]: - rank_zero_deprecation( - "`trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7." - " Use `trainer.resume_checkpoint_path` instead." - ) - return self.resume_checkpoint_path - @property def resume_checkpoint_path(self) -> Optional[Union[str, Path]]: return self.checkpoint_connector.resume_checkpoint_path From 2bb5bc52d6a9955d0718562dda8b5ac232313c37 Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Fri, 24 Sep 2021 12:18:26 -0700 Subject: [PATCH 05/12] resume_start doc update --- .../trainer/connectors/checkpoint_connector.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 961671fd3b82e..f261d5f31258e 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -55,12 +55,8 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None: """Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority: 1. from HPC weights if found - 2. from `resume_from_checkpoint` file if provided - .. deprecated:: v1.5 - `Trainer(resume_from_checkpoint=)` is deprecated in v1.5 and will be removed in v1.7. - Please use `Trainer.fit(ckpt_path=)` instead. - 3. 
from `checkpoint_path` file if provided - 4. don't restore + 2. from `checkpoint_path` file if provided + 3. don't restore Raises: FileNotFoundError: If the path to the checkpoint file is provided but the file does not exist. From ef2d2cd6550a8be0cd4bde5441e6eb273d7887d9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 24 Sep 2021 21:01:36 +0000 Subject: [PATCH 06/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/trainer.py | 8 ++++++-- tests/accelerators/test_cpu.py | 4 +--- tests/accelerators/test_tpu_backend.py | 4 +--- tests/deprecated_api/test_remove_1-7.py | 1 + tests/plugins/test_custom_plugin.py | 4 +--- tests/plugins/test_sharded_plugin.py | 12 +++--------- tests/trainer/test_trainer.py | 4 +--- tests/utilities/test_auto_restart.py | 6 ++++-- 8 files changed, 18 insertions(+), 25 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4b8e85dc79305..56f133386fc42 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -594,7 +594,9 @@ def fit( " Use `trainer.fit(train_dataloaders)` instead. HINT: added 's'" ) train_dataloaders = train_dataloader - self._call_and_handle_interrupt(self._fit_impl, model, train_dataloaders, val_dataloaders, ckpt_path, datamodule) + self._call_and_handle_interrupt( + self._fit_impl, model, train_dataloaders, val_dataloaders, ckpt_path, datamodule + ) def _fit_impl( self, @@ -962,7 +964,9 @@ def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None # restore callback states self.checkpoint_connector.restore_callbacks() - def _run(self, model: "pl.LightningModule", ckpt_path: Optional[str] = None) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: + def _run( + self, model: "pl.LightningModule", ckpt_path: Optional[str] = None + ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: # clean hparams if hasattr(model, "hparams"): parsing.clean_namespace(model.hparams) diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index 3a83c2577aa8f..be66730efdb5c 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -80,9 +80,7 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: assert accelerator.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch assert plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch - trainer = Trainer( - default_root_dir=tmpdir, accelerator=accelerator, fast_dev_run=True - ) + trainer = Trainer(default_root_dir=tmpdir, accelerator=accelerator, fast_dev_run=True) trainer.fit(model, ckpt_path=checkpoint_path) for func in (trainer.test, trainer.validate, trainer.predict): accelerator.training_type_plugin.predispatched_called = False diff --git a/tests/accelerators/test_tpu_backend.py b/tests/accelerators/test_tpu_backend.py index 6d2d5d38dafb0..0d7c0d5d8c3f8 100644 --- a/tests/accelerators/test_tpu_backend.py +++ b/tests/accelerators/test_tpu_backend.py @@ -61,9 +61,7 @@ def test_resume_training_on_cpu(tmpdir): assert weight_tensor.device == torch.device("cpu") # Verify that training is resumed on CPU - trainer = Trainer( - checkpoint_callback=True, max_epochs=1, default_root_dir=tmpdir - ) + trainer = Trainer(checkpoint_callback=True, max_epochs=1, default_root_dir=tmpdir) trainer.fit(model, ckpt_path=model_path) assert trainer.state.finished, 
f"Training failed with {trainer.state}" diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index d332ba0c7a86d..4b9f874a92c08 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -244,6 +244,7 @@ def test_v1_7_0_lightning_logger_base_close(tmpdir): logger = LoggerCollection([logger]) logger.close() + def test_v1_7_0_resume_from_checkpoint_trainer_constructor(tmpdir): with pytest.deprecated_call(match=r"Setting `Trainer\(resume_from_checkpoint=\)` is deprecated in v1.5"): _ = Trainer(resume_from_checkpoint="a") diff --git a/tests/plugins/test_custom_plugin.py b/tests/plugins/test_custom_plugin.py index bebd29736e840..9d23f48bed98e 100644 --- a/tests/plugins/test_custom_plugin.py +++ b/tests/plugins/test_custom_plugin.py @@ -62,8 +62,6 @@ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: model = BoringModel() plugin = TestPlugin(torch.device("cpu")) - trainer = Trainer( - default_root_dir=tmpdir, fast_dev_run=True, plugins=plugin - ) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins=plugin) trainer.fit(model, ckpt_path=checkpoint_path) assert plugin.load_optimizer_state_dict_called == restore_optimizer_and_schedulers diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index e95e35c71da52..cb7f3b82a2642 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -144,9 +144,7 @@ def test_ddp_sharded_plugin_fit_ckpt_path(tmpdir): model = BoringModel() - trainer = Trainer( - accelerator="ddp_sharded_spawn", num_processes=2, fast_dev_run=True - ) + trainer = Trainer(accelerator="ddp_sharded_spawn", num_processes=2, fast_dev_run=True) trainer.fit(model, ckpt_path=checkpoint_path) @@ -166,9 +164,7 @@ def test_ddp_sharded_plugin_fit_ckpt_path_downsize_gpus(tmpdir): model = BoringModel() - trainer = Trainer( - accelerator="ddp_sharded_spawn", fast_dev_run=True, gpus=1 - ) + trainer = Trainer(accelerator="ddp_sharded_spawn", fast_dev_run=True, gpus=1) trainer.fit(model, ckpt_path=checkpoint_path) @@ -186,9 +182,7 @@ def test_ddp_sharded_plugin_fit_ckpt_path_gpu_to_cpu(tmpdir): model = BoringModel() - trainer = Trainer( - accelerator="ddp_sharded_spawn", num_processes=2, fast_dev_run=True - ) + trainer = Trainer(accelerator="ddp_sharded_spawn", num_processes=2, fast_dev_run=True) trainer.fit(model, ckpt_path=checkpoint_path) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index c14453e4738cf..1b2960c9beb31 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1848,9 +1848,7 @@ def test_on_load_checkpoint_missing_callbacks(tmpdir): trainer = Trainer(default_root_dir=tmpdir, max_epochs=3, callbacks=[chk, CustomCallbackOnLoadCheckpoint()]) trainer.fit(model) - trainer = Trainer( - default_root_dir=tmpdir, max_epochs=5, progress_bar_refresh_rate=1 - ) + trainer = Trainer(default_root_dir=tmpdir, max_epochs=5, progress_bar_refresh_rate=1) with pytest.warns(UserWarning, match="CustomCallbackOnLoadCheckpoint"): trainer.fit(model, ckpt_path=chk.last_model_path) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 22b50e997c597..39134c0e890f2 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -907,7 +907,7 @@ def configure_optimizers(self): return torch.optim.SGD(self.layer.parameters(), lr=0.1) -def _run_training(trainer_kwargs, dataset_classes, fail_on_step: 
int = -1, ckpt_path = None): +def _run_training(trainer_kwargs, dataset_classes, fail_on_step: int = -1, ckpt_path=None): seed_everything(1) train_dataloader = [ DataLoader(dataset_class(3, 1), batch_size=1, num_workers=0) for dataset_class in dataset_classes @@ -958,7 +958,9 @@ def test_dataset_rng_states_restart_with_lightning(tmpdir, dataset_classes, mult assert os.path.exists(checkpoint_path) # Resume after failure - resumed_batches, weights1 = _run_training(trainer_kwargs, dataset_classes, fail_on_step=-1, ckpt_path=checkpoint_path) + resumed_batches, weights1 = _run_training( + trainer_kwargs, dataset_classes, fail_on_step=-1, ckpt_path=checkpoint_path + ) assert len(resumed_batches) == 5 # the resumed batches should match the batches of the successful training From 205a38017dbd4faaa83058d285560df8c934ee53 Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Fri, 24 Sep 2021 14:16:04 -0700 Subject: [PATCH 07/12] mypy --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 2 +- tests/trainer/test_trainer.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index e767b6f2a7578..c131d1d8cd3c1 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -37,7 +37,7 @@ class CheckpointConnector: def __init__(self, trainer: "pl.Trainer", resume_from_checkpoint: Optional[_PATH] = None) -> None: self.trainer = trainer - self.resume_checkpoint_path = None + self.resume_checkpoint_path: Optional[_PATH] = None if resume_from_checkpoint is not None: rank_zero_deprecation( "Setting `Trainer(resume_from_checkpoint=)` is deprecated in v1.5 and" diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 1b2960c9beb31..8e37b8b5f5f47 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -741,7 +741,6 @@ def predict_step(self, batch, *_): trainer.fit(model) trainer_fn = getattr(trainer, fn) - # path_attr = f"{fn}{'d' if fn == 'validate' else 'ed'}_ckpt_path" assert trainer.resume_checkpoint_path is None if ckpt_path == "best": From ea41b4103f84d30cc2858c869c4e91ddd8adb796 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 27 Sep 2021 20:35:18 +0000 Subject: [PATCH 08/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/deprecated_api/test_remove_1-7.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 84a4ee5ec0697..2f94354866040 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -249,7 +249,7 @@ def test_v1_7_0_lightning_logger_base_close(tmpdir): logger = LoggerCollection([logger]) logger.close() - + def test_v1_7_0_deprecate_lightning_distributed(tmpdir): with pytest.deprecated_call(match="LightningDistributed is deprecated in v1.5 and will be removed in v1.7."): from pytorch_lightning.distributed.dist import LightningDistributed From 7d89b889e705120aeb23d6457716f477657d499a Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Mon, 27 Sep 2021 15:31:15 -0700 Subject: [PATCH 09/12] add resume_end, depr trainer.resume_checkpoint_path --- pytorch_lightning/trainer/trainer.py | 8 ++------ tests/trainer/connectors/test_checkpoint_connector.py | 1 - 
tests/trainer/test_trainer.py | 8 -------- 3 files changed, 2 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b7b3eb7301268..f5087efea4825 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1054,6 +1054,8 @@ def _run( if self.state.fn == TrainerFn.FITTING: self.checkpoint_connector.restore_training_state() + self.checkpoint_connector.resume_end() + # dispatch `start_training` or `start_evaluating` or `start_predicting` self._dispatch() @@ -1153,8 +1155,6 @@ def _pre_training_routine(self): # register signals self.signal_connector.register_signal_handlers() - self.checkpoint_connector.resume_end() - # -------------------------- # Pre-train # -------------------------- @@ -1741,10 +1741,6 @@ def checkpoint_callbacks(self) -> List[ModelCheckpoint]: in the Trainer.callbacks list.""" return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)] - @property - def resume_checkpoint_path(self) -> Optional[Union[str, Path]]: - return self.checkpoint_connector.resume_checkpoint_path - def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None: self.checkpoint_connector.save_checkpoint(filepath, weights_only) diff --git a/tests/trainer/connectors/test_checkpoint_connector.py b/tests/trainer/connectors/test_checkpoint_connector.py index f5fbfef8b1cf3..7c8a4e2dc9bfb 100644 --- a/tests/trainer/connectors/test_checkpoint_connector.py +++ b/tests/trainer/connectors/test_checkpoint_connector.py @@ -60,7 +60,6 @@ def test_preloaded_checkpoint_lifecycle(tmpdir): connector = trainer.checkpoint_connector - assert not trainer.resume_checkpoint_path assert not connector.resume_checkpoint_path assert not connector._loaded_checkpoint diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index b2e32e3254b13..63953efab6769 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -741,7 +741,6 @@ def predict_step(self, batch, *_): trainer.fit(model) trainer_fn = getattr(trainer, fn) - assert trainer.resume_checkpoint_path is None if ckpt_path == "best": # ckpt_path is 'best', meaning we load the best weights @@ -752,20 +751,15 @@ def predict_step(self, batch, *_): trainer_fn(model, ckpt_path=ckpt_path) else: trainer_fn(ckpt_path=ckpt_path) - assert trainer.resume_checkpoint_path == trainer.checkpoint_callback.best_model_path - trainer_fn(model, ckpt_path=ckpt_path) - assert trainer.resume_checkpoint_path == trainer.checkpoint_callback.best_model_path elif ckpt_path is None: # ckpt_path is None, meaning we don't load any checkpoints and use the provided model trainer_fn(model, ckpt_path=ckpt_path) - assert trainer.resume_checkpoint_path is None if save_top_k > 0: # ckpt_path is None with no model provided means load the best weights with pytest.warns(UserWarning, match="The best model of the previous `fit` call will be used"): trainer_fn(ckpt_path=ckpt_path) - assert trainer.resume_checkpoint_path == trainer.checkpoint_callback.best_model_path else: # specific checkpoint, pick one from saved ones if save_top_k == 0: @@ -778,10 +772,8 @@ def predict_step(self, batch, *_): ].absolute() ) trainer_fn(ckpt_path=ckpt_path) - assert trainer.resume_checkpoint_path == ckpt_path trainer_fn(model, ckpt_path=ckpt_path) - assert trainer.resume_checkpoint_path == ckpt_path def test_disabled_training(tmpdir): From df7d4a92485d69705f493111614ab0496577107d Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Mon, 27 Sep 2021 18:03:09 -0700 
Subject: [PATCH 10/12] fit arg order --- pytorch_lightning/trainer/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f5087efea4825..d86764cb44882 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -573,9 +573,9 @@ def fit( model: "pl.LightningModule", train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, val_dataloaders: Optional[EVAL_DATALOADERS] = None, - ckpt_path: Optional[str] = None, datamodule: Optional[LightningDataModule] = None, train_dataloader=None, # TODO: remove with 1.6 + ckpt_path: Optional[str] = None, ) -> None: r""" Runs the full optimization routine. @@ -602,7 +602,7 @@ def fit( ) train_dataloaders = train_dataloader self._call_and_handle_interrupt( - self._fit_impl, model, train_dataloaders, val_dataloaders, ckpt_path, datamodule + self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path ) def _fit_impl( @@ -610,8 +610,8 @@ def _fit_impl( model: "pl.LightningModule", train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, val_dataloaders: Optional[EVAL_DATALOADERS] = None, - ckpt_path: Optional[str] = None, datamodule: Optional[LightningDataModule] = None, + ckpt_path: Optional[str] = None, ) -> None: Trainer._log_api_event("fit") From 4315cf5a16472932e570c8e3b484b81dc0bc5d36 Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Thu, 30 Sep 2021 11:37:03 -0700 Subject: [PATCH 11/12] bring back properties --- .../connectors/checkpoint_connector.py | 2 +- pytorch_lightning/trainer/trainer.py | 30 ++++++++++++------- tests/deprecated_api/test_remove_1-7.py | 4 ++- tests/trainer/test_trainer.py | 9 ++++++ 4 files changed, 33 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index c131d1d8cd3c1..89734527b5dcd 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -37,7 +37,7 @@ class CheckpointConnector: def __init__(self, trainer: "pl.Trainer", resume_from_checkpoint: Optional[_PATH] = None) -> None: self.trainer = trainer - self.resume_checkpoint_path: Optional[_PATH] = None + self.resume_checkpoint_path: Optional[_PATH] = resume_from_checkpoint if resume_from_checkpoint is not None: rank_zero_deprecation( "Setting `Trainer(resume_from_checkpoint=)` is deprecated in v1.5 and" diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index cb587921dcfe9..6c53fb9769f9d 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -423,9 +423,6 @@ def __init__( self.signal_connector = SignalConnector(self) self.tuner = Tuner(self) - # TODO: remove in v1.7 - self._resume_from_checkpoint = resume_from_checkpoint - # max_epochs won't default to 1000 if max_steps/max_time are specified (including being set to -1). 
fit_loop = FitLoop( min_epochs=(1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs), @@ -452,6 +449,11 @@ def __init__( # Needed because of LightningOptimizer self._lightning_optimizers = None + # .validate() and .test() set this when they load a checkpoint + self.validated_ckpt_path: Optional[str] = None + self.tested_ckpt_path: Optional[str] = None + self.predicted_ckpt_path: Optional[str] = None + # init callbacks # Declare attributes to be set in callback_connector on_trainer_init self.callback_connector.on_trainer_init( @@ -636,7 +638,7 @@ def _fit_impl( ) # TODO: ckpt_path only in v1.7 - ckpt_path = ckpt_path or self._resume_from_checkpoint + ckpt_path = ckpt_path or self.resume_from_checkpoint self._run(model, ckpt_path) assert self.state.stopped @@ -718,12 +720,12 @@ def _validate_impl( # links data to the trainer self.data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule) - ckpt_path = self.__set_ckpt_path( + self.validated_ckpt_path = self.__set_ckpt_path( ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None ) # run validate - results = self._run(model, ckpt_path) + results = self._run(model, self.validated_ckpt_path) assert self.state.stopped self.validating = False @@ -807,12 +809,12 @@ def _test_impl( # links data to the trainer self.data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule) - ckpt_path = self.__set_ckpt_path( + self.tested_ckpt_path = self.__set_ckpt_path( ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None ) # run test - results = self._run(model, ckpt_path) + results = self._run(model, self.tested_ckpt_path) assert self.state.stopped self.testing = False @@ -890,11 +892,11 @@ def _predict_impl( # links data to the trainer self.data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule) - ckpt_path = self.__set_ckpt_path( + self.predicted_ckpt_path = self.__set_ckpt_path( ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None ) - results = self._run(model, ckpt_path) + results = self._run(model, self.predicted_ckpt_path) assert self.state.stopped self.predicting = False @@ -1742,6 +1744,14 @@ def checkpoint_callbacks(self) -> List[ModelCheckpoint]: in the Trainer.callbacks list.""" return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)] + @property + def resume_from_checkpoint(self) -> Optional[Union[str, Path]]: + rank_zero_deprecation( + "`trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7." + " Specify fit ckpt_path with `trainer.fit(ckpt_path=)` instead." 
+ ) + return self.checkpoint_connector.resume_checkpoint_path + def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None: self.checkpoint_connector.save_checkpoint(filepath, weights_only) diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 2f94354866040..d1b16a2aa7fd2 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -259,4 +259,6 @@ def test_v1_7_0_deprecate_lightning_distributed(tmpdir): def test_v1_7_0_resume_from_checkpoint_trainer_constructor(tmpdir): with pytest.deprecated_call(match=r"Setting `Trainer\(resume_from_checkpoint=\)` is deprecated in v1.5"): - _ = Trainer(resume_from_checkpoint="a") + trainer = Trainer(resume_from_checkpoint="a") + with pytest.deprecated_call(match=r"trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7."): + _ = trainer.resume_from_checkpoint diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index aededbc19e8ec..41fbbe75677bb 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -741,6 +741,8 @@ def predict_step(self, batch, *_): trainer.fit(model) trainer_fn = getattr(trainer, fn) + path_attr = f"{fn}{'d' if fn == 'validate' else 'ed'}_ckpt_path" + assert getattr(trainer, path_attr) is None if ckpt_path == "best": # ckpt_path is 'best', meaning we load the best weights @@ -751,15 +753,20 @@ def predict_step(self, batch, *_): trainer_fn(model, ckpt_path=ckpt_path) else: trainer_fn(ckpt_path=ckpt_path) + assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path + trainer_fn(model, ckpt_path=ckpt_path) + assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path elif ckpt_path is None: # ckpt_path is None, meaning we don't load any checkpoints and use the provided model trainer_fn(model, ckpt_path=ckpt_path) + assert getattr(trainer, path_attr) is None if save_top_k > 0: # ckpt_path is None with no model provided means load the best weights with pytest.warns(UserWarning, match="The best model of the previous `fit` call will be used"): trainer_fn(ckpt_path=ckpt_path) + assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path else: # specific checkpoint, pick one from saved ones if save_top_k == 0: @@ -772,8 +779,10 @@ def predict_step(self, batch, *_): ].absolute() ) trainer_fn(ckpt_path=ckpt_path) + assert getattr(trainer, path_attr) == ckpt_path trainer_fn(model, ckpt_path=ckpt_path) + assert getattr(trainer, path_attr) == ckpt_path def test_disabled_training(tmpdir): From d5069ecd1e3837733daa2dc3eaf9168142bf159d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 30 Sep 2021 18:38:50 +0000 Subject: [PATCH 12/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/deprecated_api/test_remove_1-7.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index d1b16a2aa7fd2..795d7f00236fb 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -260,5 +260,7 @@ def test_v1_7_0_deprecate_lightning_distributed(tmpdir): def test_v1_7_0_resume_from_checkpoint_trainer_constructor(tmpdir): with pytest.deprecated_call(match=r"Setting `Trainer\(resume_from_checkpoint=\)` is deprecated in v1.5"): trainer = 
Trainer(resume_from_checkpoint="a") - with pytest.deprecated_call(match=r"trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7."): + with pytest.deprecated_call( + match=r"trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7." + ): _ = trainer.resume_from_checkpoint
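
Note for reviewers: taken together, the patches above move checkpoint resumption off the Trainer constructor and onto the `fit` call. Below is a minimal before/after usage sketch; the toy module, dataset, and checkpoint filename are illustrative assumptions and are not code added by this series.

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl


    class ToyModel(pl.LightningModule):
        """Tiny LightningModule used only to demonstrate the API change."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    train = DataLoader(TensorDataset(torch.randn(8, 4), torch.randn(8, 1)), batch_size=4)

    # Train once and save a checkpoint to resume from (placeholder filename).
    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(ToyModel(), train)
    trainer.save_checkpoint("toy.ckpt")

    # Deprecated since v1.5, removal planned for v1.7: resume path set on the constructor.
    trainer = pl.Trainer(max_epochs=2, resume_from_checkpoint="toy.ckpt")
    trainer.fit(ToyModel(), train)

    # Replacement introduced by this series: pass the checkpoint path to `fit` instead.
    trainer = pl.Trainer(max_epochs=2)
    trainer.fit(ToyModel(), train, ckpt_path="toy.ckpt")

With patch 10/12, `ckpt_path` sits after `datamodule` and the legacy `train_dataloader` argument in the `fit` signature, so it should be passed by keyword as shown above.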