ModelCheckpoint's save_last now ignores every_n_epochs #12418


Merged · 13 commits · Mar 24, 2022
12 changes: 12 additions & 0 deletions CHANGELOG.md
@@ -368,6 +368,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Marked `trainer.logger_connector` as protected ([#12195](https://github.com/PyTorchLightning/pytorch-lightning/pull/12195))


- `ModelCheckpoint(save_last=True, every_n_epochs=N)` now saves a "last" checkpoint every epoch (disregarding `every_n_epochs`) instead of only once at the end of training ([#TODO](https://github.com/PyTorchLightning/pytorch-lightning/pull/TODO))


- The strategies that support `sync_batchnorm` now only apply it when fitting ([#11919](https://github.com/PyTorchLightning/pytorch-lightning/pull/11919))


@@ -830,6 +833,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed the case where `logger=None` is passed to the Trainer ([#12249](https://github.com/PyTorchLightning/pytorch-lightning/pull/12249))


- Fixed bug where the global step tracked by `ModelCheckpoint` was still set even if no checkpoint was saved ([#TODO](https://github.com/PyTorchLightning/pytorch-lightning/pull/TODO))


- Fixed bug where `ModelCheckpoint` was overriding the `epoch` and `step` logged values ([#TODO](https://github.com/PyTorchLightning/pytorch-lightning/pull/TODO))


- Fixed bug where monitoring the default `epoch` and `step` values with `ModelCheckpoint` would fail ([#TODO](https://github.com/PyTorchLightning/pytorch-lightning/pull/TODO))


- Fixed initializing optimizers unnecessarily in `DDPFullyShardedStrategy` ([#12267](https://github.com/PyTorchLightning/pytorch-lightning/pull/12267))


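To illustrate the first CHANGELOG entry above, here is a minimal sketch of the configuration it describes; the `dirpath`, `monitor`, and epoch counts are assumptions chosen for the example, not part of this PR.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Illustrative configuration: save_last combined with every_n_epochs.
checkpoint_cb = ModelCheckpoint(
    dirpath="checkpoints/",  # assumed path for the example
    monitor="val_loss",      # assumed metric for the example
    save_top_k=1,
    save_last=True,
    every_n_epochs=2,
)
trainer = Trainer(max_epochs=4, callbacks=[checkpoint_cb])

# After this change, the top-k checkpoint is only written on epochs that
# satisfy every_n_epochs (epochs 2 and 4 here), while "last.ckpt" is
# refreshed at the end of every epoch instead of following every_n_epochs.
```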
149 changes: 48 additions & 101 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -108,12 +108,10 @@ class ModelCheckpoint(Callback):
every_n_train_steps: Number of training steps between checkpoints.
If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training.
To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative.
This must be mutually exclusive with ``train_time_interval`` and ``every_n_epochs``.
train_time_interval: Checkpoints are monitored at the specified time interval.
For all practical purposes, this cannot be smaller than the amount
of time it takes to process a single training batch. This is not
guaranteed to execute at the exact time specified, but should be close.
This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``.
every_n_epochs: Number of epochs between checkpoints.
This value must be ``None`` or non-negative.
To disable saving after each epoch, set ``every_n_epochs = 0``.
@@ -123,11 +121,11 @@ class ModelCheckpoint(Callback):
If ``every_n_epochs == None`` and either ``every_n_train_steps != None`` or ``train_time_interval != None``,
saving at the end of each epoch is disabled
(equivalent to ``every_n_epochs = 0``).
This must be mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``.
Setting both ``ModelCheckpoint(..., every_n_epochs=V, save_on_train_epoch_end=False)`` and
``Trainer(max_epochs=N, check_val_every_n_epoch=M)``
will only save checkpoints at epochs 0 < E <= N
where both values for ``every_n_epochs`` and ``check_val_every_n_epoch`` evenly divide E.
This argument does not impact the saving of ``save_last=True`` checkpoints.
save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch.
If this is ``False``, then the check runs at the end of the validation.
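
A sketch of the `every_n_epochs` / `check_val_every_n_epoch` interaction documented above, assuming a monitored metric named `val_loss` (the name is illustrative):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# With save_on_train_epoch_end=False the check runs after validation, so a
# top-k save needs an epoch where validation runs *and* every_n_epochs matches.
checkpoint_cb = ModelCheckpoint(monitor="val_loss", every_n_epochs=2, save_on_train_epoch_end=False)
trainer = Trainer(max_epochs=12, check_val_every_n_epoch=3, callbacks=[checkpoint_cb])

# Validation runs after epochs 3, 6, 9, 12; every_n_epochs=2 accepts epochs 2, 4, 6, ...
# Only epochs divisible by both values (6 and 12 here) produce a top-k checkpoint,
# matching "both values ... evenly divide E" in the docstring.
```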

@@ -238,7 +236,9 @@ def __init__(

self.__init_monitor_mode(mode)
self.__init_ckpt_dir(dirpath, filename)
self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval)
self._train_time_interval = train_time_interval
self._every_n_epochs = 1 if every_n_epochs is None else every_n_epochs
self._every_n_train_steps = 0 if every_n_train_steps is None else every_n_train_steps
self.__validate_init_configuration()

@property
@@ -299,24 +299,13 @@ def on_train_batch_end(

def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Save a checkpoint at the end of the training epoch."""
if (
not self._should_skip_saving_checkpoint(trainer)
and self._save_on_train_epoch_end
and self._every_n_epochs > 0
and (trainer.current_epoch + 1) % self._every_n_epochs == 0
):
if not self._should_skip_saving_checkpoint(trainer) and self._save_on_train_epoch_end:
self.save_checkpoint(trainer)

def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Save a checkpoint at the end of the validation stage."""
if (
self._should_skip_saving_checkpoint(trainer)
or self._save_on_train_epoch_end
or self._every_n_epochs < 1
or (trainer.current_epoch + 1) % self._every_n_epochs != 0
):
return
self.save_checkpoint(trainer)
if not self._should_skip_saving_checkpoint(trainer) and not self._save_on_train_epoch_end:
self.save_checkpoint(trainer)

def on_save_checkpoint(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
@@ -358,20 +347,35 @@ def save_checkpoint(self, trainer: "pl.Trainer") -> None:
This method runs on all ranks. It is the responsibility of `trainer.save_checkpoint` to correctly handle the
behaviour in distributed training, i.e., saving only on rank 0 for data parallel use cases.
"""
self._validate_monitor_key(trainer)

# what can be monitored
monitor_candidates = self._monitor_candidates(trainer, epoch=trainer.current_epoch, step=trainer.global_step)

# callback supports multiple simultaneous modes
# here we call each mode sequentially
# Mode 1: save the top k checkpoints
self._save_top_k_checkpoint(trainer, monitor_candidates)
# Mode 2: save monitor=None checkpoints
self._save_none_monitor_checkpoint(trainer, monitor_candidates)
# Mode 3: save last checkpoints
monitor_candidates = self._monitor_candidates(trainer)
self._save_topk_checkpoint(trainer, monitor_candidates)
self._save_last_checkpoint(trainer, monitor_candidates)

def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
if self.save_top_k == 0 or self._every_n_epochs < 1 or (trainer.current_epoch + 1) % self._every_n_epochs != 0:
# `every_n_epochs` only applies to monitored checkpoints
return

# validate metric
if self.monitor is not None:
if self.monitor not in monitor_candidates:
m = (
f"`ModelCheckpoint(monitor='{self.monitor}')` could not find the monitored key in the returned"
f" metrics: {list(monitor_candidates)}."
f" HINT: Did you call `log({self.monitor!r}, value)` in the `LightningModule`?"
)
if trainer.fit_loop.epoch_loop.val_loop._has_run:
raise MisconfigurationException(m)
warning_cache.warn(m)
self._save_monitor_checkpoint(trainer, monitor_candidates)
else:
self._save_none_monitor_checkpoint(trainer, monitor_candidates)

def _actual_save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
trainer.save_checkpoint(filepath, self.save_weights_only)

self._last_global_step_saved = trainer.global_step

# notify loggers
if trainer.is_global_zero:
for logger in trainer.loggers:
@@ -397,16 +401,6 @@ def __validate_init_configuration(self) -> None:
if self._every_n_epochs < 0:
raise MisconfigurationException(f"Invalid value for every_n_epochs={self._every_n_epochs}. Must be >= 0")

every_n_train_steps_triggered = self._every_n_train_steps >= 1
every_n_epochs_triggered = self._every_n_epochs >= 1
train_time_interval_triggered = self._train_time_interval is not None
if every_n_train_steps_triggered + every_n_epochs_triggered + train_time_interval_triggered > 1:
raise MisconfigurationException(
f"Combination of parameters every_n_train_steps={self._every_n_train_steps}, "
f"every_n_epochs={self._every_n_epochs} and train_time_interval={self._train_time_interval} "
"should be mutually exclusive."
)

if self.monitor is None:
# -1: save all epochs, 0: nothing is saved, 1: save last epoch
if self.save_top_k not in (-1, 0, 1):
@@ -439,27 +433,6 @@ def __init_monitor_mode(self, mode: str) -> None:

self.kth_value, self.mode = mode_dict[mode]

def __init_triggers(
self,
every_n_train_steps: Optional[int],
every_n_epochs: Optional[int],
train_time_interval: Optional[timedelta],
) -> None:

# Default to running once after each validation epoch if neither
# every_n_train_steps nor every_n_epochs is set
if every_n_train_steps is None and every_n_epochs is None and train_time_interval is None:
every_n_epochs = 1
every_n_train_steps = 0
log.debug("Both every_n_train_steps and every_n_epochs are not set. Setting every_n_epochs=1")
else:
every_n_epochs = every_n_epochs or 0
every_n_train_steps = every_n_train_steps or 0

self._train_time_interval: Optional[timedelta] = train_time_interval
self._every_n_epochs: int = every_n_epochs
self._every_n_train_steps: int = every_n_train_steps

@property
def every_n_epochs(self) -> Optional[int]:
return self._every_n_epochs
@@ -594,21 +567,6 @@ def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
if self.save_top_k != 0 and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0:
rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

def _validate_monitor_key(self, trainer: "pl.Trainer") -> None:
metrics = trainer.callback_metrics

# validate metric
if self.monitor is not None and not self._is_valid_monitor_key(metrics):
m = (
f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:"
f" {list(metrics.keys())}. "
f"HINT: Did you call self.log('{self.monitor}', value) in the LightningModule?"
)
if not trainer.fit_loop.epoch_loop.val_loop._has_run:
warning_cache.warn(m)
else:
raise MisconfigurationException(m)

def _get_metric_interpolated_filepath_name(
self, monitor_candidates: Dict[str, _METRIC], trainer: "pl.Trainer", del_filepath: Optional[str] = None
) -> str:
@@ -621,51 +579,40 @@ def _get_metric_interpolated_filepath_name(

return filepath

def _monitor_candidates(self, trainer: "pl.Trainer", epoch: int, step: int) -> Dict[str, _METRIC]:
def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, _METRIC]:
monitor_candidates = deepcopy(trainer.callback_metrics)
monitor_candidates.update(epoch=epoch, step=step)
monitor_candidates.setdefault("epoch", torch.tensor(trainer.current_epoch))
monitor_candidates.setdefault("step", torch.tensor(trainer.global_step))
return monitor_candidates

def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
if not self.save_last:
return
self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step)

filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST)
# set the last model path before saving because it will be part of the state.
previous, self.last_model_path = self.last_model_path, filepath
trainer.save_checkpoint(filepath, self.save_weights_only)
self._actual_save_checkpoint(trainer, filepath)
if previous and previous != filepath:
trainer.strategy.remove_checkpoint(previous)

def _save_top_k_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
if self.monitor is None or self.save_top_k == 0:
return
self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step)

def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
current = monitor_candidates.get(self.monitor)
if self.check_monitor_top_k(trainer, current):
self._update_best_and_save(current, trainer, monitor_candidates)
elif self.verbose:
epoch = monitor_candidates.get("epoch")
step = monitor_candidates.get("step")
rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor} was not in top {self.save_top_k}")
epoch = monitor_candidates["epoch"]
step = monitor_candidates["step"]
rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} was not in top {self.save_top_k}")

def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
if self.monitor is not None or self.save_top_k == 0:
return
self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step)

filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
# set the best model path before saving because it will be part of the state.
previous, self.best_model_path = self.best_model_path, filepath
trainer.save_checkpoint(filepath, self.save_weights_only)
self._actual_save_checkpoint(trainer, filepath)
if self.save_top_k == 1 and previous and previous != filepath:
trainer.strategy.remove_checkpoint(previous)

def _is_valid_monitor_key(self, metrics: Dict[str, _METRIC]) -> bool:
return self.monitor in metrics or len(metrics) == 0

def _update_best_and_save(
self, current: torch.Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]
) -> None:
@@ -697,13 +644,13 @@ def _update_best_and_save(
self.best_model_score = self.best_k_models[self.best_model_path]

if self.verbose:
epoch = monitor_candidates.get("epoch")
step = monitor_candidates.get("step")
epoch = monitor_candidates["epoch"]
step = monitor_candidates["step"]
rank_zero_info(
f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}"
f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}'
f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} reached {current:0.5f}"
f" (best {self.best_model_score:0.5f}), saving model to {filepath!r} as top {k}"
)
trainer.save_checkpoint(filepath, self.save_weights_only)
self._actual_save_checkpoint(trainer, filepath)

if del_filepath is not None and filepath != del_filepath:
trainer.strategy.remove_checkpoint(del_filepath)
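Finally, a short sketch of the use case addressed by the `_monitor_candidates` change above (and the related CHANGELOG fixes): monitoring the built-in `epoch` and `step` keys. The `filename` pattern and `save_top_k` value are assumptions for the example.

```python
from pytorch_lightning.callbacks import ModelCheckpoint

# "epoch" and "step" are now added with setdefault() as tensors, so they can
# be monitored directly and no longer override values logged under those names.
checkpoint_cb = ModelCheckpoint(
    monitor="step",             # monitor the built-in global step
    mode="max",                 # a larger step means a more recent checkpoint
    save_top_k=3,               # assumed value for the example
    filename="{epoch}-{step}",  # assumed filename pattern
)
```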