Deprecate call_hook #10979

Merged: 4 commits, Dec 8, 2021
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -63,7 +63,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changes in `LightningCLI` required for the new major release of jsonargparse v4.0.0 ([#10426](https://github.com/PyTorchLightning/pytorch-lightning/pull/10426))


- Renamed `refresh_rate_per_second` parameter to `referesh_rate` for `RichProgressBar` signature ([#10497](https://github.com/PyTorchLightning/pytorch-lightning/pull/10497))
- Renamed `refresh_rate_per_second` parameter to `refresh_rate` for `RichProgressBar` signature ([#10497](https://github.com/PyTorchLightning/pytorch-lightning/pull/10497))


- Moved ownership of the `PrecisionPlugin` into `TrainingTypePlugin` and updated all references ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570))
@@ -116,6 +116,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated the access to the attribute `IndexBatchSamplerWrapper.batch_indices` in favor of `IndexBatchSamplerWrapper.seen_batch_indices` ([#10870](https://github.com/PyTorchLightning/pytorch-lightning/pull/10870))


- Deprecated `Trainer.call_hook` in favor of `Trainer._call_callback_hooks`, `Trainer._call_lightning_module_hook`, `Trainer._call_ttp_hook`, and `Trainer._call_accelerator_hook` ([#10979](https://github.com/PyTorchLightning/pytorch-lightning/pull/10979))


- Deprecated `TrainingTypePlugin.post_dispatch` in favor of `TrainingTypePlugin.teardown` ([#10939](https://github.com/PyTorchLightning/pytorch-lightning/pull/10939))


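For downstream code that still calls `trainer.call_hook(...)`, the replacement surface named in the changelog entry above is a set of granular, per-owner dispatch methods. The sketch below shows what a migrated call site for an `on_train_start`-style hook might look like; it is only illustrative, since the replacements are underscore-prefixed Trainer internals and their exact signatures (hook name first, then the hook's own arguments) are an assumption based on this PR's naming.

```python
from pytorch_lightning import Trainer


def run_train_start_hooks(trainer: Trainer) -> None:
    """Hedged sketch of what `trainer.call_hook("on_train_start")` splits into.

    The hook-name-first signatures are assumed from this PR; the `_call_*`
    methods are protected Trainer internals, not public API.
    """
    # hooks registered via Callbacks attached to the Trainer
    trainer._call_callback_hooks("on_train_start")
    # the LightningModule's own hook (its return value is the one callers see)
    trainer._call_lightning_module_hook("on_train_start")
```

Strategy- and accelerator-level hooks go through `Trainer._call_ttp_hook` and `Trainer._call_accelerator_hook` in the same style.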
57 changes: 57 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -1446,6 +1446,63 @@ def _call_teardown_hook(self) -> None:
    # summarize profile results
    self.profiler.describe()

def call_hook(
    self, hook_name: str, *args: Any, pl_module: Optional["pl.LightningModule"] = None, **kwargs: Any
) -> Any:
    r"""
    .. deprecated:: v1.6
        The Trainer's `call_hook` method was deprecated in v1.6 and will be removed in v1.8.
    """
    rank_zero_deprecation("The Trainer's `call_hook` method was deprecated in v1.6 and will be removed in v1.8.")
    pl_module = self.lightning_module or pl_module
    if pl_module:
        prev_fx_name = pl_module._current_fx_name
        pl_module._current_fx_name = hook_name

    # always profile hooks
    with self.profiler.profile(hook_name):

        # first call trainer hook
        callback_fx = getattr(self, hook_name, None)
        if callable(callback_fx):
            callback_fx(*args, **kwargs)

        # next call hook in lightningModule
        output = None
        model_fx = getattr(pl_module, hook_name, None)
        if callable(model_fx):
            output = model_fx(*args, **kwargs)

        # *Bad code alert*
        # The `Accelerator` mostly calls the `TrainingTypePlugin` but some of those calls are deprecated.
        # The following logic selectively chooses which hooks are called on each object.
        # In the case of `setup` and `teardown`, the hooks on the `LightningModule` should not call the hooks of the
        # same name in these objects as they are meant to be managed outside of the `LightningModule` lifecycle.
        # All of this should be fixed by #8506

        # call the accelerator hook
        if hook_name in ("on_train_start",) and hasattr(self.accelerator, hook_name):
            accelerator_hook = getattr(self.accelerator, hook_name)
            accelerator_output = accelerator_hook(*args, **kwargs)
            # Rely on the accelerator output if lightningModule hook returns nothing
            # Required for cases such as DataParallel where we reduce the output for the user
            # todo: move this data parallel logic into the data parallel plugin
            output = accelerator_output if output is None else output

        # call the ttp hook
        if hook_name not in ("setup", "teardown", "on_train_start") and hasattr(
            self.training_type_plugin, hook_name
        ):
            ttp_hook = getattr(self.training_type_plugin, hook_name)
            ttp_output = ttp_hook(*args, **kwargs)
            output = ttp_output if output is None else output

    if pl_module:
        # restore current_fx when nested context
        pl_module._current_fx_name = prev_fx_name

    return output

def _call_lightning_module_hook(
    self,
    hook_name: str,
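One detail worth noting in the deprecated implementation above is the return-value precedence: the LightningModule's output wins, and the accelerator or training-type plugin output is only used when the module hook returned `None` (the DataParallel reduction case mentioned in the comment). Below is a tiny standalone sketch of that rule, not part of the PR itself.

```python
from typing import Any, Optional


def resolve_hook_output(module_output: Optional[Any], plugin_output: Optional[Any]) -> Optional[Any]:
    # Mirrors `output = accelerator_output if output is None else output` above:
    # the LightningModule's return value takes precedence, and the plugin or
    # accelerator result only fills in when the module hook returned nothing.
    return plugin_output if module_output is None else module_output


assert resolve_hook_output(None, "dp-reduced") == "dp-reduced"
assert resolve_hook_output("module-result", "dp-reduced") == "module-result"
```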
13 changes: 13 additions & 0 deletions tests/deprecated_api/test_remove_1-8.py
@@ -16,6 +16,7 @@
import pytest
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.enums import DeviceType, DistributedType
from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY
@@ -41,3 +42,15 @@ def test_v1_8_0_deprecated_torchtext_batch():
    data_iterator, _ = get_dummy_torchtext_data_iterator(num_samples=3, batch_size=3)
    batch = next(iter(data_iterator))
    _ = move_data_to_device(batch=batch, device=torch.device("cpu"))


def test_v1_8_0_deprecated_call_hook():
    trainer = Trainer(
        max_epochs=1,
        limit_val_batches=0.1,
        limit_train_batches=0.2,
        enable_progress_bar=False,
        logger=False,
    )
    with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8."):
        trainer.call_hook("test_hook")
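The test asserts that calling the deprecated method emits the warning. For projects that cannot migrate right away, the standard `warnings` machinery can silence just this message during the transition; the sketch below assumes the warning is raised as a `DeprecationWarning` subclass (as `rank_zero_deprecation` warnings are), so adjust the category or pattern to your installed version.

```python
import warnings

# Hedged sketch: suppress only the call_hook deprecation while migrating call sites.
warnings.filterwarnings(
    "ignore",
    message=r"The Trainer's `call_hook` method was deprecated in v1\.6",
    category=DeprecationWarning,
)
```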