@@ -30,13 +30,17 @@ def setUp(self) -> None:
30
30
self ._num_samples = 1
31
31
self ._num_datasets = 3
32
32
33
- def get_equal_dataset (self , num_samples , num_datasets , as_list = False ) -> Union [List [BaseNode ], Dict [str , BaseNode ]]:
33
+ def get_equal_dataset (
34
+ self , num_samples , num_datasets , as_list = False
35
+ ) -> Union [List [BaseNode [T ]], Dict [str , BaseNode [T ]]]:
34
36
"""Returns a dictionary of datasets with the same number of samples"""
35
37
if as_list :
36
38
return [IterableWrapper (DummyIterableDataset (num_samples , f"ds{ i } " )) for i in range (num_datasets )]
37
39
return {f"ds{ i } " : IterableWrapper (DummyIterableDataset (num_samples , f"ds{ i } " )) for i in range (num_datasets )}
38
40
39
- def get_unequal_dataset (self , num_samples , num_datasets , as_list = False ):
41
+ def get_unequal_dataset (
42
+ self , num_samples , num_datasets , as_list = False
43
+ ) -> Union [List [BaseNode [T ]], Dict [str , BaseNode [T ]]]:
40
44
"""Returns a dictionary of datasets with the different number of samples.
41
45
For example if num_samples = 1 and num_datasets = 3, the datasets will have 1, 2, 3 samples, respectively.
42
46
datasets = {"ds0":[0], "ds1":[0, 1], "ds2":[0, 1, 2]}
@@ -85,6 +89,37 @@ def test_single_dataset(self, num_samples: int) -> None:
85
89
pass
86
90
self .assertEqual (num_sample + 1 , num_samples )
87
91
92
+ @parameterized .expand ([4 , 8 , 16 ])
93
+ def test_single_dataset_tagged (self , num_samples : int ) -> None :
94
+ datasets = self .get_equal_dataset (num_samples , 1 )
95
+ sampler = MultiNodeRoundRobinSampler (datasets , StopCriteria .FIRST_DATASET_EXHAUSTED , tag_output = True )
96
+ for num_sample , _ in enumerate (sampler ):
97
+ pass
98
+ self .assertEqual (num_sample + 1 , num_samples )
99
+
100
+ datasets = self .get_equal_dataset (num_samples , 1 , as_list = True )
101
+ sampler = MultiNodeRoundRobinSampler (datasets , StopCriteria .FIRST_DATASET_EXHAUSTED , tag_output = True )
102
+ for num_sample , _ in enumerate (sampler ):
103
+ pass
104
+ self .assertEqual (num_sample + 1 , num_samples )
105
+
106
+ @parameterized .expand ([4 , 8 , 16 ])
107
+ def test_single_dataset_tags (self , num_samples : int ) -> None :
108
+ datasets = self .get_equal_dataset (num_samples , 1 )
109
+ sampler = MultiNodeRoundRobinSampler (datasets , StopCriteria .FIRST_DATASET_EXHAUSTED , tag_output = True )
110
+ for num_sample , element in enumerate (sampler ):
111
+ self .assertIsInstance (element , dict )
112
+ self .assertEqual (set (element .keys ()), {"dataset_key" , "data" })
113
+ self .assertEqual (element ["dataset_key" ], "ds0" )
114
+
115
+ self .assertEqual (num_sample + 1 , num_samples )
116
+
117
+ datasets = self .get_equal_dataset (num_samples , 1 , as_list = True )
118
+ sampler = MultiNodeRoundRobinSampler (datasets , StopCriteria .FIRST_DATASET_EXHAUSTED , tag_output = True )
119
+ for num_sample , _ in enumerate (sampler ):
120
+ pass
121
+ self .assertEqual (num_sample + 1 , num_samples )
122
+
88
123
@parameterized .expand ([4 , 8 , 16 ])
89
124
def test_single_dataset_batched (self , num_samples : int ) -> None :
90
125
datasets = self .get_equal_dataset (num_samples , 1 )
@@ -117,7 +152,7 @@ def test_single_dataset_drop_last_batched(self, num_samples: int, drop_last: boo
117
152
batch_size = 5
118
153
batcher = Batcher (sampler , batch_size = batch_size , drop_last = drop_last )
119
154
num_batches = 0
120
- for batch_number , batch in enumerate (batcher ):
155
+ for _ , batch in enumerate (batcher ):
121
156
num_batches += 1
122
157
self .assertGreater (len (batch ), 0 )
123
158
if drop_last :
@@ -152,14 +187,14 @@ def test_stop_criteria_all_datasets_exhausted(self, num_samples, num_datasets) -
152
187
datasets = self .get_unequal_dataset (num_samples , num_datasets )
153
188
total_items = sum (range (num_samples , num_samples + num_datasets ))
154
189
sampler = MultiNodeRoundRobinSampler (datasets , StopCriteria .ALL_DATASETS_EXHAUSTED )
155
- for num_sample , item in enumerate (sampler ):
190
+ for num_sample , _ in enumerate (sampler ):
156
191
pass
157
192
self .assertEqual (num_sample + 1 , total_items )
158
193
159
194
datasets = self .get_unequal_dataset (num_samples , num_datasets , as_list = True )
160
195
total_items = sum (range (num_samples , num_samples + num_datasets ))
161
196
sampler = MultiNodeRoundRobinSampler (datasets , StopCriteria .ALL_DATASETS_EXHAUSTED )
162
- for num_sample , item in enumerate (sampler ):
197
+ for num_sample , _ in enumerate (sampler ):
163
198
pass
164
199
self .assertEqual (num_sample + 1 , total_items )
165
200
@@ -176,7 +211,7 @@ def test_stop_criteria_all_datasets_exhausted_batched(self, num_samples, num_dat
176
211
batch_size = 3
177
212
batcher = Batcher (sampler , batch_size = batch_size , drop_last = True )
178
213
num_batches = 0
179
- for batch in batcher :
214
+ for _ in batcher :
180
215
num_batches += 1
181
216
self .assertEqual (num_batches , total_items // batch_size )
182
217
@@ -186,7 +221,7 @@ def test_stop_criteria_all_datasets_exhausted_batched(self, num_samples, num_dat
186
221
batch_size = 3
187
222
batcher = Batcher (sampler , batch_size = batch_size , drop_last = True )
188
223
num_batches = 0
189
- for batch in batcher :
224
+ for _ in batcher :
190
225
num_batches += 1
191
226
self .assertEqual (num_batches , total_items // batch_size )
192
227
@@ -199,13 +234,13 @@ def test_stop_criteria_all_datasets_exhausted_batched(self, num_samples, num_dat
199
234
def test_stop_criteria_first_dataset_exhausted (self , num_samples , num_datasets ) -> None :
200
235
datasets = self .get_unequal_dataset (num_samples , num_datasets )
201
236
sampler = MultiNodeRoundRobinSampler (datasets , StopCriteria .FIRST_DATASET_EXHAUSTED )
202
- for num_sample , item in enumerate (sampler ):
237
+ for num_sample , _ in enumerate (sampler ):
203
238
pass
204
239
self .assertEqual (num_sample + 1 , num_datasets * num_samples )
205
240
206
241
datasets = self .get_unequal_dataset (num_samples , num_datasets , as_list = True )
207
242
sampler = MultiNodeRoundRobinSampler (datasets , StopCriteria .FIRST_DATASET_EXHAUSTED )
208
- for num_sample , item in enumerate (sampler ):
243
+ for num_sample , _ in enumerate (sampler ):
209
244
pass
210
245
self .assertEqual (num_sample + 1 , num_datasets * num_samples )
211
246
@@ -222,7 +257,7 @@ def test_stop_criteria_first_dataset_exhausted_batched(self, num_samples, num_da
222
257
batch_size = 2
223
258
batcher = Batcher (sampler , batch_size = batch_size )
224
259
num_batches = 0
225
- for batch in batcher :
260
+ for _ in batcher :
226
261
num_batches += 1
227
262
self .assertEqual (num_batches , num_samples * num_datasets // batch_size )
228
263
@@ -231,7 +266,7 @@ def test_stop_criteria_first_dataset_exhausted_batched(self, num_samples, num_da
231
266
batch_size = 2
232
267
batcher = Batcher (sampler , batch_size = batch_size )
233
268
num_batches = 0
234
- for batch in batcher :
269
+ for _ in batcher :
235
270
num_batches += 1
236
271
self .assertEqual (num_batches , num_samples * num_datasets // batch_size )
237
272
@@ -245,7 +280,7 @@ def test_stop_criteria_cycle_until_all_datasets_exhausted(self, num_samples, num
245
280
num_samples = 4
246
281
datasets = self .get_unequal_dataset (num_samples , num_datasets )
247
282
sampler = MultiNodeRoundRobinSampler (datasets , StopCriteria .CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED )
248
- for num_sample , item in enumerate (sampler ):
283
+ for num_sample , _ in enumerate (sampler ):
249
284
pass
250
285
self .assertEqual (
251
286
num_sample + 1 ,
@@ -254,7 +289,7 @@ def test_stop_criteria_cycle_until_all_datasets_exhausted(self, num_samples, num
254
289
255
290
datasets = self .get_unequal_dataset (num_samples , num_datasets , as_list = True )
256
291
sampler = MultiNodeRoundRobinSampler (datasets , StopCriteria .CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED )
257
- for num_sample , item in enumerate (sampler ):
292
+ for num_sample , _ in enumerate (sampler ):
258
293
pass
259
294
self .assertEqual (
260
295
num_sample + 1 ,
@@ -275,7 +310,7 @@ def test_stop_criteria_cycle_until_all_datasets_exhausted_batched(self, num_samp
275
310
batcher = Batcher (sampler , batch_size = batch_size , drop_last = True )
276
311
277
312
num_batches = 0
278
- for batch in batcher :
313
+ for _ in batcher :
279
314
num_batches += 1
280
315
281
316
self .assertEqual (num_batches , 3 )
@@ -373,6 +408,121 @@ def test_multi_node_round_robin_sampler_unequal_dataset_batched(self) -> None:
373
408
374
409
self .assertEqual (batch_number , 2 )
375
410
411
+ def test_tag_default_default_keys (self ) -> None :
412
+ """Test that tag_output=True wraps each item under the default keys, tagged in round-robin order."""
413
+ num_samples = 5
414
+ num_datasets = 3
415
+
416
+ # Test with dictionary input
417
+ datasets = self .get_equal_dataset (num_samples , num_datasets )
418
+ default_keys = ("dataset_key" , "data" )
419
+ sampler = MultiNodeRoundRobinSampler (datasets , StopCriteria .FIRST_DATASET_EXHAUSTED , tag_output = True )
420
+
421
+ for i , element in enumerate (sampler ):
422
+ self .assertIsInstance (element , dict )
423
+ self .assertEqual (set (element .keys ()), set (default_keys ))
424
+ self .assertTrue (element ["dataset_key" ].startswith ("ds" ))
425
+ # Since we're using round-robin, datasets should appear in order
426
+ expected_ds = f"ds{ i % num_datasets } "
427
+ self .assertEqual (element ["dataset_key" ], expected_ds )
428
+
429
+ # Also test with list input
430
+ datasets = self .get_equal_dataset (num_samples , num_datasets , as_list = True )
431
+ sampler = MultiNodeRoundRobinSampler (datasets , StopCriteria .FIRST_DATASET_EXHAUSTED , tag_output = True )
432
+
433
+ for i , element in enumerate (sampler ):
434
+ self .assertIsInstance (element , dict )
435
+ self .assertEqual (set (element .keys ()), set (default_keys ))
436
+ # Since we're using round-robin, datasets should appear in order
437
+ expected_ds = f"ds_{ i % num_datasets } " # Note list sources use ds_0 format
438
+ self .assertEqual (element ["dataset_key" ], expected_ds )
439
+
440
+ def test_tag_output_with_different_stop_criteria (self ) -> None :
441
+ """Test tagging with different stop criteria."""
442
+ # Use unequal datasets to better test the behavior
443
+ datasets = self .get_unequal_dataset (1 , 3 )
444
+
445
+ # ALL_DATASETS_EXHAUSTED
446
+ sampler = MultiNodeRoundRobinSampler (datasets , StopCriteria .ALL_DATASETS_EXHAUSTED , tag_output = True )
447
+
448
+ results = list (sampler )
449
+ # We should have 6 items total (1 + 2 + 3)
450
+ self .assertEqual (len (results ), 6 )
451
+ for item in results :
452
+ self .assertIsInstance (item , dict )
453
+ self .assertEqual (set (item .keys ()), {"dataset_key" , "data" })
454
+
455
+ # Count occurrences of each dataset
456
+ dataset_counts = {}
457
+ for item in results :
458
+ ds = item ["dataset_key" ]
459
+ dataset_counts [ds ] = dataset_counts .get (ds , 0 ) + 1
460
+
461
+ self .assertEqual (dataset_counts ["ds0" ], 1 )
462
+ self .assertEqual (dataset_counts ["ds1" ], 2 )
463
+ self .assertEqual (dataset_counts ["ds2" ], 3 )
464
+
465
+ # FIRST_DATASET_EXHAUSTED
466
+ sampler = MultiNodeRoundRobinSampler (datasets , StopCriteria .FIRST_DATASET_EXHAUSTED , tag_output = True )
467
+
468
+ results = list (sampler )
469
+ # We should have 3 items total (one from each dataset)
470
+ self .assertEqual (len (results ), 3 )
471
+
472
+ # Count occurrences of each dataset
473
+ dataset_counts = {}
474
+ for item in results :
475
+ ds = item ["dataset_key" ]
476
+ dataset_counts [ds ] = dataset_counts .get (ds , 0 ) + 1
477
+
478
+ self .assertEqual (dataset_counts ["ds0" ], 1 )
479
+ self .assertEqual (dataset_counts ["ds1" ], 1 )
480
+ self .assertEqual (dataset_counts ["ds2" ], 1 )
481
+
482
+ # CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED
483
+ sampler = MultiNodeRoundRobinSampler (datasets , StopCriteria .CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED , tag_output = True )
484
+
485
+ results = list (sampler )
486
+ # Exhausted datasets are recycled until every dataset has been exhausted, giving 4 + 4 + 3 = 11 items
487
+ self .assertEqual (len (results ), 11 )
488
+
489
+ dataset_counts = {}
490
+ for item in results :
491
+ ds = item ["dataset_key" ]
492
+ dataset_counts [ds ] = dataset_counts .get (ds , 0 ) + 1
493
+
494
+ # ds0 and ds1 are exhausted early and recycled, so each appears 4 times; ds2 yields its 3 items once
495
+ self .assertEqual (dataset_counts ["ds0" ], 4 )
496
+ self .assertEqual (dataset_counts ["ds1" ], 4 )
497
+ self .assertEqual (dataset_counts ["ds2" ], 3 )
498
+
499
+ def test_tag_output_with_batching (self ) -> None :
500
+ """Test that tagging works correctly with batching."""
501
+ num_samples = 6
502
+ num_datasets = 3
503
+ batch_size = 3
504
+
505
+ datasets = self .get_equal_dataset (num_samples , num_datasets )
506
+ sampler = MultiNodeRoundRobinSampler (datasets , StopCriteria .FIRST_DATASET_EXHAUSTED , tag_output = True )
507
+
508
+ batcher = Batcher (sampler , batch_size = batch_size )
509
+
510
+ for batch in batcher :
511
+ self .assertEqual (len (batch ), batch_size )
512
+ for item in batch :
513
+ self .assertIsInstance (item , dict )
514
+ self .assertEqual (set (item .keys ()), {"dataset_key" , "data" })
515
+ self .assertTrue (item ["dataset_key" ].startswith ("ds" ))
516
+
517
+ def test_tag_output_invalid_inputs (self ) -> None :
518
+ """Test validation of invalid tag_output inputs."""
519
+ datasets = self .get_equal_dataset (3 , 3 )
520
+
521
+ # Test with invalid type
522
+ with self .assertRaises (TypeError ) as cm :
523
+ MultiNodeRoundRobinSampler (datasets , StopCriteria .FIRST_DATASET_EXHAUSTED , tag_output = 123 )
524
+ self .assertIn ("tag_output must be a boolean (True/False), got" , str (cm .exception ))
525
+
376
526
def test_unequal_batch_size (self ) -> None :
377
527
datasets = self .get_unequal_dataset (self ._num_samples , self ._num_datasets )
378
528
@@ -409,28 +559,29 @@ def test_get_state(self) -> None:
409
559
StopCriteria .FIRST_DATASET_EXHAUSTED ,
410
560
StopCriteria .CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED ,
411
561
],
562
+ [True , False ],
412
563
)
413
564
)
414
- def test_save_load_state (self , midpoint : int , stop_criteria : str ) -> None :
565
+ def test_save_load_state (self , midpoint : int , stop_criteria : str , tag_output ) -> None :
415
566
num_samples = 1500
416
567
num_datasets = 3
417
568
datasets = self .get_equal_dataset (num_samples , num_datasets )
418
- sampler = MultiNodeRoundRobinSampler (datasets , stop_criteria )
569
+ sampler = MultiNodeRoundRobinSampler (datasets , stop_criteria , tag_output )
419
570
prefetcher = Prefetcher (sampler , 3 )
420
571
run_test_save_load_state (self , prefetcher , midpoint )
421
572
422
573
datasets = self .get_unequal_dataset (num_samples , num_datasets )
423
- sampler = MultiNodeRoundRobinSampler (datasets , StopCriteria .ALL_DATASETS_EXHAUSTED )
574
+ sampler = MultiNodeRoundRobinSampler (datasets , StopCriteria .ALL_DATASETS_EXHAUSTED , tag_output )
424
575
prefetcher = Prefetcher (sampler , 3 )
425
576
run_test_save_load_state (self , prefetcher , 400 )
426
577
427
578
datasets = self .get_equal_dataset (num_samples , num_datasets , as_list = True )
428
- sampler = MultiNodeRoundRobinSampler (datasets , stop_criteria )
579
+ sampler = MultiNodeRoundRobinSampler (datasets , stop_criteria , tag_output )
429
580
prefetcher = Prefetcher (sampler , 3 )
430
581
run_test_save_load_state (self , prefetcher , midpoint )
431
582
432
583
datasets = self .get_unequal_dataset (num_samples , num_datasets )
433
- sampler = MultiNodeRoundRobinSampler (datasets , StopCriteria .ALL_DATASETS_EXHAUSTED )
584
+ sampler = MultiNodeRoundRobinSampler (datasets , StopCriteria .ALL_DATASETS_EXHAUSTED , tag_output )
434
585
prefetcher = Prefetcher (sampler , 3 )
435
586
run_test_save_load_state (self , prefetcher , 400 )
436
587
0 commit comments