
Commit e5dfdf3

Avoid deprecation warning after #9901 (#9951)
1 parent 1f09cf2 commit e5dfdf3

File tree

5 files changed: +20 / -4 lines changed

pytorch_lightning/accelerators/gpu.py

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
         return super().setup(trainer)
 
     def on_train_start(self) -> None:
+        super().on_train_start()
         # clear cache before training
         torch.cuda.empty_cache()
 

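Note on the gpu.py change: the GPUAccelerator override previously skipped the parent implementation entirely; the added super().on_train_start() call runs the base Accelerator hook first (presumably so the hook still reaches the rest of the stack, given the call_hook change further down) before the CUDA cache is cleared. A minimal sketch of the override-and-call-super pattern, using hypothetical Base/Child classes that are not part of Lightning:

class Base:
    def on_train_start(self) -> None:
        # stands in for the base Accelerator hook
        print("base hook ran")


class Child(Base):
    def on_train_start(self) -> None:
        # call the parent first so its behaviour is preserved, then do the
        # subclass-specific work (clearing the CUDA cache, in GPUAccelerator's case)
        super().on_train_start()
        print("child hook ran")


Child().on_train_start()  # prints "base hook ran" then "child hook ran"
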
pytorch_lightning/plugins/training_type/ipu.py

Lines changed: 1 addition & 1 deletion
@@ -285,7 +285,7 @@ def on_test_end(self):
     def on_predict_end(self):
         self._detach_models()
 
-    def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
+    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         # Updates optimizer stats if LR scheduler modified the optimizer state
         optimizer = self.lightning_module.trainer.optimizers[0]
         self.poptorch_models[RunningStage.TRAINING].setOptimizer(optimizer)

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 1 addition & 1 deletion
@@ -345,7 +345,7 @@ def on_predict_end(self):
         """Called when predict ends."""
         pass
 
-    def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
+    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         """Called in the training loop before anything happens for that batch."""
         pass
 

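In both ipu.py and training_type_plugin.py the hook gains an optional dataloader_idx argument defaulting to 0, so the plugin accepts the call whether or not the trainer still passes a dataloader index during the deprecation period started by #9901. A rough sketch of how the default keeps both call styles working (the Plugin class here is illustrative, not Lightning's):

from typing import Any


class Plugin:
    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        # dataloader_idx defaults to 0, so two- and three-argument callers both work
        print(f"batch_idx={batch_idx}, dataloader_idx={dataloader_idx}")


plugin = Plugin()
plugin.on_train_batch_start("batch", 3)     # new two-argument call style
plugin.on_train_batch_start("batch", 3, 1)  # older call style that still passes the index
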
pytorch_lightning/trainer/trainer.py

Lines changed: 16 additions & 1 deletion
@@ -1401,15 +1401,30 @@ def call_hook(
         if callable(model_fx):
             output = model_fx(*args, **kwargs)
 
+        # *Bad code alert*
+        # The `Accelerator` mostly calls the `TrainingTypePlugin` but some of those calls are deprecated.
+        # The following logic selectively chooses which hooks are called on each object.
+        # In the case of `setup` and `teardown`, the hooks on the `LightningModule` should not call the hooks of the
+        # same name in these objects as they are meant to be managed outside of the `LightningModule` lifecycle.
+        # All of this should be fixed by #8506
+
         # call the accelerator hook
-        if hook_name not in ("setup", "teardown") and hasattr(self.accelerator, hook_name):
+        if hook_name in ("on_train_start",) and hasattr(self.accelerator, hook_name):
             accelerator_hook = getattr(self.accelerator, hook_name)
             accelerator_output = accelerator_hook(*args, **kwargs)
             # Rely on the accelerator output if lightningModule hook returns nothing
             # Required for cases such as DataParallel where we reduce the output for the user
             # todo: move this data parallel logic into the data parallel plugin
             output = accelerator_output if output is None else output
 
+        # call the ttp hook
+        if hook_name not in ("setup", "teardown", "on_train_start") and hasattr(
+            self.training_type_plugin, hook_name
+        ):
+            ttp_hook = getattr(self.training_type_plugin, hook_name)
+            ttp_output = ttp_hook(*args, **kwargs)
+            output = ttp_output if output is None else output
+
         if pl_module:
             # restore current_fx when nested context
             pl_module._current_fx_name = prev_fx_name

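The trainer.py hunk is the heart of the commit: call_hook now routes on_train_start through the accelerator only, sends every other hook except setup and teardown directly to the training type plugin, and keeps the LightningModule's return value whenever it is not None. A condensed, standalone sketch of that dispatch order (a simplified stand-in, not the actual Trainer code):

from typing import Any


def dispatch_hook(hook_name: str, model: Any, accelerator: Any, ttp: Any, *args: Any, **kwargs: Any) -> Any:
    # 1. the LightningModule hook runs first; its return value wins if it is not None
    output = None
    model_fx = getattr(model, hook_name, None)
    if callable(model_fx):
        output = model_fx(*args, **kwargs)

    # 2. only `on_train_start` is still routed through the accelerator
    if hook_name in ("on_train_start",) and hasattr(accelerator, hook_name):
        accelerator_output = getattr(accelerator, hook_name)(*args, **kwargs)
        output = accelerator_output if output is None else output

    # 3. everything else (except setup/teardown) goes straight to the training type plugin
    if hook_name not in ("setup", "teardown", "on_train_start") and hasattr(ttp, hook_name):
        ttp_output = getattr(ttp, hook_name)(*args, **kwargs)
        output = ttp_output if output is None else output

    return output

Because on_train_start is excluded from the plugin branch, the accelerator's own override (see the gpu.py change above) has to call super() so the base behaviour still runs.
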
tests/loops/test_training_loop.py

Lines changed: 1 addition & 1 deletion
@@ -86,7 +86,7 @@ def run_training(**trainer_kwargs):
 @pytest.mark.parametrize(["max_epochs", "batch_idx_"], [(2, 5), (3, 8), (4, 12)])
 def test_on_train_batch_start_return_minus_one(max_epochs, batch_idx_, tmpdir):
     class CurrentModel(BoringModel):
-        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+        def on_train_batch_start(self, batch, batch_idx):
             if batch_idx == batch_idx_:
                 return -1
 

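The test now overrides the two-argument on_train_batch_start, matching the non-deprecated LightningModule signature exercised above; the hook's documented contract is that returning -1 skips the rest of the current epoch, which is what the parametrized assertions rely on. A self-contained sketch of that behaviour with a hypothetical TinyModel (not taken from the repo):

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def on_train_batch_start(self, batch, batch_idx):
        # returning -1 asks Lightning to skip the remainder of the current epoch
        if batch_idx == 5:
            return -1


train_data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=4)
trainer = Trainer(max_epochs=2, limit_train_batches=10)
trainer.fit(TinyModel(), train_data)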