
Commit 4decbc0

Deprecate dataloader_idx from on_train_batch_start/end (#9816)
* deprecate hooks
* dep todo
* explicit
* Apply suggestions from code review
* Apply suggestions from code review
* code review
* base
1 parent 0561fd6 commit 4decbc0

31 files changed (+150, -67 lines)
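For downstream code, the user-facing effect is a narrower hook signature: `dataloader_idx` disappears from `on_train_batch_start`/`on_train_batch_end` on both `LightningModule` and `Callback`. A minimal sketch of the migrated overrides (the `MyModel`/`MyCallback` names are illustrative, not part of this commit); keeping the old signature still runs in v1.5 but goes through the deprecation path added below:

```python
import pytorch_lightning as pl


class MyModel(pl.LightningModule):
    # Pre-1.5 signature (now deprecated): on_train_batch_start(self, batch, batch_idx, dataloader_idx)
    def on_train_batch_start(self, batch, batch_idx):
        pass

    def on_train_batch_end(self, outputs, batch, batch_idx):
        pass


class MyCallback(pl.Callback):
    # Callback hooks drop the argument the same way; trainer/pl_module stay.
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        pass
```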

pytorch_lightning/accelerators/accelerator.py

Lines changed: 3 additions & 2 deletions
@@ -487,6 +487,7 @@ def on_train_end(self) -> None:
         """Called when train ends."""
         return self.training_type_plugin.on_train_end()

-    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    # TODO: Update this in v1.7 (deprecation: #9816)
+    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         """Called in the training loop before anything happens for that batch."""
-        return self.training_type_plugin.on_train_batch_start(batch, batch_idx, dataloader_idx)
+        return self.training_type_plugin.on_train_batch_start(batch, batch_idx)

pytorch_lightning/callbacks/base.py

Lines changed: 7 additions & 2 deletions
@@ -97,7 +97,12 @@ def on_sanity_check_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningMod
         pass

     def on_train_batch_start(
-        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
+        self,
+        trainer: "pl.Trainer",
+        pl_module: "pl.LightningModule",
+        batch: Any,
+        batch_idx: int,
+        unused: Optional[int] = 0,
     ) -> None:
         """Called when the train batch begins."""
         pass
@@ -109,7 +114,7 @@ def on_train_batch_end(
         outputs: STEP_OUTPUT,
         batch: Any,
         batch_idx: int,
-        dataloader_idx: int,
+        unused: Optional[int] = 0,
     ) -> None:
         """Called when the train batch ends."""
         pass

pytorch_lightning/callbacks/gpu_stats_monitor.py

Lines changed: 1 addition & 2 deletions
@@ -135,7 +135,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo

     @rank_zero_only
     def on_train_batch_start(
-        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int
     ) -> None:
         if self._log_stats.intra_step_time:
             self._snap_intra_step_time = time.time()
@@ -161,7 +161,6 @@ def on_train_batch_end(
         outputs: STEP_OUTPUT,
         batch: Any,
         batch_idx: int,
-        dataloader_idx: int,
     ) -> None:
         if self._log_stats.inter_step_time:
             self._snap_inter_step_time = time.time()

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 1 addition & 4 deletions
@@ -279,7 +279,6 @@ def on_train_batch_end(
         outputs: STEP_OUTPUT,
         batch: Any,
         batch_idx: int,
-        dataloader_idx: int,
     ) -> None:
         """Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`"""
         if self._should_skip_saving_checkpoint(trainer):
@@ -304,9 +303,7 @@ def on_train_batch_end(

         self.save_checkpoint(trainer)

-    def on_train_epoch_end(
-        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", unused: Optional = None
-    ) -> None:
+    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Save a checkpoint at the end of the training epoch."""
         # as we advance one step at end of training, we use `global_step - 1` to avoid saving duplicates
         trainer.fit_loop.global_step -= 1

pytorch_lightning/callbacks/progress/base.py

Lines changed: 3 additions & 3 deletions
@@ -35,8 +35,8 @@ def __init__(self):
     def disable(self):
         self.enable = False

-    def on_train_batch_end(self, trainer, pl_module, outputs):
-        super().on_train_batch_end(trainer, pl_module, outputs)  # don't forget this :)
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch_idx):
+        super().on_train_batch_end(trainer, pl_module, outputs, batch_idx)  # don't forget this :)
         percent = (self.train_batch_idx / self.total_train_batches) * 100
         sys.stdout.flush()
         sys.stdout.write(f'{percent:.01f} percent complete \r')
@@ -161,7 +161,7 @@ def on_train_start(self, trainer, pl_module):
     def on_train_epoch_start(self, trainer, pl_module):
         self._train_batch_idx = trainer.fit_loop.epoch_loop.batch_progress.current.completed

-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
         self._train_batch_idx += 1

     def on_validation_start(self, trainer, pl_module):

pytorch_lightning/callbacks/progress/rich_progress.py

Lines changed: 2 additions & 2 deletions
@@ -369,8 +369,8 @@ def on_predict_epoch_start(self, trainer, pl_module):
         super().on_predict_epoch_start(trainer, pl_module)
         self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description)

-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
         self._update(self.main_progress_bar_id)

     def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):

pytorch_lightning/callbacks/progress/tqdm_progress.py

Lines changed: 2 additions & 2 deletions
@@ -231,8 +231,8 @@ def on_train_epoch_start(self, trainer, pl_module):
         reset(self.main_progress_bar, total=total_batches, current=self.train_batch_idx)
         self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")

-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
         total_batches = self.total_train_batches + self.total_val_batches
         total_batches = convert_inf(total_batches)
         if self._should_update(self.train_batch_idx, total_batches):

pytorch_lightning/core/hooks.py

Lines changed: 4 additions & 4 deletions
@@ -79,25 +79,25 @@ def on_pretrain_routine_end(self) -> None:
         - training_start
         """

-    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_train_batch_start(self, batch: Any, batch_idx: int, unused: Optional[int] = 0) -> None:
         """Called in the training loop before anything happens for that batch.

         If you return -1 here, you will skip training for the rest of the current epoch.

         Args:
             batch: The batched data as it is returned by the training DataLoader.
             batch_idx: the index of the batch
-            dataloader_idx: the index of the dataloader
+            unused: Deprecated argument. Will be removed in v1.7.
         """

-    def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, unused: Optional[int] = 0) -> None:
         """Called in the training loop after the batch.

         Args:
             outputs: The outputs of training_step_end(training_step(x))
             batch: The batched data as it is returned by the training DataLoader.
             batch_idx: the index of the batch
-            dataloader_idx: the index of the dataloader
+            unused: Deprecated argument. Will be removed in v1.7.
         """

     def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
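As the `on_train_batch_start` docstring above notes, returning -1 from the hook skips training for the rest of the current epoch. A small sketch using the new signature (the class name and the `max_batches_per_epoch` threshold are illustrative; `training_step`/`configure_optimizers` are omitted for brevity):

```python
from pytorch_lightning import LightningModule


class TruncatedEpochModel(LightningModule):
    max_batches_per_epoch = 10  # illustrative cap, not part of this commit

    def on_train_batch_start(self, batch, batch_idx):
        # Returning -1 tells the training loop to skip the remainder of this epoch.
        if batch_idx >= self.max_batches_per_epoch:
            return -1
```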

pytorch_lightning/loops/batch/training_batch_loop.py

Lines changed: 9 additions & 1 deletion
@@ -24,6 +24,7 @@
 from pytorch_lightning.loops.utilities import _get_active_optimizers
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.utilities import AttributeDict
+from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
 from pytorch_lightning.utilities.warnings import WarningCache

 _OUTPUTS_TYPE = List[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]]
@@ -76,7 +77,14 @@ def run(self, batch: Any, batch_idx: int) -> AttributeDict:
             return AttributeDict(signal=-1)

         # hook
-        response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, 0)
+        # TODO: Update this in v1.7 (deprecation: #9816)
+        model_fx = self.trainer.lightning_module.on_train_batch_start
+        extra_kwargs = (
+            {"dataloader_idx": 0}
+            if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
+            else {}
+        )
+        response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
         if response == -1:
             return AttributeDict(signal=-1)
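The loop now decides at runtime whether to pass the deprecated argument by inspecting the user's hook signature with `is_param_in_hook_signature` from `pytorch_lightning.utilities.signature_utils`. A rough sketch of that kind of check using only the standard library (the `has_explicit_param` helper is a hypothetical illustration, not the library's implementation):

```python
import inspect
from typing import Callable


def has_explicit_param(fn: Callable, name: str) -> bool:
    """Illustrative stand-in for a check like `is_param_in_hook_signature(fn, name, explicit=True)`.

    Returns True only if `fn` explicitly declares a parameter called `name`;
    *args/**kwargs catch-alls do not count.
    """
    parameters = inspect.signature(fn).parameters.values()
    return any(
        p.name == name and p.kind in (p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY)
        for p in parameters
    )


# Usage mirroring the loop code above: only pass the deprecated kwarg if the
# user's override still declares it.
# extra_kwargs = {"dataloader_idx": 0} if has_explicit_param(model_fx, "dataloader_idx") else {}
```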

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 10 additions & 1 deletion
@@ -27,6 +27,7 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.fetching import AbstractDataFetcher
 from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature

 _OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]

@@ -170,7 +171,15 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
             automatic=self.trainer.lightning_module.trainer.lightning_module.automatic_optimization,
             num_optimizers=len(self.trainer.optimizers),
         )
-        self.trainer.call_hook("on_train_batch_end", batch_end_outputs, batch, self.batch_idx, 0)
+
+        # TODO: Update this in v1.7 (deprecation: #9816)
+        model_fx = self.trainer.lightning_module.on_train_batch_end
+        extra_kwargs = (
+            {"dataloader_idx": 0}
+            if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
+            else {}
+        )
+        self.trainer.call_hook("on_train_batch_end", batch_end_outputs, batch, batch_idx, **extra_kwargs)
         self.trainer.call_hook("on_batch_end")
         self.trainer.logger_connector.on_batch_end()

pytorch_lightning/plugins/training_type/ipu.py

Lines changed: 1 addition & 1 deletion
@@ -285,7 +285,7 @@ def on_test_end(self):
     def on_predict_end(self):
         self._detach_models()

-    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
         # Updates optimizer stats if LR scheduler modified the optimizer state
         optimizer = self.lightning_module.trainer.optimizers[0]
         self.poptorch_models[RunningStage.TRAINING].setOptimizer(optimizer)

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 1 addition & 1 deletion
@@ -345,7 +345,7 @@ def on_predict_end(self):
         """Called when predict ends."""
         pass

-    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
         """Called in the training loop before anything happens for that batch."""
         pass

pytorch_lightning/trainer/callback_hook.py

Lines changed: 13 additions & 4 deletions
@@ -21,6 +21,7 @@
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
 from pytorch_lightning.utilities.types import STEP_OUTPUT


@@ -161,15 +162,23 @@ def on_batch_end(self):
         for callback in self.callbacks:
             callback.on_batch_end(self, self.lightning_module)

-    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+    # TODO: Update this in v1.7 (deprecation: #9816)
+    def on_train_batch_start(self, batch, batch_idx, dataloader_idx=0):
         """Called when the training batch begins."""
         for callback in self.callbacks:
-            callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx)
+            if is_param_in_hook_signature(callback.on_train_batch_start, "dataloader_idx", explicit=True):
+                callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx, 0)
+            else:
+                callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx)

-    def on_train_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_idx):
+    # TODO: Update this in v1.7 (deprecation: #9816)
+    def on_train_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_idx=0):
         """Called when the training batch ends."""
         for callback in self.callbacks:
-            callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx)
+            if is_param_in_hook_signature(callback.on_train_batch_end, "dataloader_idx", explicit=True):
+                callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx, 0)
+            else:
+                callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx)

     def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
         """Called when the validation batch begins."""

pytorch_lightning/trainer/configuration_validator.py

Lines changed: 17 additions & 0 deletions
@@ -50,6 +50,8 @@ def verify_loop_configurations(self, model: "pl.LightningModule") -> None:
         self._check_on_post_move_to_device(model)
         # TODO: Delete _check_on_keyboard_interrupt in v1.7
         self._check_on_keyboard_interrupt()
+        # TODO: Remove this in v1.7 (deprecation: #9816)
+        self._check_dl_idx_in_on_train_batch_hooks(model)

     def __verify_train_loop_configuration(self, model: "pl.LightningModule") -> None:
         # -----------------------------------
@@ -261,3 +263,18 @@ def _check_on_keyboard_interrupt(self) -> None:
                 "The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7."
                 " Please use the `on_exception` callback hook instead."
             )
+
+    def _check_dl_idx_in_on_train_batch_hooks(self, model: "pl.LightningModule") -> None:
+        for hook in ("on_train_batch_start", "on_train_batch_end"):
+            if is_param_in_hook_signature(getattr(model, hook), "dataloader_idx", explicit=True):
+                rank_zero_deprecation(
+                    f"Base `LightningModule.{hook}` hook signature has changed in v1.5."
+                    " The `dataloader_idx` argument will be removed in v1.7."
+                )
+
+            for cb in self.trainer.callbacks:
+                if is_param_in_hook_signature(getattr(cb, hook), "dataloader_idx", explicit=True):
+                    rank_zero_deprecation(
+                        f"Base `Callback.{hook}` hook signature has changed in v1.5."
+                        " The `dataloader_idx` argument will be removed in v1.7."
+                    )
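In practice, an override that still accepts `dataloader_idx` keeps working in v1.5 but produces a deprecation warning when the trainer validates the loop configuration. A hedged sketch of the kind of override that trips the check (the `LegacyModel` name is illustrative):

```python
import pytorch_lightning as pl


class LegacyModel(pl.LightningModule):
    # Still the pre-1.5 signature: `_check_dl_idx_in_on_train_batch_hooks` detects the
    # explicit `dataloader_idx` parameter via `is_param_in_hook_signature` and emits
    # "Base `LightningModule.on_train_batch_start` hook signature has changed in v1.5. ..."
    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
        pass
```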

pytorch_lightning/tuner/lr_finder.py

Lines changed: 1 addition & 1 deletion
@@ -344,7 +344,7 @@ def on_batch_start(self, trainer, pl_module):

         self.lrs.append(trainer.lr_schedulers[0]["scheduler"].lr[0])

-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
         """Called when the training batch ends, logs the calculated loss."""
         if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
             return

tests/accelerators/test_tpu_backend.py

Lines changed: 2 additions & 2 deletions
@@ -165,7 +165,7 @@ def __init__(self):
     def should_update(self):
         return self.count % 2 == 0

-    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+    def on_train_batch_start(self, batch, batch_idx):
         self.called["on_train_batch_start"] += 1
         self.weight_before = self.layer.weight.clone()

@@ -181,7 +181,7 @@ def training_step(self, batch, batch_idx):
         opt.zero_grad()
         return loss

-    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(self, outputs, batch, batch_idx):
         self.called["on_train_batch_end"] += 1
         after_before = self.layer.weight.clone()
         if self.should_update:

tests/callbacks/test_callback_hook_outputs.py

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@ def test_train_step_no_return(tmpdir, single_cb: bool):
     """Tests that only training_step can be used."""

     class CB(Callback):
-        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
             assert "loss" in outputs

         def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
@@ -32,7 +32,7 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal
         assert "x" in outputs

     class TestModel(BoringModel):
-        def on_train_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx: int) -> None:
+        def on_train_batch_end(self, outputs, batch, batch_idx: int) -> None:
             assert "loss" in outputs

         def on_validation_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx: int) -> None:

tests/callbacks/test_progress_bar.py

Lines changed: 4 additions & 4 deletions
@@ -185,12 +185,12 @@ class CurrentProgressBar(ProgressBar):
     val_batches_seen = 0
     test_batches_seen = 0

-    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
-        super().on_train_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
+    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
+        super().on_train_batch_start(trainer, pl_module, batch, batch_idx)
         assert self.train_batch_idx == trainer.fit_loop.batch_idx

-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
         assert self.train_batch_idx == trainer.fit_loop.batch_idx + 1
         if not self.is_disabled and self.train_batch_idx % self.refresh_rate == 0:
             assert self.main_progress_bar.n == self.train_batch_idx

tests/core/test_lightning_optimizer.py

Lines changed: 2 additions & 2 deletions
@@ -331,12 +331,12 @@ class TestModel(BoringModel):
     def configure_optimizers(self):
         return OptimizerWithHooks(self)

-    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
         self.count_on_train_batch_start += 1
         optimizer = self.optimizers(use_pl_optimizer=False)
         assert len(optimizer._fwd_handles) == 1

-    def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None:
         self.count_on_train_batch_end += 1
         del self.trainer._lightning_optimizers
         gc.collect()  # not necessary, just in case
