From c9eaa4f637b6a291eb55233363111126fea1e38d Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Wed, 23 Mar 2022 03:24:35 +0100
Subject: [PATCH 01/11] Fix `ModelCheckpoint` trigger interactions

---
 CHANGELOG.md                                 |  12 ++
 .../callbacks/model_checkpoint.py            | 149 ++++++------------
 tests/checkpointing/test_model_checkpoint.py |  52 +++---
 3 files changed, 87 insertions(+), 126 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a79f57d27e0e5..17e2120e1f273 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -368,6 +368,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Marked `trainer.logger_connector` as protected ([#12195](https://github.com/PyTorchLightning/pytorch-lightning/pull/12195))
 
 
+- `ModelCheckpoint(save_last=True, every_n_epochs=N)` now saves a "last" checkpoint every epoch (disregarding `every_n_epochs`) instead of only once at the end of training ([#TODO](https://github.com/PyTorchLightning/pytorch-lightning/pull/TODO))
+
+
 - The strategies that support `sync_batchnorm` now only apply it when fitting ([#11919](https://github.com/PyTorchLightning/pytorch-lightning/pull/11919))
 
 
@@ -830,6 +833,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed the case where `logger=None` is passed to the Trainer ([#12249](https://github.com/PyTorchLightning/pytorch-lightning/pull/12249))
 
 
+- Fixed bug where the global step tracked by `ModelCheckpoint` was still set even if no checkpoint was saved ([#TODO](https://github.com/PyTorchLightning/pytorch-lightning/pull/TODO))
+
+
+- Fixed bug where `ModelCheckpoint` was overriding the `epoch` and `step` logged values ([#TODO](https://github.com/PyTorchLightning/pytorch-lightning/pull/TODO))
+
+
+- Fixed bug where monitoring the default `epoch` and `step` values with `ModelCheckpoint` would fail ([#TODO](https://github.com/PyTorchLightning/pytorch-lightning/pull/TODO))
+
+
 - Fixed initializing optimizers unnecessarily in `DDPFullyShardedStrategy` ([#12267](https://github.com/PyTorchLightning/pytorch-lightning/pull/12267))
 
 
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 57ed9098bd211..8efcb247d817e 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -108,12 +108,10 @@ class ModelCheckpoint(Callback):
         every_n_train_steps: Number of training steps between checkpoints.
             If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training.
             To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative.
-            This must be mutually exclusive with ``train_time_interval`` and ``every_n_epochs``.
         train_time_interval: Checkpoints are monitored at the specified time interval.
             For all practical purposes, this cannot be smaller than the amount
             of time it takes to process a single training batch. This is not
             guaranteed to execute at the exact time specified, but should be close.
-            This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``.
         every_n_epochs: Number of epochs between checkpoints.
             This value must be ``None`` or non-negative.
             To disable saving after each epoch, set ``every_n_epochs = 0``.
@@ -123,11 +121,11 @@ class ModelCheckpoint(Callback):
             If ``every_n_epochs == None`` and either ``every_n_train_steps != None`` or ``train_time_interval != None``,
             saving at the end of each epoch is disabled
             (equivalent to ``every_n_epochs = 0``).
- This must be mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``. Setting both ``ModelCheckpoint(..., every_n_epochs=V, save_on_train_epoch_end=False)`` and ``Trainer(max_epochs=N, check_val_every_n_epoch=M)`` will only save checkpoints at epochs 0 < E <= N where both values for ``every_n_epochs`` and ``check_val_every_n_epoch`` evenly divide E. + This argument does not impact the saving of ``save_last=True`` checkpoints. save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch. If this is ``False``, then the check runs at the end of the validation. @@ -238,7 +236,9 @@ def __init__( self.__init_monitor_mode(mode) self.__init_ckpt_dir(dirpath, filename) - self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval) + self._train_time_interval = train_time_interval + self._every_n_epochs = 1 if every_n_epochs is None else every_n_epochs + self._every_n_train_steps = 0 if every_n_train_steps is None else every_n_train_steps self.__validate_init_configuration() @property @@ -299,24 +299,13 @@ def on_train_batch_end( def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Save a checkpoint at the end of the training epoch.""" - if ( - not self._should_skip_saving_checkpoint(trainer) - and self._save_on_train_epoch_end - and self._every_n_epochs > 0 - and (trainer.current_epoch + 1) % self._every_n_epochs == 0 - ): + if not self._should_skip_saving_checkpoint(trainer) and self._save_on_train_epoch_end: self.save_checkpoint(trainer) def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Save a checkpoint at the end of the validation stage.""" - if ( - self._should_skip_saving_checkpoint(trainer) - or self._save_on_train_epoch_end - or self._every_n_epochs < 1 - or (trainer.current_epoch + 1) % self._every_n_epochs != 0 - ): - return - self.save_checkpoint(trainer) + if not self._should_skip_saving_checkpoint(trainer) and not self._save_on_train_epoch_end: + self.save_checkpoint(trainer) def on_save_checkpoint( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] @@ -358,20 +347,35 @@ def save_checkpoint(self, trainer: "pl.Trainer") -> None: This method runs on all ranks. It is the responsibility of `trainer.save_checkpoint` to correctly handle the behaviour in distributed training, i.e., saving only on rank 0 for data parallel use cases. 
""" - self._validate_monitor_key(trainer) - - # what can be monitored - monitor_candidates = self._monitor_candidates(trainer, epoch=trainer.current_epoch, step=trainer.global_step) - - # callback supports multiple simultaneous modes - # here we call each mode sequentially - # Mode 1: save the top k checkpoints - self._save_top_k_checkpoint(trainer, monitor_candidates) - # Mode 2: save monitor=None checkpoints - self._save_none_monitor_checkpoint(trainer, monitor_candidates) - # Mode 3: save last checkpoints + monitor_candidates = self._monitor_candidates(trainer) + self._save_topk_checkpoint(trainer, monitor_candidates) self._save_last_checkpoint(trainer, monitor_candidates) + def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None: + if self.save_top_k == 0 or self._every_n_epochs < 1 or (trainer.current_epoch + 1) % self._every_n_epochs != 0: + # `every_n_epochs` only applies to monitored checkpoints + return + + # validate metric + if self.monitor is not None: + if self.monitor not in monitor_candidates: + m = ( + f"`ModelCheckpoint(monitor='{self.monitor}')` could not find the monitored key in the returned" + f" metrics: {list(monitor_candidates)}." + f" HINT: Did you call `log({self.monitor!r}, value)` in the `LightningModule`?" + ) + if trainer.fit_loop.epoch_loop.val_loop._has_run: + raise MisconfigurationException(m) + warning_cache.warn(m) + self._save_monitor_checkpoint(trainer, monitor_candidates) + else: + self._save_none_monitor_checkpoint(trainer, monitor_candidates) + + def _actual_save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None: + trainer.save_checkpoint(filepath, self.save_weights_only) + + self._last_global_step_saved = trainer.global_step + # notify loggers if trainer.is_global_zero: for logger in trainer.loggers: @@ -397,16 +401,6 @@ def __validate_init_configuration(self) -> None: if self._every_n_epochs < 0: raise MisconfigurationException(f"Invalid value for every_n_epochs={self._every_n_epochs}. Must be >= 0") - every_n_train_steps_triggered = self._every_n_train_steps >= 1 - every_n_epochs_triggered = self._every_n_epochs >= 1 - train_time_interval_triggered = self._train_time_interval is not None - if every_n_train_steps_triggered + every_n_epochs_triggered + train_time_interval_triggered > 1: - raise MisconfigurationException( - f"Combination of parameters every_n_train_steps={self._every_n_train_steps}, " - f"every_n_epochs={self._every_n_epochs} and train_time_interval={self._train_time_interval} " - "should be mutually exclusive." - ) - if self.monitor is None: # -1: save all epochs, 0: nothing is saved, 1: save last epoch if self.save_top_k not in (-1, 0, 1): @@ -439,27 +433,6 @@ def __init_monitor_mode(self, mode: str) -> None: self.kth_value, self.mode = mode_dict[mode] - def __init_triggers( - self, - every_n_train_steps: Optional[int], - every_n_epochs: Optional[int], - train_time_interval: Optional[timedelta], - ) -> None: - - # Default to running once after each validation epoch if neither - # every_n_train_steps nor every_n_epochs is set - if every_n_train_steps is None and every_n_epochs is None and train_time_interval is None: - every_n_epochs = 1 - every_n_train_steps = 0 - log.debug("Both every_n_train_steps and every_n_epochs are not set. 
Setting every_n_epochs=1") - else: - every_n_epochs = every_n_epochs or 0 - every_n_train_steps = every_n_train_steps or 0 - - self._train_time_interval: Optional[timedelta] = train_time_interval - self._every_n_epochs: int = every_n_epochs - self._every_n_train_steps: int = every_n_train_steps - @property def every_n_epochs(self) -> Optional[int]: return self._every_n_epochs @@ -594,21 +567,6 @@ def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None: if self.save_top_k != 0 and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0: rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.") - def _validate_monitor_key(self, trainer: "pl.Trainer") -> None: - metrics = trainer.callback_metrics - - # validate metric - if self.monitor is not None and not self._is_valid_monitor_key(metrics): - m = ( - f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:" - f" {list(metrics.keys())}. " - f"HINT: Did you call self.log('{self.monitor}', value) in the LightningModule?" - ) - if not trainer.fit_loop.epoch_loop.val_loop._has_run: - warning_cache.warn(m) - else: - raise MisconfigurationException(m) - def _get_metric_interpolated_filepath_name( self, monitor_candidates: Dict[str, _METRIC], trainer: "pl.Trainer", del_filepath: Optional[str] = None ) -> str: @@ -621,51 +579,40 @@ def _get_metric_interpolated_filepath_name( return filepath - def _monitor_candidates(self, trainer: "pl.Trainer", epoch: int, step: int) -> Dict[str, _METRIC]: + def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, _METRIC]: monitor_candidates = deepcopy(trainer.callback_metrics) - monitor_candidates.update(epoch=epoch, step=step) + monitor_candidates.setdefault("epoch", torch.tensor(trainer.current_epoch)) + monitor_candidates.setdefault("step", torch.tensor(trainer.global_step)) return monitor_candidates def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None: if not self.save_last: return - self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step) filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST) # set the last model path before saving because it will be part of the state. 
previous, self.last_model_path = self.last_model_path, filepath - trainer.save_checkpoint(filepath, self.save_weights_only) + self._actual_save_checkpoint(trainer, filepath) if previous and previous != filepath: trainer.strategy.remove_checkpoint(previous) - def _save_top_k_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None: - if self.monitor is None or self.save_top_k == 0: - return - self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step) - + def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None: current = monitor_candidates.get(self.monitor) if self.check_monitor_top_k(trainer, current): self._update_best_and_save(current, trainer, monitor_candidates) elif self.verbose: - epoch = monitor_candidates.get("epoch") - step = monitor_candidates.get("step") - rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor} was not in top {self.save_top_k}") + epoch = monitor_candidates["epoch"] + step = monitor_candidates["step"] + rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} was not in top {self.save_top_k}") def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None: - if self.monitor is not None or self.save_top_k == 0: - return - self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step) - filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer) # set the best model path before saving because it will be part of the state. previous, self.best_model_path = self.best_model_path, filepath - trainer.save_checkpoint(filepath, self.save_weights_only) + self._actual_save_checkpoint(trainer, filepath) if self.save_top_k == 1 and previous and previous != filepath: trainer.strategy.remove_checkpoint(previous) - def _is_valid_monitor_key(self, metrics: Dict[str, _METRIC]) -> bool: - return self.monitor in metrics or len(metrics) == 0 - def _update_best_and_save( self, current: torch.Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC] ) -> None: @@ -697,13 +644,13 @@ def _update_best_and_save( self.best_model_score = self.best_k_models[self.best_model_path] if self.verbose: - epoch = monitor_candidates.get("epoch") - step = monitor_candidates.get("step") + epoch = monitor_candidates["epoch"] + step = monitor_candidates["step"] rank_zero_info( - f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}" - f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}' + f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} reached {current:0.5f}" + f" (best {self.best_model_score:0.5f}), saving model to {filepath!r} as top {k}" ) - trainer.save_checkpoint(filepath, self.save_weights_only) + self._actual_save_checkpoint(trainer, filepath) if del_filepath is not None and filepath != del_filepath: trainer.strategy.remove_checkpoint(del_filepath) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 3dadf0b733a74..e3857d03bf9a7 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -109,7 +109,6 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx): log_value = self.val_logs[self.current_epoch, batch_idx] self.log("val_log", log_value) - self.log("epoch", self.current_epoch, on_epoch=True) return super().validation_step(batch, batch_idx) def 
configure_optimizers(self): @@ -538,22 +537,6 @@ def test_invalid_every_n_train_steps(tmpdir): ModelCheckpoint(dirpath=tmpdir, every_n_epochs=2) -def test_invalid_trigger_combination(tmpdir): - """Test that a MisconfigurationException is raised if more than one of every_n_epochs, every_n_train_steps, and - train_time_interval are enabled together.""" - with pytest.raises(MisconfigurationException, match=r".*Combination of parameters every_n_train_steps"): - ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1, every_n_epochs=2) - with pytest.raises(MisconfigurationException, match=r".*Combination of parameters every_n_train_steps"): - ModelCheckpoint(train_time_interval=timedelta(minutes=1), every_n_epochs=2) - with pytest.raises(MisconfigurationException, match=r".*Combination of parameters every_n_train_steps"): - ModelCheckpoint(train_time_interval=timedelta(minutes=1), every_n_train_steps=2) - - # These should not fail - ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_epochs=3) - ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=4, every_n_epochs=0) - ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_epochs=0, train_time_interval=timedelta(minutes=1)) - - def test_none_every_n_train_steps_val_epochs(tmpdir): checkpoint_callback = ModelCheckpoint(dirpath=tmpdir) assert checkpoint_callback.every_n_epochs == 1 @@ -625,7 +608,6 @@ def test_ckpt_every_n_train_steps(tmpdir): epoch_length = 64 checkpoint_callback = ModelCheckpoint( filename="{step}", - every_n_epochs=0, every_n_train_steps=every_n_train_steps, dirpath=tmpdir, save_top_k=-1, @@ -675,10 +657,9 @@ def test_model_checkpoint_train_time_interval(mock_datetime, tmpdir) -> None: ) trainer.fit(model) - # Each batch takes 7 sec and we checkpoint every minute. There are 64 - # batches per epoch, so total time to run is 7*64*2 = 896 sec < 14.96 minutes, - # so we should have 14 checkpoints. - assert len(os.listdir(tmpdir)) == 14 + # Each batch takes 7 sec and we checkpoint every minute. There are 64 batches per epoch, so total time to run is + # 7*64*2 = 896 sec < 14.96 minutes so we should have 14 checkpoints. 
+2 for those saved at the end of the each epoch + assert len(os.listdir(tmpdir)) == 14 + 2 def test_model_checkpoint_topk_zero(tmpdir): @@ -1086,7 +1067,7 @@ def __init__(self, hparams): super().__init__() self.save_hyperparameters(hparams) - model_checkpoint = ModelCheckpoint(dirpath=tmpdir, save_top_k=1, monitor="foo") + model_checkpoint = ModelCheckpoint(dirpath=tmpdir, save_top_k=1) trainer = Trainer( max_epochs=1, default_root_dir=tmpdir, @@ -1253,7 +1234,7 @@ def test_save_last_saves_correct_last_model_path(tmpdir): trainer = Trainer(callbacks=mc) trainer.strategy.connect(BoringModel()) - mc._save_last_checkpoint(trainer, {"foo": 1}) + mc._save_last_checkpoint(trainer, {"foo": 1, "step": 0}) expected = "foo=1-last.ckpt" assert os.listdir(tmpdir) == [expected] full_path = str(tmpdir / expected) @@ -1266,7 +1247,7 @@ def test_none_monitor_saves_correct_best_model_path(tmpdir): trainer = Trainer(callbacks=mc) trainer.strategy.connect(BoringModel()) - mc._save_none_monitor_checkpoint(trainer, {}) + mc._save_none_monitor_checkpoint(trainer, {"step": 0}) expected = "epoch=0-step=0.ckpt" assert os.listdir(tmpdir) == [expected] full_path = str(tmpdir / expected) @@ -1281,3 +1262,24 @@ def test_last_global_step_saved(): trainer.callback_metrics = {"foo": 123} model_checkpoint.save_checkpoint(trainer) assert model_checkpoint._last_global_step_saved == 0 + + +def test_save_last_every_n_epochs_interaction(tmpdir): + """Test that `save_last` ignores `every_n_epochs`.""" + mc = ModelCheckpoint(every_n_epochs=5, save_last=True, save_top_k=0, save_on_train_epoch_end=True) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + callbacks=mc, + limit_train_batches=1, + limit_val_batches=0, + enable_progress_bar=False, + enable_model_summary=False, + logger=False, + ) + assert mc.every_n_epochs > trainer.max_epochs + model = BoringModel() + with patch.object(trainer, "save_checkpoint") as save_mock: + trainer.fit(model) + assert mc.last_model_path # a "last" ckpt was saved + assert save_mock.call_count == trainer.max_epochs From 9ba217deaefd2de23e2d934a36a793b626609dd7 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 23 Mar 2022 03:30:15 +0100 Subject: [PATCH 02/11] Docs fix --- pytorch_lightning/callbacks/model_checkpoint.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 8efcb247d817e..7468ccfed94cb 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -114,14 +114,11 @@ class ModelCheckpoint(Callback): guaranteed to execute at the exact time specified, but should be close. every_n_epochs: Number of epochs between checkpoints. This value must be ``None`` or non-negative. - To disable saving after each epoch, set ``every_n_epochs = 0``. + To disable saving top-k checkpoints, set ``every_n_epochs = 0``. If all of ``every_n_epochs``, ``every_n_train_steps`` and ``train_time_interval`` are ``None``, we save a checkpoint at the end of every epoch (equivalent to ``every_n_epochs = 1``). - If ``every_n_epochs == None`` and either ``every_n_train_steps != None`` or ``train_time_interval != None``, - saving at the end of each epoch is disabled - (equivalent to ``every_n_epochs = 0``). 
- Setting both ``ModelCheckpoint(..., every_n_epochs=V, save_on_train_epoch_end=False)`` and + Setting both ``ModelCheckpoint(every_n_epochs=V, save_on_train_epoch_end=False)`` and ``Trainer(max_epochs=N, check_val_every_n_epoch=M)`` will only save checkpoints at epochs 0 < E <= N where both values for ``every_n_epochs`` and ``check_val_every_n_epoch`` evenly divide E. From 0dee5353474dcb79e0d1341fff90ab20ba28d509 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 23 Mar 2022 03:38:28 +0100 Subject: [PATCH 03/11] CHANGELOG --- CHANGELOG.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 17e2120e1f273..6400762af57d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -368,7 +368,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Marked `trainer.logger_connector` as protected ([#12195](https://github.com/PyTorchLightning/pytorch-lightning/pull/12195)) -- `ModelCheckpoint(save_last=True, every_n_epochs=N)` now saves a "last" checkpoint every epoch (disregarding `every_n_epochs`) instead of only once at the end of training ([#TODO](https://github.com/PyTorchLightning/pytorch-lightning/pull/TODO)) +- `ModelCheckpoint(save_last=True, every_n_epochs=N)` now saves a "last" checkpoint every epoch (disregarding `every_n_epochs`) instead of only once at the end of training ([#12418](https://github.com/PyTorchLightning/pytorch-lightning/pull/12418)) - The strategies that support `sync_batchnorm` now only apply it when fitting ([#11919](https://github.com/PyTorchLightning/pytorch-lightning/pull/11919)) @@ -833,13 +833,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed the case where `logger=None` is passed to the Trainer ([#12249](https://github.com/PyTorchLightning/pytorch-lightning/pull/12249)) -- Fixed bug where the global step tracked by `ModelCheckpoint` was still set even if no checkpoint was saved ([#TODO](https://github.com/PyTorchLightning/pytorch-lightning/pull/TODO)) +- Fixed bug where the global step tracked by `ModelCheckpoint` was still set even if no checkpoint was saved ([#12418](https://github.com/PyTorchLightning/pytorch-lightning/pull/12418)) +- +- Fixed bug where `ModelCheckpoint` was overriding the `epoch` and `step` logged values ([#12418](https://github.com/PyTorchLightning/pytorch-lightning/pull/12418)) -- Fixed bug where `ModelCheckpoint` was overriding the `epoch` and `step` logged values ([#TODO](https://github.com/PyTorchLightning/pytorch-lightning/pull/TODO)) - -- Fixed bug where monitoring the default `epoch` and `step` values with `ModelCheckpoint` would fail ([#TODO](https://github.com/PyTorchLightning/pytorch-lightning/pull/TODO)) +- Fixed bug where monitoring the default `epoch` and `step` values with `ModelCheckpoint` would fail ([#12418](https://github.com/PyTorchLightning/pytorch-lightning/pull/12418)) - Fixed initializing optimizers unnecessarily in `DDPFullyShardedStrategy` ([#12267](https://github.com/PyTorchLightning/pytorch-lightning/pull/12267)) From 1bf091567bd77028e86a897aba69ff9f82d60cf1 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 23 Mar 2022 03:51:36 +0100 Subject: [PATCH 04/11] Undo trigger changes, perhaps for another day :) --- .../callbacks/model_checkpoint.py | 55 ++++++++++++++++--- tests/checkpointing/test_model_checkpoint.py | 24 +++++++- 2 files changed, 69 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py 
b/pytorch_lightning/callbacks/model_checkpoint.py index 7468ccfed94cb..b997410874504 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -108,17 +108,23 @@ class ModelCheckpoint(Callback): every_n_train_steps: Number of training steps between checkpoints. If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training. To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative. + This must be mutually exclusive with ``train_time_interval`` and ``every_n_epochs``. train_time_interval: Checkpoints are monitored at the specified time interval. For all practical purposes, this cannot be smaller than the amount of time it takes to process a single training batch. This is not guaranteed to execute at the exact time specified, but should be close. + This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``. every_n_epochs: Number of epochs between checkpoints. This value must be ``None`` or non-negative. To disable saving top-k checkpoints, set ``every_n_epochs = 0``. If all of ``every_n_epochs``, ``every_n_train_steps`` and ``train_time_interval`` are ``None``, we save a checkpoint at the end of every epoch (equivalent to ``every_n_epochs = 1``). - Setting both ``ModelCheckpoint(every_n_epochs=V, save_on_train_epoch_end=False)`` and + If ``every_n_epochs == None`` and either ``every_n_train_steps != None`` or ``train_time_interval != None``, + saving at the end of each epoch is disabled + (equivalent to ``every_n_epochs = 0``). + This must be mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``. + Setting both ``ModelCheckpoint(..., every_n_epochs=V, save_on_train_epoch_end=False)`` and ``Trainer(max_epochs=N, check_val_every_n_epoch=M)`` will only save checkpoints at epochs 0 < E <= N where both values for ``every_n_epochs`` and ``check_val_every_n_epoch`` evenly divide E. 
@@ -233,9 +239,7 @@ def __init__( self.__init_monitor_mode(mode) self.__init_ckpt_dir(dirpath, filename) - self._train_time_interval = train_time_interval - self._every_n_epochs = 1 if every_n_epochs is None else every_n_epochs - self._every_n_train_steps = 0 if every_n_train_steps is None else every_n_train_steps + self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval) self.__validate_init_configuration() @property @@ -297,12 +301,18 @@ def on_train_batch_end( def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Save a checkpoint at the end of the training epoch.""" if not self._should_skip_saving_checkpoint(trainer) and self._save_on_train_epoch_end: - self.save_checkpoint(trainer) + monitor_candidates = self._monitor_candidates(trainer) + if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0: + self._save_topk_checkpoint(trainer, monitor_candidates) + self._save_last_checkpoint(trainer, monitor_candidates) def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Save a checkpoint at the end of the validation stage.""" if not self._should_skip_saving_checkpoint(trainer) and not self._save_on_train_epoch_end: - self.save_checkpoint(trainer) + monitor_candidates = self._monitor_candidates(trainer) + if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0: + self._save_topk_checkpoint(trainer, monitor_candidates) + self._save_last_checkpoint(trainer, monitor_candidates) def on_save_checkpoint( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] @@ -349,7 +359,7 @@ def save_checkpoint(self, trainer: "pl.Trainer") -> None: self._save_last_checkpoint(trainer, monitor_candidates) def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None: - if self.save_top_k == 0 or self._every_n_epochs < 1 or (trainer.current_epoch + 1) % self._every_n_epochs != 0: + if self.save_top_k == 0: # `every_n_epochs` only applies to monitored checkpoints return @@ -398,6 +408,16 @@ def __validate_init_configuration(self) -> None: if self._every_n_epochs < 0: raise MisconfigurationException(f"Invalid value for every_n_epochs={self._every_n_epochs}. Must be >= 0") + every_n_train_steps_triggered = self._every_n_train_steps >= 1 + every_n_epochs_triggered = self._every_n_epochs >= 1 + train_time_interval_triggered = self._train_time_interval is not None + if every_n_train_steps_triggered + every_n_epochs_triggered + train_time_interval_triggered > 1: + raise MisconfigurationException( + f"Combination of parameters every_n_train_steps={self._every_n_train_steps}, " + f"every_n_epochs={self._every_n_epochs} and train_time_interval={self._train_time_interval} " + "should be mutually exclusive." 
+ ) + if self.monitor is None: # -1: save all epochs, 0: nothing is saved, 1: save last epoch if self.save_top_k not in (-1, 0, 1): @@ -430,6 +450,27 @@ def __init_monitor_mode(self, mode: str) -> None: self.kth_value, self.mode = mode_dict[mode] + def __init_triggers( + self, + every_n_train_steps: Optional[int], + every_n_epochs: Optional[int], + train_time_interval: Optional[timedelta], + ) -> None: + + # Default to running once after each validation epoch if neither + # every_n_train_steps nor every_n_epochs is set + if every_n_train_steps is None and every_n_epochs is None and train_time_interval is None: + every_n_epochs = 1 + every_n_train_steps = 0 + log.debug("Both every_n_train_steps and every_n_epochs are not set. Setting every_n_epochs=1") + else: + every_n_epochs = every_n_epochs or 0 + every_n_train_steps = every_n_train_steps or 0 + + self._train_time_interval: Optional[timedelta] = train_time_interval + self._every_n_epochs: int = every_n_epochs + self._every_n_train_steps: int = every_n_train_steps + @property def every_n_epochs(self) -> Optional[int]: return self._every_n_epochs diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index e3857d03bf9a7..0c7d33c6d0e14 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -537,6 +537,22 @@ def test_invalid_every_n_train_steps(tmpdir): ModelCheckpoint(dirpath=tmpdir, every_n_epochs=2) +def test_invalid_trigger_combination(tmpdir): + """Test that a MisconfigurationException is raised if more than one of every_n_epochs, every_n_train_steps, and + train_time_interval are enabled together.""" + with pytest.raises(MisconfigurationException, match=r".*Combination of parameters every_n_train_steps"): + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1, every_n_epochs=2) + with pytest.raises(MisconfigurationException, match=r".*Combination of parameters every_n_train_steps"): + ModelCheckpoint(train_time_interval=timedelta(minutes=1), every_n_epochs=2) + with pytest.raises(MisconfigurationException, match=r".*Combination of parameters every_n_train_steps"): + ModelCheckpoint(train_time_interval=timedelta(minutes=1), every_n_train_steps=2) + + # These should not fail + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_epochs=3) + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=4, every_n_epochs=0) + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_epochs=0, train_time_interval=timedelta(minutes=1)) + + def test_none_every_n_train_steps_val_epochs(tmpdir): checkpoint_callback = ModelCheckpoint(dirpath=tmpdir) assert checkpoint_callback.every_n_epochs == 1 @@ -608,6 +624,7 @@ def test_ckpt_every_n_train_steps(tmpdir): epoch_length = 64 checkpoint_callback = ModelCheckpoint( filename="{step}", + every_n_epochs=0, every_n_train_steps=every_n_train_steps, dirpath=tmpdir, save_top_k=-1, @@ -657,9 +674,10 @@ def test_model_checkpoint_train_time_interval(mock_datetime, tmpdir) -> None: ) trainer.fit(model) - # Each batch takes 7 sec and we checkpoint every minute. There are 64 batches per epoch, so total time to run is - # 7*64*2 = 896 sec < 14.96 minutes so we should have 14 checkpoints. +2 for those saved at the end of the each epoch - assert len(os.listdir(tmpdir)) == 14 + 2 + # Each batch takes 7 sec and we checkpoint every minute. There are 64 + # batches per epoch, so total time to run is 7*64*2 = 896 sec < 14.96 minutes, + # so we should have 14 checkpoints. 
+ assert len(os.listdir(tmpdir)) == 14 def test_model_checkpoint_topk_zero(tmpdir): From d414792d72e390d341dd219c619ed38154b54487 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 23 Mar 2022 03:56:02 +0100 Subject: [PATCH 05/11] Simplify --- pytorch_lightning/callbacks/model_checkpoint.py | 3 +-- tests/checkpointing/test_model_checkpoint.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index b997410874504..056225dc44895 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -117,6 +117,7 @@ class ModelCheckpoint(Callback): every_n_epochs: Number of epochs between checkpoints. This value must be ``None`` or non-negative. To disable saving top-k checkpoints, set ``every_n_epochs = 0``. + This argument does not impact the saving of ``save_last=True`` checkpoints. If all of ``every_n_epochs``, ``every_n_train_steps`` and ``train_time_interval`` are ``None``, we save a checkpoint at the end of every epoch (equivalent to ``every_n_epochs = 1``). @@ -128,7 +129,6 @@ class ModelCheckpoint(Callback): ``Trainer(max_epochs=N, check_val_every_n_epoch=M)`` will only save checkpoints at epochs 0 < E <= N where both values for ``every_n_epochs`` and ``check_val_every_n_epoch`` evenly divide E. - This argument does not impact the saving of ``save_last=True`` checkpoints. save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch. If this is ``False``, then the check runs at the end of the validation. @@ -360,7 +360,6 @@ def save_checkpoint(self, trainer: "pl.Trainer") -> None: def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None: if self.save_top_k == 0: - # `every_n_epochs` only applies to monitored checkpoints return # validate metric diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 0c7d33c6d0e14..9baa50ff0eba2 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1252,7 +1252,7 @@ def test_save_last_saves_correct_last_model_path(tmpdir): trainer = Trainer(callbacks=mc) trainer.strategy.connect(BoringModel()) - mc._save_last_checkpoint(trainer, {"foo": 1, "step": 0}) + mc._save_last_checkpoint(trainer, {"foo": 1}) expected = "foo=1-last.ckpt" assert os.listdir(tmpdir) == [expected] full_path = str(tmpdir / expected) @@ -1265,7 +1265,7 @@ def test_none_monitor_saves_correct_best_model_path(tmpdir): trainer = Trainer(callbacks=mc) trainer.strategy.connect(BoringModel()) - mc._save_none_monitor_checkpoint(trainer, {"step": 0}) + mc._save_none_monitor_checkpoint(trainer, {}) expected = "epoch=0-step=0.ckpt" assert os.listdir(tmpdir) == [expected] full_path = str(tmpdir / expected) From 1035e325304bc111578f517853fce8cdb8227fb1 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 23 Mar 2022 04:02:44 +0100 Subject: [PATCH 06/11] pragma: no-cover --- pytorch_lightning/callbacks/model_checkpoint.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 056225dc44895..f41c4a628fdbc 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -296,7 +296,9 @@ def on_train_batch_end( if not skip_time: self._last_time_checked = now - 
self.save_checkpoint(trainer) + monitor_candidates = self._monitor_candidates(trainer) + self._save_topk_checkpoint(trainer, monitor_candidates) + self._save_last_checkpoint(trainer, monitor_candidates) def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Save a checkpoint at the end of the training epoch.""" @@ -348,12 +350,13 @@ def on_load_checkpoint( self.last_model_path = callback_state.get("last_model_path", self.last_model_path) self.best_model_path = callback_state["best_model_path"] - def save_checkpoint(self, trainer: "pl.Trainer") -> None: + def save_checkpoint(self, trainer: "pl.Trainer") -> None: # pragma: no-cover """Performs the main logic around saving a checkpoint. This method runs on all ranks. It is the responsibility of `trainer.save_checkpoint` to correctly handle the behaviour in distributed training, i.e., saving only on rank 0 for data parallel use cases. """ + # TODO: unused method. deprecate it monitor_candidates = self._monitor_candidates(trainer) self._save_topk_checkpoint(trainer, monitor_candidates) self._save_last_checkpoint(trainer, monitor_candidates) From ccbdb77fc31ec7a02b2daf2eba6b79d77cf7939e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 23 Mar 2022 15:09:41 +0100 Subject: [PATCH 07/11] Kush review --- pytorch_lightning/callbacks/model_checkpoint.py | 8 ++++---- tests/checkpointing/test_model_checkpoint.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index f41c4a628fdbc..0a360f2999899 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -380,7 +380,7 @@ def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[ else: self._save_none_monitor_checkpoint(trainer, monitor_candidates) - def _actual_save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None: + def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None: trainer.save_checkpoint(filepath, self.save_weights_only) self._last_global_step_saved = trainer.global_step @@ -632,7 +632,7 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[ filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST) # set the last model path before saving because it will be part of the state. previous, self.last_model_path = self.last_model_path, filepath - self._actual_save_checkpoint(trainer, filepath) + self._save_checkpoint(trainer, filepath) if previous and previous != filepath: trainer.strategy.remove_checkpoint(previous) @@ -649,7 +649,7 @@ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidate filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer) # set the best model path before saving because it will be part of the state. 
previous, self.best_model_path = self.best_model_path, filepath - self._actual_save_checkpoint(trainer, filepath) + self._save_checkpoint(trainer, filepath) if self.save_top_k == 1 and previous and previous != filepath: trainer.strategy.remove_checkpoint(previous) @@ -690,7 +690,7 @@ def _update_best_and_save( f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} reached {current:0.5f}" f" (best {self.best_model_score:0.5f}), saving model to {filepath!r} as top {k}" ) - self._actual_save_checkpoint(trainer, filepath) + self._save_checkpoint(trainer, filepath) if del_filepath is not None and filepath != del_filepath: trainer.strategy.remove_checkpoint(del_filepath) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 9baa50ff0eba2..544cbb5affe47 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1282,9 +1282,10 @@ def test_last_global_step_saved(): assert model_checkpoint._last_global_step_saved == 0 -def test_save_last_every_n_epochs_interaction(tmpdir): +@pytest.mark.parametrize("every_n_epochs", (0, 5)) +def test_save_last_every_n_epochs_interaction(tmpdir, every_n_epochs): """Test that `save_last` ignores `every_n_epochs`.""" - mc = ModelCheckpoint(every_n_epochs=5, save_last=True, save_top_k=0, save_on_train_epoch_end=True) + mc = ModelCheckpoint(every_n_epochs=every_n_epochs, save_last=True, save_top_k=0, save_on_train_epoch_end=True) trainer = Trainer( default_root_dir=tmpdir, max_epochs=2, @@ -1295,7 +1296,6 @@ def test_save_last_every_n_epochs_interaction(tmpdir): enable_model_summary=False, logger=False, ) - assert mc.every_n_epochs > trainer.max_epochs model = BoringModel() with patch.object(trainer, "save_checkpoint") as save_mock: trainer.fit(model) From 61f9bef1dae7b151f723702faf39ce4e1cb62d6b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 23 Mar 2022 15:26:20 +0100 Subject: [PATCH 08/11] Fix tests --- pytorch_lightning/callbacks/model_checkpoint.py | 10 ++++++++-- tests/loggers/test_mlflow.py | 3 +-- tests/models/test_restore.py | 2 +- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 0a360f2999899..9b9cddf11ee10 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -621,8 +621,14 @@ def _get_metric_interpolated_filepath_name( def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, _METRIC]: monitor_candidates = deepcopy(trainer.callback_metrics) - monitor_candidates.setdefault("epoch", torch.tensor(trainer.current_epoch)) - monitor_candidates.setdefault("step", torch.tensor(trainer.global_step)) + # cast to int if necessary because `self.log("epoch", 123)` will convert it to float. 
if it's not a tensor + # or does not exist we overwrite it as it's likely an error + epoch = monitor_candidates.get("epoch") + monitor_candidates["epoch"] = ( + epoch.int() if isinstance(epoch, torch.Tensor) else torch.tensor(trainer.current_epoch) + ) + step = monitor_candidates.get("step") + monitor_candidates["step"] = step.int() if isinstance(step, torch.Tensor) else torch.tensor(trainer.global_step) return monitor_candidates def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None: diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index 5ce5ceb75a0b1..77afe361b035f 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -159,8 +159,7 @@ def test_mlflow_logger_dirs_creation(tmpdir): assert set(os.listdir(tmpdir / exp_id)) == {run_id, "meta.yaml"} class CustomModel(BoringModel): - def training_epoch_end(self, *args, **kwargs): - super().training_epoch_end(*args, **kwargs) + def on_train_epoch_end(self, *args, **kwargs): self.log("epoch", self.current_epoch) model = CustomModel() diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index e5259c4047ad2..e5c91e0f71e8d 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -106,7 +106,7 @@ def validation_step_end(self, outputs): def test_model_properties_fit_ckpt_path(tmpdir): """Test that properties like `current_epoch` and `global_step` in model and trainer are always the same.""" model = BoringModel() - checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True) + checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_last=True) trainer_args = dict( default_root_dir=tmpdir, max_epochs=1, From 81f862a8cb7a5f844aaeea24f370d25c7b185805 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 23 Mar 2022 15:48:54 +0100 Subject: [PATCH 09/11] Skip test --- tests/trainer/test_trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index b6980fc860a20..17e40e2aeb2c5 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -403,6 +403,7 @@ def test_model_freeze_unfreeze(): assert param.requires_grad +@pytest.mark.xfail(reason="FIXME(@carmocca): this test wasn't running and is now broken") @pytest.mark.parametrize("url_ckpt", [True, False]) def test_fit_ckpt_path_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ckpt): """Verify resuming from checkpoint runs the right number of epochs.""" @@ -429,7 +430,7 @@ def on_load_checkpoint(self, _): max_epochs=2, limit_train_batches=0.65, limit_val_batches=1, - callbacks=[ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_top_k=-1)], + callbacks=[ModelCheckpoint(dirpath=tmpdir, save_top_k=-1)], default_root_dir=tmpdir, val_check_interval=1.0, enable_progress_bar=False, @@ -449,6 +450,7 @@ def on_load_checkpoint(self, _): ip, port = tmpdir_server checkpoints = [f"http://{ip}:{port}/" + ckpt.name for ckpt in checkpoints] + assert checkpoints for ckpt in checkpoints: next_model = TestModel() state = pl_load(ckpt) From 5fd354fdf5883cd52bbefa2a2dd50052c4a028d6 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 23 Mar 2022 16:29:46 +0100 Subject: [PATCH 10/11] Fix docstring format --- pytorch_lightning/callbacks/model_checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 
9b9cddf11ee10..9d13cc49f8c40 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -100,8 +100,8 @@ class ModelCheckpoint(Callback):
         based on either the maximization or the minimization of the monitored quantity.
         For ``'val_acc'``, this should be ``'max'``, for ``'val_loss'`` this should be ``'min'``, etc.
         auto_insert_metric_name: When ``True``, the checkpoints filenames will contain the metric name.
-            For example, ``filename='checkpoint_{epoch:02d}-{acc:02d}`` with epoch 1 and acc 80 will resolve to
-            ``checkpoint_epoch=01-acc=80.ckp``. Is useful to set it to ``False`` when metric names contain ``/``
+            For example, ``filename='checkpoint_{epoch:02d}-{acc:02.0f}`` with epoch ``1`` and acc ``1.12`` will resolve
+            to ``checkpoint_epoch=01-acc=01.ckp``. Is useful to set it to ``False`` when metric names contain ``/``
             as this will result in extra folders.
         save_weights_only: if ``True``, then only the model's weights will be saved. Otherwise, the optimizer states,
             lr-scheduler states, etc are added in the checkpoint too.

From a49efb7f2d5c4becfa32b998b92cd7996820a958 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Thu, 24 Mar 2022 13:44:57 +0100
Subject: [PATCH 11/11] Apply suggestions from code review

Co-authored-by: Rohit Gupta
---
 pytorch_lightning/callbacks/model_checkpoint.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 9d13cc49f8c40..ac05e6f66ebdd 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -101,7 +101,7 @@ class ModelCheckpoint(Callback):
         For ``'val_acc'``, this should be ``'max'``, for ``'val_loss'`` this should be ``'min'``, etc.
         auto_insert_metric_name: When ``True``, the checkpoints filenames will contain the metric name.
             For example, ``filename='checkpoint_{epoch:02d}-{acc:02.0f}`` with epoch ``1`` and acc ``1.12`` will resolve
-            to ``checkpoint_epoch=01-acc=01.ckp``. Is useful to set it to ``False`` when metric names contain ``/``
+            to ``checkpoint_epoch=01-acc=01.ckpt``. Is useful to set it to ``False`` when metric names contain ``/``
             as this will result in extra folders.
         save_weights_only: if ``True``, then only the model's weights will be saved. Otherwise, the optimizer states,
             lr-scheduler states, etc are added in the checkpoint too.
@@ -369,7 +369,7 @@ def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[
         if self.monitor is not None:
             if self.monitor not in monitor_candidates:
                 m = (
-                    f"`ModelCheckpoint(monitor='{self.monitor}')` could not find the monitored key in the returned"
+                    f"`ModelCheckpoint(monitor={self.monitor!r})` could not find the monitored key in the returned"
                     f" metrics: {list(monitor_candidates)}."
                     f" HINT: Did you call `log({self.monitor!r}, value)` in the `LightningModule`?"
                 )
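
Example of the resulting behavior (illustrative, not part of the diffs above): after this series, `every_n_epochs` only gates the monitored/top-k checkpoints, while `save_last=True` refreshes the "last" checkpoint at every checkpointing hook. A minimal sketch follows; `LitModel` and the logged `val_loss` metric are placeholders for whatever the user's `LightningModule` defines and logs.

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    # Top-k checkpoints are only considered every 5th epoch, but with these
    # patches applied the "last" checkpoint is still updated after every epoch.
    checkpoint_cb = ModelCheckpoint(
        dirpath="checkpoints/",
        monitor="val_loss",  # assumes the LightningModule logs `val_loss`
        save_top_k=1,
        every_n_epochs=5,
        save_last=True,
    )

    trainer = Trainer(max_epochs=10, callbacks=[checkpoint_cb])
    trainer.fit(LitModel())  # LitModel: placeholder LightningModule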