Commit aeb0b55: Deprecate call_hook (#10979)

Parent: 6369e3b

File tree: 3 files changed (+74, -1 lines changed)


CHANGELOG.md

Lines changed: 4 additions & 1 deletion
@@ -63,7 +63,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changes in `LightningCLI` required for the new major release of jsonargparse v4.0.0 ([#10426](https://github.com/PyTorchLightning/pytorch-lightning/pull/10426))
 
 
-- Renamed `refresh_rate_per_second` parameter to `referesh_rate` for `RichProgressBar` signature ([#10497](https://github.com/PyTorchLightning/pytorch-lightning/pull/10497))
+- Renamed `refresh_rate_per_second` parameter to `refresh_rate` for `RichProgressBar` signature ([#10497](https://github.com/PyTorchLightning/pytorch-lightning/pull/10497))
 
 
 - Moved ownership of the `PrecisionPlugin` into `TrainingTypePlugin` and updated all references ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570))

@@ -116,6 +116,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated the access to the attribute `IndexBatchSamplerWrapper.batch_indices` in favor of `IndexBatchSamplerWrapper.seen_batch_indices` ([#10870](https://github.com/PyTorchLightning/pytorch-lightning/pull/10870))
 
 
+- Deprecated `Trainer.call_hook` in favor of `Trainer._call_callback_hooks`, `Trainer._call_lightning_module_hook`, `Trainer._call_ttp_hook`, and `Trainer._call_accelerator_hook` ([#10979](https://github.com/PyTorchLightning/pytorch-lightning/pull/10979))
+
+
 - Deprecated `TrainingTypePlugin.post_dispatch` in favor of `TrainingTypePlugin.teardown` ([#10939](https://github.com/PyTorchLightning/pytorch-lightning/pull/10939))
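For call sites that used `call_hook` directly, the entry above points at a per-target migration. A minimal sketch of what that might look like, assuming the replacement methods keep the `(hook_name, *args, **kwargs)` calling convention that `call_hook` uses (see trainer.py below for the actual definitions):

    # Before: one method dispatched to every object that owns the hook (now emits a DeprecationWarning)
    output = trainer.call_hook("on_train_start")

    # After: call the replacement that matches the hook's owner
    trainer._call_callback_hooks("on_train_start")                  # Callback hooks
    output = trainer._call_lightning_module_hook("on_train_start")  # LightningModule hook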

pytorch_lightning/trainer/trainer.py

Lines changed: 57 additions & 0 deletions
@@ -1446,6 +1446,63 @@ def _call_teardown_hook(self) -> None:
         # summarize profile results
         self.profiler.describe()
 
+    def call_hook(
+        self, hook_name: str, *args: Any, pl_module: Optional["pl.LightningModule"] = None, **kwargs: Any
+    ) -> Any:
+        r"""
+        .. deprecated:: v1.6
+            The Trainer's `call_hook` method was deprecated in v1.6 and will be removed in v1.8.
+        """
+        rank_zero_deprecation("The Trainer's `call_hook` method was deprecated in v1.6 and will be removed in v1.8.")
+        pl_module = self.lightning_module or pl_module
+        if pl_module:
+            prev_fx_name = pl_module._current_fx_name
+            pl_module._current_fx_name = hook_name
+
+        # always profile hooks
+        with self.profiler.profile(hook_name):
+
+            # first call trainer hook
+            callback_fx = getattr(self, hook_name, None)
+            if callable(callback_fx):
+                callback_fx(*args, **kwargs)
+
+            # next call hook in the LightningModule
+            output = None
+            model_fx = getattr(pl_module, hook_name, None)
+            if callable(model_fx):
+                output = model_fx(*args, **kwargs)
+
+            # *Bad code alert*
+            # The `Accelerator` mostly calls the `TrainingTypePlugin`, but some of those calls are deprecated.
+            # The following logic selectively chooses which hooks are called on each object.
+            # In the case of `setup` and `teardown`, the hooks on the `LightningModule` should not call the hooks
+            # of the same name in these objects, as they are meant to be managed outside of the `LightningModule`
+            # lifecycle. All of this should be fixed by #8506.
+
+            # call the accelerator hook
+            if hook_name in ("on_train_start",) and hasattr(self.accelerator, hook_name):
+                accelerator_hook = getattr(self.accelerator, hook_name)
+                accelerator_output = accelerator_hook(*args, **kwargs)
+                # Rely on the accelerator output if the LightningModule hook returns nothing.
+                # Required for cases such as DataParallel where we reduce the output for the user.
+                # TODO: move this data parallel logic into the data parallel plugin
+                output = accelerator_output if output is None else output
+
+            # call the TTP hook
+            if hook_name not in ("setup", "teardown", "on_train_start") and hasattr(
+                self.training_type_plugin, hook_name
+            ):
+                ttp_hook = getattr(self.training_type_plugin, hook_name)
+                ttp_output = ttp_hook(*args, **kwargs)
+                output = ttp_output if output is None else output
+
+        if pl_module:
+            # restore current_fx when nested context
+            pl_module._current_fx_name = prev_fx_name
+
+        return output
+
     def _call_lightning_module_hook(
         self,
         hook_name: str,
tests/deprecated_api/test_remove_1-8.py

Lines changed: 13 additions & 0 deletions
@@ -16,6 +16,7 @@
 import pytest
 import torch
 
+from pytorch_lightning import Trainer
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.enums import DeviceType, DistributedType
 from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY
@@ -41,3 +42,15 @@ def test_v1_8_0_deprecated_torchtext_batch():
     data_iterator, _ = get_dummy_torchtext_data_iterator(num_samples=3, batch_size=3)
     batch = next(iter(data_iterator))
     _ = move_data_to_device(batch=batch, device=torch.device("cpu"))
+
+
+def test_v1_8_0_deprecated_call_hook():
+    trainer = Trainer(
+        max_epochs=1,
+        limit_val_batches=0.1,
+        limit_train_batches=0.2,
+        enable_progress_bar=False,
+        logger=False,
+    )
+    with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8."):
+        trainer.call_hook("test_hook")
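`pytest.deprecated_call(match=...)` is a context manager that fails the test unless the block emits a DeprecationWarning (or PendingDeprecationWarning) whose message matches the regex, so invoking `call_hook` once is enough to cover the new path. The same pattern in isolation, with a hypothetical `legacy_api` standing in for the deprecated method:

    import warnings

    import pytest

    def legacy_api() -> None:
        # stand-in for a deprecated method such as Trainer.call_hook
        warnings.warn("`legacy_api` was deprecated in v1.6 and will be removed in v1.8.", DeprecationWarning)

    def test_legacy_api_warns() -> None:
        with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8."):
            legacy_api()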
