@@ -48,7 +48,6 @@ def __init__(self) -> None:
48
48
49
49
self ._outputs : EPOCH_OUTPUT = []
50
50
self ._dl_max_batches = 0
51
- self ._num_dataloaders = 0
52
51
self ._dataloader_iter : Optional [Iterator ] = None
53
52
self ._data_fetcher : Optional [AbstractDataFetcher ] = None
54
53
self ._dataloader_state_dict : Dict [str , Any ] = {}
@@ -61,7 +60,6 @@ def done(self) -> bool:
61
60
def reset (self ) -> None :
62
61
"""Resets the loop's internal state."""
63
62
self ._dl_max_batches = 0
64
- self ._num_dataloaders = 0
65
63
self ._data_fetcher = None
66
64
self ._outputs = []
67
65
@@ -71,39 +69,36 @@ def reset(self) -> None:
71
69
self .batch_progress .reset_on_restart ()
72
70
73
71
def on_run_start ( # type: ignore[override]
74
- self , data_fetcher : AbstractDataFetcher , dataloader_idx : int , dl_max_batches : int , num_dataloaders : int
72
+ self , data_fetcher : AbstractDataFetcher , dataloader_idx : Optional [ int ] , dl_max_batches : int
75
73
) -> None :
76
74
"""Adds the passed arguments to the loop's state if necessary.
77
75
78
76
Args:
79
77
data_fetcher: the current data_fetcher wrapping the dataloader
80
78
dataloader_idx: index of the current dataloader
81
79
dl_max_batches: maximum number of batches the dataloader can produce
82
- num_dataloaders: the total number of dataloaders
83
80
"""
84
81
void (dataloader_idx )
85
82
self ._dl_max_batches = dl_max_batches
86
- self ._num_dataloaders = num_dataloaders
87
83
self ._data_fetcher = data_fetcher
88
84
89
85
self ._reload_dataloader_state_dict (data_fetcher )
90
86
self ._dataloader_iter = _update_dataloader_iter (data_fetcher , self .batch_progress .current .ready )
91
87
92
88
def advance ( # type: ignore[override]
93
- self , data_fetcher : AbstractDataFetcher , dataloader_idx : int , dl_max_batches : int , num_dataloaders : int
89
+ self , data_fetcher : AbstractDataFetcher , dataloader_idx : Optional [ int ] , dl_max_batches : int
94
90
) -> None :
95
91
"""Calls the evaluation step with the corresponding hooks and updates the logger connector.
96
92
97
93
Args:
98
94
data_fetcher: iterator over the dataloader
99
95
dataloader_idx: index of the current dataloader
100
96
dl_max_batches: maximum number of batches the dataloader can produce
101
- num_dataloaders: the total number of dataloaders
102
97
103
98
Raises:
104
99
StopIteration: If the current batch is None
105
100
"""
106
- void (dl_max_batches , num_dataloaders )
101
+ void (dl_max_batches )
107
102
108
103
assert self ._dataloader_iter is not None
109
104
batch_idx , (batch , self .batch_progress .is_last_batch ) = next (self ._dataloader_iter )
@@ -113,24 +108,27 @@ def advance( # type: ignore[override]
113
108
114
109
if not data_fetcher .store_on_device :
115
110
with self .trainer .profiler .profile ("evaluation_batch_to_device" ):
116
- batch = self .trainer .training_type_plugin .batch_to_device (batch , dataloader_idx = dataloader_idx )
111
+ batch = self .trainer .training_type_plugin .batch_to_device (batch , dataloader_idx = ( dataloader_idx or 0 ) )
117
112
118
113
self .batch_progress .increment_ready ()
119
114
115
+ # configure step_kwargs
116
+ kwargs = self ._build_kwargs (batch , batch_idx , dataloader_idx )
117
+
120
118
# hook
121
- self ._on_evaluation_batch_start (batch , batch_idx , dataloader_idx )
119
+ self ._on_evaluation_batch_start (** kwargs )
122
120
123
121
self .batch_progress .increment_started ()
124
122
125
123
# lightning module methods
126
124
with self .trainer .profiler .profile ("evaluation_step_and_end" ):
127
- output = self ._evaluation_step (batch , batch_idx , dataloader_idx )
125
+ output = self ._evaluation_step (** kwargs )
128
126
output = self ._evaluation_step_end (output )
129
127
130
128
self .batch_progress .increment_processed ()
131
129
132
130
# track loss history
133
- self ._on_evaluation_batch_end (output , batch , batch_idx , dataloader_idx )
131
+ self ._on_evaluation_batch_end (output , ** kwargs )
134
132
135
133
self .batch_progress .increment_completed ()
136
134
@@ -208,7 +206,7 @@ def _num_completed_batches_reached(self) -> bool:
208
206
def _has_completed (self ) -> bool :
209
207
return self .batch_progress .current .ready == self .batch_progress .current .completed
210
208
211
- def _evaluation_step (self , batch : Any , batch_idx : int , dataloader_idx : int ) -> Optional [STEP_OUTPUT ]:
209
+ def _evaluation_step (self , ** kwargs : Any ) -> Optional [STEP_OUTPUT ]:
212
210
"""The evaluation step (validation_step or test_step depending on the trainer's state).
213
211
214
212
Args:
@@ -219,17 +217,14 @@ def _evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> O
219
217
Returns:
220
218
the outputs of the step
221
219
"""
222
- # configure step_kwargs
223
- step_kwargs = self ._build_kwargs (batch , batch_idx , dataloader_idx )
224
-
225
220
if self .trainer .testing :
226
221
self .trainer .lightning_module ._current_fx_name = "test_step"
227
222
with self .trainer .profiler .profile ("test_step" ):
228
- output = self .trainer .accelerator .test_step (* step_kwargs .values ())
223
+ output = self .trainer .accelerator .test_step (* kwargs .values ())
229
224
else :
230
225
self .trainer .lightning_module ._current_fx_name = "validation_step"
231
226
with self .trainer .profiler .profile ("validation_step" ):
232
- output = self .trainer .accelerator .validation_step (* step_kwargs .values ())
227
+ output = self .trainer .accelerator .validation_step (* kwargs .values ())
233
228
234
229
return output
235
230
@@ -239,7 +234,7 @@ def _evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPU
239
234
output = self .trainer .call_hook (hook_name , * args , ** kwargs )
240
235
return output
241
236
242
- def _on_evaluation_batch_start (self , batch : Any , batch_idx : int , dataloader_idx : int ) -> None :
237
+ def _on_evaluation_batch_start (self , ** kwargs : Any ) -> None :
243
238
"""Calls the ``on_{validation/test}_batch_start`` hook.
244
239
245
240
Args:
@@ -250,19 +245,15 @@ def _on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx:
250
245
Raises:
251
246
AssertionError: If the number of dataloaders is None (has not yet been set).
252
247
"""
253
- self .trainer .logger_connector .on_batch_start (batch_idx , batch )
254
-
255
- assert self ._num_dataloaders is not None
256
- self .trainer .logger_connector .on_evaluation_batch_start (dataloader_idx , self ._num_dataloaders )
248
+ self .trainer .logger_connector .on_batch_start (** kwargs )
257
249
250
+ kwargs .setdefault ("dataloader_idx" , 0 ) # TODO: the argument should be keyword for these
258
251
if self .trainer .testing :
259
- self .trainer .call_hook ("on_test_batch_start" , batch , batch_idx , dataloader_idx )
252
+ self .trainer .call_hook ("on_test_batch_start" , * kwargs . values () )
260
253
else :
261
- self .trainer .call_hook ("on_validation_batch_start" , batch , batch_idx , dataloader_idx )
254
+ self .trainer .call_hook ("on_validation_batch_start" , * kwargs . values () )
262
255
263
- def _on_evaluation_batch_end (
264
- self , output : Optional [STEP_OUTPUT ], batch : Any , batch_idx : int , dataloader_idx : int
265
- ) -> None :
256
+ def _on_evaluation_batch_end (self , output : Optional [STEP_OUTPUT ], ** kwargs : Any ) -> None :
266
257
"""The ``on_{validation/test}_batch_end`` hook.
267
258
268
259
Args:
@@ -271,12 +262,13 @@ def _on_evaluation_batch_end(
271
262
batch_idx: The index of the current batch
272
263
dataloader_idx: Index of the dataloader producing the current batch
273
264
"""
265
+ kwargs .setdefault ("dataloader_idx" , 0 ) # TODO: the argument should be keyword for these
274
266
hook_name = "on_test_batch_end" if self .trainer .testing else "on_validation_batch_end"
275
- self .trainer .call_hook (hook_name , output , batch , batch_idx , dataloader_idx )
267
+ self .trainer .call_hook (hook_name , output , * kwargs . values () )
276
268
277
269
self .trainer .logger_connector .on_batch_end ()
278
270
279
- def _build_kwargs (self , batch : Any , batch_idx : int , dataloader_idx : int ) -> Dict [str , Union [Any , int ]]:
271
+ def _build_kwargs (self , batch : Any , batch_idx : int , dataloader_idx : Optional [ int ] ) -> Dict [str , Union [Any , int ]]:
280
272
"""Helper function to build the arguments for the current step.
281
273
282
274
Args:
@@ -289,13 +281,8 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict
289
281
"""
290
282
# make dataloader_idx arg in validation_step optional
291
283
step_kwargs = OrderedDict ([("batch" , batch ), ("batch_idx" , batch_idx )])
292
-
293
- multiple_val_loaders = not self .trainer .testing and self ._num_dataloaders > 1
294
- multiple_test_loaders = self .trainer .testing and self ._num_dataloaders > 1
295
-
296
- if multiple_test_loaders or multiple_val_loaders :
284
+ if dataloader_idx is not None :
297
285
step_kwargs ["dataloader_idx" ] = dataloader_idx
298
-
299
286
return step_kwargs
300
287
301
288
@lru_cache (1 )
0 commit comments