
Commit 9ebdc52

nithinraok authored and carmocca committed
Set the state before saving "last" or "none" checkpoints (#11481)
Co-authored-by: Carlos Mocholi <[email protected]>
1 parent bcd7c87 commit 9ebdc52
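
The gist of the fix: `ModelCheckpoint` stores `last_model_path` and `best_model_path` as part of its own state, and that state is written into every checkpoint file. Previously, `_save_last_checkpoint` and `_save_none_monitor_checkpoint` assigned these attributes only after calling `trainer.save_checkpoint`, so the file on disk carried the stale path from the previous save. A run resumed from such a checkpoint would then fail to remove the previous "last" checkpoint, or would not point at the actual best one. The fix swaps the order: assign the new path first (keeping the old value in a local `previous`), save, then remove the previous file.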

File tree

4 files changed: +38, -10 lines

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Pin sphinx-autodoc-typehints with <v1.15 ([#11400](https://github.com/PyTorchLightning/pytorch-lightning/pull/11400))
 - Skip testing with PyTorch 1.7 and Python 3.9 on Ubuntu ([#11217](https://github.com/PyTorchLightning/pytorch-lightning/pull/11217))
 - Fixed type promotion when tensors of higher category than float are logged ([#11401](https://github.com/PyTorchLightning/pytorch-lightning/pull/11401))
+- Fixed bug where the path for "last" checkpoints was not getting saved correctly which caused newer runs to not remove the previous "last" checkpoint ([#11481](https://github.com/PyTorchLightning/pytorch-lightning/pull/11481))
+- Fixed bug where the path for best checkpoints was not getting saved correctly when no metric was monitored which caused newer runs to not use the best checkpoint ([#11481](https://github.com/PyTorchLightning/pytorch-lightning/pull/11481))
 - Fixed the format of the configuration saved automatically by the CLI's `SaveConfigCallback` ([#11532](https://github.com/PyTorchLightning/pytorch-lightning/pull/11532))
 
 ### Changed

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 8 additions & 10 deletions
@@ -667,12 +667,11 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[
             return
 
         filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST)
+        # set the last model path before saving because it will be part of the state.
+        previous, self.last_model_path = self.last_model_path, filepath
         trainer.save_checkpoint(filepath, self.save_weights_only)
-
-        if self.last_model_path and self.last_model_path != filepath:
-            trainer.training_type_plugin.remove_checkpoint(self.last_model_path)
-
-        self.last_model_path = filepath
+        if previous and previous != filepath:
+            trainer.training_type_plugin.remove_checkpoint(previous)
 
     def _save_top_k_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
         if self.monitor is None or self.save_top_k == 0:
@@ -692,12 +691,11 @@ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidate
             return
 
         filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
+        # set the best model path before saving because it will be part of the state.
+        previous, self.best_model_path = self.best_model_path, filepath
         trainer.save_checkpoint(filepath, self.save_weights_only)
-
-        if self.save_top_k == 1 and self.best_model_path and self.best_model_path != filepath:
-            trainer.training_type_plugin.remove_checkpoint(self.best_model_path)
-
-        self.best_model_path = filepath
+        if self.save_top_k == 1 and previous and previous != filepath:
+            trainer.training_type_plugin.remove_checkpoint(previous)
 
     def _is_valid_monitor_key(self, metrics: Dict[str, _METRIC]) -> bool:
         return self.monitor in metrics or len(metrics) == 0

tests/checkpointing/test_model_checkpoint.py

Lines changed: 27 additions & 0 deletions
@@ -1240,3 +1240,30 @@ def test_model_checkpoint_saveload_ckpt(tmpdir):
             assert getattr(cb_restore, key) == val
         else:
             assert getattr(cb_restore, key) != val
+
+
+def test_save_last_saves_correct_last_model_path(tmpdir):
+    mc = ModelCheckpoint(dirpath=tmpdir, save_last=True)
+    mc.CHECKPOINT_NAME_LAST = "{foo}-last"
+    trainer = Trainer(callbacks=mc)
+    trainer.training_type_plugin.connect(BoringModel())
+
+    mc._save_last_checkpoint(trainer, {"foo": 1})
+    expected = "foo=1-last.ckpt"
+    assert os.listdir(tmpdir) == [expected]
+    full_path = str(tmpdir / expected)
+    ckpt = torch.load(full_path)
+    assert ckpt["callbacks"][mc.state_key]["last_model_path"] == full_path
+
+
+def test_none_monitor_saves_correct_best_model_path(tmpdir):
+    mc = ModelCheckpoint(dirpath=tmpdir, monitor=None)
+    trainer = Trainer(callbacks=mc)
+    trainer.training_type_plugin.connect(BoringModel())
+
+    mc._save_none_monitor_checkpoint(trainer, {})
+    expected = "epoch=0-step=0.ckpt"
+    assert os.listdir(tmpdir) == [expected]
+    full_path = str(tmpdir / expected)
+    ckpt = torch.load(full_path)
+    assert ckpt["callbacks"][mc.state_key]["best_model_path"] == full_path
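
The new tests exercise the private save hooks directly and then assert on the callback state embedded in the file. For reference, the same state can be inspected in any saved checkpoint; a rough sketch (the path is a placeholder, and the dict layout assumes this era's checkpoint format, where `ckpt["callbacks"]` maps each callback's state key to its state dict):

import torch

# Placeholder path: substitute a checkpoint produced by a run with ModelCheckpoint.
ckpt = torch.load("epoch=0-step=0.ckpt", map_location="cpu")
for state_key, state in ckpt["callbacks"].items():
    print(state_key, state.get("last_model_path"), state.get("best_model_path"))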

tests/deprecated_api/test_remove_1-7.py

Lines changed: 1 addition & 0 deletions
@@ -425,6 +425,7 @@ def test_v1_7_0_resume_from_checkpoint_trainer_constructor(tmpdir):
     assert trainer.checkpoint_connector.resume_checkpoint_path is None
     assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path == ckpt_path
     trainer.fit(model)
+    ckpt_path = trainer.checkpoint_callback.best_model_path  # last `fit` replaced the `best_model_path`
     assert callback.state == 111
     assert trainer.checkpoint_connector.resume_checkpoint_path is None
     assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None
