diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index b36ece18e3fce..e90dff744067c 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -131,8 +131,6 @@ def restarting(self, restarting: bool) -> None:
             self.epoch_progress.current.processed,
         )
         finished_before_on_train_end = any(v != self.epoch_progress.current.completed for v in values)
-        if finished_before_on_train_end:
-            self.epoch_progress.current.completed = self.epoch_progress.current.processed
         restarting &= finished_before_on_train_end
         Loop.restarting.fset(self, restarting)  # call the parent setter
 
@@ -168,6 +166,9 @@ def done(self) -> bool:
         # `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved.
         # we use it here because the checkpoint data won't have `completed` increased yet
         stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)
+        if stop_epochs:
+            # in case they are not equal, override so `trainer.current_epoch` has the expected value
+            self.epoch_progress.current.completed = self.epoch_progress.current.processed
 
         should_stop = False
         if self.trainer.should_stop:
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index d8810eab180ef..de127628af4cb 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -616,7 +616,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
         "state_dict": ANY,
         "loops": ANY,
     }
-    saved_ckpt = {**loaded_ckpt, "global_step": steps_after_reload, "epoch": 1}
+    saved_ckpt = {**loaded_ckpt, "global_step": steps_after_reload}
     expected = [
         dict(name="Callback.on_init_start", args=(trainer,)),
         dict(name="Callback.on_init_end", args=(trainer,)),
@@ -646,7 +646,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
         dict(name="on_epoch_start"),
         dict(name="Callback.on_train_epoch_start", args=(trainer, model)),
         dict(name="on_train_epoch_start"),
-        *model._train_batch(trainer, model, steps_after_reload, current_batch=1, current_epoch=1),
+        *model._train_batch(trainer, model, steps_after_reload, current_batch=1),
         dict(name="training_epoch_end", args=([dict(loss=ANY)] * train_batches,)),
         dict(name="Callback.on_train_epoch_end", args=(trainer, model)),
         dict(name="Callback.state_dict"),
diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py
index e5c91e0f71e8d..dfcbb1c82df89 100644
--- a/tests/models/test_restore.py
+++ b/tests/models/test_restore.py
@@ -199,9 +199,7 @@ def on_train_start(self):
             if self.trainer.state.fn == TrainerFn.TUNING:
                 self._test_on_val_test_predict_tune_start()
             else:
-                # `-1` because this checkpoint is saved `on_train_epoch_end` which is considered part of the epoch so
-                # the `current_epoch` count has not been increased yet
-                assert self.trainer.current_epoch - 1 == state_dict["epoch"]
+                assert self.trainer.current_epoch == state_dict["epoch"]
                 assert self.trainer.global_step == state_dict["global_step"]
                 assert self._check_model_state_dict()
                 assert self._check_optimizers()
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 314e4f3578e48..736141e2af99f 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -403,7 +403,7 @@ def test_model_freeze_unfreeze():
         assert param.requires_grad
 
 
-@pytest.mark.xfail(reason="FIXME(@carmocca): this test wasn't running and is now broken")
+# TODO: move to `test/models/test_restore.py`
 @pytest.mark.parametrize("url_ckpt", [True, False])
 def test_fit_ckpt_path_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
     """Verify resuming from checkpoint runs the right number of epochs."""
@@ -426,11 +426,12 @@ def on_load_checkpoint(self, _):
             self.num_on_load_checkpoint_called += 1
 
     model = TestModel()
+    max_epochs = 2
     trainer = Trainer(
-        max_epochs=2,
+        max_epochs=max_epochs,
         limit_train_batches=0.65,
         limit_val_batches=1,
-        callbacks=[ModelCheckpoint(dirpath=tmpdir, save_top_k=-1)],
+        callbacks=ModelCheckpoint(dirpath=tmpdir, save_top_k=-1),
         default_root_dir=tmpdir,
         val_check_interval=1.0,
         enable_progress_bar=False,
@@ -439,27 +440,25 @@ def on_load_checkpoint(self, _):
     )
     trainer.fit(model)
 
-    assert model.num_epochs_end_seen == 2
-    assert model.num_batches_seen == trainer.num_training_batches * 2
+    assert model.num_epochs_end_seen == max_epochs
+    assert model.num_batches_seen == trainer.num_training_batches * max_epochs == trainer.global_step
     assert model.num_on_load_checkpoint_called == 0
 
-    # Other checkpoints can be uncommented if/when resuming mid-epoch is supported
-    checkpoints = Path(trainer.checkpoint_callback.dirpath).glob("*.ckpt")
+    checkpoints = set(Path(trainer.checkpoint_callback.dirpath).glob("*.ckpt"))
     if url_ckpt:
         # transform local paths into url checkpoints
         ip, port = tmpdir_server
        checkpoints = [f"http://{ip}:{port}/" + ckpt.name for ckpt in checkpoints]
 
-    assert checkpoints
+    assert len(checkpoints) == max_epochs
     for ckpt in checkpoints:
-        next_model = TestModel()
+        model = TestModel()
         state = pl_load(ckpt)
 
-        # Resume training
-        new_trainer = Trainer(default_root_dir=tmpdir, max_epochs=2)
-        new_trainer.fit(next_model, ckpt_path=ckpt)
-        assert state["global_step"] + next_model.num_batches_seen == trainer.num_training_batches * trainer.max_epochs
-        assert next_model.num_on_load_checkpoint_called == 1
+        trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, enable_progress_bar=False)
+        trainer.fit(model, ckpt_path=ckpt)
+        assert state["global_step"] + model.num_batches_seen == trainer.global_step
+        assert model.num_on_load_checkpoint_called == 1
 
 
 def test_trainer_max_steps_and_epochs(tmpdir):
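
For context on the `fit_loop.py` change: the override of `completed` to match `processed` now happens in `done` once the max-epochs limit is reached, instead of in the `restarting` setter, so `trainer.current_epoch` matches the `epoch` recorded in a checkpoint saved during `on_train_epoch_end` (which is why the `- 1` adjustments disappear from the tests). The snippet below is only a minimal, self-contained sketch of that counter logic; `_FitLoopSketch`, `_Progress`, and this version of `_is_max_limit_reached` are simplified stand-ins for illustration, not the actual Lightning classes.

```python
from dataclasses import dataclass, field
from typing import Optional


def _is_max_limit_reached(current: int, maximum: Optional[int]) -> bool:
    # stand-in for the helper used in the diff: the limit only applies when set and not -1
    return maximum is not None and maximum != -1 and current >= maximum


@dataclass
class _Progress:
    # simplified stand-in for the `epoch_progress.current` counters
    processed: int = 0  # bumped before `on_train_epoch_end`, where checkpoints are saved
    completed: int = 0  # bumped after `on_train_epoch_end`


@dataclass
class _FitLoopSketch:
    max_epochs: Optional[int] = None
    current: _Progress = field(default_factory=_Progress)

    @property
    def done(self) -> bool:
        stop_epochs = _is_max_limit_reached(self.current.processed, self.max_epochs)
        if stop_epochs:
            # mirror the change above: when resuming from a checkpoint saved in
            # `on_train_epoch_end`, `completed` can lag behind `processed`;
            # override it so the reported current epoch has the expected value
            self.current.completed = self.current.processed
        return stop_epochs


# resuming from a checkpoint saved at the end of the final epoch:
# `processed` was already bumped, `completed` was not
loop = _FitLoopSketch(max_epochs=2)
loop.current.processed = 2
loop.current.completed = 1
assert loop.done
assert loop.current.completed == 2  # now consistent with the checkpointed epoch
```

One way to read the change: by reconciling the counters inside `done`, the override only fires when the loop actually decides to stop at the epoch limit, rather than on every restart.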