
Commit 905a4d8

Add profiling for on_load_checkpoint/on_save_checkpoint callback and LM hooks (#12149)
1 parent 5d156f4 commit 905a4d8
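
In short: the callback `on_save_checkpoint`/`on_load_checkpoint` calls are now wrapped in the trainer's profiler (replacing the old TODO comments), and the LightningModule's checkpoint hooks are routed through `Trainer._call_lightning_module_hook` instead of being called directly, so they get the same profiling and `_current_fx_name` bookkeeping as the other hooks. The wrapping itself is just the profiler's context manager; a minimal, self-contained sketch of the pattern (illustrative only, not the Trainer code):

```python
from pytorch_lightning.profiler import SimpleProfiler

# Illustrative sketch of the pattern this commit applies (not the Trainer
# implementation): time a hook call under a readable action name such as
# "[Callback]<state_key>.on_save_checkpoint".
profiler = SimpleProfiler()

def timed_hook(action_name, hook, *args):
    # profiler.profile() is a context manager that records the wall time
    # of everything inside the block under `action_name`.
    with profiler.profile(action_name):
        return hook(*args)

timed_hook("[Callback]Example.on_save_checkpoint", lambda ckpt: None, {"state_dict": {}})
print(profiler.summary())  # the action shows up with its recorded duration
```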

File tree

4 files changed: +30, -6 lines

CHANGELOG.md (3 additions, 0 deletions)

@@ -143,6 +143,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for pluggable Accelerators ([#12030](https://github.com/PyTorchLightning/pytorch-lightning/pull/12030))
 
 
+- Added profiling for `on_load_checkpoint`/`on_save_checkpoint` callback and LightningModule hooks ([#12149](https://github.com/PyTorchLightning/pytorch-lightning/pull/12149))
+
+
 - Added `LayerSync` and `NativeSyncBatchNorm` plugins ([#11754](https://github.com/PyTorchLightning/pytorch-lightning/pull/11754))
 
 
pytorch_lightning/trainer/connectors/checkpoint_connector.py (2 additions, 2 deletions)

@@ -167,7 +167,7 @@ def restore_model(self) -> None:
         model = self.trainer.lightning_module
 
         # hook: give user access to checkpoint if needed.
-        model.on_load_checkpoint(self._loaded_checkpoint)
+        self.trainer._call_lightning_module_hook("on_load_checkpoint", self._loaded_checkpoint)
 
         # TODO: remove this in v1.8.
         # call hpc specific hook
@@ -393,7 +393,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
         # support for returning state in on_save_checkpoint
         # will be removed in v1.8
         self.trainer._call_callbacks_on_save_checkpoint(checkpoint)
-        model.on_save_checkpoint(checkpoint)
+        self.trainer._call_lightning_module_hook("on_save_checkpoint", checkpoint)
         if datamodule is not None:
            datamodule.on_save_checkpoint(checkpoint)
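
Routing the LightningModule hook through `_call_lightning_module_hook` (rather than calling `model.on_load_checkpoint`/`model.on_save_checkpoint` directly) is what adds the profiling and hook bookkeeping on the module side. The dispatcher itself is not part of this diff; roughly, it behaves like the following simplified sketch (the actual signature and action-name format live in `trainer.py` and may differ):

```python
# Simplified sketch of what Trainer._call_lightning_module_hook does for a call
# like trainer._call_lightning_module_hook("on_load_checkpoint", checkpoint):
# look up the hook on the LightningModule, mark it as the current hook, profile
# the call, then restore the previous hook name.
def _call_lightning_module_hook(trainer, hook_name, *args, **kwargs):
    pl_module = trainer.lightning_module
    fn = getattr(pl_module, hook_name)

    prev_fx_name = pl_module._current_fx_name
    pl_module._current_fx_name = hook_name

    with trainer.profiler.profile(f"[LightningModule]{type(pl_module).__name__}.{hook_name}"):
        output = fn(*args, **kwargs)

    pl_module._current_fx_name = prev_fx_name
    return output
```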

pytorch_lightning/trainer/trainer.py (22 additions, 4 deletions)

@@ -1678,19 +1678,33 @@ def _call_callbacks_on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None
         Will be removed in v1.8: If state is returned, we insert the callback state into
         ``checkpoint["callbacks"][Callback.state_key]``. It overrides ``state_dict`` if already present.
         """
+        pl_module = self.lightning_module
+        if pl_module:
+            prev_fx_name = pl_module._current_fx_name
+            pl_module._current_fx_name = "on_save_checkpoint"
+
         for callback in self.callbacks:
-            # TODO: Add profiling for on_save_checkpoint hook
-            state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
+            with self.profiler.profile(f"[Callback]{callback.state_key}.on_save_checkpoint"):
+                state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
             if state:
                 # TODO: Add deprecation warning if state is returned (see reference PR #11887)
                 checkpoint["callbacks"][callback.state_key] = state
 
+        if pl_module:
+            # restore current_fx when nested context
+            pl_module._current_fx_name = prev_fx_name
+
     def _call_callbacks_on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         """Called when loading a model checkpoint.
 
         Calls every callback's `on_load_checkpoint` hook. We have a dedicated function for this rather than using
         `_call_callback_hooks` because we have special logic for getting callback_states.
         """
+        pl_module = self.lightning_module
+        if pl_module:
+            prev_fx_name = pl_module._current_fx_name
+            pl_module._current_fx_name = "on_load_checkpoint"
+
         callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks")
 
         if callback_states is None:
@@ -1710,8 +1724,12 @@ def _call_callbacks_on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None
             state = callback_states.get(callback.state_key, callback_states.get(callback._legacy_state_key))
             if state:
                 state = deepcopy(state)
-                # TODO: Add profiling for on_load_checkpoint hook
-                callback.on_load_checkpoint(self, self.lightning_module, state)
+                with self.profiler.profile(f"[Callback]{callback.state_key}.on_load_checkpoint"):
+                    callback.on_load_checkpoint(self, self.lightning_module, state)
+
+        if pl_module:
+            # restore current_fx when nested context
+            pl_module._current_fx_name = prev_fx_name
 
     def _call_callbacks_load_state_dict(self, checkpoint: Dict[str, Any]) -> None:
         """Called when loading a model checkpoint, calls every callback's `load_state_dict`."""

tests/trainer/logging_/test_logger_connector.py (3 additions, 0 deletions)

@@ -243,6 +243,9 @@ def test_fx_validator_integration(tmpdir):
         "on_validation_model_eval": "You can't",
         "on_validation_model_train": "You can't",
         "lr_scheduler_step": "You can't",
+        "on_save_checkpoint": "You can't",
+        "on_load_checkpoint": "You can't",
+        "on_exception": "You can't",
         "summarize": "not managed by the `Trainer",
     }
     model = HookedModel(not_supported)
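
The new test entries expect `self.log` to be rejected inside `on_save_checkpoint`, `on_load_checkpoint`, and `on_exception`, which is possible now that `_current_fx_name` is set around the checkpoint hooks. To see the profiling side of the change end to end, here is a minimal run with the simple profiler; this example is illustrative and not part of the commit, and the exact action names in the report (e.g. `[Callback]ModelCheckpoint{...}.on_save_checkpoint`) depend on each callback's `state_key`:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


data = DataLoader(TensorDataset(torch.randn(32, 4), torch.randn(32, 1)), batch_size=8)

# With the simple profiler enabled, the checkpoint hooks instrumented by this
# commit appear in the report printed after fit() (the default ModelCheckpoint
# saves at the end of the epoch, which triggers on_save_checkpoint).
trainer = pl.Trainer(max_epochs=1, profiler="simple", logger=False)
trainer.fit(TinyModel(), data)
```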
