Fix current epoch value override on restart #12429

Merged (3 commits, Mar 28, 2022)
Changes from 1 commit
pytorch_lightning/loops/fit_loop.py (5 changes: 3 additions & 2 deletions)
@@ -131,8 +131,6 @@ def restarting(self, restarting: bool) -> None:
self.epoch_progress.current.processed,
)
finished_before_on_train_end = any(v != self.epoch_progress.current.completed for v in values)
- if finished_before_on_train_end:
-     self.epoch_progress.current.completed = self.epoch_progress.current.processed
restarting &= finished_before_on_train_end
Loop.restarting.fset(self, restarting) # call the parent setter

@@ -168,6 +166,9 @@ def done(self) -> bool:
# `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved.
# we use it here because the checkpoint data won't have `completed` increased yet
stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)
+ if stop_epochs:
+     # in case they are not equal, override so `trainer.current_epoch` has the expected value
+     self.epoch_progress.current.completed = self.epoch_progress.current.processed

should_stop = False
if self.trainer.should_stop:
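For context on the moved override: the unconditional sync that used to run in the `restarting` setter now only happens in `done`, once the `processed` epoch count reaches `max_epochs`. Since checkpoints are written from `on_train_epoch_end`, after `processed` is bumped but before `completed`, syncing the two at the stopping point keeps `trainer.current_epoch` at the expected final value while leaving a mid-training resume untouched. The sketch below is not the Lightning implementation, only a standalone illustration of the added check; the simplified `EpochProgress` dataclass and the numbers are assumptions for the example.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class EpochProgress:
    # simplified stand-in for `epoch_progress.current` (assumption for this sketch)
    processed: int = 0  # epochs whose batches have all been run
    completed: int = 0  # epochs for which `on_train_epoch_end` has already finished


def done(progress: EpochProgress, max_epochs: Optional[int]) -> bool:
    # `processed` is compared against the limit because a checkpoint saved in
    # `on_train_epoch_end` does not have `completed` increased yet
    stop_epochs = max_epochs is not None and progress.processed >= max_epochs
    if stop_epochs:
        # in case they are not equal, override so the reported epoch count
        # has the expected value after restarting from such a checkpoint
        progress.completed = progress.processed
    return stop_epochs


# restoring a checkpoint written at the very end of training (2 of 2 epochs processed)
finished = EpochProgress(processed=2, completed=1)
assert done(finished, max_epochs=2)
assert finished.completed == 2  # the reported epoch count matches what was trained

# restoring mid-training (1 of 2 epochs processed) leaves the progress untouched,
# so the resumed run keeps reporting the epoch it was actually in
in_progress = EpochProgress(processed=1, completed=0)
assert not done(in_progress, max_epochs=2)
assert in_progress.completed == 0
```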
tests/models/test_hooks.py (4 changes: 2 additions & 2 deletions)
@@ -616,7 +616,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
"state_dict": ANY,
"loops": ANY,
}
- saved_ckpt = {**loaded_ckpt, "global_step": steps_after_reload, "epoch": 1}
+ saved_ckpt = {**loaded_ckpt, "global_step": steps_after_reload}
expected = [
dict(name="Callback.on_init_start", args=(trainer,)),
dict(name="Callback.on_init_end", args=(trainer,)),
@@ -646,7 +646,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
dict(name="on_epoch_start"),
dict(name="Callback.on_train_epoch_start", args=(trainer, model)),
dict(name="on_train_epoch_start"),
- *model._train_batch(trainer, model, steps_after_reload, current_batch=1, current_epoch=1),
+ *model._train_batch(trainer, model, steps_after_reload, current_batch=1),
dict(name="training_epoch_end", args=([dict(loss=ANY)] * train_batches,)),
dict(name="Callback.on_train_epoch_end", args=(trainer, model)),
dict(name="Callback.state_dict"),
tests/models/test_restore.py (4 changes: 1 addition & 3 deletions)
@@ -199,9 +199,7 @@ def on_train_start(self):
if self.trainer.state.fn == TrainerFn.TUNING:
self._test_on_val_test_predict_tune_start()
else:
- # `-1` because this checkpoint is saved `on_train_epoch_end` which is considered part of the epoch so
- # the `current_epoch` count has not been increased yet
- assert self.trainer.current_epoch - 1 == state_dict["epoch"]
+ assert self.trainer.current_epoch == state_dict["epoch"]
assert self.trainer.global_step == state_dict["global_step"]
assert self._check_model_state_dict()
assert self._check_optimizers()
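The dropped `- 1` is the user-visible part of the change exercised by this test: after resuming, `trainer.current_epoch` observed in `on_train_start` now equals the `epoch` recorded in the checkpoint instead of sitting one ahead of it. A small plain-Python illustration of the old and new expectations follows; the checkpoint contents are hypothetical and not taken from the test.

```python
# hypothetical checkpoint saved from `on_train_epoch_end` of epoch index 1
state_dict = {"epoch": 1, "global_step": 20}

# behaviour the removed assertion encoded: the resumed trainer reported one epoch more
old_current_epoch = state_dict["epoch"] + 1
assert old_current_epoch - 1 == state_dict["epoch"]

# behaviour the new assertion encodes: the checkpointed value is reported as-is
new_current_epoch = state_dict["epoch"]
assert new_current_epoch == state_dict["epoch"]
```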
tests/trainer/test_trainer.py (24 changes: 13 additions & 11 deletions)
@@ -403,6 +403,7 @@ def test_model_freeze_unfreeze():
assert param.requires_grad


+ # TODO: move to `test/models/test_restore.py`
@pytest.mark.parametrize("url_ckpt", [True, False])
def test_fit_ckpt_path_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
"""Verify resuming from checkpoint runs the right number of epochs."""
@@ -425,11 +426,12 @@ def on_load_checkpoint(self, _):
self.num_on_load_checkpoint_called += 1

model = TestModel()
+ max_epochs = 2
trainer = Trainer(
- max_epochs=2,
+ max_epochs=max_epochs,
limit_train_batches=0.65,
limit_val_batches=1,
- callbacks=[ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_top_k=-1)],
+ callbacks=ModelCheckpoint(dirpath=tmpdir, save_top_k=-1),
default_root_dir=tmpdir,
val_check_interval=1.0,
enable_progress_bar=False,
@@ -438,26 +440,26 @@ def on_load_checkpoint(self, _):
)
trainer.fit(model)

- assert model.num_epochs_end_seen == 2
- assert model.num_batches_seen == trainer.num_training_batches * 2
+ assert model.num_epochs_end_seen == max_epochs
+ assert model.num_batches_seen == trainer.num_training_batches * max_epochs == trainer.global_step
assert model.num_on_load_checkpoint_called == 0

- # Other checkpoints can be uncommented if/when resuming mid-epoch is supported
- checkpoints = Path(trainer.checkpoint_callback.dirpath).glob("*.ckpt")
+ checkpoints = set(Path(trainer.checkpoint_callback.dirpath).glob("*.ckpt"))
if url_ckpt:
# transform local paths into url checkpoints
ip, port = tmpdir_server
checkpoints = [f"http://{ip}:{port}/" + ckpt.name for ckpt in checkpoints]

+ assert len(checkpoints) == max_epochs
for ckpt in checkpoints:
- next_model = TestModel()
+ model = TestModel()
state = pl_load(ckpt)

# Resume training
- new_trainer = Trainer(default_root_dir=tmpdir, max_epochs=2)
- new_trainer.fit(next_model, ckpt_path=ckpt)
- assert state["global_step"] + next_model.num_batches_seen == trainer.num_training_batches * trainer.max_epochs
- assert next_model.num_on_load_checkpoint_called == 1
+ trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, enable_progress_bar=False)
+ trainer.fit(model, ckpt_path=ckpt)
+ assert state["global_step"] + model.num_batches_seen == trainer.global_step
+ assert model.num_on_load_checkpoint_called == 1


def test_trainer_max_steps_and_epochs(tmpdir):
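The rewritten resume loop checks a simple step-count invariant: the `global_step` stored in a checkpoint plus the batches run after resuming from it must equal the final `global_step` of the resumed trainer. A worked example with hypothetical numbers follows; the test itself derives its batch count from `limit_train_batches=0.65`, not from these figures.

```python
batches_per_epoch = 10  # hypothetical; not the test's actual batch count
max_epochs = 2

# the checkpoint written after epoch 0 has seen one epoch worth of batches
checkpoint_global_step = batches_per_epoch * 1

# resuming from it trains only the remaining epoch
batches_seen_after_resume = batches_per_epoch * (max_epochs - 1)

# so the resumed run finishes at the same step count as an uninterrupted run
final_global_step = batches_per_epoch * max_epochs

assert checkpoint_global_step + batches_seen_after_resume == final_global_step
```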