[see #10061 instead] Unify checkpoint load paths #9693

Closed
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -293,6 +293,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated Accelerator collective API `barrier`, `broadcast`, and `all_gather`, call `TrainingTypePlugin` collective API directly ([#9677](https://github.com/PyTorchLightning/pytorch-lightning/pull/9677))


- Deprecated passing `resume_from_checkpoint` to the `Trainer` constructor in favor of `trainer.fit(ckpt_path=)` ([#9693](https://github.com/PyTorchLightning/pytorch-lightning/pull/9693))


### Removed

- Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
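For context on this changelog entry, a minimal before/after sketch of the migration (the `model` object and checkpoint path are placeholders):

```python
from pytorch_lightning import Trainer

# Before (deprecated in v1.5, to be removed in v1.7):
# resume via the Trainer constructor.
trainer = Trainer(resume_from_checkpoint="path/to/last.ckpt", max_epochs=10)
trainer.fit(model)

# After: pass the checkpoint path directly to fit().
trainer = Trainer(max_epochs=10)
trainer.fit(model, ckpt_path="path/to/last.ckpt")
```
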
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/callback_hook.py
@@ -270,7 +270,7 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
difference = callback_states.keys() - current_callbacks_keys
if difference:
rank_zero_warn(
"Be aware that when using `resume_from_checkpoint`,"
"Be aware that when using `ckpt_path`,"
" callbacks used to create the checkpoint need to be provided."
f" Please add the following callbacks: {list(difference)}.",
UserWarning,
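The warning above fires when a checkpoint carries state for a callback that the resuming Trainer no longer has. A sketch of that situation (`MyStatefulCallback` and `model` are illustrative placeholders, not part of this PR):

```python
# First run: the checkpoint stores MyStatefulCallback's state.
trainer = Trainer(max_steps=2, callbacks=[MyStatefulCallback()])
trainer.fit(model)

# Resume without that callback: on_load_checkpoint finds leftover callback
# state in the checkpoint and emits the UserWarning above, now worded
# around `ckpt_path` instead of `resume_from_checkpoint`.
trainer = Trainer(max_steps=2)
trainer.fit(model, ckpt_path="path/to/last.ckpt")
```
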
27 changes: 11 additions & 16 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -37,7 +37,12 @@
class CheckpointConnector:
def __init__(self, trainer: "pl.Trainer", resume_from_checkpoint: Optional[_PATH] = None) -> None:
self.trainer = trainer
self.resume_checkpoint_path = resume_from_checkpoint
self.resume_checkpoint_path: Optional[_PATH] = resume_from_checkpoint
if resume_from_checkpoint is not None:
rank_zero_deprecation(
"Setting `Trainer(resume_from_checkpoint=)` is deprecated in v1.5 and"
" will be removed in v1.7. Please pass `Trainer.fit(ckpt_path=)` directly instead."
)
self._loaded_checkpoint: Dict[str, Any] = {}

@property
@@ -47,17 +52,17 @@ def hpc_resume_path(self) -> Optional[str]:
if max_version is not None:
return os.path.join(dir_path_hpc, f"hpc_ckpt_{max_version}.ckpt")

def resume_start(self) -> None:
def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
"""Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority:

1. from HPC weights if found
2. from `resume_from_checkpoint` file if provided
2. from `checkpoint_path` file if provided
3. don't restore

Raises:
FileNotFoundError: If the path to the checkpoint file is provided but the file does not exist.
"""
self.resume_checkpoint_path = self.hpc_resume_path or self.resume_checkpoint_path
self.resume_checkpoint_path = self.hpc_resume_path or checkpoint_path
checkpoint_path = self.resume_checkpoint_path
if not checkpoint_path:
return
@@ -96,16 +101,15 @@ def restore(self, checkpoint_path: Optional[_PATH] = None) -> None:
state-restore, in this priority:

1. from HPC weights if found
2. from `resume_from_checkpoint` file if provided
2. from `checkpoint_path` file if provided
3. don't restore

All restored states are listed in return value description of `dump_checkpoint`.

Args:
checkpoint_path: Path to a PyTorch Lightning checkpoint file.
"""
self.resume_checkpoint_path = checkpoint_path
self.resume_start()
self.resume_start(checkpoint_path)

# restore module states
self.restore_datamodule()
@@ -154,15 +158,6 @@ def restore_model(self) -> None:
if isinstance(module, Metric):
module.reset()

def restore_model_weights(self, checkpoint_path: Optional[_PATH]) -> None:
"""Restore only the model weights."""
checkpoint = self._loaded_checkpoint
if checkpoint_path is not None:
checkpoint = self._load_and_validate_checkpoint(checkpoint_path)

self.trainer.lightning_module.on_load_checkpoint(checkpoint)
self.trainer.training_type_plugin.load_model_state_dict(checkpoint)

def restore_training_state(self) -> None:
"""Restore the trainer state from the pre-loaded checkpoint.

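A condensed sketch of the path-resolution order that `resume_start` implements after this change (a standalone helper for illustration only; the real method also validates the file and loads it through the training type plugin):

```python
from typing import Optional


def resolve_resume_path(hpc_resume_path: Optional[str], checkpoint_path: Optional[str]) -> Optional[str]:
    """Illustrative only: mirrors the priority documented in `resume_start`."""
    # 1. An HPC checkpoint found on disk always takes precedence.
    if hpc_resume_path is not None:
        return hpc_resume_path
    # 2. Otherwise use the path passed to `fit(ckpt_path=...)`
    #    (or the deprecated `Trainer(resume_from_checkpoint=...)`).
    if checkpoint_path is not None:
        return checkpoint_path
    # 3. Nothing to restore.
    return None
```
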
79 changes: 39 additions & 40 deletions pytorch_lightning/trainer/trainer.py
@@ -342,6 +342,10 @@ def __init__(
no checkpoint file at the path, start from scratch. If resuming from mid-epoch checkpoint,
training will start from the beginning of the next epoch.

.. deprecated:: v1.5
``resume_from_checkpoint`` is deprecated in v1.5 and will be removed in v1.7.
Please use ``Trainer.fit(ckpt_path)`` instead.

sync_batchnorm: Synchronize batch norm layers between process groups/whole world.

terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
@@ -574,6 +578,7 @@ def fit(
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional[LightningDataModule] = None,
train_dataloader=None, # TODO: remove with 1.6
ckpt_path: Optional[str] = None,
) -> None:
r"""
Runs the full optimization routine.
@@ -587,6 +592,10 @@

val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.

ckpt_path: Path/URL of the checkpoint from which training is resumed. If there is
no checkpoint file at the path, start from scratch. If resuming from mid-epoch checkpoint,
training will start from the beginning of the next epoch.

datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
"""
if train_dataloader is not None:
@@ -595,14 +604,17 @@
" Use `trainer.fit(train_dataloaders)` instead. HINT: added 's'"
)
train_dataloaders = train_dataloader
self._call_and_handle_interrupt(self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule)
self._call_and_handle_interrupt(
self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
)

def _fit_impl(
self,
model: "pl.LightningModule",
train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional[LightningDataModule] = None,
ckpt_path: Optional[str] = None,
Contributor: should this also be typed as _PATH?

) -> None:
Trainer._log_api_event("fit")

@@ -625,7 +637,9 @@ def _fit_impl(
model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
)

self._run(model)
# TODO: ckpt_path only in v1.7
ckpt_path = ckpt_path or self.resume_from_checkpoint
self._run(model, ckpt_path)

assert self.state.stopped
self.training = False
@@ -711,7 +725,7 @@ def _validate_impl(
)

# run validate
results = self._run(model)
results = self._run(model, self.validated_ckpt_path)

assert self.state.stopped
self.validating = False
@@ -800,7 +814,7 @@ def _test_impl(
)

# run test
results = self._run(model)
results = self._run(model, self.tested_ckpt_path)

assert self.state.stopped
self.testing = False
@@ -882,7 +896,7 @@ def _predict_impl(
ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
)

results = self._run(model)
results = self._run(model, self.predicted_ckpt_path)

assert self.state.stopped
self.predicting = False
@@ -951,24 +965,18 @@ def tune(

return result

def _restore_modules_and_callbacks(self) -> None:
def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None:
# restore modules after setup
if self.state.fn == TrainerFn.FITTING:
self.checkpoint_connector.resume_start()
self.checkpoint_connector.restore_datamodule()
self.checkpoint_connector.resume_start(checkpoint_path)
self.checkpoint_connector.restore_model()
# restore callback states
self.checkpoint_connector.restore_callbacks()

def _load_checkpoint_weights(self):
# only one process running at this point for TPUs, as spawn isn't triggered yet
# todo: move this logic internally within the barrier.
if not self._device_type == DeviceType.TPU:
self.training_type_plugin.barrier()
rank_zero_info(f"Loading model weights from checkpoint at {self._ckpt_path}")
self.checkpoint_connector.restore_model_weights(self._ckpt_path)

def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
if self.state.fn == TrainerFn.FITTING:
self.checkpoint_connector.restore_datamodule()
# restore callback states
self.checkpoint_connector.restore_callbacks()

def _run(
self, model: "pl.LightningModule", ckpt_path: Optional[str] = None
) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
# clean hparams
if hasattr(model, "hparams"):
parsing.clean_namespace(model.hparams)
@@ -985,9 +993,6 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
self.data_connector.prepare_data()
self.callback_connector._attach_model_callbacks()

if self._ckpt_path and not self.accelerator.restore_checkpoint_after_pre_dispatch:
self._load_checkpoint_weights()

# ----------------------------
# SET UP TRAINING
# ----------------------------
@@ -997,7 +1002,7 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,

# check if we should delay restoring checkpoint till later
if not self.accelerator.restore_checkpoint_after_pre_dispatch:
self._restore_modules_and_callbacks()
self._restore_modules_and_callbacks(ckpt_path)

self._call_configure_sharded_model() # allow user to setup in model sharded environment
self.accelerator.setup(self)
@@ -1046,12 +1051,13 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
self._pre_dispatch()

if self.accelerator.restore_checkpoint_after_pre_dispatch:
if self._ckpt_path:
self._load_checkpoint_weights()
self._restore_modules_and_callbacks()
self._restore_modules_and_callbacks(ckpt_path)

# restore optimizers, etc.
self.checkpoint_connector.restore_training_state()
if self.state.fn == TrainerFn.FITTING:
self.checkpoint_connector.restore_training_state()
Contributor: restore_training_state includes things which can be resumed even if not fitting, such as the loop state. IMO we shouldn't add the check for fitting here, but rather inside the relevant parts of restore_training_state.

Contributor Author (@jjenniferdai, Sep 27, 2021): (above comment?)

Contributor: Yeah, we do restore loops there, and now some other attributes as well. I'd suggest waiting for this one to get merged: #9413

self.checkpoint_connector.resume_end()
Contributor: n00b question: why is this bumped up to here vs. in _pre_training_routine?

Contributor Author: Now that this calls resume_start for all trainer functions (not only fitting), resume_end is similarly called for all as well.

# dispatch `start_training` or `start_evaluating` or `start_predicting`
self._dispatch()
@@ -1152,8 +1158,6 @@ def _pre_training_routine(self):
# register signals
self.signal_connector.register_signal_handlers()

self.checkpoint_connector.resume_end()

# --------------------------
# Pre-train
# --------------------------
@@ -1742,6 +1746,10 @@ def checkpoint_callbacks(self) -> List[ModelCheckpoint]:

@property
def resume_from_checkpoint(self) -> Optional[Union[str, Path]]:
rank_zero_deprecation(
"`trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7."
" Specify fit ckpt_path with `trainer.fit(ckpt_path=)` instead."
)
return self.checkpoint_connector.resume_checkpoint_path

def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None:
@@ -1974,15 +1982,6 @@ def train_loop(self) -> FitLoop:
)
return self.fit_loop

@property
def _ckpt_path(self) -> Optional[str]:
if self.state.fn == TrainerFn.VALIDATING:
return self.validated_ckpt_path
if self.state.fn == TrainerFn.TESTING:
return self.tested_ckpt_path
if self.state.fn == TrainerFn.PREDICTING:
return self.predicted_ckpt_path

"""
Logging properties
"""
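With `_ckpt_path` and `_load_checkpoint_weights` removed, each entry point now hands its own checkpoint path to `_run`. A caller-side usage sketch (model and paths are placeholders):

```python
trainer = Trainer(max_epochs=5)

# fit: restores the full training state (model, loops, optimizers, callbacks).
trainer.fit(model, ckpt_path="checkpoints/last.ckpt")

# validate / test / predict: the same keyword selects the checkpoint whose
# weights are restored before evaluation or prediction.
trainer.validate(model, ckpt_path="checkpoints/best.ckpt")
trainer.test(model, ckpt_path="checkpoints/best.ckpt")
trainer.predict(model, ckpt_path="checkpoints/best.ckpt")
```
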
6 changes: 2 additions & 4 deletions tests/accelerators/test_cpu.py
@@ -80,10 +80,8 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
assert accelerator.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch
assert plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch

trainer = Trainer(
default_root_dir=tmpdir, accelerator=accelerator, fast_dev_run=True, resume_from_checkpoint=checkpoint_path
)
trainer.fit(model)
trainer = Trainer(default_root_dir=tmpdir, accelerator=accelerator, fast_dev_run=True)
trainer.fit(model, ckpt_path=checkpoint_path)
for func in (trainer.test, trainer.validate, trainer.predict):
accelerator.training_type_plugin.predispatched_called = False
func(model, ckpt_path=checkpoint_path)
6 changes: 2 additions & 4 deletions tests/accelerators/test_tpu_backend.py
@@ -61,10 +61,8 @@ def test_resume_training_on_cpu(tmpdir):
assert weight_tensor.device == torch.device("cpu")

# Verify that training is resumed on CPU
trainer = Trainer(
resume_from_checkpoint=model_path, checkpoint_callback=True, max_epochs=1, default_root_dir=tmpdir
)
trainer.fit(model)
trainer = Trainer(checkpoint_callback=True, max_epochs=1, default_root_dir=tmpdir)
trainer.fit(model, ckpt_path=model_path)
assert trainer.state.finished, f"Training failed with {trainer.state}"


10 changes: 4 additions & 6 deletions tests/callbacks/test_callbacks.py
@@ -132,8 +132,8 @@ def test_resume_callback_state_saved_by_type(tmpdir):
assert ckpt_path.exists()

callback = OldStatefulCallback(state=222)
trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path)
trainer.fit(model)
trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback])
trainer.fit(model, ckpt_path=ckpt_path)
assert callback.state == 111


@@ -153,16 +153,14 @@ def test_resume_incomplete_callbacks_list_warning(tmpdir):
default_root_dir=tmpdir,
max_steps=1,
callbacks=[callback1], # one callback is missing!
resume_from_checkpoint=ckpt_path,
)
with pytest.warns(UserWarning, match=escape(f"Please add the following callbacks: [{repr(callback0.state_key)}]")):
trainer.fit(model)
trainer.fit(model, ckpt_path=ckpt_path)

trainer = Trainer(
default_root_dir=tmpdir,
max_steps=1,
callbacks=[callback1, callback0], # all callbacks here, order switched
resume_from_checkpoint=ckpt_path,
)
with no_warning_call(UserWarning, match="Please add the following callbacks:"):
trainer.fit(model)
trainer.fit(model, ckpt_path=ckpt_path)
3 changes: 1 addition & 2 deletions tests/callbacks/test_early_stopping.py
@@ -90,12 +90,11 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
new_trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
resume_from_checkpoint=checkpoint_filepath,
callbacks=[early_stop_callback],
)

with pytest.raises(MisconfigurationException, match=r"You restored a checkpoint with current_epoch"):
new_trainer.fit(model)
new_trainer.fit(model, ckpt_path=checkpoint_filepath)


def test_early_stopping_no_extraneous_invocations(tmpdir):
14 changes: 6 additions & 8 deletions tests/callbacks/test_finetuning_callback.py
@@ -128,8 +128,8 @@ def configure_optimizers(self):
trainer.fit(model)

assert model.backbone.has_been_used
trainer = Trainer(max_epochs=3, resume_from_checkpoint=chk.last_model_path)
trainer.fit(model)
trainer = Trainer(max_epochs=3)
trainer.fit(model, ckpt_path=chk.last_model_path)


def test_freeze_unfreeze_function(tmpdir):
@@ -258,9 +258,9 @@ def configure_optimizers(self):

model = FreezeModel()
cb = OnEpochLayerFinetuning()
trainer = Trainer(max_epochs=10, resume_from_checkpoint=chk.last_model_path, callbacks=[cb])
trainer = Trainer(max_epochs=10, callbacks=[cb])
with pytest.raises(IndexError, match="index 6 is out of range"):
trainer.fit(model)
trainer.fit(model, ckpt_path=chk.last_model_path)


def test_on_before_accelerator_backend_setup(tmpdir):
@@ -400,10 +400,9 @@ def test_callbacks_restore(tmpdir):
}

trainer_kwargs["max_epochs"] = 3
trainer_kwargs["resume_from_checkpoint"] = chk.last_model_path

trainer = Trainer(**trainer_kwargs)
trainer.fit(model)
trainer.fit(model, ckpt_path=chk.last_model_path)


def test_callbacks_restore_backbone(tmpdir):
@@ -438,6 +437,5 @@ def forward(self, x):
max_epochs=3,
enable_progress_bar=False,
callbacks=BackboneFinetuning(unfreeze_backbone_at_epoch=1),
resume_from_checkpoint=ckpt.last_model_path,
)
trainer.fit(BackboneBoringModel())
trainer.fit(BackboneBoringModel(), ckpt_path=ckpt.last_model_path)