[see #10061 instead] Unify checkpoint load paths #9693
Changes from all commits
@@ -342,6 +342,10 @@ def __init__(
            no checkpoint file at the path, start from scratch. If resuming from mid-epoch checkpoint,
            training will start from the beginning of the next epoch.

            .. deprecated:: v1.5
                ``resume_from_checkpoint`` is deprecated in v1.5 and will be removed in v1.7.
                Please use ``Trainer.fit(ckpt_path)`` instead.

        sync_batchnorm: Synchronize batch norm layers between process groups/whole world.

        terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
@@ -574,6 +578,7 @@ def fit(
        val_dataloaders: Optional[EVAL_DATALOADERS] = None,
        datamodule: Optional[LightningDataModule] = None,
        train_dataloader=None,  # TODO: remove with 1.6
        ckpt_path: Optional[str] = None,
    ) -> None:
        r"""
        Runs the full optimization routine.
@@ -587,6 +592,10 @@ def fit(

            val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.

            ckpt_path: Path/URL of the checkpoint from which training is resumed. If there is
                no checkpoint file at the path, start from scratch. If resuming from mid-epoch checkpoint,
                training will start from the beginning of the next epoch.

            datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
        """
        if train_dataloader is not None:
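For context, the migration this enables looks roughly like the following. This is a usage sketch, not part of the diff: the model class and checkpoint path are placeholders.

from pytorch_lightning import Trainer

model = MyLightningModule()  # placeholder: any LightningModule

# Before (deprecated in v1.5, slated for removal in v1.7): resume via the constructor.
trainer = Trainer(resume_from_checkpoint="some/path/last.ckpt", max_epochs=10)
trainer.fit(model)

# After this PR: pass the checkpoint directly to `fit`.
trainer = Trainer(max_epochs=10)
trainer.fit(model, ckpt_path="some/path/last.ckpt")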
@@ -595,14 +604,17 @@ def fit(
                " Use `trainer.fit(train_dataloaders)` instead. HINT: added 's'"
            )
            train_dataloaders = train_dataloader
        self._call_and_handle_interrupt(self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule)
        self._call_and_handle_interrupt(
            self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
        )

    def _fit_impl(
        self,
        model: "pl.LightningModule",
        train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
        val_dataloaders: Optional[EVAL_DATALOADERS] = None,
        datamodule: Optional[LightningDataModule] = None,
        ckpt_path: Optional[str] = None,
    ) -> None:
        Trainer._log_api_event("fit")
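The call site above now threads ckpt_path through the interrupt-handling wrapper into _fit_impl. As a rough, hypothetical sketch of that forwarding pattern (not the Trainer's actual implementation), the wrapper only needs to pass its arguments through unchanged:

from typing import Any, Callable, Optional

def call_and_handle_interrupt(trainer_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Optional[Any]:
    """Forward all arguments to the implementation; an extra argument such as
    ckpt_path reaches it without the wrapper having to know about it."""
    try:
        return trainer_fn(*args, **kwargs)
    except KeyboardInterrupt:
        # the real Trainer performs graceful teardown here; omitted in this sketch
        return None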
@@ -625,7 +637,9 @@ def _fit_impl(
            model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
        )

        self._run(model)
        # TODO: ckpt_path only in v1.7
        ckpt_path = ckpt_path or self.resume_from_checkpoint
        self._run(model, ckpt_path)

        assert self.state.stopped
        self.training = False
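The line `ckpt_path = ckpt_path or self.resume_from_checkpoint` gives the `fit(ckpt_path=...)` argument precedence and only falls back to the deprecated constructor value when no path was passed; the TODO notes the fallback itself goes away in v1.7. A tiny standalone illustration of that precedence rule, with made-up paths:

from typing import Optional

def resolve_ckpt_path(fit_ckpt_path: Optional[str], resume_from_checkpoint: Optional[str]) -> Optional[str]:
    # the explicit fit() argument wins; the deprecated Trainer argument is only a fallback
    return fit_ckpt_path or resume_from_checkpoint

assert resolve_ckpt_path("new.ckpt", "old.ckpt") == "new.ckpt"
assert resolve_ckpt_path(None, "old.ckpt") == "old.ckpt"
assert resolve_ckpt_path(None, None) is None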
@@ -711,7 +725,7 @@ def _validate_impl(
        )

        # run validate
        results = self._run(model)
        results = self._run(model, self.validated_ckpt_path)

        assert self.state.stopped
        self.validating = False
@@ -800,7 +814,7 @@ def _test_impl(
        )

        # run test
        results = self._run(model)
        results = self._run(model, self.tested_ckpt_path)

        assert self.state.stopped
        self.testing = False
@@ -882,7 +896,7 @@ def _predict_impl(
            ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
        )

        results = self._run(model)
        results = self._run(model, self.predicted_ckpt_path)

        assert self.state.stopped
        self.predicting = False
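With these three hunks, validate, test, and predict each hand their own resolved checkpoint path (validated_ckpt_path / tested_ckpt_path / predicted_ckpt_path) to _run, mirroring what fit now does. A hedged usage sketch of the public entry points, assuming the validate/test/predict signatures of this Lightning version; the model, dataloaders, and paths are placeholders, and the "best" shorthand is the existing ckpt_path behaviour of these methods rather than something introduced here:

trainer = Trainer(max_epochs=10)
trainer.fit(model, train_dataloaders=train_loader, ckpt_path="ckpts/last.ckpt")

# each evaluation entry point resolves and records its own checkpoint path
trainer.validate(model, dataloaders=val_loader, ckpt_path="ckpts/epoch-3.ckpt")
trainer.test(model, dataloaders=test_loader, ckpt_path="best")
trainer.predict(model, dataloaders=predict_loader, ckpt_path="best")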
@@ -951,24 +965,18 @@ def tune(

        return result

    def _restore_modules_and_callbacks(self) -> None:
    def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None:
        # restore modules after setup
        if self.state.fn == TrainerFn.FITTING:
            self.checkpoint_connector.resume_start()
            self.checkpoint_connector.restore_datamodule()
        self.checkpoint_connector.resume_start(checkpoint_path)
        self.checkpoint_connector.restore_model()
        # restore callback states
        self.checkpoint_connector.restore_callbacks()

    def _load_checkpoint_weights(self):
        # only one process running at this point for TPUs, as spawn isn't triggered yet
        # todo: move this logic internally within the barrier.
        if not self._device_type == DeviceType.TPU:
            self.training_type_plugin.barrier()
        rank_zero_info(f"Loading model weights from checkpoint at {self._ckpt_path}")
        self.checkpoint_connector.restore_model_weights(self._ckpt_path)

    def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
        if self.state.fn == TrainerFn.FITTING:
            self.checkpoint_connector.restore_datamodule()
            # restore callback states
            self.checkpoint_connector.restore_callbacks()

    def _run(
        self, model: "pl.LightningModule", ckpt_path: Optional[str] = None
    ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
        # clean hparams
        if hasattr(model, "hparams"):
            parsing.clean_namespace(model.hparams)
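Taken together, this hunk replaces the separate _load_checkpoint_weights path with a single flow: _run now receives ckpt_path and hands it to _restore_modules_and_callbacks, which drives the checkpoint connector for every trainer function. The following is condensed, definition-only pseudocode of that flow as I read the diff; the exact placement of the fit-only restores should be checked against the merged code, and the PR title itself defers to #10061.

def _restore_modules_and_callbacks(self, checkpoint_path=None):
    # single entry point for fit/validate/test/predict checkpoints
    self.checkpoint_connector.resume_start(checkpoint_path)  # locate and load the checkpoint
    self.checkpoint_connector.restore_model()                # always restore module weights
    if self.state.fn == TrainerFn.FITTING:
        self.checkpoint_connector.restore_datamodule()       # fit-only state
        self.checkpoint_connector.restore_callbacks()

def _run(self, model, ckpt_path=None):
    ...
    if not self.accelerator.restore_checkpoint_after_pre_dispatch:
        self._restore_modules_and_callbacks(ckpt_path)
    ...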
@@ -985,9 +993,6 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
        self.data_connector.prepare_data()
        self.callback_connector._attach_model_callbacks()

        if self._ckpt_path and not self.accelerator.restore_checkpoint_after_pre_dispatch:
            self._load_checkpoint_weights()

        # ----------------------------
        # SET UP TRAINING
        # ----------------------------
@@ -997,7 +1002,7 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,

        # check if we should delay restoring checkpoint till later
        if not self.accelerator.restore_checkpoint_after_pre_dispatch:
            self._restore_modules_and_callbacks()
            self._restore_modules_and_callbacks(ckpt_path)

        self._call_configure_sharded_model()  # allow user to setup in model sharded environment
        self.accelerator.setup(self)
@@ -1046,12 +1051,13 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
        self._pre_dispatch()

        if self.accelerator.restore_checkpoint_after_pre_dispatch:
            if self._ckpt_path:
                self._load_checkpoint_weights()
            self._restore_modules_and_callbacks()
            self._restore_modules_and_callbacks(ckpt_path)

        # restore optimizers, etc.
        self.checkpoint_connector.restore_training_state()
        if self.state.fn == TrainerFn.FITTING:
            self.checkpoint_connector.restore_training_state()

Review comments:
- restore training state includes things which can be resumed even if not fitting, such as the loop state. imo we shouldn't add the check for fitting here, but rather inside the select parts inside of
- (above comment?)
- yeah, we do restore loops there and now some other attributes as well. I'd suggest waiting for this one to get merged: #9413

        self.checkpoint_connector.resume_end()

Review comments:
- n00b question: why is this bumped up to here vs in
- now that this calls

        # dispatch `start_training` or `start_evaluating` or `start_predicting`
        self._dispatch()
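For plugins that set restore_checkpoint_after_pre_dispatch, the same unified restore is simply deferred until after _pre_dispatch(); the training-state restore is then gated on fitting, and resume_end() runs before dispatch (it previously ran in _pre_training_routine, which the next hunk removes). A condensed restatement of that ordering with explanatory comments added by me; this is my reading of the hunk above, not separate code:

self._pre_dispatch()

if self.accelerator.restore_checkpoint_after_pre_dispatch:
    # deferred restore goes through the same unified entry point
    self._restore_modules_and_callbacks(ckpt_path)

if self.state.fn == TrainerFn.FITTING:
    # optimizers, schedulers, loop progress, etc. are only restored when fitting
    self.checkpoint_connector.restore_training_state()

# finish the resume sequence here (moved out of `_pre_training_routine`)
self.checkpoint_connector.resume_end()

# dispatch `start_training` or `start_evaluating` or `start_predicting`
self._dispatch()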
@@ -1152,8 +1158,6 @@ def _pre_training_routine(self):
        # register signals
        self.signal_connector.register_signal_handlers()

        self.checkpoint_connector.resume_end()

        # --------------------------
        # Pre-train
        # --------------------------
@@ -1742,6 +1746,10 @@ def checkpoint_callbacks(self) -> List[ModelCheckpoint]:

    @property
    def resume_from_checkpoint(self) -> Optional[Union[str, Path]]:
        rank_zero_deprecation(
            "`trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7."
            " Specify fit ckpt_path with `trainer.fit(ckpt_path=)` instead."
        )
        return self.checkpoint_connector.resume_checkpoint_path

    def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None:
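During the deprecation window the property keeps working: reading trainer.resume_from_checkpoint emits the deprecation message (on rank zero) and then returns the checkpoint connector's resume_checkpoint_path. A minimal, hypothetical illustration:

from pytorch_lightning import Trainer

trainer = Trainer()
# emits: `trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7. ...
path = trainer.resume_from_checkpoint
print(path)  # the connector's resume_checkpoint_path; None if nothing was set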
@@ -1974,15 +1982,6 @@ def train_loop(self) -> FitLoop:
        )
        return self.fit_loop

    @property
    def _ckpt_path(self) -> Optional[str]:
        if self.state.fn == TrainerFn.VALIDATING:
            return self.validated_ckpt_path
        if self.state.fn == TrainerFn.TESTING:
            return self.tested_ckpt_path
        if self.state.fn == TrainerFn.PREDICTING:
            return self.predicted_ckpt_path

    """
    Logging properties
    """
Review comment:
- should this also be typed as `_PATH`?