Commit 6d79184

Unify checkpoint load paths [redo #9693] (#10061)

1 parent 76081fb · commit 6d79184

32 files changed: +209 −180 lines

CHANGELOG.md

Lines changed: 5 additions & 0 deletions

@@ -226,6 +226,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `XLACheckpointIO` plugin ([#9972](https://github.com/PyTorchLightning/pytorch-lightning/pull/9972))


+- Added `ckpt_path` argument for `trainer.fit()` ([#10061](https://github.com/PyTorchLightning/pytorch-lightning/pull/10061))
+

 ### Changed

@@ -418,6 +420,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated access to the `AcceleratorConnector.configure_slurm_ddp` method and marked it as protected ([#10101](https://github.com/PyTorchLightning/pytorch-lightning/pull/10101))


+- Deprecated passing `resume_from_checkpoint` to the `Trainer` constructor in favor of `trainer.fit(ckpt_path=)` ([#10061](https://github.com/PyTorchLightning/pytorch-lightning/pull/10061))
+
+
 ### Removed

 - Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
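
In practice the migration looks like this (a minimal sketch; `MyModel` and the checkpoint path are placeholders, not part of this commit):

    from pytorch_lightning import Trainer

    model = MyModel()  # placeholder LightningModule

    # before (deprecated in v1.5, scheduled for removal in v1.7):
    # trainer = Trainer(resume_from_checkpoint="checkpoints/last.ckpt")
    # trainer.fit(model)

    # after: pass the checkpoint path to fit() directly
    trainer = Trainer()
    trainer.fit(model, ckpt_path="checkpoints/last.ckpt")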

docs/source/advanced/advanced_gpu.rst

Lines changed: 1 addition & 1 deletion

@@ -622,7 +622,7 @@ After training using ZeRO Stage 3, you'll notice that your checkpoints are a dir

 .. warning::

-    This single file checkpoint does not include the optimizer/lr-scheduler states. This means we cannot restore training via the `resume_from_checkpoint` Trainer argument. Ensure to keep the sharded checkpoint directory if this is required.
+    This single file checkpoint does not include the optimizer/lr-scheduler states. This means we cannot restore training via the ``trainer.fit(ckpt_path=)`` call. Ensure to keep the sharded checkpoint directory if this is required.

 Custom DeepSpeed Config
 """""""""""""""""""""""

docs/source/common/trainer.rst

Lines changed: 4 additions & 0 deletions

@@ -1349,6 +1349,10 @@ By setting to False, you have to add your own distributed sampler:
 resume_from_checkpoint
 ^^^^^^^^^^^^^^^^^^^^^^

+.. warning:: ``resume_from_checkpoint`` is deprecated in v1.5 and will be removed in v1.7.
+    Please pass ``trainer.fit(ckpt_path="some/path/to/my_checkpoint.ckpt")`` instead.
+
+
 .. raw:: html

     <video width="50%" max-width="400px" controls

docs/source/common/weights_loading.rst

Lines changed: 2 additions & 2 deletions

@@ -212,7 +212,7 @@ do the following:
 .. code-block:: python

     model = LitModel()
-    trainer = Trainer(resume_from_checkpoint="some/path/to/my_checkpoint.ckpt")
+    trainer = Trainer()

     # automatically restores model, epoch, step, LR schedulers, apex, etc...
-    trainer.fit(model)
+    trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")

docs/source/extensions/loops_advanced.rst

Lines changed: 1 addition & 1 deletion

@@ -30,7 +30,7 @@ The two hooks :class:`~pytorch_lightning.loops.base.Loop.on_save_checkpoint` and
     def on_load_checkpoint(self, state_dict):
         self.iteration = state_dict["iteration"]

-When the Trainer is restarting from a checkpoint (e.g., through :code:`Trainer(resume_from_checkpoint=...)`), the loop exposes a boolean attribute :attr:`~pytorch_lightning.loops.base.Loop.restarting`.
+When the Trainer is restarting from a checkpoint (e.g., through :code:`trainer.fit(ckpt_path=...)`), the loop exposes a boolean attribute :attr:`~pytorch_lightning.loops.base.Loop.restarting`.
 Based around the value of this variable, the user can write the loop in such a way that it can restart from an arbitrary point given the state loaded from the checkpoint.
 For example, the implementation of the :meth:`~pytorch_lightning.loops.base.Loop.reset` method could look like this given our previous example:

pytorch_lightning/trainer/callback_hook.py

Lines changed: 2 additions & 2 deletions

@@ -279,8 +279,8 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         difference = callback_states.keys() - current_callbacks_keys
         if difference:
             rank_zero_warn(
-                "Be aware that when using `resume_from_checkpoint`,"
-                " callbacks used to create the checkpoint need to be provided."
+                "Be aware that when using `ckpt_path`,"
+                " callbacks used to create the checkpoint need to be provided during `Trainer` instantiation."
                 f" Please add the following callbacks: {list(difference)}.",
                 UserWarning,
             )
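
In user terms, the warning asks that a run resumed with `ckpt_path` be given the same callbacks the checkpoint was created with, so their states can be restored; a hedged sketch (the callback choice, model, and path are illustrative):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    model = MyModel()  # placeholder LightningModule
    # re-create the callbacks that existed when the checkpoint was saved
    trainer = Trainer(callbacks=[EarlyStopping(monitor="val_loss")])
    trainer.fit(model, ckpt_path="checkpoints/last.ckpt")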

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 24 additions & 17 deletions

@@ -38,7 +38,14 @@
 class CheckpointConnector:
     def __init__(self, trainer: "pl.Trainer", resume_from_checkpoint: Optional[_PATH] = None) -> None:
         self.trainer = trainer
-        self.resume_checkpoint_path = resume_from_checkpoint
+        self.resume_checkpoint_path: Optional[_PATH] = None
+        # TODO: remove resume_from_checkpoint_fit_path in v1.7
+        self.resume_from_checkpoint_fit_path: Optional[_PATH] = resume_from_checkpoint
+        if resume_from_checkpoint is not None:
+            rank_zero_deprecation(
+                "Setting `Trainer(resume_from_checkpoint=)` is deprecated in v1.5 and"
+                " will be removed in v1.7. Please pass `Trainer.fit(ckpt_path=)` directly instead."
+            )
         self._loaded_checkpoint: Dict[str, Any] = {}

     @property
@@ -53,14 +60,14 @@ def hpc_resume_path(self) -> Optional[str]:
         if os.path.exists(auto_save_checkpoint):
             return auto_save_checkpoint

-    def resume_start(self) -> None:
+    def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
         """Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority:

         1. from HPC weights if found
-        2. from `resume_from_checkpoint` file if provided
+        2. from `checkpoint_path` file if provided
         3. don't restore
         """
-        self.resume_checkpoint_path = self.hpc_resume_path or self.resume_checkpoint_path
+        self.resume_checkpoint_path = self.hpc_resume_path or checkpoint_path
         checkpoint_path = self.resume_checkpoint_path
         if not checkpoint_path:
             return
@@ -83,8 +90,18 @@ def _load_and_validate_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any
     def resume_end(self) -> None:
         """Signal the connector that all states have resumed and memory for the checkpoint object can be
         released."""
+        assert self.trainer.state.fn is not None
         if self.resume_checkpoint_path:
-            rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}")
+            if self.trainer.state.fn == TrainerFn.FITTING:
+                rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}")
+            elif self.trainer.state.fn in (TrainerFn.VALIDATING, TrainerFn.TESTING, TrainerFn.PREDICTING):
+                rank_zero_info(f"Loaded model weights from checkpoint at {self.resume_checkpoint_path}")
+        # TODO: remove resume_from_checkpoint_fit_path in v1.7
+        if (
+            self.trainer.state.fn == TrainerFn.FITTING
+            and self.resume_checkpoint_path == self.resume_from_checkpoint_fit_path
+        ):
+            self.resume_from_checkpoint_fit_path = None
         self.resume_checkpoint_path = None
         self._loaded_checkpoint = {}

@@ -99,16 +116,15 @@ def restore(self, checkpoint_path: Optional[_PATH] = None) -> None:
         state-restore, in this priority:

         1. from HPC weights if found
-        2. from `resume_from_checkpoint` file if provided
+        2. from `checkpoint_path` file if provided
         3. don't restore

         All restored states are listed in return value description of `dump_checkpoint`.

         Args:
             checkpoint_path: Path to a PyTorch Lightning checkpoint file.
         """
-        self.resume_checkpoint_path = checkpoint_path
-        self.resume_start()
+        self.resume_start(checkpoint_path)

         # restore module states
         self.restore_datamodule()
@@ -157,15 +173,6 @@ def restore_model(self) -> None:
             if isinstance(module, Metric):
                 module.reset()

-    def restore_model_weights(self, checkpoint_path: Optional[_PATH]) -> None:
-        """Restore only the model weights."""
-        checkpoint = self._loaded_checkpoint
-        if checkpoint_path is not None:
-            checkpoint = self._load_and_validate_checkpoint(checkpoint_path)
-
-        self.trainer.lightning_module.on_load_checkpoint(checkpoint)
-        self.trainer.training_type_plugin.load_model_state_dict(checkpoint)
-
     def restore_training_state(self) -> None:
         """Restore the trainer state from the pre-loaded checkpoint.

pytorch_lightning/trainer/trainer.py

Lines changed: 41 additions & 47 deletions

@@ -362,6 +362,10 @@ def __init__(
                 no checkpoint file at the path, an exception is raised. If resuming from mid-epoch checkpoint,
                 training will start from the beginning of the next epoch.

+                .. deprecated:: v1.5
+                    ``resume_from_checkpoint`` is deprecated in v1.5 and will be removed in v1.7.
+                    Please pass the path to ``Trainer.fit(..., ckpt_path=...)`` instead.
+
             strategy: Supports different training strategies with aliases
                 as well custom training type plugins.

@@ -617,6 +621,7 @@ def fit(
         val_dataloaders: Optional[EVAL_DATALOADERS] = None,
         datamodule: Optional[LightningDataModule] = None,
         train_dataloader=None,  # TODO: remove with 1.6
+        ckpt_path: Optional[str] = None,
     ) -> None:
         r"""
         Runs the full optimization routine.
@@ -630,6 +635,10 @@ def fit(

             val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.

+            ckpt_path: Path/URL of the checkpoint from which training is resumed. If there is
+                no checkpoint file at the path, an exception is raised. If resuming from mid-epoch checkpoint,
+                training will start from the beginning of the next epoch.
+
             datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
         """
         if train_dataloader is not None:
@@ -638,14 +647,17 @@ def fit(
                 " Use `trainer.fit(train_dataloaders)` instead. HINT: added 's'"
             )
             train_dataloaders = train_dataloader
-        self._call_and_handle_interrupt(self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule)
+        self._call_and_handle_interrupt(
+            self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
+        )

     def _fit_impl(
         self,
         model: "pl.LightningModule",
         train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
         val_dataloaders: Optional[EVAL_DATALOADERS] = None,
         datamodule: Optional[LightningDataModule] = None,
+        ckpt_path: Optional[str] = None,
     ) -> None:
         Trainer._log_api_event("fit")

@@ -668,7 +680,9 @@ def _fit_impl(
             model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
         )

-        self._run(model)
+        # TODO: ckpt_path only in v1.7
+        ckpt_path = ckpt_path or self.resume_from_checkpoint
+        self._run(model, ckpt_path=ckpt_path)

         assert self.state.stopped
         self.training = False
@@ -755,7 +769,7 @@ def _validate_impl(
         )

         # run validate
-        results = self._run(model)
+        results = self._run(model, ckpt_path=self.validated_ckpt_path)

         assert self.state.stopped
         self.validating = False
@@ -845,7 +859,7 @@ def _test_impl(
         )

         # run test
-        results = self._run(model)
+        results = self._run(model, ckpt_path=self.tested_ckpt_path)

         assert self.state.stopped
         self.testing = False
@@ -928,7 +942,7 @@ def _predict_impl(
             ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )

-        results = self._run(model)
+        results = self._run(model, ckpt_path=self.predicted_ckpt_path)

         assert self.state.stopped
         self.predicting = False
@@ -997,24 +1011,18 @@ def tune(

         return result

-    def _restore_modules_and_callbacks(self) -> None:
-        if self.state.fn != TrainerFn.FITTING:
-            return
-
-        self.checkpoint_connector.restore_datamodule()
+    def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None:
+        # restore modules after setup
+        self.checkpoint_connector.resume_start(checkpoint_path)
         self.checkpoint_connector.restore_model()
-        # restore callback states
-        self.checkpoint_connector.restore_callbacks()
-
-    def _load_checkpoint_weights(self):
-        # only one process running at this point for TPUs, as spawn isn't triggered yet
-        # todo: move this logic internally within the barrier.
-        if not self._device_type == DeviceType.TPU:
-            self.training_type_plugin.barrier()
-        rank_zero_info(f"Loading model weights from checkpoint at {self._ckpt_path}")
-        self.checkpoint_connector.restore_model_weights(self._ckpt_path)
-
-    def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
+        if self.state.fn == TrainerFn.FITTING:
+            self.checkpoint_connector.restore_datamodule()
+            # restore callback states
+            self.checkpoint_connector.restore_callbacks()
+
+    def _run(
+        self, model: "pl.LightningModule", ckpt_path: Optional[str] = None
+    ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
         # clean hparams
         if hasattr(model, "hparams"):
             parsing.clean_namespace(model.hparams)
@@ -1031,9 +1039,6 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
         self._data_connector.prepare_data()
         self.callback_connector._attach_model_callbacks()

-        if self._ckpt_path and not self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
-            self._load_checkpoint_weights()
-
         # ----------------------------
         # SET UP TRAINING
         # ----------------------------
@@ -1042,9 +1047,8 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
         self._call_setup_hook()  # allow user to setup lightning_module in accelerator environment

         # check if we should delay restoring checkpoint till later
-        if not self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
-            self.checkpoint_connector.resume_start()
-            self._restore_modules_and_callbacks()
+        if not self.accelerator.restore_checkpoint_after_pre_dispatch:
+            self._restore_modules_and_callbacks(ckpt_path)

         self._call_configure_sharded_model()  # allow user to setup in model sharded environment
         self.accelerator.setup(self)
@@ -1092,16 +1096,14 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
         # plugin will setup fitting (e.g. ddp will launch child processes)
         self._pre_dispatch()

-        if self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
-            if self._ckpt_path:
-                self._load_checkpoint_weights()
-
-            self.checkpoint_connector.resume_start()
-            self._restore_modules_and_callbacks()
+        if self.accelerator.restore_checkpoint_after_pre_dispatch:
+            self._restore_modules_and_callbacks(ckpt_path)

         # restore optimizers, etc.
         self.checkpoint_connector.restore_training_state()

+        self.checkpoint_connector.resume_end()
+
         # dispatch `start_training` or `start_evaluating` or `start_predicting`
         self._dispatch()

@@ -1201,9 +1203,6 @@ def _pre_training_routine(self):
         # register signals
         self.signal_connector.register_signal_handlers()

-        if self.state.fn != TrainerFn.TUNING:
-            self.checkpoint_connector.resume_end()
-
         # --------------------------
         # Pre-train
         # --------------------------
@@ -1804,7 +1803,11 @@ def checkpoint_callbacks(self) -> List[ModelCheckpoint]:

     @property
     def resume_from_checkpoint(self) -> Optional[Union[str, Path]]:
-        return self.checkpoint_connector.resume_checkpoint_path
+        rank_zero_deprecation(
+            "`trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7."
+            " Specify the fit checkpoint path with `trainer.fit(ckpt_path=)` instead."
+        )
+        return self.checkpoint_connector.resume_from_checkpoint_fit_path

     def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None:
         self.checkpoint_connector.save_checkpoint(filepath, weights_only)
@@ -2029,15 +2032,6 @@ def _active_loop(self) -> Optional[Union[FitLoop, EvaluationLoop, PredictionLoop
         if self.predicting:
             return self.predict_loop

-    @property
-    def _ckpt_path(self) -> Optional[str]:
-        if self.state.fn == TrainerFn.VALIDATING:
-            return self.validated_ckpt_path
-        if self.state.fn == TrainerFn.TESTING:
-            return self.tested_ckpt_path
-        if self.state.fn == TrainerFn.PREDICTING:
-            return self.predicted_ckpt_path
-
     """
     Logging properties
     """

tests/accelerators/test_cpu.py

Lines changed: 2 additions & 4 deletions

@@ -80,10 +80,8 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
     assert accelerator.training_type_plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch
     assert plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch

-    trainer = Trainer(
-        default_root_dir=tmpdir, accelerator=accelerator, fast_dev_run=True, resume_from_checkpoint=checkpoint_path
-    )
-    trainer.fit(model)
+    trainer = Trainer(default_root_dir=tmpdir, accelerator=accelerator, fast_dev_run=True)
+    trainer.fit(model, ckpt_path=checkpoint_path)
     for func in (trainer.test, trainer.validate, trainer.predict):
         accelerator.training_type_plugin.predispatched_called = False
         func(model, ckpt_path=checkpoint_path)

tests/accelerators/test_tpu.py

Lines changed: 2 additions & 2 deletions

@@ -62,8 +62,8 @@ def test_resume_training_on_cpu(tmpdir):
     assert weight_tensor.device == torch.device("cpu")

     # Verify that training is resumed on CPU
-    trainer = Trainer(resume_from_checkpoint=model_path, max_epochs=1, default_root_dir=tmpdir)
-    trainer.fit(model)
+    trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
+    trainer.fit(model, ckpt_path=model_path)
     assert trainer.state.finished, f"Training failed with {trainer.state}"