diff --git a/CHANGELOG.md b/CHANGELOG.md
index fa9bd9c0ce71b..ed23d361164d7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -63,7 +63,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Changes in `LightningCLI` required for the new major release of jsonargparse v4.0.0 ([#10426](https://github.com/PyTorchLightning/pytorch-lightning/pull/10426))
 
-- Renamed `refresh_rate_per_second` parameter to `referesh_rate` for `RichProgressBar` signature ([#10497](https://github.com/PyTorchLightning/pytorch-lightning/pull/10497))
+- Renamed `refresh_rate_per_second` parameter to `refresh_rate` for `RichProgressBar` signature ([#10497](https://github.com/PyTorchLightning/pytorch-lightning/pull/10497))
 
 - Moved ownership of the `PrecisionPlugin` into `TrainingTypePlugin` and updated all references ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570))
 
@@ -116,6 +116,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated the access to the attribute `IndexBatchSamplerWrapper.batch_indices` in favor of `IndexBatchSamplerWrapper.seen_batch_indices` ([#10870](https://github.com/PyTorchLightning/pytorch-lightning/pull/10870))
 
 
+- Deprecated `Trainer.call_hook` in favor of `Trainer._call_callback_hooks`, `Trainer._call_lightning_module_hook`, `Trainer._call_ttp_hook`, and `Trainer._call_accelerator_hook` ([#10979](https://github.com/PyTorchLightning/pytorch-lightning/pull/10979))
+
+
 - Deprecated `TrainingTypePlugin.post_dispatch` in favor of `TrainingTypePlugin.teardown` ([#10939](https://github.com/PyTorchLightning/pytorch-lightning/pull/10939))
 
 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index bb9398570bc08..7687b5ecf148c 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1446,6 +1446,63 @@ def _call_teardown_hook(self) -> None:
         # summarize profile results
         self.profiler.describe()
 
+    def call_hook(
+        self, hook_name: str, *args: Any, pl_module: Optional["pl.LightningModule"] = None, **kwargs: Any
+    ) -> Any:
+        r"""
+        .. deprecated:: v1.6
+            The Trainer's `call_hook` method was deprecated in v1.6 and will be removed in v1.8.
+        """
+        rank_zero_deprecation("The Trainer's `call_hook` method was deprecated in v1.6 and will be removed in v1.8.")
+        pl_module = self.lightning_module or pl_module
+        if pl_module:
+            prev_fx_name = pl_module._current_fx_name
+            pl_module._current_fx_name = hook_name
+
+        # always profile hooks
+        with self.profiler.profile(hook_name):
+
+            # first call trainer hook
+            callback_fx = getattr(self, hook_name, None)
+            if callable(callback_fx):
+                callback_fx(*args, **kwargs)
+
+            # next call hook in LightningModule
+            output = None
+            model_fx = getattr(pl_module, hook_name, None)
+            if callable(model_fx):
+                output = model_fx(*args, **kwargs)
+
+            # *Bad code alert*
+            # The `Accelerator` mostly calls the `TrainingTypePlugin` but some of those calls are deprecated.
+            # The following logic selectively chooses which hooks are called on each object.
+            # In the case of `setup` and `teardown`, the hooks on the `LightningModule` should not call the hooks of
+            # the same name in these objects as they are meant to be managed outside of the `LightningModule` lifecycle.
+            # All of this should be fixed by #8506
+
+            # call the accelerator hook
+            if hook_name in ("on_train_start",) and hasattr(self.accelerator, hook_name):
+                accelerator_hook = getattr(self.accelerator, hook_name)
+                accelerator_output = accelerator_hook(*args, **kwargs)
+                # Rely on the accelerator output if the LightningModule hook returns nothing
+                # Required for cases such as DataParallel where we reduce the output for the user
+                # todo: move this data parallel logic into the data parallel plugin
+                output = accelerator_output if output is None else output
+
+            # call the ttp hook
+            if hook_name not in ("setup", "teardown", "on_train_start") and hasattr(
+                self.training_type_plugin, hook_name
+            ):
+                ttp_hook = getattr(self.training_type_plugin, hook_name)
+                ttp_output = ttp_hook(*args, **kwargs)
+                output = ttp_output if output is None else output
+
+        if pl_module:
+            # restore current_fx when nested context
+            pl_module._current_fx_name = prev_fx_name
+
+        return output
+
     def _call_lightning_module_hook(
         self,
         hook_name: str,
diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py
index 7ef0fe2a15e4f..e58e7927641c3 100644
--- a/tests/deprecated_api/test_remove_1-8.py
+++ b/tests/deprecated_api/test_remove_1-8.py
@@ -16,6 +16,7 @@
 import pytest
 import torch
 
+from pytorch_lightning import Trainer
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.enums import DeviceType, DistributedType
 from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY
@@ -41,3 +42,15 @@ def test_v1_8_0_deprecated_torchtext_batch():
     data_iterator, _ = get_dummy_torchtext_data_iterator(num_samples=3, batch_size=3)
     batch = next(iter(data_iterator))
     _ = move_data_to_device(batch=batch, device=torch.device("cpu"))
+
+
+def test_v1_8_0_deprecated_call_hook():
+    trainer = Trainer(
+        max_epochs=1,
+        limit_val_batches=0.1,
+        limit_train_batches=0.2,
+        enable_progress_bar=False,
+        logger=False,
+    )
+    with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8."):
+        trainer.call_hook("test_hook")
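
Migration note (illustrative, not part of the patch): the CHANGELOG entry above points callers of the deprecated aggregate `Trainer.call_hook` at four granular methods. Below is a minimal sketch of the equivalent calls, assuming a `trainer` with an attached `LightningModule` and using `"on_train_start"` as the example hook; the `_call_*` methods are internal API, and their exact signatures are an assumption based on the CHANGELOG entry, not a stable contract.

    # Before: one aggregate call that fans out to all objects.
    # Emits the v1.6 deprecation warning; removed in v1.8.
    trainer.call_hook("on_train_start")

    # After (sketch): each target object is addressed explicitly.
    trainer._call_callback_hooks("on_train_start")         # hooks of registered Callbacks
    trainer._call_lightning_module_hook("on_train_start")  # the user's LightningModule hook
    trainer._call_ttp_hook("on_train_start")               # the TrainingTypePlugin hook
    trainer._call_accelerator_hook("on_train_start")       # the Accelerator hook

Splitting the call sites this way removes the name-based dispatch in `call_hook` (the `getattr`/`callable` probing and the hard-coded `setup`/`teardown`/`on_train_start` special cases), since each caller now states which object's hook it means.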