
Commit a28b4cd

Sort out the dataloader idx logic for evaluation (#10923)
carmocca and awaelchli authored
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 7792b77 commit a28b4cd

File tree

10 files changed, +63 −77 lines changed


pytorch_lightning/core/lightning.py

Lines changed: 0 additions & 2 deletions
@@ -109,7 +109,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         # optionally can be set by user
         self._example_input_array = None
         self._current_fx_name: Optional[str] = None
-        self._current_dataloader_idx: Optional[int] = None
         self._automatic_optimization: bool = True
         self._truncated_bptt_steps: int = 0
         self._param_requires_grad_state = {}
@@ -419,7 +418,6 @@ def log(
             reduce_fx=reduce_fx,
             enable_graph=enable_graph,
             add_dataloader_idx=add_dataloader_idx,
-            dataloader_idx=self._current_dataloader_idx,
             batch_size=batch_size,
             sync_dist=sync_dist and distributed_available(),
             sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp,
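Note: from user code nothing changes; `validation_step` still accepts `dataloader_idx` when several val dataloaders are configured and may omit it otherwise, and `self.log` keeps suffixing the metric key per dataloader internally. A minimal usage sketch, assuming the standard LightningModule API of this release (the zero loss is a placeholder):

    import torch
    import pytorch_lightning as pl

    class LitModel(pl.LightningModule):
        # with several val dataloaders the loop passes the index, so the parameter
        # must be accepted; with a single dataloader it is simply not passed
        def validation_step(self, batch, batch_idx, dataloader_idx=0):
            loss = torch.tensor(0.0)  # placeholder metric for the sketch
            # per-dataloader suffixing of the logged key is now handled by the ResultCollection
            self.log("val_loss", loss)
            return loss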

pytorch_lightning/loops/dataloader/evaluation_loop.py

Lines changed: 6 additions & 8 deletions
@@ -100,14 +100,16 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
         """Performs evaluation on one single dataloader."""
         void(*args, **kwargs)

-        dataloader_idx: int = self.current_dataloader_idx
+        dataloader_idx = self.current_dataloader_idx
         dataloader = self.trainer.training_type_plugin.process_dataloader(self.current_dataloader)
         self.data_fetcher = dataloader = self.trainer._data_connector.get_profiled_dataloader(
             dataloader, dataloader_idx=dataloader_idx
         )
         dl_max_batches = self._max_batches[dataloader_idx]

-        dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
+        dl_outputs = self.epoch_loop.run(
+            dataloader, dataloader_idx if self.num_dataloaders > 1 else None, dl_max_batches
+        )

         # store batch level output per dataloader
         self._outputs.append(dl_outputs)
@@ -212,17 +214,13 @@ def _evaluation_epoch_end(self, outputs: List[EPOCH_OUTPUT]) -> None:
         # inform logger the batch loop has finished
         self.trainer.logger_connector.epoch_end_reached()

-        # call the model epoch end
-        model = self.trainer.lightning_module
-
-        # unset dataloader_idx in model
-        model._current_dataloader_idx = None
-
         # with a single dataloader don't pass a 2D list
         output_or_outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]] = (
             outputs[0] if len(outputs) > 0 and self.num_dataloaders == 1 else outputs
         )

+        # call the model epoch end
+        model = self.trainer.lightning_module
         if self.trainer.testing:
             if is_overridden("test_epoch_end", model):
                 model._current_fx_name = "test_epoch_end"
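Note: the central decision now happens here: the dataloader loop forwards a real index to the epoch loop only when more than one dataloader is being evaluated, and `None` otherwise. A minimal standalone sketch of that gating (plain function, not the actual loop class):

    from typing import Optional

    def index_for_epoch_loop(current_idx: int, num_dataloaders: int) -> Optional[int]:
        # mirrors `dataloader_idx if self.num_dataloaders > 1 else None` from the diff above
        return current_idx if num_dataloaders > 1 else None

    assert index_for_epoch_loop(0, 1) is None   # single dataloader: no index propagated
    assert index_for_epoch_loop(0, 2) == 0      # multiple dataloaders: the real index
    assert index_for_epoch_loop(1, 2) == 1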

pytorch_lightning/loops/epoch/evaluation_epoch_loop.py

Lines changed: 23 additions & 36 deletions
@@ -48,7 +48,6 @@ def __init__(self) -> None:

         self._outputs: EPOCH_OUTPUT = []
         self._dl_max_batches = 0
-        self._num_dataloaders = 0
         self._dataloader_iter: Optional[Iterator] = None
         self._data_fetcher: Optional[AbstractDataFetcher] = None
         self._dataloader_state_dict: Dict[str, Any] = {}
@@ -61,7 +60,6 @@ def done(self) -> bool:
     def reset(self) -> None:
         """Resets the loop's internal state."""
         self._dl_max_batches = 0
-        self._num_dataloaders = 0
         self._data_fetcher = None
         self._outputs = []

@@ -71,39 +69,36 @@ def reset(self) -> None:
             self.batch_progress.reset_on_restart()

     def on_run_start(  # type: ignore[override]
-        self, data_fetcher: AbstractDataFetcher, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
+        self, data_fetcher: AbstractDataFetcher, dataloader_idx: Optional[int], dl_max_batches: int
     ) -> None:
         """Adds the passed arguments to the loop's state if necessary.

         Args:
             data_fetcher: the current data_fetcher wrapping the dataloader
             dataloader_idx: index of the current dataloader
             dl_max_batches: maximum number of batches the dataloader can produce
-            num_dataloaders: the total number of dataloaders
         """
         void(dataloader_idx)
         self._dl_max_batches = dl_max_batches
-        self._num_dataloaders = num_dataloaders
         self._data_fetcher = data_fetcher

         self._reload_dataloader_state_dict(data_fetcher)
         self._dataloader_iter = _update_dataloader_iter(data_fetcher, self.batch_progress.current.ready)

     def advance(  # type: ignore[override]
-        self, data_fetcher: AbstractDataFetcher, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
+        self, data_fetcher: AbstractDataFetcher, dataloader_idx: Optional[int], dl_max_batches: int
     ) -> None:
         """Calls the evaluation step with the corresponding hooks and updates the logger connector.

         Args:
             data_fetcher: iterator over the dataloader
             dataloader_idx: index of the current dataloader
             dl_max_batches: maximum number of batches the dataloader can produce
-            num_dataloaders: the total number of dataloaders

         Raises:
             StopIteration: If the current batch is None
         """
-        void(dl_max_batches, num_dataloaders)
+        void(dl_max_batches)

         assert self._dataloader_iter is not None
         batch_idx, (batch, self.batch_progress.is_last_batch) = next(self._dataloader_iter)
@@ -113,24 +108,27 @@ def advance(  # type: ignore[override]

         if not data_fetcher.store_on_device:
             with self.trainer.profiler.profile("evaluation_batch_to_device"):
-                batch = self.trainer.training_type_plugin.batch_to_device(batch, dataloader_idx=dataloader_idx)
+                batch = self.trainer.training_type_plugin.batch_to_device(batch, dataloader_idx=(dataloader_idx or 0))

         self.batch_progress.increment_ready()

+        # configure step_kwargs
+        kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)
+
         # hook
-        self._on_evaluation_batch_start(batch, batch_idx, dataloader_idx)
+        self._on_evaluation_batch_start(**kwargs)

         self.batch_progress.increment_started()

         # lightning module methods
         with self.trainer.profiler.profile("evaluation_step_and_end"):
-            output = self._evaluation_step(batch, batch_idx, dataloader_idx)
+            output = self._evaluation_step(**kwargs)
             output = self._evaluation_step_end(output)

         self.batch_progress.increment_processed()

         # track loss history
-        self._on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx)
+        self._on_evaluation_batch_end(output, **kwargs)

         self.batch_progress.increment_completed()

@@ -208,7 +206,7 @@ def _num_completed_batches_reached(self) -> bool:
     def _has_completed(self) -> bool:
         return self.batch_progress.current.ready == self.batch_progress.current.completed

-    def _evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Optional[STEP_OUTPUT]:
+    def _evaluation_step(self, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         """The evaluation step (validation_step or test_step depending on the trainer's state).

         Args:
@@ -219,17 +217,14 @@ def _evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> O
         Returns:
             the outputs of the step
         """
-        # configure step_kwargs
-        step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)
-
         if self.trainer.testing:
             self.trainer.lightning_module._current_fx_name = "test_step"
             with self.trainer.profiler.profile("test_step"):
-                output = self.trainer.accelerator.test_step(*step_kwargs.values())
+                output = self.trainer.accelerator.test_step(*kwargs.values())
         else:
             self.trainer.lightning_module._current_fx_name = "validation_step"
             with self.trainer.profiler.profile("validation_step"):
-                output = self.trainer.accelerator.validation_step(*step_kwargs.values())
+                output = self.trainer.accelerator.validation_step(*kwargs.values())

         return output

@@ -239,7 +234,7 @@ def _evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPU
         output = self.trainer.call_hook(hook_name, *args, **kwargs)
         return output

-    def _on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def _on_evaluation_batch_start(self, **kwargs: Any) -> None:
         """Calls the ``on_{validation/test}_batch_start`` hook.

         Args:
@@ -250,19 +245,15 @@ def _on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx:
         Raises:
             AssertionError: If the number of dataloaders is None (has not yet been set).
         """
-        self.trainer.logger_connector.on_batch_start(batch_idx, batch)
-
-        assert self._num_dataloaders is not None
-        self.trainer.logger_connector.on_evaluation_batch_start(dataloader_idx, self._num_dataloaders)
+        self.trainer.logger_connector.on_batch_start(**kwargs)

+        kwargs.setdefault("dataloader_idx", 0)  # TODO: the argument should be keyword for these
         if self.trainer.testing:
-            self.trainer.call_hook("on_test_batch_start", batch, batch_idx, dataloader_idx)
+            self.trainer.call_hook("on_test_batch_start", *kwargs.values())
         else:
-            self.trainer.call_hook("on_validation_batch_start", batch, batch_idx, dataloader_idx)
+            self.trainer.call_hook("on_validation_batch_start", *kwargs.values())

-    def _on_evaluation_batch_end(
-        self, output: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int
-    ) -> None:
+    def _on_evaluation_batch_end(self, output: Optional[STEP_OUTPUT], **kwargs: Any) -> None:
         """The ``on_{validation/test}_batch_end`` hook.

         Args:
@@ -271,12 +262,13 @@ def _on_evaluation_batch_end(
             batch_idx: The index of the current batch
             dataloader_idx: Index of the dataloader producing the current batch
         """
+        kwargs.setdefault("dataloader_idx", 0)  # TODO: the argument should be keyword for these
         hook_name = "on_test_batch_end" if self.trainer.testing else "on_validation_batch_end"
-        self.trainer.call_hook(hook_name, output, batch, batch_idx, dataloader_idx)
+        self.trainer.call_hook(hook_name, output, *kwargs.values())

         self.trainer.logger_connector.on_batch_end()

-    def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict[str, Union[Any, int]]:
+    def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]) -> Dict[str, Union[Any, int]]:
         """Helper function to build the arguments for the current step.

         Args:
@@ -289,13 +281,8 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict
         """
         # make dataloader_idx arg in validation_step optional
         step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])
-
-        multiple_val_loaders = not self.trainer.testing and self._num_dataloaders > 1
-        multiple_test_loaders = self.trainer.testing and self._num_dataloaders > 1
-
-        if multiple_test_loaders or multiple_val_loaders:
+        if dataloader_idx is not None:
             step_kwargs["dataloader_idx"] = dataloader_idx
-
         return step_kwargs

     @lru_cache(1)
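Note: the epoch loop now builds one kwargs dict per batch and feeds every hook from it; the `dataloader_idx` entry exists only when an index was actually given, and `setdefault("dataloader_idx", 0)` re-adds it for the hooks that are still called positionally. A standalone sketch of that pattern, with a plain function standing in for a trainer hook:

    from collections import OrderedDict
    from typing import Any, Dict, Optional

    def build_kwargs(batch: Any, batch_idx: int, dataloader_idx: Optional[int]) -> Dict[str, Any]:
        # same rule as `_build_kwargs`: the key is only present when an index was provided
        kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])
        if dataloader_idx is not None:
            kwargs["dataloader_idx"] = dataloader_idx
        return kwargs

    def on_batch_start(batch: Any, batch_idx: int, dataloader_idx: int) -> tuple:
        # stand-in for a positional hook such as `on_validation_batch_start`
        return batch, batch_idx, dataloader_idx

    kwargs = build_kwargs(batch="some batch", batch_idx=3, dataloader_idx=None)
    assert "dataloader_idx" not in kwargs     # validation_step(batch, batch_idx) is enough
    kwargs.setdefault("dataloader_idx", 0)    # positional hooks still expect three arguments
    assert on_batch_start(*kwargs.values()) == ("some batch", 3, 0)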

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 1 addition & 1 deletion
@@ -158,7 +158,7 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[ov

         self.batch_progress.increment_ready()

-        self.trainer.logger_connector.on_batch_start(batch_idx, batch)
+        self.trainer.logger_connector.on_batch_start(batch, batch_idx)

         if batch is None:
             self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 15 additions & 13 deletions
@@ -139,17 +139,12 @@ def _increment_eval_log_step(self) -> None:
         elif self.trainer.state.stage is RunningStage.TESTING:
             self._test_log_step += 1

-    def on_evaluation_batch_start(self, dataloader_idx: int, num_dataloaders: int) -> None:
-        model = self.trainer.lightning_module
-        # set dataloader_idx only if multiple ones
-        model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None
-
     def update_eval_step_metrics(self) -> None:
+        assert not self._epoch_end_reached
         if self.trainer.sanity_checking:
             return

         # logs user requested information to logger
-        assert not self._epoch_end_reached
         self.log_metrics(self.metrics["log"], step=self._eval_log_step)

         # increment the step even if nothing was logged
@@ -259,23 +254,29 @@ def _log_gpus_metrics(self) -> None:
     def on_epoch_start(self) -> None:
         self._epoch_end_reached = False

-    def on_batch_start(self, batch_idx: int, batch: Any) -> None:
+    def on_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> None:
         self._batch_idx = batch_idx
         self._epoch_end_reached = False

-        assert self.trainer._results is not None
+        results = self.trainer._results
+        assert results is not None
         # attach reference to the new batch and remove the cached batch_size
-        self.trainer._results.batch = batch
-        self.trainer._results.batch_size = None
+        results.batch = batch
+        results.batch_size = None
+        results.dataloader_idx = dataloader_idx

     def epoch_end_reached(self) -> None:
         self._epoch_end_reached = True
         self._batch_idx = None
         self._split_idx = None
-        assert self.trainer._results is not None

     def on_epoch_end(self) -> None:
         assert self._epoch_end_reached
+        results = self.trainer._results
+        assert results is not None
+        # we need to reset this index before the `self.metrics` call below
+        results.dataloader_idx = None
+
         metrics = self.metrics
         self._progress_bar_metrics.update(metrics["pbar"])
         self._callback_metrics.update(metrics["callback"])
@@ -308,8 +309,9 @@ def reset_metrics(self) -> None:
         self._callback_metrics = {}

     def reset_results(self) -> None:
-        if self.trainer._results is not None:
-            self.trainer._results.reset()
+        results = self.trainer._results
+        if results is not None:
+            results.reset()

         self._batch_idx = None
         self._split_idx = None
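Note: on the connector side the index now lives for exactly one batch-to-epoch-end span on the results object: `on_batch_start` stores it (training calls pass no index, so it defaults to `None`), and `on_epoch_end` clears it before epoch metrics are aggregated. A rough sketch of that lifecycle with minimal stand-in classes (not the real connector or ResultCollection):

    from typing import Any, Optional

    class MiniResults:
        def __init__(self) -> None:
            self.batch: Optional[Any] = None
            self.batch_size: Optional[int] = None
            self.dataloader_idx: Optional[int] = None

    class MiniConnector:
        def __init__(self, results: MiniResults) -> None:
            self._results = results

        def on_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> None:
            # attach the new batch, drop the cached batch_size, remember the index
            self._results.batch = batch
            self._results.batch_size = None
            self._results.dataloader_idx = dataloader_idx

        def on_epoch_end(self) -> None:
            # the index must be reset before epoch-level metrics are computed
            self._results.dataloader_idx = None

    results = MiniResults()
    connector = MiniConnector(results)
    connector.on_batch_start(batch="eval batch", batch_idx=0, dataloader_idx=2)
    assert results.dataloader_idx == 2
    connector.on_epoch_end()
    assert results.dataloader_idx is None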

pytorch_lightning/trainer/connectors/logger_connector/result.py

Lines changed: 5 additions & 5 deletions
@@ -393,6 +393,7 @@ def __init__(self, training: bool, device: Optional[Union[str, torch.device]] =
         self.device: Optional[Union[str, torch.device]] = device
         self.batch: Optional[Any] = None
         self.batch_size: Optional[int] = None
+        self.dataloader_idx: Optional[int] = None

     @property
     def result_metrics(self) -> List[ResultMetric]:
@@ -436,7 +437,6 @@ def log(
         sync_dist_fn: Callable = _Sync.no_op,
         sync_dist_group: Optional[Any] = None,
         add_dataloader_idx: bool = True,
-        dataloader_idx: Optional[int] = None,
         batch_size: Optional[int] = None,
         metric_attribute: Optional[str] = None,
         rank_zero_only: bool = False,
@@ -453,9 +453,9 @@
         # storage key
         key = f"{fx}.{name}"
         # add dataloader_suffix to both key and fx
-        if add_dataloader_idx and dataloader_idx is not None:
-            key += f".{dataloader_idx}"
-            fx += f".{dataloader_idx}"
+        if add_dataloader_idx and self.dataloader_idx is not None:
+            key += f".{self.dataloader_idx}"
+            fx += f".{self.dataloader_idx}"

         meta = _Metadata(
             fx=fx,
@@ -467,7 +467,7 @@
             reduce_fx=reduce_fx,
             enable_graph=enable_graph,
             add_dataloader_idx=add_dataloader_idx,
-            dataloader_idx=dataloader_idx,
+            dataloader_idx=self.dataloader_idx,
             metric_attribute=metric_attribute,
         )
         meta.sync = _Sync(_should=sync_dist, fn=sync_dist_fn, _group=sync_dist_group, rank_zero_only=rank_zero_only)
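Note: with the index stored on the collection itself, the storage key and `fx` suffixing is driven purely by `self.dataloader_idx`; callers of `log()` no longer pass it in. A small self-contained sketch of just the suffixing rule (a hypothetical `MiniResultCollection`, none of the metric machinery):

    from typing import Optional

    class MiniResultCollection:
        def __init__(self) -> None:
            self.dataloader_idx: Optional[int] = None

        def storage_key(self, fx: str, name: str, add_dataloader_idx: bool = True) -> str:
            key = f"{fx}.{name}"
            # same rule as in `ResultCollection.log` above
            if add_dataloader_idx and self.dataloader_idx is not None:
                key += f".{self.dataloader_idx}"
            return key

    results = MiniResultCollection()
    assert results.storage_key("validation_step", "val_loss") == "validation_step.val_loss"
    results.dataloader_idx = 1
    assert results.storage_key("validation_step", "val_loss") == "validation_step.val_loss.1"
    assert results.storage_key("validation_step", "val_loss", add_dataloader_idx=False) == "validation_step.val_loss"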

pytorch_lightning/trainer/trainer.py

Lines changed: 0 additions & 1 deletion
@@ -1430,7 +1430,6 @@ def _call_teardown_hook(self) -> None:
         self.call_hook("teardown", stage=fn)

         self.lightning_module._current_fx_name = None
-        self.lightning_module._current_dataloader_idx = None
         # these could have become stale if metrics are defined in `setup`
         self.lightning_module._metric_attributes = None

tests/loops/test_loop_state_dict.py

Lines changed: 4 additions & 0 deletions
@@ -81,13 +81,15 @@ def test_loops_state_dict_structure():
             "epoch_loop.val_loop._results": {
                 "batch": None,
                 "batch_size": None,
+                "dataloader_idx": None,
                 "training": False,
                 "device": None,
                 "items": {},
             },
             "epoch_loop._results": {
                 "batch": None,
                 "batch_size": None,
+                "dataloader_idx": None,
                 "training": True,
                 "device": None,
                 "items": {},
@@ -109,6 +111,7 @@ def test_loops_state_dict_structure():
             "_results": {
                 "batch": None,
                 "batch_size": None,
+                "dataloader_idx": None,
                 "training": False,
                 "device": None,
                 "items": {},
@@ -126,6 +129,7 @@ def test_loops_state_dict_structure():
             "_results": {
                 "batch": None,
                 "batch_size": None,
+                "dataloader_idx": None,
                 "training": False,
                 "device": None,
                 "items": {},
