Re-design call_hook interface #10575


Merged (43 commits, Dec 4, 2021)

Changes from 40 commits

Commits (43)
dc8e838
first draft
daniellepintz Nov 16, 2021
3667a4d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 16, 2021
36078f2
doc fix
daniellepintz Nov 16, 2021
09949ee
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Nov 22, 2021
eabeba4
separate call_hooks
daniellepintz Nov 22, 2021
ec445a9
update call_hook refs
daniellepintz Nov 22, 2021
0c59dd8
fix more refs
daniellepintz Nov 22, 2021
6513caa
fix log
daniellepintz Nov 22, 2021
28701a0
cover edge case hooks
daniellepintz Nov 24, 2021
d3bdb46
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 24, 2021
319ec5f
small fix
daniellepintz Nov 24, 2021
6ba14e4
Merge branch 'call_hook' of github.com:daniellepintz/pytorch-lightnin…
daniellepintz Nov 24, 2021
fda4d49
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Nov 24, 2021
088c441
only profile hook_name
daniellepintz Nov 24, 2021
800a11d
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Nov 24, 2021
460028a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 24, 2021
282a86b
flake8
daniellepintz Nov 24, 2021
14163ea
Merge branch 'call_hook' of github.com:daniellepintz/pytorch-lightnin…
daniellepintz Nov 24, 2021
fe2a76f
Fix failing tests. A ttp call was missed
carmocca Nov 25, 2021
d64d524
address comments
daniellepintz Nov 25, 2021
2617675
Merge branch 'call_hook' of github.com:daniellepintz/pytorch-lightnin…
daniellepintz Nov 25, 2021
a5fe96d
fix
daniellepintz Nov 25, 2021
63323bf
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Nov 26, 2021
a3d14ed
fix mypy
daniellepintz Nov 26, 2021
6eeec86
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Nov 26, 2021
53d6b53
remove setting of _current_fx_name
daniellepintz Nov 29, 2021
e5a3e3a
fix flake8
daniellepintz Nov 29, 2021
d4f66ae
fix
daniellepintz Nov 29, 2021
41dfc44
add asserts and optimizations
daniellepintz Dec 2, 2021
848a511
addr comments
daniellepintz Dec 2, 2021
d0b59c1
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Dec 3, 2021
ba89d1c
fix hook not callable
daniellepintz Dec 3, 2021
13e7b5d
fix ttp trainer ref
daniellepintz Dec 3, 2021
890e2e7
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Dec 3, 2021
9261c99
fix bad merge
daniellepintz Dec 3, 2021
3c5fe0c
fix on_train_batch_start
daniellepintz Dec 3, 2021
813c24a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 3, 2021
c48cdb2
fix
daniellepintz Dec 3, 2021
7b4f005
fix
daniellepintz Dec 3, 2021
7146f26
fix broken test
daniellepintz Dec 3, 2021
a163136
fix test
daniellepintz Dec 4, 2021
7e8ed03
fix _call_callback_hooks
daniellepintz Dec 4, 2021
492fc62
TypeError and other fix
daniellepintz Dec 4, 2021
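
The commits above converge on one design change: the single `trainer.call_hook(name, ...)` entry point is split into per-target methods, `_call_callback_hooks`, `_call_lightning_module_hook`, `_call_ttp_hook` (training-type plugin) and `_call_accelerator_hook`, which the diffs below apply call site by call site. As a rough, self-contained sketch of that split, the method names come from this PR, while the dummy classes, profiler, and method bodies are illustrative assumptions, not the actual Trainer implementation:

```python
from contextlib import contextmanager


class _Profiler:
    @contextmanager
    def profile(self, action_name):
        # A real profiler would record timing for `action_name`; this one is a stub.
        yield


class _DemoModule:
    _current_fx_name = None

    def on_train_start(self):
        print("LightningModule.on_train_start")


class _DemoCallback:
    def on_train_start(self, trainer, pl_module):
        print("Callback.on_train_start")


class _TrainerSketch:
    def __init__(self):
        self.lightning_module = _DemoModule()
        self.callbacks = [_DemoCallback()]
        self.profiler = _Profiler()

    def _call_callback_hooks(self, hook_name, *args, **kwargs):
        # Callbacks are called for their side effects only; return values are ignored.
        for callback in self.callbacks:
            fn = getattr(callback, hook_name, None)
            if callable(fn):
                with self.profiler.profile(f"[Callback]{hook_name}"):
                    fn(self, self.lightning_module, *args, **kwargs)

    def _call_lightning_module_hook(self, hook_name, *args, **kwargs):
        # The module hook owns logging, so the current fx name is set around the call
        # (assumed; mirrors the _current_fx_name bookkeeping visible in the diffs).
        fn = getattr(self.lightning_module, hook_name, None)
        if not callable(fn):
            return None
        self.lightning_module._current_fx_name = hook_name
        with self.profiler.profile(f"[LightningModule]{hook_name}"):
            return fn(*args, **kwargs)


trainer = _TrainerSketch()
trainer._call_callback_hooks("on_train_start")
trainer._call_lightning_module_hook("on_train_start")
```

Splitting the dispatch this way lets each call site decide exactly which of the callbacks, the LightningModule, the training-type plugin, and the accelerator should see a given hook, instead of `call_hook` fanning out to all of them.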
4 changes: 3 additions & 1 deletion pl_examples/loop_examples/yielding_training_step.py
@@ -88,7 +88,9 @@ def _training_step(self, generator):
training_step_output = next(generator)
self.trainer.training_type_plugin.post_training_step()

training_step_output = self.trainer.call_hook("training_step_end", training_step_output)
model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
ttp_output = self.trainer._call_ttp_hook("training_step_end", training_step_output)
training_step_output = ttp_output if model_output is None else model_output

# The closure result takes care of properly detaching the loss for logging and performs
# some additional checks that the output format is correct.
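
The replacement above calls `training_step_end` on both the LightningModule and the training-type plugin and keeps the module's result when it returned one. A tiny, purely illustrative check of that precedence rule (the values are made up):

```python
def _resolve_step_end_output(model_output, ttp_output):
    # The LightningModule's *_step_end result wins; the plugin's result is the fallback.
    return ttp_output if model_output is None else model_output


assert _resolve_step_end_output(None, {"loss": 1.0}) == {"loss": 1.0}    # module returned nothing
assert _resolve_step_end_output({"loss": 2.0}, {"loss": 1.0}) == {"loss": 2.0}  # module result wins
```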
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
@@ -374,7 +374,7 @@ def log(
value = apply_to_collection(value, numbers.Number, self.__to_tensor)

if self.trainer.logger_connector.should_reset_tensors(self._current_fx_name):
# if we started a new epoch (running it's first batch) the hook name has changed
# if we started a new epoch (running its first batch) the hook name has changed
# reset any tensors for the new hook name
results.reset(metrics=False, fx=self._current_fx_name)

46 changes: 27 additions & 19 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -19,7 +19,6 @@
from pytorch_lightning.loops.dataloader import DataLoaderLoop
from pytorch_lightning.loops.epoch import EvaluationEpochLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, ResultCollection
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EPOCH_OUTPUT


@@ -170,16 +169,20 @@ def _on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
self._results.to(device=self.trainer.lightning_module.device)

if self.trainer.testing:
self.trainer.call_hook("on_test_start", *args, **kwargs)
self.trainer._call_callback_hooks("on_test_start", *args, **kwargs)
self.trainer._call_lightning_module_hook("on_test_start", *args, **kwargs)
self.trainer._call_ttp_hook("on_test_start", *args, **kwargs)
else:
self.trainer.call_hook("on_validation_start", *args, **kwargs)
self.trainer._call_callback_hooks("on_validation_start", *args, **kwargs)
self.trainer._call_lightning_module_hook("on_validation_start", *args, **kwargs)
self.trainer._call_ttp_hook("on_validation_start", *args, **kwargs)

def _on_evaluation_model_eval(self) -> None:
"""Sets model to eval mode."""
if self.trainer.testing:
self.trainer.call_hook("on_test_model_eval")
self.trainer._call_lightning_module_hook("on_test_model_eval")
else:
self.trainer.call_hook("on_validation_model_eval")
self.trainer._call_lightning_module_hook("on_validation_model_eval")

def _on_evaluation_model_train(self) -> None:
"""Sets model to train mode."""
@@ -192,22 +195,29 @@ def _on_evaluation_model_train(self) -> None:
def _on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
"""Runs ``on_{validation/test}_end`` hook."""
if self.trainer.testing:
self.trainer.call_hook("on_test_end", *args, **kwargs)
self.trainer._call_callback_hooks("on_test_end", *args, **kwargs)
self.trainer._call_lightning_module_hook("on_test_end", *args, **kwargs)
self.trainer._call_ttp_hook("on_test_end", *args, **kwargs)
else:
self.trainer.call_hook("on_validation_end", *args, **kwargs)
self.trainer._call_callback_hooks("on_validation_end", *args, **kwargs)
self.trainer._call_lightning_module_hook("on_validation_end", *args, **kwargs)
self.trainer._call_ttp_hook("on_validation_end", *args, **kwargs)

# reset the logger connector state
self.trainer.logger_connector.reset_results()

def _on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
"""Runs ``on_epoch_start`` and ``on_{validation/test}_epoch_start`` hooks."""
self.trainer.logger_connector.on_epoch_start()
self.trainer.call_hook("on_epoch_start", *args, **kwargs)
self.trainer._call_callback_hooks("on_epoch_start", *args, **kwargs)
self.trainer._call_lightning_module_hook("on_epoch_start", *args, **kwargs)

if self.trainer.testing:
self.trainer.call_hook("on_test_epoch_start", *args, **kwargs)
self.trainer._call_callback_hooks("on_test_epoch_start", *args, **kwargs)
self.trainer._call_lightning_module_hook("on_test_epoch_start", *args, **kwargs)
else:
self.trainer.call_hook("on_validation_epoch_start", *args, **kwargs)
self.trainer._call_callback_hooks("on_validation_epoch_start", *args, **kwargs)
self.trainer._call_lightning_module_hook("on_validation_epoch_start", *args, **kwargs)

def _evaluation_epoch_end(self, outputs: List[EPOCH_OUTPUT]) -> None:
"""Runs ``{validation/test}_epoch_end``"""
Expand All @@ -222,18 +232,16 @@ def _evaluation_epoch_end(self, outputs: List[EPOCH_OUTPUT]) -> None:
# call the model epoch end
model = self.trainer.lightning_module
if self.trainer.testing:
if is_overridden("test_epoch_end", model):
model._current_fx_name = "test_epoch_end"
model.test_epoch_end(output_or_outputs)

self.trainer._call_lightning_module_hook("test_epoch_end", output_or_outputs)
else:
if is_overridden("validation_epoch_end", model):
model._current_fx_name = "validation_epoch_end"
model.validation_epoch_end(output_or_outputs)
self.trainer._call_lightning_module_hook("validation_epoch_end", output_or_outputs)

def _on_evaluation_epoch_end(self) -> None:
"""Runs ``on_{validation/test}_epoch_end`` hook."""
hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end"
self.trainer.call_hook(hook_name)
self.trainer.call_hook("on_epoch_end")
self.trainer._call_callback_hooks(hook_name)
self.trainer._call_lightning_module_hook(hook_name)

self.trainer._call_callback_hooks("on_epoch_end")
self.trainer._call_lightning_module_hook("on_epoch_end")
self.trainer.logger_connector.on_epoch_end()
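
Note the deleted `is_overridden` guard and the manual `_current_fx_name` assignment around `test_epoch_end`/`validation_epoch_end`: the call site collapses to one line, on the assumption that the overridden-check and fx-name bookkeeping now happen inside `_call_lightning_module_hook`. A minimal illustration of how such a helper could decide there is nothing to call (the names and the check itself are stand-ins for demonstration, not Lightning's `is_overridden`):

```python
class _BaseModule:
    def test_epoch_end(self, outputs):
        """Default no-op implementation."""


class _UserModule(_BaseModule):
    pass  # hook not overridden by the user


class _CustomModule(_BaseModule):
    def test_epoch_end(self, outputs):
        return len(outputs)


def _call_module_hook(module, hook_name, *args):
    # Skip the call when the user kept the base-class implementation.
    if getattr(type(module), hook_name) is getattr(_BaseModule, hook_name):
        return None
    return getattr(module, hook_name)(*args)


assert _call_module_hook(_UserModule(), "test_epoch_end", []) is None
assert _call_module_hook(_CustomModule(), "test_epoch_end", [1, 2, 3]) == 3
```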
15 changes: 11 additions & 4 deletions pytorch_lightning/loops/dataloader/prediction_loop.py
@@ -107,8 +107,12 @@ def _on_predict_start(self) -> None:
self.trainer.lightning_module.zero_grad()

# hook
self.trainer.call_hook("on_predict_start")
self.trainer.call_hook("on_predict_epoch_start")
self.trainer._call_callback_hooks("on_predict_start")
self.trainer._call_lightning_module_hook("on_predict_start")
self.trainer._call_ttp_hook("on_predict_start")

self.trainer._call_callback_hooks("on_predict_epoch_start")
self.trainer._call_lightning_module_hook("on_predict_epoch_start")

def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
"""Calls ``on_predict_epoch_end`` hook.
@@ -118,7 +122,8 @@ def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
"""
results = self.predictions

self.trainer.call_hook("on_predict_epoch_end", results)
self.trainer._call_callback_hooks("on_predict_epoch_end", results)
self.trainer._call_lightning_module_hook("on_predict_epoch_end", results)

if self.return_predictions:
return results[0] if self.num_dataloaders == 1 else results
@@ -130,7 +135,9 @@ def _on_predict_end(self) -> None:
self.epoch_batch_indices = []

# hook
self.trainer.call_hook("on_predict_end")
self.trainer._call_callback_hooks("on_predict_end")
self.trainer._call_lightning_module_hook("on_predict_end")
self.trainer._call_ttp_hook("on_predict_end")

def _on_predict_model_eval(self) -> None:
"""Calls ``on_predict_model_eval`` hook."""
19 changes: 11 additions & 8 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -219,19 +219,19 @@ def _evaluation_step(self, **kwargs: Any) -> Optional[STEP_OUTPUT]:
"""
if self.trainer.testing:
self.trainer.lightning_module._current_fx_name = "test_step"
with self.trainer.profiler.profile("test_step"):
output = self.trainer.accelerator.test_step(*kwargs.values())
output = self.trainer._call_accelerator_hook("test_step", *kwargs.values())
else:
self.trainer.lightning_module._current_fx_name = "validation_step"
with self.trainer.profiler.profile("validation_step"):
output = self.trainer.accelerator.validation_step(*kwargs.values())
output = self.trainer._call_accelerator_hook("validation_step", *kwargs.values())

return output

def _evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
"""Calls the `{validation/test}_step_end` hook."""
hook_name = "test_step_end" if self.trainer.testing else "validation_step_end"
output = self.trainer.call_hook(hook_name, *args, **kwargs)
model_output = self.trainer._call_lightning_module_hook(hook_name, *args, **kwargs)
ttp_output = self.trainer._call_ttp_hook(hook_name, *args, **kwargs)
output = ttp_output if model_output is None else model_output
return output

def _on_evaluation_batch_start(self, **kwargs: Any) -> None:
@@ -249,9 +249,11 @@ def _on_evaluation_batch_start(self, **kwargs: Any) -> None:

kwargs.setdefault("dataloader_idx", 0) # TODO: the argument should be keyword for these
if self.trainer.testing:
self.trainer.call_hook("on_test_batch_start", *kwargs.values())
self.trainer._call_callback_hooks("on_test_batch_start", *kwargs.values())
self.trainer._call_lightning_module_hook("on_test_batch_start", *kwargs.values())
else:
self.trainer.call_hook("on_validation_batch_start", *kwargs.values())
self.trainer._call_callback_hooks("on_validation_batch_start", *kwargs.values())
self.trainer._call_lightning_module_hook("on_validation_batch_start", *kwargs.values())

def _on_evaluation_batch_end(self, output: Optional[STEP_OUTPUT], **kwargs: Any) -> None:
"""The ``on_{validation/test}_batch_end`` hook.
@@ -264,7 +266,8 @@ def _on_evaluation_batch_end(self, output: Optional[STEP_OUTPUT], **kwargs: Any)
"""
kwargs.setdefault("dataloader_idx", 0) # TODO: the argument should be keyword for these
hook_name = "on_test_batch_end" if self.trainer.testing else "on_validation_batch_end"
self.trainer.call_hook(hook_name, output, *kwargs.values())
self.trainer._call_callback_hooks(hook_name, output, *kwargs.values())
self.trainer._call_lightning_module_hook(hook_name, output, *kwargs.values())

self.trainer.logger_connector.on_batch_end()

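
In `_evaluation_step` the explicit `with self.trainer.profiler.profile(...)` block wrapped around `accelerator.test_step` / `accelerator.validation_step` disappears in favour of `_call_accelerator_hook`. A standalone sketch of the pattern that helper presumably folds together, profiling plus name-based dispatch to the accelerator (the dummy classes below are assumptions, not Lightning's implementation):

```python
from contextlib import contextmanager


class _Profiler:
    @contextmanager
    def profile(self, action_name):
        print(f"profiling: {action_name}")
        yield


class _DemoAccelerator:
    def validation_step(self, batch, batch_idx):
        return {"batch": batch, "batch_idx": batch_idx}


class _TrainerSketch:
    def __init__(self):
        self.profiler = _Profiler()
        self.accelerator = _DemoAccelerator()

    def _call_accelerator_hook(self, hook_name, *args, **kwargs):
        # Profile under the hook name and dispatch to the accelerator by name,
        # instead of repeating the profile/dispatch pair at every call site.
        with self.profiler.profile(hook_name):
            return getattr(self.accelerator, hook_name)(*args, **kwargs)


output = _TrainerSketch()._call_accelerator_hook("validation_step", "some-batch", 0)
assert output["batch_idx"] == 0
```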
12 changes: 6 additions & 6 deletions pytorch_lightning/loops/epoch/prediction_epoch_loop.py
@@ -125,21 +125,21 @@ def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None
# extract batch_indices and store them
self.current_batch_indices = self._seen_batch_indices[batch_idx] if self._seen_batch_indices else []

model_ref = self.trainer.lightning_module

self.trainer.call_hook("on_predict_batch_start", batch, batch_idx, dataloader_idx)
self.trainer._call_callback_hooks("on_predict_batch_start", batch, batch_idx, dataloader_idx)
self.trainer._call_lightning_module_hook("on_predict_batch_start", batch, batch_idx, dataloader_idx)

self.batch_progress.increment_started()

model_ref._current_fx_name = "predict_step"
predictions = self.trainer.accelerator.predict_step(*step_kwargs.values())
self.trainer.lightning_module._current_fx_name = "predict_step"
predictions = self.trainer._call_accelerator_hook("predict_step", *step_kwargs.values())

self.batch_progress.increment_processed()

if predictions is None:
self._warning_cache.warn("predict returned None if it was on purpose, ignore this warning...")

self.trainer.call_hook("on_predict_batch_end", predictions, batch, batch_idx, dataloader_idx)
self.trainer._call_callback_hooks("on_predict_batch_end", predictions, batch, batch_idx, dataloader_idx)
self.trainer._call_lightning_module_hook("on_predict_batch_end", predictions, batch, batch_idx, dataloader_idx)

self.batch_progress.increment_completed()

33 changes: 23 additions & 10 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -132,8 +132,11 @@ def reset(self) -> None:
def on_run_start(self, data_fetcher: AbstractDataFetcher) -> None: # type: ignore[override]
# hook
self.trainer.logger_connector.on_epoch_start()
self.trainer.call_hook("on_epoch_start")
self.trainer.call_hook("on_train_epoch_start")
self.trainer._call_callback_hooks("on_epoch_start")
self.trainer._call_lightning_module_hook("on_epoch_start")

self.trainer._call_callback_hooks("on_train_epoch_start")
self.trainer._call_lightning_module_hook("on_train_epoch_start")
self.trainer.fit_loop.epoch_progress.increment_started()

self._reload_dataloader_state_dict(data_fetcher)
@@ -165,7 +168,7 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None: # type: ignore[ov
batch_output = []
else:
# hook
self.trainer.call_hook("on_batch_start")
self.trainer._call_callback_hooks("on_batch_start")

# TODO: Update this in v1.7 (deprecation: #9816)
model_fx = self.trainer.lightning_module.on_train_batch_start
@@ -176,7 +179,12 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None: # type: ignore[ov
)

# hook
response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
self.trainer._call_callback_hooks("on_train_batch_start", batch, batch_idx, **extra_kwargs)
model_response = self.trainer._call_lightning_module_hook(
"on_train_batch_start", batch, batch_idx, **extra_kwargs
)
ttp_response = self.trainer._call_ttp_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
response = ttp_response if model_response is None else model_response
if response == -1:
self.batch_progress.increment_processed()
raise StopIteration
@@ -207,8 +215,11 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None: # type: ignore[ov
if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
else {}
)
self.trainer.call_hook("on_train_batch_end", batch_end_outputs, batch, batch_idx, **extra_kwargs)
self.trainer.call_hook("on_batch_end")
self.trainer._call_callback_hooks("on_train_batch_end", batch_end_outputs, batch, batch_idx, **extra_kwargs)
self.trainer._call_lightning_module_hook(
"on_train_batch_end", batch_end_outputs, batch, batch_idx, **extra_kwargs
)
self.trainer._call_callback_hooks("on_batch_end")
self.trainer.logger_connector.on_batch_end()

self.batch_progress.increment_completed()
@@ -276,8 +287,7 @@ def on_run_end(self) -> None:
)
# run lightning module hook training_epoch_end
# refresh the result for custom logging at the epoch level
model._current_fx_name = "training_epoch_end"
epoch_end_outputs = model.training_epoch_end(epoch_end_outputs)
epoch_end_outputs = self.trainer._call_lightning_module_hook("training_epoch_end", epoch_end_outputs)
if epoch_end_outputs is not None:
raise MisconfigurationException(
"`training_epoch_end` expects a return of None. "
@@ -289,8 +299,11 @@ def on_run_end(self) -> None:
self.trainer.fit_loop.epoch_progress.increment_processed()

# call train epoch end hooks
self.trainer.call_hook("on_train_epoch_end")
self.trainer.call_hook("on_epoch_end")
self.trainer._call_callback_hooks("on_train_epoch_end")
self.trainer._call_lightning_module_hook("on_train_epoch_end")

self.trainer._call_callback_hooks("on_epoch_end")
self.trainer._call_lightning_module_hook("on_epoch_end")
self.trainer.logger_connector.on_epoch_end()

if self._num_ready_batches_reached():
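
The `on_train_batch_start` change above is the one place where a hook's return value still carries control flow: a response of `-1` skips the rest of the epoch. With the split dispatch, the trainer takes the LightningModule's response and falls back to the training-type plugin's before doing that check. A small illustrative reduction of that logic (not the actual loop code, which raises `StopIteration`):

```python
def _should_skip_rest_of_epoch(model_response, ttp_response):
    # The module's on_train_batch_start response takes precedence; -1 means
    # "stop processing batches for this epoch", behaviour carried over from call_hook.
    response = ttp_response if model_response is None else model_response
    return response == -1


assert _should_skip_rest_of_epoch(-1, None) is True     # module asks to skip the epoch
assert _should_skip_rest_of_epoch(None, -1) is True     # plugin asks to skip the epoch
assert _should_skip_rest_of_epoch(None, None) is False  # neither returned anything
```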
8 changes: 6 additions & 2 deletions pytorch_lightning/loops/fit_loop.py
@@ -193,7 +193,9 @@ def on_run_start(self) -> None: # type: ignore[override]
self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module)
self._is_fresh_start_epoch = True
self._results.to(device=self.trainer.lightning_module.device)
self.trainer.call_hook("on_train_start")
self.trainer._call_callback_hooks("on_train_start")
self.trainer._call_lightning_module_hook("on_train_start")
self.trainer._call_accelerator_hook("on_train_start")

def on_advance_start(self) -> None: # type: ignore[override]
"""Prepares the dataloader for training and calls the hooks ``on_epoch_start`` and
@@ -248,7 +250,9 @@ def on_run_end(self) -> None:
self.current_epoch = max(self.current_epoch - 1, 0)

# hook
self.trainer.call_hook("on_train_end")
self.trainer._call_callback_hooks("on_train_end")
self.trainer._call_lightning_module_hook("on_train_end")
self.trainer._call_ttp_hook("on_train_end")

# give accelerators a chance to finish
self.trainer.training_type_plugin.on_train_end()
10 changes: 5 additions & 5 deletions pytorch_lightning/loops/optimization/manual_loop.py
@@ -103,14 +103,14 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override]

# manually capture logged metrics
lightning_module._current_fx_name = "training_step"
with self.trainer.profiler.profile("training_step"):
training_step_output = self.trainer.accelerator.training_step(*step_kwargs.values())
self.trainer.training_type_plugin.post_training_step()
training_step_output = self.trainer._call_accelerator_hook("training_step", *step_kwargs.values())
self.trainer.training_type_plugin.post_training_step()

del step_kwargs

training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
ttp_output = self.trainer._call_ttp_hook("training_step_end", training_step_output)
training_step_output = ttp_output if model_output is None else model_output
self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)

result = self.output_result_cls.from_training_step_output(training_step_output)