From e7fea68a90087cf9a6ad69d8e9de05ce7041e4ae Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Wed, 23 Mar 2022 19:56:08 +0100
Subject: [PATCH 1/2] Fix current epoch value override on restart

---
 pytorch_lightning/loops/fit_loop.py |  5 +++--
 tests/models/test_hooks.py          |  4 ++--
 tests/models/test_restore.py        |  4 +---
 tests/trainer/test_trainer.py       | 24 +++++++++++++-----------
 4 files changed, 19 insertions(+), 18 deletions(-)

diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index 5111969ca79db..7d2b77425a94a 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -131,8 +131,6 @@ def restarting(self, restarting: bool) -> None:
             self.epoch_progress.current.processed,
         )
         finished_before_on_train_end = any(v != self.epoch_progress.current.completed for v in values)
-        if finished_before_on_train_end:
-            self.epoch_progress.current.completed = self.epoch_progress.current.processed
         restarting &= finished_before_on_train_end
         Loop.restarting.fset(self, restarting)  # call the parent setter
 
@@ -168,6 +166,9 @@ def done(self) -> bool:
         # `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved.
         # we use it here because the checkpoint data won't have `completed` increased yet
         stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)
+        if stop_epochs:
+            # in case they are not equal, override so `trainer.current_epoch` has the expected value
+            self.epoch_progress.current.completed = self.epoch_progress.current.processed
 
         should_stop = False
         if self.trainer.should_stop:
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 4a5cfbdae28cc..0f6ffec4f0b45 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -616,7 +616,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
         "state_dict": ANY,
         "loops": ANY,
     }
-    saved_ckpt = {**loaded_ckpt, "global_step": steps_after_reload, "epoch": 1}
+    saved_ckpt = {**loaded_ckpt, "global_step": steps_after_reload}
     expected = [
         dict(name="Callback.on_init_start", args=(trainer,)),
         dict(name="Callback.on_init_end", args=(trainer,)),
@@ -646,7 +646,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
         dict(name="on_epoch_start"),
         dict(name="Callback.on_train_epoch_start", args=(trainer, model)),
         dict(name="on_train_epoch_start"),
-        *model._train_batch(trainer, model, steps_after_reload, current_batch=1, current_epoch=1),
+        *model._train_batch(trainer, model, steps_after_reload, current_batch=1),
         dict(name="training_epoch_end", args=([dict(loss=ANY)] * train_batches,)),
         dict(name="Callback.on_train_epoch_end", args=(trainer, model)),
         dict(name="Callback.state_dict"),
diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py
index e5259c4047ad2..907b690fcc5cf 100644
--- a/tests/models/test_restore.py
+++ b/tests/models/test_restore.py
@@ -199,9 +199,7 @@ def on_train_start(self):
             if self.trainer.state.fn == TrainerFn.TUNING:
                 self._test_on_val_test_predict_tune_start()
             else:
-                # `-1` because this checkpoint is saved `on_train_epoch_end` which is considered part of the epoch so
-                # the `current_epoch` count has not been increased yet
-                assert self.trainer.current_epoch - 1 == state_dict["epoch"]
+                assert self.trainer.current_epoch == state_dict["epoch"]
                 assert self.trainer.global_step == state_dict["global_step"]
                 assert self._check_model_state_dict()
                 assert self._check_optimizers()
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index fca1aa0c37e53..962b20b3723a2 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -403,6 +403,7 @@ def test_model_freeze_unfreeze():
         assert param.requires_grad
 
 
+# TODO: move to `test/models/test_restore.py`
 @pytest.mark.parametrize("url_ckpt", [True, False])
 def test_fit_ckpt_path_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
     """Verify resuming from checkpoint runs the right number of epochs."""
@@ -425,11 +426,12 @@ def on_load_checkpoint(self, _):
             self.num_on_load_checkpoint_called += 1
 
     model = TestModel()
+    max_epochs = 2
     trainer = Trainer(
-        max_epochs=2,
+        max_epochs=max_epochs,
         limit_train_batches=0.65,
         limit_val_batches=1,
-        callbacks=[ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_top_k=-1)],
+        callbacks=ModelCheckpoint(dirpath=tmpdir, save_top_k=-1),
         default_root_dir=tmpdir,
         val_check_interval=1.0,
         enable_progress_bar=False,
@@ -438,26 +440,26 @@ def on_load_checkpoint(self, _):
     )
     trainer.fit(model)
 
-    assert model.num_epochs_end_seen == 2
-    assert model.num_batches_seen == trainer.num_training_batches * 2
+    assert model.num_epochs_end_seen == max_epochs
+    assert model.num_batches_seen == trainer.num_training_batches * max_epochs == trainer.global_step
     assert model.num_on_load_checkpoint_called == 0
 
     # Other checkpoints can be uncommented if/when resuming mid-epoch is supported
-    checkpoints = Path(trainer.checkpoint_callback.dirpath).glob("*.ckpt")
+    checkpoints = set(Path(trainer.checkpoint_callback.dirpath).glob("*.ckpt"))
     if url_ckpt:
         # transform local paths into url checkpoints
         ip, port = tmpdir_server
         checkpoints = [f"http://{ip}:{port}/" + ckpt.name for ckpt in checkpoints]
+    assert len(checkpoints) == max_epochs
 
     for ckpt in checkpoints:
-        next_model = TestModel()
+        model = TestModel()
         state = pl_load(ckpt)
 
-        # Resume training
-        new_trainer = Trainer(default_root_dir=tmpdir, max_epochs=2)
-        new_trainer.fit(next_model, ckpt_path=ckpt)
-        assert state["global_step"] + next_model.num_batches_seen == trainer.num_training_batches * trainer.max_epochs
-        assert next_model.num_on_load_checkpoint_called == 1
+        trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, enable_progress_bar=False)
+        trainer.fit(model, ckpt_path=ckpt)
+        assert state["global_step"] + model.num_batches_seen == trainer.global_step
+        assert model.num_on_load_checkpoint_called == 1
 
 
 def test_trainer_max_steps_and_epochs(tmpdir):

From 0bfc48c30c1794551533a978e33eda392f5598c0 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Wed, 23 Mar 2022 20:01:26 +0100
Subject: [PATCH 2/2] Remove outdated comment

---
 tests/trainer/test_trainer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 962b20b3723a2..736141e2af99f 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -444,7 +444,6 @@ def on_load_checkpoint(self, _):
     assert model.num_batches_seen == trainer.num_training_batches * max_epochs == trainer.global_step
     assert model.num_on_load_checkpoint_called == 0
 
-    # Other checkpoints can be uncommented if/when resuming mid-epoch is supported
     checkpoints = set(Path(trainer.checkpoint_callback.dirpath).glob("*.ckpt"))
     if url_ckpt:
         # transform local paths into url checkpoints
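For illustration only, not part of the patch: a minimal sketch of the resume behavior the updated tests/models/test_restore.py assertion exercises, namely that after restoring from a checkpoint, `trainer.current_epoch` matches the checkpoint's stored `epoch` value with no off-by-one. `BoringModel` is assumed to come from the Lightning test helpers; the directory name, epoch counts, and `ResumedModel` class are illustrative placeholders.

import torch
from pytorch_lightning import Trainer
from tests.helpers import BoringModel  # test-suite helper; any LightningModule works here

# first run: train for two epochs, keep the checkpoint written at the end of training
trainer = Trainer(default_root_dir="ckpts", max_epochs=2, enable_progress_bar=False)
trainer.fit(BoringModel())
ckpt_path = trainer.checkpoint_callback.best_model_path
state = torch.load(ckpt_path)


class ResumedModel(BoringModel):
    # hypothetical subclass mirroring the assertion in tests/models/test_restore.py
    def on_train_start(self):
        # on restart, the trainer's epoch counter equals the checkpointed "epoch" value
        assert self.trainer.current_epoch == state["epoch"]


# second run: resume from the saved checkpoint and train further
trainer = Trainer(default_root_dir="ckpts", max_epochs=3, enable_progress_bar=False)
trainer.fit(ResumedModel(), ckpt_path=ckpt_path)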