@@ -100,8 +100,8 @@ class ModelCheckpoint(Callback):
             based on either the maximization or the minimization of the monitored quantity.
             For ``'val_acc'``, this should be ``'max'``, for ``'val_loss'`` this should be ``'min'``, etc.
         auto_insert_metric_name: When ``True``, the checkpoints filenames will contain the metric name.
-            For example, ``filename='checkpoint_{epoch:02d}-{acc:02d}`` with epoch 1 and acc 80 will resolve to
-            ``checkpoint_epoch=01-acc=80.ckpt``. Is useful to set it to ``False`` when metric names contain ``/``
+            For example, ``filename='checkpoint_{epoch:02d}-{acc:02.0f}`` with epoch ``1`` and acc ``1.12`` will resolve
+            to ``checkpoint_epoch=01-acc=01.ckpt``. Is useful to set it to ``False`` when metric names contain ``/``
             as this will result in extra folders.
         save_weights_only: if ``True``, then only the model's weights will be
             saved. Otherwise, the optimizer states, lr-scheduler states, etc are added in the checkpoint too.
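
The resolved name quoted in the new docstring lines can be checked with plain Python string formatting. This is only an illustrative sketch: with ``auto_insert_metric_name=True`` the callback inserts the ``epoch=``/``acc=`` prefixes itself, here they are written out by hand.

# Stand-alone check of the documented example: epoch=1, acc=1.12
print("checkpoint_epoch={epoch:02d}-acc={acc:02.0f}.ckpt".format(epoch=1, acc=1.12))
# -> checkpoint_epoch=01-acc=01.ckpt
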
@@ -116,7 +116,8 @@ class ModelCheckpoint(Callback):
             This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``.
         every_n_epochs: Number of epochs between checkpoints.
             This value must be ``None`` or non-negative.
-            To disable saving after each epoch, set ``every_n_epochs = 0``.
+            To disable saving top-k checkpoints, set ``every_n_epochs = 0``.
+            This argument does not impact the saving of ``save_last=True`` checkpoints.
             If all of ``every_n_epochs``, ``every_n_train_steps`` and
             ``train_time_interval`` are ``None``, we save a checkpoint at the end of every epoch
             (equivalent to ``every_n_epochs = 1``).
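
A hedged usage sketch of the documented behaviour (standard ``ModelCheckpoint`` arguments; the directory name is made up): disabling top-k saving no longer disables the "last" checkpoint.

from pytorch_lightning.callbacks import ModelCheckpoint

# every_n_epochs=0 turns off top-k checkpoints, but save_last=True still writes
# a "last" checkpoint, per the updated docstring above.
ckpt_cb = ModelCheckpoint(dirpath="checkpoints/", save_last=True, every_n_epochs=0)
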
@@ -295,28 +296,25 @@ def on_train_batch_end(
         if not skip_time:
             self._last_time_checked = now

-        self.save_checkpoint(trainer)
+        monitor_candidates = self._monitor_candidates(trainer)
+        self._save_topk_checkpoint(trainer, monitor_candidates)
+        self._save_last_checkpoint(trainer, monitor_candidates)

     def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Save a checkpoint at the end of the training epoch."""
-        if (
-            not self._should_skip_saving_checkpoint(trainer)
-            and self._save_on_train_epoch_end
-            and self._every_n_epochs > 0
-            and (trainer.current_epoch + 1) % self._every_n_epochs == 0
-        ):
-            self.save_checkpoint(trainer)
+        if not self._should_skip_saving_checkpoint(trainer) and self._save_on_train_epoch_end:
+            monitor_candidates = self._monitor_candidates(trainer)
+            if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
+                self._save_topk_checkpoint(trainer, monitor_candidates)
+            self._save_last_checkpoint(trainer, monitor_candidates)

     def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Save a checkpoint at the end of the validation stage."""
-        if (
-            self._should_skip_saving_checkpoint(trainer)
-            or self._save_on_train_epoch_end
-            or self._every_n_epochs < 1
-            or (trainer.current_epoch + 1) % self._every_n_epochs != 0
-        ):
-            return
-        self.save_checkpoint(trainer)
+        if not self._should_skip_saving_checkpoint(trainer) and not self._save_on_train_epoch_end:
+            monitor_candidates = self._monitor_candidates(trainer)
+            if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
+                self._save_topk_checkpoint(trainer, monitor_candidates)
+            self._save_last_checkpoint(trainer, monitor_candidates)

     def on_save_checkpoint(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
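
A hedged configuration sketch of how the reworked hooks divide the work (argument names are the callback's public ones; the default of ``save_on_train_epoch_end`` depends on the Trainer setup): with ``save_on_train_epoch_end=False`` checkpoints are written in ``on_validation_end``, so metrics logged during validation are available to monitor.

from pytorch_lightning.callbacks import ModelCheckpoint

# Save in on_validation_end (not on_train_epoch_end), monitoring a validation metric.
ckpt_cb = ModelCheckpoint(monitor="val_loss", mode="min", save_on_train_epoch_end=False)
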
@@ -352,26 +350,41 @@ def on_load_checkpoint(
         self.last_model_path = callback_state.get("last_model_path", self.last_model_path)
         self.best_model_path = callback_state["best_model_path"]

-    def save_checkpoint(self, trainer: "pl.Trainer") -> None:
+    def save_checkpoint(self, trainer: "pl.Trainer") -> None:  # pragma: no-cover
         """Performs the main logic around saving a checkpoint.

         This method runs on all ranks. It is the responsibility of `trainer.save_checkpoint` to correctly handle the
         behaviour in distributed training, i.e., saving only on rank 0 for data parallel use cases.
         """
-        self._validate_monitor_key(trainer)
-
-        # what can be monitored
-        monitor_candidates = self._monitor_candidates(trainer, epoch=trainer.current_epoch, step=trainer.global_step)
-
-        # callback supports multiple simultaneous modes
-        # here we call each mode sequentially
-        # Mode 1: save the top k checkpoints
-        self._save_top_k_checkpoint(trainer, monitor_candidates)
-        # Mode 2: save monitor=None checkpoints
-        self._save_none_monitor_checkpoint(trainer, monitor_candidates)
-        # Mode 3: save last checkpoints
+        # TODO: unused method. deprecate it
+        monitor_candidates = self._monitor_candidates(trainer)
+        self._save_topk_checkpoint(trainer, monitor_candidates)
         self._save_last_checkpoint(trainer, monitor_candidates)

+    def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
+        if self.save_top_k == 0:
+            return
+
+        # validate metric
+        if self.monitor is not None:
+            if self.monitor not in monitor_candidates:
+                m = (
+                    f"`ModelCheckpoint(monitor={self.monitor!r})` could not find the monitored key in the returned"
+                    f" metrics: {list(monitor_candidates)}."
+                    f" HINT: Did you call `log({self.monitor!r}, value)` in the `LightningModule`?"
+                )
+                if trainer.fit_loop.epoch_loop.val_loop._has_run:
+                    raise MisconfigurationException(m)
+                warning_cache.warn(m)
+            self._save_monitor_checkpoint(trainer, monitor_candidates)
+        else:
+            self._save_none_monitor_checkpoint(trainer, monitor_candidates)
+
+    def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
+        trainer.save_checkpoint(filepath, self.save_weights_only)
+
+        self._last_global_step_saved = trainer.global_step
+
         # notify loggers
         if trainer.is_global_zero:
             for logger in trainer.loggers:
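
The new validation in ``_save_topk_checkpoint`` looks the monitored key up in ``monitor_candidates``, a copy of ``trainer.callback_metrics``. A minimal, hypothetical module sketch of the user-side fix the HINT points at: logging under the monitored name is what puts it into ``callback_metrics``.

import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):  # hypothetical example; only the logging call matters here
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        acc = (self.layer(x).argmax(dim=1) == y).float().mean()
        # self.log("val_acc", ...) makes "val_acc" appear in trainer.callback_metrics,
        # so ModelCheckpoint(monitor="val_acc") can find it.
        self.log("val_acc", acc)
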
@@ -594,21 +607,6 @@ def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
         if self.save_top_k != 0 and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0:
             rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

-    def _validate_monitor_key(self, trainer: "pl.Trainer") -> None:
-        metrics = trainer.callback_metrics
-
-        # validate metric
-        if self.monitor is not None and not self._is_valid_monitor_key(metrics):
-            m = (
-                f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:"
-                f" {list(metrics.keys())}. "
-                f"HINT: Did you call self.log('{self.monitor}', value) in the LightningModule?"
-            )
-            if not trainer.fit_loop.epoch_loop.val_loop._has_run:
-                warning_cache.warn(m)
-            else:
-                raise MisconfigurationException(m)
-
     def _get_metric_interpolated_filepath_name(
         self, monitor_candidates: Dict[str, _METRIC], trainer: "pl.Trainer", del_filepath: Optional[str] = None
     ) -> str:
@@ -621,51 +619,46 @@ def _get_metric_interpolated_filepath_name(

         return filepath

-    def _monitor_candidates(self, trainer: "pl.Trainer", epoch: int, step: int) -> Dict[str, _METRIC]:
+    def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, _METRIC]:
         monitor_candidates = deepcopy(trainer.callback_metrics)
-        monitor_candidates.update(epoch=epoch, step=step)
+        # cast to int if necessary because `self.log("epoch", 123)` will convert it to float. if it's not a tensor
+        # or does not exist we overwrite it as it's likely an error
+        epoch = monitor_candidates.get("epoch")
+        monitor_candidates["epoch"] = (
+            epoch.int() if isinstance(epoch, torch.Tensor) else torch.tensor(trainer.current_epoch)
+        )
+        step = monitor_candidates.get("step")
+        monitor_candidates["step"] = step.int() if isinstance(step, torch.Tensor) else torch.tensor(trainer.global_step)
         return monitor_candidates

     def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
         if not self.save_last:
             return
-        self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step)

         filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST)
         # set the last model path before saving because it will be part of the state.
         previous, self.last_model_path = self.last_model_path, filepath
-        trainer.save_checkpoint(filepath, self.save_weights_only)
+        self._save_checkpoint(trainer, filepath)
         if previous and previous != filepath:
             trainer.strategy.remove_checkpoint(previous)

-    def _save_top_k_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
-        if self.monitor is None or self.save_top_k == 0:
-            return
-        self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step)
-
+    def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
         current = monitor_candidates.get(self.monitor)
         if self.check_monitor_top_k(trainer, current):
             self._update_best_and_save(current, trainer, monitor_candidates)
         elif self.verbose:
-            epoch = monitor_candidates.get("epoch")
-            step = monitor_candidates.get("step")
-            rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor} was not in top {self.save_top_k}")
+            epoch = monitor_candidates["epoch"]
+            step = monitor_candidates["step"]
+            rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} was not in top {self.save_top_k}")

     def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
-        if self.monitor is not None or self.save_top_k == 0:
-            return
-        self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step)
-
         filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
         # set the best model path before saving because it will be part of the state.
         previous, self.best_model_path = self.best_model_path, filepath
-        trainer.save_checkpoint(filepath, self.save_weights_only)
+        self._save_checkpoint(trainer, filepath)
         if self.save_top_k == 1 and previous and previous != filepath:
             trainer.strategy.remove_checkpoint(previous)

-    def _is_valid_monitor_key(self, metrics: Dict[str, _METRIC]) -> bool:
-        return self.monitor in metrics or len(metrics) == 0
-
     def _update_best_and_save(
         self, current: torch.Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]
     ) -> None:
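
A small stand-alone sketch of the cast ``_monitor_candidates`` now performs (values made up): a logged ``epoch`` arrives as a float tensor, and the ``.int()`` cast is what keeps the ``:d`` format specifiers in the log messages below working.

import torch

callback_metrics = {"epoch": torch.tensor(3.0)}  # logged values arrive as float tensors
epoch = callback_metrics.get("epoch")
epoch = epoch.int() if isinstance(epoch, torch.Tensor) else torch.tensor(0)
print(f"Epoch {epoch:d}")  # works; formatting the original float tensor with ":d" would raise
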
@@ -697,13 +690,13 @@ def _update_best_and_save(
         self.best_model_score = self.best_k_models[self.best_model_path]

         if self.verbose:
-            epoch = monitor_candidates.get("epoch")
-            step = monitor_candidates.get("step")
+            epoch = monitor_candidates["epoch"]
+            step = monitor_candidates["step"]
             rank_zero_info(
-                f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}"
-                f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}'
+                f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} reached {current:0.5f}"
+                f" (best {self.best_model_score:0.5f}), saving model to {filepath!r} as top {k}"
             )
-        trainer.save_checkpoint(filepath, self.save_weights_only)
+        self._save_checkpoint(trainer, filepath)

         if del_filepath is not None and filepath != del_filepath:
             trainer.strategy.remove_checkpoint(del_filepath)
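
For reference, the switch to ``!r`` quotes the monitor name and path in the verbose output; a quick stand-alone illustration with made-up values:

monitor, filepath, k = "val_loss", "checkpoints/epoch=01-step=100.ckpt", 1
print(f"{monitor!r} reached 0.25000 (best 0.25000), saving model to {filepath!r} as top {k}")
# 'val_loss' reached 0.25000 (best 0.25000), saving model to 'checkpoints/epoch=01-step=100.ckpt' as top 1
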