
Commit 99adc45

Follow-up changes to #10575 (#10957)

Authored by carmocca and awaelchli
Co-authored-by: Adrian Wälchli <[email protected]>

1 parent 42b5417 · commit 99adc45

11 files changed: +65 −46 lines


pytorch_lightning/loops/dataloader/evaluation_loop.py

Lines changed: 2 additions & 3 deletions

@@ -186,11 +186,10 @@ def _on_evaluation_model_eval(self) -> None:
 
     def _on_evaluation_model_train(self) -> None:
         """Sets model to train mode."""
-        model_ref = self.trainer.lightning_module
         if self.trainer.testing:
-            model_ref.on_test_model_train()
+            self.trainer._call_lightning_module_hook("on_test_model_train")
         else:
-            model_ref.on_validation_model_train()
+            self.trainer._call_lightning_module_hook("on_validation_model_train")
 
     def _on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
         """Runs ``on_{validation/test}_end`` hook."""

pytorch_lightning/loops/epoch/evaluation_epoch_loop.py

Lines changed: 2 additions & 2 deletions

@@ -218,9 +218,9 @@ def _evaluation_step(self, **kwargs: Any) -> Optional[STEP_OUTPUT]:
             the outputs of the step
         """
         if self.trainer.testing:
-            output = self.trainer._call_accelerator_hook("test_step", *kwargs.values())
+            output = self.trainer._call_ttp_hook("test_step", *kwargs.values())
         else:
-            output = self.trainer._call_accelerator_hook("validation_step", *kwargs.values())
+            output = self.trainer._call_ttp_hook("validation_step", *kwargs.values())
 
         return output

pytorch_lightning/loops/epoch/prediction_epoch_loop.py

Lines changed: 1 addition & 1 deletion

@@ -130,7 +130,7 @@ def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None
 
         self.batch_progress.increment_started()
 
-        predictions = self.trainer._call_accelerator_hook("predict_step", *step_kwargs.values())
+        predictions = self.trainer._call_ttp_hook("predict_step", *step_kwargs.values())
 
         self.batch_progress.increment_processed()

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 2 additions & 3 deletions

@@ -180,11 +180,10 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[ov
 
             # hook
             self.trainer._call_callback_hooks("on_train_batch_start", batch, batch_idx, **extra_kwargs)
-            model_response = self.trainer._call_lightning_module_hook(
+            response = self.trainer._call_lightning_module_hook(
                 "on_train_batch_start", batch, batch_idx, **extra_kwargs
             )
-            ttp_response = self.trainer._call_ttp_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
-            response = ttp_response if model_response is None else model_response
+            self.trainer._call_ttp_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
             if response == -1:
                 self.batch_progress.increment_processed()
                 raise StopIteration
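
With this change, only the LightningModule's return value from on_train_batch_start decides whether the epoch is cut short; the training type plugin's hook is still invoked, but its return value is ignored. A hedged sketch of a module using the documented -1 contract to skip the rest of an epoch (batch_looks_corrupt is a hypothetical helper):

import pytorch_lightning as pl


def batch_looks_corrupt(batch) -> bool:
    # hypothetical user-defined check, e.g. NaNs in the inputs
    return False


class EarlyEpochStopModel(pl.LightningModule):
    def on_train_batch_start(self, batch, batch_idx):
        # returning -1 tells the training epoch loop to stop iterating this epoch
        if batch_looks_corrupt(batch):
            return -1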

pytorch_lightning/loops/optimization/manual_loop.py

Lines changed: 1 addition & 1 deletion

@@ -102,7 +102,7 @@ def advance(self, batch: Any, batch_idx: int) -> None:  # type: ignore[override]
         )
 
         # manually capture logged metrics
-        training_step_output = self.trainer._call_accelerator_hook("training_step", *step_kwargs.values())
+        training_step_output = self.trainer._call_ttp_hook("training_step", *step_kwargs.values())
         self.trainer.training_type_plugin.post_training_step()
 
         del step_kwargs

pytorch_lightning/loops/optimization/optimizer_loop.py

Lines changed: 7 additions & 10 deletions

@@ -144,12 +144,10 @@ def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
         )
 
         if self._zero_grad_fn is not None:
-            with self._profiler.profile("zero_grad"):
-                self._zero_grad_fn()
+            self._zero_grad_fn()
 
         if self._backward_fn is not None and step_output.closure_loss is not None:
-            with self._profiler.profile("backward"):
-                self._backward_fn(step_output.closure_loss)
+            self._backward_fn(step_output.closure_loss)
 
         return step_output
@@ -320,7 +318,7 @@ def _make_backward_fn(self, optimizer: Optimizer, opt_idx: int) -> Optional[Call
             return None
 
         def backward_fn(loss: Tensor) -> None:
-            self.trainer.training_type_plugin.backward(loss, optimizer, opt_idx)
+            self.trainer._call_ttp_hook("backward", loss, optimizer, opt_idx)
 
             # check if model weights are nan
             if self.trainer._terminate_on_nan:
@@ -362,16 +360,15 @@ def _optimizer_step(
             train_step_and_backward_closure: the closure function performing the train step and computing the
                 gradients. By default called by the optimizer (if possible)
         """
-        lightning_module = self.trainer.lightning_module
-
         is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
         # wraps into LightningOptimizer only for running step
         optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx)
 
         self.optim_progress.optimizer.step.increment_ready()
 
         # model hook
-        lightning_module.optimizer_step(
+        self.trainer._call_lightning_module_hook(
+            "optimizer_step",
             self.trainer.current_epoch,
             batch_idx,
             optimizer,
@@ -403,7 +400,7 @@ def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer,
             optimizer: the current optimizer
             opt_idx: the index of the current optimizer
         """
-        self.trainer.training_type_plugin.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
+        self.trainer._call_ttp_hook("optimizer_zero_grad", self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
         self.optim_progress.optimizer.zero_grad.increment_completed()
 
     def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> ClosureResult:
@@ -427,7 +424,7 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos
         )
 
         # manually capture logged metrics
-        training_step_output = self.trainer._call_accelerator_hook("training_step", *step_kwargs.values())
+        training_step_output = self.trainer._call_ttp_hook("training_step", *step_kwargs.values())
         self.trainer.training_type_plugin.post_training_step()
 
         del step_kwargs
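
Since optimizer_zero_grad now flows through trainer._call_ttp_hook into the training type plugin, and from there into the LightningModule (see the plugin change below), a user override still works as before. A sketch of the hook that ends up being called, not part of this commit, here used to zero gradients with set_to_none=True:

import pytorch_lightning as pl


class ZeroGradToNoneModel(pl.LightningModule):
    def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx):
        # reached via trainer._call_ttp_hook("optimizer_zero_grad", ...)
        # -> TrainingTypePlugin.optimizer_zero_grad -> this override
        optimizer.zero_grad(set_to_none=True)  # frees gradient memory instead of writing zeros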

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 1 addition & 2 deletions

@@ -164,8 +164,7 @@ def optimizer_step(
 
     def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
         """Zeros all model parameter's gradients."""
-        model_ref = self.lightning_module
-        model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)
+        self.lightning_module.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)
 
     def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
         """Setup a model and multiple optimizers together.

pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py

Lines changed: 11 additions & 0 deletions

@@ -31,15 +31,24 @@ class _LogOptions(TypedDict):
         "on_before_backward": _LogOptions(
             allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
         ),
+        "backward": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
+        ),
         "on_after_backward": _LogOptions(
             allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
         ),
         "on_before_optimizer_step": _LogOptions(
             allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
         ),
+        "optimizer_step": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
+        ),
         "on_before_zero_grad": _LogOptions(
             allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
         ),
+        "optimizer_zero_grad": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
+        ),
         "on_init_start": None,
         "on_init_end": None,
         "on_fit_start": None,
@@ -160,6 +169,8 @@ class _LogOptions(TypedDict):
         "configure_callbacks": None,
         "on_validation_model_eval": None,
         "on_test_model_eval": None,
+        "on_validation_model_train": None,
+        "on_test_model_train": None,
     }
 
     @classmethod
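
The new _LogOptions entries mean self.log is no longer rejected inside backward, optimizer_step, and optimizer_zero_grad overrides, with the same step-level defaults as their on_before_* counterparts. A small illustrative sketch of logging from backward (the metric name is arbitrary, not part of this commit):

import pytorch_lightning as pl


class LoggingBackwardModel(pl.LightningModule):
    def backward(self, loss, optimizer, optimizer_idx, *args, **kwargs):
        # allowed now that "backward" has a _LogOptions entry (defaults: on_step=True, on_epoch=False)
        self.log("loss_before_backward", loss)
        super().backward(loss, optimizer, optimizer_idx, *args, **kwargs)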

pytorch_lightning/trainer/trainer.py

Lines changed: 26 additions & 16 deletions

@@ -1452,15 +1452,15 @@ def _call_lightning_module_hook(
         *args: Any,
         pl_module: Optional["pl.LightningModule"] = None,
         **kwargs: Any,
-    ):
+    ) -> Any:
         pl_module = pl_module or self.lightning_module
 
         if pl_module is None:
             raise TypeError("No Lightning Module is available to call hooks on")
 
         fn = getattr(pl_module, hook_name)
         if not callable(fn):
-            return None
+            return
 
         prev_fx_name = pl_module._current_fx_name
         pl_module._current_fx_name = hook_name
@@ -1479,16 +1479,15 @@ def _call_callback_hooks(
         hook_name: str,
         *args: Any,
         **kwargs: Any,
-    ) -> Optional[Any]:
-        output = None
+    ) -> None:
         if hook_name in ("on_init_start", "on_init_end"):
             # these `Callback` hooks are the only ones that do not take a lightning module.
             # we also don't profile bc profiler hasn't been set yet
             for callback in self.callbacks:
                 fn = getattr(callback, hook_name)
                 if callable(fn):
-                    output = fn(self, *args, **kwargs)
-            return output
+                    fn(self, *args, **kwargs)
+            return
 
         pl_module = self.lightning_module
         if pl_module:
@@ -1500,34 +1499,39 @@ def _call_callback_hooks(
             fn = getattr(self, hook_name)
             if callable(fn):
                 with self.profiler.profile(hook_name):
-                    output = fn(*args, **kwargs)
+                    fn(*args, **kwargs)
         else:
             for callback in self.callbacks:
                 fn = getattr(callback, hook_name)
                 if callable(fn):
                     with self.profiler.profile(hook_name):
-                        output = fn(self, self.lightning_module, *args, **kwargs)
+                        fn(self, self.lightning_module, *args, **kwargs)
 
         if pl_module:
             # restore current_fx when nested context
             pl_module._current_fx_name = prev_fx_name
 
-        return output
-
     # TODO: rename to _call_strategy_hook and eventually no longer need this
     def _call_ttp_hook(
         self,
         hook_name: str,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> Any:
+        pl_module = self.lightning_module
+        prev_fx_name = pl_module._current_fx_name
+        pl_module._current_fx_name = hook_name
+
         fn = getattr(self.training_type_plugin, hook_name)
         if not callable(fn):
-            return None
+            return
 
         with self.profiler.profile(hook_name):
             output = fn(*args, **kwargs)
 
+        # restore current_fx when nested context
+        pl_module._current_fx_name = prev_fx_name
+
         return output
 
     # TODO: eventually no longer need this
@@ -1536,15 +1540,21 @@ def _call_accelerator_hook(
         hook_name: str,
         *args: Any,
         **kwargs: Any,
-    ) -> Optional[Any]:
-        self.lightning_module._current_fx_name = hook_name
-        fn = getattr(self.training_type_plugin, hook_name)
+    ) -> Any:
+        pl_module = self.lightning_module
+        prev_fx_name = pl_module._current_fx_name
+        pl_module._current_fx_name = hook_name
+
+        fn = getattr(self.accelerator, hook_name)
         if not callable(fn):
-            return None
+            return
 
         with self.profiler.profile(hook_name):
             output = fn(*args, **kwargs)
 
+        # restore current_fx when nested context
+        pl_module._current_fx_name = prev_fx_name
+
        return output
 
     @staticmethod
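
Both _call_ttp_hook and _call_accelerator_hook now save and restore pl_module._current_fx_name around the call, the same bookkeeping _call_lightning_module_hook already did, so that self.log validation stays correct when hooks nest. The snippet below is purely illustrative (not part of the commit) and just isolates that pattern as a context manager:

from contextlib import contextmanager


@contextmanager
def _track_current_fx(pl_module, hook_name: str):
    # remember which hook is currently executing, point it at the new hook, restore on exit
    prev_fx_name = pl_module._current_fx_name
    pl_module._current_fx_name = hook_name
    try:
        yield
    finally:
        pl_module._current_fx_name = prev_fx_name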

tests/trainer/logging_/test_logger_connector.py

Lines changed: 9 additions & 8 deletions

@@ -31,9 +31,9 @@
 
 
 def test_fx_validator(tmpdir):
-    funcs_name = sorted(get_members(Callback))
+    funcs_name = get_members(Callback)
 
-    callbacks_func = [
+    callbacks_func = {
         "on_before_backward",
         "on_after_backward",
         "on_before_optimizer_step",
@@ -82,9 +82,9 @@ def test_fx_validator(tmpdir):
         "on_predict_start",
         "setup",
         "teardown",
-    ]
+    }
 
-    not_supported = [
+    not_supported = {
         "on_before_accelerator_backend_setup",
         "on_fit_end",
         "on_fit_start",
@@ -110,11 +110,10 @@ def test_fx_validator(tmpdir):
         "on_validation_end",
         "setup",
         "teardown",
-    ]
+    }
 
-    assert funcs_name == sorted(
-        callbacks_func
-    ), "Detected new callback function. Need to add its logging permission to FxValidator and update this test"
+    # Detected new callback function. Need to add its logging permission to FxValidator and update this test
+    assert funcs_name == callbacks_func
 
     validator = _FxValidator()
 
@@ -233,6 +232,7 @@ def test_fx_validator_integration(tmpdir):
         "prepare_data": "You can't",
         "configure_callbacks": "You can't",
         "on_validation_model_eval": "You can't",
+        "on_validation_model_train": "You can't",
         "summarize": "not managed by the `Trainer",
     }
     model = HookedModel(not_supported)
@@ -260,6 +260,7 @@ def test_fx_validator_integration(tmpdir):
             "on_test_dataloader": "You can't",
             "test_dataloader": "You can't",
             "on_test_model_eval": "You can't",
+            "on_test_model_train": "You can't",
             "on_test_end": "You can't",
         }
     )

tests/trainer/logging_/test_loop_logging.py

Lines changed: 3 additions & 0 deletions

@@ -50,9 +50,12 @@ def _make_assertion(model, hooks, result_mock, on_step, on_epoch, extra_kwargs):
     trainer.state.stage = RunningStage.TRAINING
     hooks = [
         "on_before_backward",
+        "backward",
         "on_after_backward",
         "on_before_optimizer_step",
+        "optimizer_step",
         "on_before_zero_grad",
+        "optimizer_zero_grad",
         "training_step",
         "training_step_end",
         "on_batch_start",
