
Add profiling for on_load_checkpoint/on_save_checkpoint callback and LM hooks #12149


Merged · 10 commits · Mar 22, 2022
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -143,6 +143,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for pluggable Accelerators ([#12030](https://github.com/PyTorchLightning/pytorch-lightning/pull/12030))


+ - Added profiling for `on_load_checkpoint`/`on_save_checkpoint` callback and LightningModule hooks ([#12149](https://github.com/PyTorchLightning/pytorch-lightning/pull/12149))


- Added `LayerSync` and `NativeSyncBatchNorm` plugins ([#11754](https://github.com/PyTorchLightning/pytorch-lightning/pull/11754))


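For context, here is a minimal sketch (not part of the diff) of how the newly profiled hooks surface for a user. The model, the toy data, and the `[LightningModule]...` report label are illustrative assumptions; the `[Callback]...` label format comes from the trainer change below.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class MyModel(pl.LightningModule):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def on_save_checkpoint(self, checkpoint):
        # With this PR, the time spent in this hook is recorded by the profiler.
        checkpoint["my_extra_state"] = 42


train_data = DataLoader(TensorDataset(torch.randn(8, 4), torch.randn(8, 1)), batch_size=4)
trainer = pl.Trainer(max_epochs=1, profiler="simple")
trainer.fit(MyModel(), train_data)
# The "simple" profiler summary should now include rows such as
# [LightningModule]MyModel.on_save_checkpoint and
# [Callback]ModelCheckpoint{...}.on_save_checkpoint (exact labels assumed).
```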
@@ -167,7 +167,7 @@ def restore_model(self) -> None:
model = self.trainer.lightning_module

# hook: give user access to checkpoint if needed.
- model.on_load_checkpoint(self._loaded_checkpoint)
+ self.trainer._call_lightning_module_hook("on_load_checkpoint", self._loaded_checkpoint)

# TODO: remove this in v1.8.
# call hpc specific hook
@@ -393,7 +393,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
# support for returning state in on_save_checkpoint
# will be removed in v1.8
self.trainer._call_callbacks_on_save_checkpoint(checkpoint)
- model.on_save_checkpoint(checkpoint)
+ self.trainer._call_lightning_module_hook("on_save_checkpoint", checkpoint)
if datamodule is not None:
datamodule.on_save_checkpoint(checkpoint)

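The two replacements above route the LightningModule hook through the trainer's hook dispatcher instead of calling it directly on the model, which is what places the call under the profiler and keeps `_current_fx_name` accurate. Below is a simplified sketch of what such a dispatcher does; it is an assumption about the general shape, not the library's exact `_call_lightning_module_hook` implementation.

```python
# Simplified, assumed sketch of a profiled hook dispatcher; not the exact library code.
def _call_lightning_module_hook(trainer, hook_name, *args, **kwargs):
    pl_module = trainer.lightning_module
    fn = getattr(pl_module, hook_name)

    # Record which hook is running so e.g. self.log() can validate its caller.
    prev_fx_name = pl_module._current_fx_name
    pl_module._current_fx_name = hook_name

    # Every dispatched hook is wrapped in the trainer's profiler.
    with trainer.profiler.profile(f"[LightningModule]{type(pl_module).__name__}.{hook_name}"):
        output = fn(*args, **kwargs)

    # Restore the previous name so nested hook calls stay consistent.
    pl_module._current_fx_name = prev_fx_name
    return output
```

Dispatching this way is also what makes the fx-validator test additions at the end of the diff meaningful: the hook name is now registered while the user's hook code runs.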
26 changes: 22 additions & 4 deletions pytorch_lightning/trainer/trainer.py
@@ -1677,19 +1677,33 @@ def _call_callbacks_on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None
Will be removed in v1.8: If state is returned, we insert the callback state into
``checkpoint["callbacks"][Callback.state_key]``. It overrides ``state_dict`` if already present.
"""
+ pl_module = self.lightning_module
+ if pl_module:
+     prev_fx_name = pl_module._current_fx_name
+     pl_module._current_fx_name = "on_save_checkpoint"

for callback in self.callbacks:
- # TODO: Add profiling for on_save_checkpoint hook
- state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
+ with self.profiler.profile(f"[Callback]{callback.state_key}.on_save_checkpoint"):
+     state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
if state:
# TODO: Add deprecation warning if state is returned (see reference PR #11887)
checkpoint["callbacks"][callback.state_key] = state

+ if pl_module:
+     # restore current_fx when nested context
+     pl_module._current_fx_name = prev_fx_name

def _call_callbacks_on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
"""Called when loading a model checkpoint.

Calls every callback's `on_load_checkpoint` hook. We have a dedicated function for this rather than using
`_call_callback_hooks` because we have special logic for getting callback_states.
"""
+ pl_module = self.lightning_module
+ if pl_module:
+     prev_fx_name = pl_module._current_fx_name
+     pl_module._current_fx_name = "on_load_checkpoint"

callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks")

if callback_states is None:
@@ -1709,8 +1723,12 @@ def _call_callbacks_on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None
state = callback_states.get(callback.state_key, callback_states.get(callback._legacy_state_key))
if state:
state = deepcopy(state)
- # TODO: Add profiling for on_load_checkpoint hook
- callback.on_load_checkpoint(self, self.lightning_module, state)
+ with self.profiler.profile(f"[Callback]{callback.state_key}.on_load_checkpoint"):
+     callback.on_load_checkpoint(self, self.lightning_module, state)

+ if pl_module:
+     # restore current_fx when nested context
+     pl_module._current_fx_name = prev_fx_name

def _call_callbacks_load_state_dict(self, checkpoint: Dict[str, Any]) -> None:
"""Called when loading a model checkpoint, calls every callback's `load_state_dict`."""
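For reference, here is a hedged sketch of a callback that exercises both hooks the new profiling wraps. `BatchCounter` is an illustrative name, and the return-a-dict behaviour is the deprecated path the docstring above says will be removed in v1.8.

```python
from pytorch_lightning import Callback


class BatchCounter(Callback):  # hypothetical callback, for illustration only
    def __init__(self):
        self.batches_seen = 0

    # Deprecated path referenced above: returning a dict makes the trainer store it
    # under checkpoint["callbacks"][self.state_key], overriding any state_dict() output.
    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        return {"batches_seen": self.batches_seen}

    def on_load_checkpoint(self, trainer, pl_module, callback_state):
        self.batches_seen = callback_state["batches_seen"]
```

With this change, both calls appear in the profiler report under labels like `[Callback]BatchCounter.on_save_checkpoint` and `[Callback]BatchCounter.on_load_checkpoint` (the default `state_key` of a parameter-free callback is assumed to be its class name).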
3 changes: 3 additions & 0 deletions tests/trainer/logging_/test_logger_connector.py
@@ -243,6 +243,9 @@ def test_fx_validator_integration(tmpdir):
"on_validation_model_eval": "You can't",
"on_validation_model_train": "You can't",
"lr_scheduler_step": "You can't",
"on_save_checkpoint": "You can't",
"on_load_checkpoint": "You can't",
"on_exception": "You can't",
"summarize": "not managed by the `Trainer",
}
model = HookedModel(not_supported)
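A hedged sketch of what the new test entries mean in practice: now that `_current_fx_name` is set while these hooks run, calling `self.log()` from them is rejected by the fx validator with an error message containing "You can't", which is exactly the phrase the extended test matches on. The model below is illustrative, not from the PR.

```python
import torch
import pytorch_lightning as pl


class LogsFromCheckpointHook(pl.LightningModule):  # hypothetical model for illustration
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def on_save_checkpoint(self, checkpoint):
        # Expected to raise an error whose message contains "You can't":
        # logging is not supported from this hook.
        self.log("not_allowed", 1.0)
```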