
Commit d37e8da

Follow-up changes to #10575
1 parent 5f8504e

File tree

2 files changed: +9, -4 lines

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -180,11 +180,10 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[override]
 
         # hook
         self.trainer._call_callback_hooks("on_train_batch_start", batch, batch_idx, **extra_kwargs)
-        model_response = self.trainer._call_lightning_module_hook(
+        response = self.trainer._call_lightning_module_hook(
             "on_train_batch_start", batch, batch_idx, **extra_kwargs
         )
-        ttp_response = self.trainer._call_ttp_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
-        response = ttp_response if model_response is None else model_response
+        self.trainer._call_ttp_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
         if response == -1:
             self.batch_progress.increment_processed()
             raise StopIteration
```
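The net effect of this hunk: the LightningModule's return value alone now decides whether the epoch stops early, while the training-type-plugin hook is still invoked but its return value is discarded rather than merged in. A minimal sketch of the resulting control flow, with hypothetical class and function names that only mirror the dispatch order in the diff above:

```python
# Illustrative sketch of the hook dispatch after this commit. _Module,
# _Plugin, and advance_one_batch are hypothetical names, not Lightning's.


class _Module:
    def on_train_batch_start(self, batch, batch_idx):
        # In Lightning, returning -1 from this hook skips the rest of the epoch.
        return -1 if batch_idx >= 100 else None


class _Plugin:
    def on_train_batch_start(self, batch, batch_idx):
        # Called for its side effects only; its return value is now ignored.
        pass


def advance_one_batch(module, plugin, batch, batch_idx):
    response = module.on_train_batch_start(batch, batch_idx)
    plugin.on_train_batch_start(batch, batch_idx)  # no longer merged into response
    if response == -1:
        raise StopIteration
```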

pytorch_lightning/trainer/trainer.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -1537,14 +1537,20 @@ def _call_accelerator_hook(
         *args: Any,
         **kwargs: Any,
     ) -> Optional[Any]:
-        self.lightning_module._current_fx_name = hook_name
+        pl_module = self.lightning_module
+        prev_fx_name = pl_module._current_fx_name
+        pl_module._current_fx_name = hook_name
+
         fn = getattr(self.accelerator, hook_name)
         if not callable(fn):
             return None
 
         with self.profiler.profile(hook_name):
             output = fn(*args, **kwargs)
 
+        # restore current_fx when nested context
+        pl_module._current_fx_name = prev_fx_name
+
         return output
 
     @staticmethod
```
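This hunk saves the previous `_current_fx_name` before overwriting it and restores it after the accelerator hook returns, so that a nested hook call does not clobber the name of the hook still in flight (Lightning uses `_current_fx_name`, for example, to attribute and validate `self.log` calls). The same idiom can be expressed as a context manager; the sketch below is illustrative and `_fx_name` is not part of Lightning's API:

```python
# Hypothetical helper expressing the save/restore idiom from the diff above
# as a context manager; _fx_name is not a real Lightning function.
from contextlib import contextmanager


@contextmanager
def _fx_name(pl_module, hook_name):
    prev_fx_name = pl_module._current_fx_name
    pl_module._current_fx_name = hook_name
    try:
        yield
    finally:
        # Restore the outer value so nested hook calls leave it unchanged.
        pl_module._current_fx_name = prev_fx_name
```

Note that the commit itself restores the value unconditionally rather than in a `try`/`finally`, so an exception raised inside the hook would leave `_current_fx_name` pointing at the inner hook.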
