Skip to content

Commit 4fc8275

Browse files
authored
Remove the deprecated trainer.call_hook (#14869)
1 parent 35419b5 commit 4fc8275

File tree

4 files changed

+6
-57
lines changed

4 files changed

+6
-57
lines changed

docs/source-pytorch/extensions/loops.rst

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,7 @@ Here is a simple example how to add a new hook:
150150
151151
class CustomFitLoop(FitLoop):
152152
def advance(self):
153-
# ... whatever code before
154-
155-
# pass anything you want to the hook
156-
self.trainer.call_hook("my_new_hook", *args, **kwargs)
157-
158-
# ... whatever code after
153+
"""Put your custom logic here."""
159154
160155
Now simply attach the correct loop in the trainer directly:
161156

src/pytorch_lightning/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
221221

222222
- Removed the deprecated `Trainer.use_amp` and `LightningModule.use_amp` attributes ([#14832](https://github.com/Lightning-AI/lightning/pull/14832))
223223

224+
224225
- Removed the deprecated `Trainer.run_stage` in favor of `Trainer.{fit,validate,test,predict}`
225226

226227

@@ -232,9 +233,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
232233

233234
- Remove the deprecated `Trainer.should_rank_save_checkpoint` ([#14885](https://github.com/Lightning-AI/lightning/pull/14885))
234235

236+
235237
- Removed the deprecated `TrainerDataLoadingMixin` ([#14888](https://github.com/Lightning-AI/lightning/pull/14888))
236238

237239

240+
- Removed the deprecated `Trainer.call_hook` in favor of `Trainer._call_callback_hooks`, `Trainer._call_lightning_module_hook`, `Trainer._call_ttp_hook`, and `Trainer._call_accelerator_hook` ([#14869](https://github.com/Lightning-AI/lightning/pull/14869))
241+
242+
238243
### Fixed
239244

240245
- Fixed an issue with `LightningLite.setup()` not setting the `.device` attribute correctly on the returned wrapper ([#14822](https://github.com/Lightning-AI/lightning/pull/14822))

src/pytorch_lightning/trainer/trainer.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,45 +1289,6 @@ def _call_teardown_hook(self) -> None:
12891289
# summarize profile results
12901290
self.profiler.describe()
12911291

1292-
def call_hook(
1293-
self, hook_name: str, *args: Any, pl_module: Optional["pl.LightningModule"] = None, **kwargs: Any
1294-
) -> Any:
1295-
r"""
1296-
.. deprecated:: v1.6
1297-
The Trainer's `call_hook` method was deprecated in v1.6 and will be removed in v1.8.
1298-
"""
1299-
rank_zero_deprecation("The Trainer's `call_hook` method was deprecated in v1.6 and will be removed in v1.8.")
1300-
pl_module = self.lightning_module or pl_module
1301-
if pl_module:
1302-
prev_fx_name = pl_module._current_fx_name
1303-
pl_module._current_fx_name = hook_name
1304-
1305-
# always profile hooks
1306-
with self.profiler.profile(hook_name):
1307-
1308-
# first call trainer hook
1309-
callback_fx = getattr(self, hook_name, None)
1310-
if callable(callback_fx):
1311-
callback_fx(*args, **kwargs)
1312-
1313-
# next call hook in lightningModule
1314-
output = None
1315-
model_fx = getattr(pl_module, hook_name, None)
1316-
if callable(model_fx):
1317-
output = model_fx(*args, **kwargs)
1318-
1319-
# call the strategy hook
1320-
if hook_name not in ("setup", "teardown", "on_train_start") and hasattr(self.strategy, hook_name):
1321-
strategy_hook = getattr(self.strategy, hook_name)
1322-
strategy_output = strategy_hook(*args, **kwargs)
1323-
output = strategy_output if output is None else output
1324-
1325-
if pl_module:
1326-
# restore current_fx when nested context
1327-
pl_module._current_fx_name = prev_fx_name
1328-
1329-
return output
1330-
13311292
def _call_lightning_module_hook(
13321293
self,
13331294
hook_name: str,

tests/tests_pytorch/deprecated_api/test_remove_1-8.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,18 +52,6 @@ def on_init_end(self, trainer):
5252
trainer.validate(model)
5353

5454

55-
def test_v1_8_0_deprecated_call_hook():
56-
trainer = Trainer(
57-
max_epochs=1,
58-
limit_val_batches=0.1,
59-
limit_train_batches=0.2,
60-
enable_progress_bar=False,
61-
logger=False,
62-
)
63-
with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8."):
64-
trainer.call_hook("test_hook")
65-
66-
6755
@pytest.mark.parametrize("fn_prefix", ["validated", "tested", "predicted"])
6856
def test_v1_8_0_trainer_ckpt_path_attributes(fn_prefix: str):
6957
test_attr = f"{fn_prefix}_ckpt_path"

0 commit comments

Comments
 (0)