
Commit 5d2229d

Merge branch 'PyTorchLightning:master' into refactor/gpus

2 parents: 461aba2 + 71e0ddb

File tree: 6 files changed (+104, -78 lines)


CHANGELOG.md

Lines changed: 12 additions & 0 deletions

@@ -382,6 +382,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Move `Strategy.process_dataloader` function call from `fit/evaluation/predict_loop.py` to `data_connector.py` ([#12251](https://github.com/PyTorchLightning/pytorch-lightning/pull/12251))
 
 
+- `ModelCheckpoint(save_last=True, every_n_epochs=N)` now saves a "last" checkpoint every epoch (disregarding `every_n_epochs`) instead of only once at the end of training ([#12418](https://github.com/PyTorchLightning/pytorch-lightning/pull/12418))
+
+
 - The strategies that support `sync_batchnorm` now only apply it when fitting ([#11919](https://github.com/PyTorchLightning/pytorch-lightning/pull/11919))
 
 
@@ -864,6 +867,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed the case where `logger=None` is passed to the Trainer ([#12249](https://github.com/PyTorchLightning/pytorch-lightning/pull/12249))
 
 
+- Fixed bug where the global step tracked by `ModelCheckpoint` was still set even if no checkpoint was saved ([#12418](https://github.com/PyTorchLightning/pytorch-lightning/pull/12418))
+
+
+- Fixed bug where `ModelCheckpoint` was overriding the `epoch` and `step` logged values ([#12418](https://github.com/PyTorchLightning/pytorch-lightning/pull/12418))
+
+
+- Fixed bug where monitoring the default `epoch` and `step` values with `ModelCheckpoint` would fail ([#12418](https://github.com/PyTorchLightning/pytorch-lightning/pull/12418))
+
+
 - Fixed initializing optimizers unnecessarily in `DDPFullyShardedStrategy` ([#12267](https://github.com/PyTorchLightning/pytorch-lightning/pull/12267))
 
 
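
To illustrate the `save_last`/`every_n_epochs` change described above, here is a minimal usage sketch (not part of this commit; the model and directory names are placeholders):

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import ModelCheckpoint

    # Top-k checkpoints are written only every 5 epochs, but with this change the
    # "last" checkpoint is refreshed at the end of every epoch.
    checkpoint_cb = ModelCheckpoint(
        dirpath="checkpoints/",  # placeholder directory
        monitor="val_loss",
        save_top_k=1,
        every_n_epochs=5,
        save_last=True,
    )
    trainer = pl.Trainer(max_epochs=10, callbacks=[checkpoint_cb])
    # trainer.fit(model)  # `model` would be any LightningModule
    # checkpoint_cb.last_model_path -> ".../last.ckpt", updated after each epoch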

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 65 additions & 72 deletions

@@ -100,8 +100,8 @@ class ModelCheckpoint(Callback):
             based on either the maximization or the minimization of the monitored quantity.
             For ``'val_acc'``, this should be ``'max'``, for ``'val_loss'`` this should be ``'min'``, etc.
         auto_insert_metric_name: When ``True``, the checkpoints filenames will contain the metric name.
-            For example, ``filename='checkpoint_{epoch:02d}-{acc:02d}`` with epoch 1 and acc 80 will resolve to
-            ``checkpoint_epoch=01-acc=80.ckp``. Is useful to set it to ``False`` when metric names contain ``/``
+            For example, ``filename='checkpoint_{epoch:02d}-{acc:02.0f}`` with epoch ``1`` and acc ``1.12`` will resolve
+            to ``checkpoint_epoch=01-acc=01.ckpt``. Is useful to set it to ``False`` when metric names contain ``/``
             as this will result in extra folders.
         save_weights_only: if ``True``, then only the model's weights will be
             saved. Otherwise, the optimizer states, lr-scheduler states, etc are added in the checkpoint too.
@@ -116,7 +116,8 @@ class ModelCheckpoint(Callback):
             This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``.
         every_n_epochs: Number of epochs between checkpoints.
             This value must be ``None`` or non-negative.
-            To disable saving after each epoch, set ``every_n_epochs = 0``.
+            To disable saving top-k checkpoints, set ``every_n_epochs = 0``.
+            This argument does not impact the saving of ``save_last=True`` checkpoints.
             If all of ``every_n_epochs``, ``every_n_train_steps`` and
             ``train_time_interval`` are ``None``, we save a checkpoint at the end of every epoch
             (equivalent to ``every_n_epochs = 1``).
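
A quick sketch of the corrected ``auto_insert_metric_name`` docstring example above (illustrative only; the ``acc`` metric and the paths are placeholders):

    from pytorch_lightning.callbacks import ModelCheckpoint

    # With `auto_insert_metric_name=True` (the default), the metric name is inserted
    # into the resolved filename, e.g. "checkpoint_epoch=01-acc=01.ckpt" for
    # epoch 1 and acc 1.12 formatted with "{acc:02.0f}".
    ckpt = ModelCheckpoint(
        dirpath="checkpoints/",  # placeholder directory
        filename="checkpoint_{epoch:02d}-{acc:02.0f}",
        monitor="acc",
        mode="max",
    )
    # The callback's own formatter can be used to check the resolved name:
    # ckpt.format_checkpoint_name({"epoch": 1, "acc": 1.12})
    # -> "checkpoints/checkpoint_epoch=01-acc=01.ckpt"
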
@@ -295,28 +296,25 @@ def on_train_batch_end(
         if not skip_time:
             self._last_time_checked = now
 
-        self.save_checkpoint(trainer)
+        monitor_candidates = self._monitor_candidates(trainer)
+        self._save_topk_checkpoint(trainer, monitor_candidates)
+        self._save_last_checkpoint(trainer, monitor_candidates)
 
     def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Save a checkpoint at the end of the training epoch."""
-        if (
-            not self._should_skip_saving_checkpoint(trainer)
-            and self._save_on_train_epoch_end
-            and self._every_n_epochs > 0
-            and (trainer.current_epoch + 1) % self._every_n_epochs == 0
-        ):
-            self.save_checkpoint(trainer)
+        if not self._should_skip_saving_checkpoint(trainer) and self._save_on_train_epoch_end:
+            monitor_candidates = self._monitor_candidates(trainer)
+            if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
+                self._save_topk_checkpoint(trainer, monitor_candidates)
+            self._save_last_checkpoint(trainer, monitor_candidates)
 
     def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Save a checkpoint at the end of the validation stage."""
-        if (
-            self._should_skip_saving_checkpoint(trainer)
-            or self._save_on_train_epoch_end
-            or self._every_n_epochs < 1
-            or (trainer.current_epoch + 1) % self._every_n_epochs != 0
-        ):
-            return
-        self.save_checkpoint(trainer)
+        if not self._should_skip_saving_checkpoint(trainer) and not self._save_on_train_epoch_end:
+            monitor_candidates = self._monitor_candidates(trainer)
+            if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
+                self._save_topk_checkpoint(trainer, monitor_candidates)
+            self._save_last_checkpoint(trainer, monitor_candidates)
 
     def on_save_checkpoint(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
@@ -352,26 +350,41 @@ def on_load_checkpoint(
         self.last_model_path = callback_state.get("last_model_path", self.last_model_path)
         self.best_model_path = callback_state["best_model_path"]
 
-    def save_checkpoint(self, trainer: "pl.Trainer") -> None:
+    def save_checkpoint(self, trainer: "pl.Trainer") -> None:  # pragma: no-cover
         """Performs the main logic around saving a checkpoint.
 
         This method runs on all ranks. It is the responsibility of `trainer.save_checkpoint` to correctly handle the
         behaviour in distributed training, i.e., saving only on rank 0 for data parallel use cases.
         """
-        self._validate_monitor_key(trainer)
-
-        # what can be monitored
-        monitor_candidates = self._monitor_candidates(trainer, epoch=trainer.current_epoch, step=trainer.global_step)
-
-        # callback supports multiple simultaneous modes
-        # here we call each mode sequentially
-        # Mode 1: save the top k checkpoints
-        self._save_top_k_checkpoint(trainer, monitor_candidates)
-        # Mode 2: save monitor=None checkpoints
-        self._save_none_monitor_checkpoint(trainer, monitor_candidates)
-        # Mode 3: save last checkpoints
+        # TODO: unused method. deprecate it
+        monitor_candidates = self._monitor_candidates(trainer)
+        self._save_topk_checkpoint(trainer, monitor_candidates)
         self._save_last_checkpoint(trainer, monitor_candidates)
 
+    def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
+        if self.save_top_k == 0:
+            return
+
+        # validate metric
+        if self.monitor is not None:
+            if self.monitor not in monitor_candidates:
+                m = (
+                    f"`ModelCheckpoint(monitor={self.monitor!r})` could not find the monitored key in the returned"
+                    f" metrics: {list(monitor_candidates)}."
+                    f" HINT: Did you call `log({self.monitor!r}, value)` in the `LightningModule`?"
+                )
+                if trainer.fit_loop.epoch_loop.val_loop._has_run:
+                    raise MisconfigurationException(m)
+                warning_cache.warn(m)
+            self._save_monitor_checkpoint(trainer, monitor_candidates)
+        else:
+            self._save_none_monitor_checkpoint(trainer, monitor_candidates)
+
+    def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
+        trainer.save_checkpoint(filepath, self.save_weights_only)
+
+        self._last_global_step_saved = trainer.global_step
+
         # notify loggers
         if trainer.is_global_zero:
             for logger in trainer.loggers:
@@ -594,21 +607,6 @@ def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
         if self.save_top_k != 0 and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0:
             rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
 
-    def _validate_monitor_key(self, trainer: "pl.Trainer") -> None:
-        metrics = trainer.callback_metrics
-
-        # validate metric
-        if self.monitor is not None and not self._is_valid_monitor_key(metrics):
-            m = (
-                f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:"
-                f" {list(metrics.keys())}. "
-                f"HINT: Did you call self.log('{self.monitor}', value) in the LightningModule?"
-            )
-            if not trainer.fit_loop.epoch_loop.val_loop._has_run:
-                warning_cache.warn(m)
-            else:
-                raise MisconfigurationException(m)
-
     def _get_metric_interpolated_filepath_name(
         self, monitor_candidates: Dict[str, _METRIC], trainer: "pl.Trainer", del_filepath: Optional[str] = None
     ) -> str:
@@ -621,51 +619,46 @@ def _get_metric_interpolated_filepath_name(
 
         return filepath
 
-    def _monitor_candidates(self, trainer: "pl.Trainer", epoch: int, step: int) -> Dict[str, _METRIC]:
+    def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, _METRIC]:
         monitor_candidates = deepcopy(trainer.callback_metrics)
-        monitor_candidates.update(epoch=epoch, step=step)
+        # cast to int if necessary because `self.log("epoch", 123)` will convert it to float. if it's not a tensor
+        # or does not exist we overwrite it as it's likely an error
+        epoch = monitor_candidates.get("epoch")
+        monitor_candidates["epoch"] = (
+            epoch.int() if isinstance(epoch, torch.Tensor) else torch.tensor(trainer.current_epoch)
+        )
+        step = monitor_candidates.get("step")
+        monitor_candidates["step"] = step.int() if isinstance(step, torch.Tensor) else torch.tensor(trainer.global_step)
         return monitor_candidates
 
     def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
         if not self.save_last:
             return
-        self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step)
 
         filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST)
         # set the last model path before saving because it will be part of the state.
         previous, self.last_model_path = self.last_model_path, filepath
-        trainer.save_checkpoint(filepath, self.save_weights_only)
+        self._save_checkpoint(trainer, filepath)
         if previous and previous != filepath:
             trainer.strategy.remove_checkpoint(previous)
 
-    def _save_top_k_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
-        if self.monitor is None or self.save_top_k == 0:
-            return
-        self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step)
-
+    def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
         current = monitor_candidates.get(self.monitor)
         if self.check_monitor_top_k(trainer, current):
            self._update_best_and_save(current, trainer, monitor_candidates)
         elif self.verbose:
-            epoch = monitor_candidates.get("epoch")
-            step = monitor_candidates.get("step")
-            rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor} was not in top {self.save_top_k}")
+            epoch = monitor_candidates["epoch"]
+            step = monitor_candidates["step"]
+            rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} was not in top {self.save_top_k}")
 
     def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
-        if self.monitor is not None or self.save_top_k == 0:
-            return
-        self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step)
-
         filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
         # set the best model path before saving because it will be part of the state.
         previous, self.best_model_path = self.best_model_path, filepath
-        trainer.save_checkpoint(filepath, self.save_weights_only)
+        self._save_checkpoint(trainer, filepath)
         if self.save_top_k == 1 and previous and previous != filepath:
             trainer.strategy.remove_checkpoint(previous)
 
-    def _is_valid_monitor_key(self, metrics: Dict[str, _METRIC]) -> bool:
-        return self.monitor in metrics or len(metrics) == 0
-
     def _update_best_and_save(
         self, current: torch.Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]
     ) -> None:
@@ -697,13 +690,13 @@ def _update_best_and_save(
         self.best_model_score = self.best_k_models[self.best_model_path]
 
         if self.verbose:
-            epoch = monitor_candidates.get("epoch")
-            step = monitor_candidates.get("step")
+            epoch = monitor_candidates["epoch"]
+            step = monitor_candidates["step"]
             rank_zero_info(
-                f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}"
-                f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}'
+                f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} reached {current:0.5f}"
+                f" (best {self.best_model_score:0.5f}), saving model to {filepath!r} as top {k}"
             )
-        trainer.save_checkpoint(filepath, self.save_weights_only)
+        self._save_checkpoint(trainer, filepath)
 
         if del_filepath is not None and filepath != del_filepath:
             trainer.strategy.remove_checkpoint(del_filepath)
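
The `_monitor_candidates` refactor above casts logged `epoch`/`step` values back to integer tensors and falls back to the trainer's counters when they are missing. A standalone sketch of the same idea (function and variable names here are illustrative, not part of the commit):

    import torch

    def normalize_epoch_step(callback_metrics, current_epoch, global_step):
        # `self.log("epoch", 3)` stores a float tensor, so cast it back to int;
        # if the key is missing or not a tensor, fall back to the trainer's counters.
        candidates = dict(callback_metrics)
        epoch = candidates.get("epoch")
        candidates["epoch"] = epoch.int() if isinstance(epoch, torch.Tensor) else torch.tensor(current_epoch)
        step = candidates.get("step")
        candidates["step"] = step.int() if isinstance(step, torch.Tensor) else torch.tensor(global_step)
        return candidates

    # A logged float "epoch" is cast to int, a missing "step" is filled in:
    print(normalize_epoch_step({"epoch": torch.tensor(3.0)}, current_epoch=3, global_step=42))
    # {'epoch': tensor(3, dtype=torch.int32), 'step': tensor(42)}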

tests/checkpointing/test_model_checkpoint.py

Lines changed: 22 additions & 2 deletions

@@ -109,7 +109,6 @@ def training_step(self, batch, batch_idx):
     def validation_step(self, batch, batch_idx):
         log_value = self.val_logs[self.current_epoch, batch_idx]
         self.log("val_log", log_value)
-        self.log("epoch", self.current_epoch, on_epoch=True)
         return super().validation_step(batch, batch_idx)
 
     def configure_optimizers(self):
@@ -1086,7 +1085,7 @@ def __init__(self, hparams):
             super().__init__()
             self.save_hyperparameters(hparams)
 
-    model_checkpoint = ModelCheckpoint(dirpath=tmpdir, save_top_k=1, monitor="foo")
+    model_checkpoint = ModelCheckpoint(dirpath=tmpdir, save_top_k=1)
     trainer = Trainer(
         max_epochs=1,
         default_root_dir=tmpdir,
@@ -1281,3 +1280,24 @@ def test_last_global_step_saved():
     trainer.callback_metrics = {"foo": 123}
     model_checkpoint.save_checkpoint(trainer)
     assert model_checkpoint._last_global_step_saved == 0
+
+
+@pytest.mark.parametrize("every_n_epochs", (0, 5))
+def test_save_last_every_n_epochs_interaction(tmpdir, every_n_epochs):
+    """Test that `save_last` ignores `every_n_epochs`."""
+    mc = ModelCheckpoint(every_n_epochs=every_n_epochs, save_last=True, save_top_k=0, save_on_train_epoch_end=True)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        callbacks=mc,
+        limit_train_batches=1,
+        limit_val_batches=0,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        logger=False,
+    )
+    model = BoringModel()
+    with patch.object(trainer, "save_checkpoint") as save_mock:
+        trainer.fit(model)
+    assert mc.last_model_path  # a "last" ckpt was saved
+    assert save_mock.call_count == trainer.max_epochs

tests/loggers/test_mlflow.py

Lines changed: 1 addition & 2 deletions

@@ -159,8 +159,7 @@ def test_mlflow_logger_dirs_creation(tmpdir):
     assert set(os.listdir(tmpdir / exp_id)) == {run_id, "meta.yaml"}
 
     class CustomModel(BoringModel):
-        def training_epoch_end(self, *args, **kwargs):
-            super().training_epoch_end(*args, **kwargs)
+        def on_train_epoch_end(self, *args, **kwargs):
             self.log("epoch", self.current_epoch)
 
     model = CustomModel()

tests/models/test_restore.py

Lines changed: 1 addition & 1 deletion

@@ -106,7 +106,7 @@ def validation_step_end(self, outputs):
 def test_model_properties_fit_ckpt_path(tmpdir):
     """Test that properties like `current_epoch` and `global_step` in model and trainer are always the same."""
     model = BoringModel()
-    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True)
+    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_last=True)
     trainer_args = dict(
         default_root_dir=tmpdir,
         max_epochs=1,

tests/trainer/test_trainer.py

Lines changed: 3 additions & 1 deletion

@@ -403,6 +403,7 @@ def test_model_freeze_unfreeze():
         assert param.requires_grad
 
 
+@pytest.mark.xfail(reason="FIXME(@carmocca): this test wasn't running and is now broken")
 @pytest.mark.parametrize("url_ckpt", [True, False])
 def test_fit_ckpt_path_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
     """Verify resuming from checkpoint runs the right number of epochs."""
@@ -429,7 +430,7 @@ def on_load_checkpoint(self, _):
         max_epochs=2,
         limit_train_batches=0.65,
         limit_val_batches=1,
-        callbacks=[ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_top_k=-1)],
+        callbacks=[ModelCheckpoint(dirpath=tmpdir, save_top_k=-1)],
         default_root_dir=tmpdir,
         val_check_interval=1.0,
         enable_progress_bar=False,
@@ -449,6 +450,7 @@ def on_load_checkpoint(self, _):
         ip, port = tmpdir_server
         checkpoints = [f"http://{ip}:{port}/" + ckpt.name for ckpt in checkpoints]
 
+    assert checkpoints
     for ckpt in checkpoints:
         next_model = TestModel()
         state = pl_load(ckpt)
