@@ -308,3 +308,68 @@ def test_multi_node_weighted_large_sample_size_with_prefetcher(self, midpoint, s
308
308
stop_criteria ,
309
309
)
310
310
run_test_save_load_state (self , node , midpoint )
311
+
312
def test_multi_node_weighted_sampler_tag_output_dict_items(self) -> None:
    """Check tagging of dictionary items when ``tag_output=True``.

    Each sampled item must expose a ``dataset_key`` entry whose value is one
    of the ``ds<i>`` names, must still contain its ``name`` and
    ``test_tensor`` entries, and its ``dataset_key`` must equal its ``name``.
    """
    sampler = MultiNodeWeightedSampler(
        self.datasets,
        self.weights,
        tag_output=True,
    )
    samples = list(sampler)

    # The set of acceptable tags does not depend on the sample, so build it once.
    expected_names = [f"ds{i}" for i in range(self._num_datasets)]

    for sample in samples:
        # The tag must be present and must name one of the expected datasets.
        self.assertIn("dataset_key", sample)
        tag = sample["dataset_key"]
        self.assertIn(tag, expected_names)

        # Tagging must not drop the item's own fields.
        self.assertIn("name", sample)
        self.assertIn("test_tensor", sample)

        # The tag has to agree with the name stored in the item itself.
        self.assertEqual(tag, sample["name"])
333
+
334
def test_multi_node_weighted_sampler_tag_output_non_dict_items(self) -> None:
    """Check tagging of non-dictionary items when ``tag_output=True``.

    The source datasets yield plain integers (disjoint ``range`` blocks of
    ten values each). Every sampled result is expected to be a dictionary
    with ``data`` and ``dataset_key`` entries, the latter being one of the
    ``ds<i>`` names.
    """
    expected_names = [f"ds{i}" for i in range(self._num_datasets)]

    # One wrapped integer range per dataset: ds0 -> 0..9, ds1 -> 10..19, ...
    int_datasets = {}
    for idx in range(self._num_datasets):
        int_datasets[f"ds{idx}"] = IterableWrapper(range(idx * 10, (idx + 1) * 10))

    sampler = MultiNodeWeightedSampler(
        int_datasets,
        self.weights,
        tag_output=True,
    )
    samples = list(sampler)

    for sample in samples:
        # Non-dict items are expected to come back as dictionaries.
        self.assertIsInstance(sample, dict)

        self.assertIn("data", sample)
        self.assertIn("dataset_key", sample)

        # The tag must name one of the datasets that were passed in.
        self.assertIn(sample["dataset_key"], expected_names)
358
+
359
def test_multi_node_weighted_sampler_tag_output_false(self) -> None:
    """Check that ``tag_output=False`` leaves items untagged.

    With tagging disabled, no sampled item may carry the ``dataset_key``
    entry, and the item's own ``name`` and ``test_tensor`` entries must
    still be present.
    """
    node = MultiNodeWeightedSampler(
        self.datasets,
        self.weights,
        tag_output=False,
    )

    results = list(node)

    for result in results:
        # The tag is stored under "dataset_key" (the key the tag_output=True
        # tests assert on). Checking for "dataset" instead would pass even
        # if tagging were applied by mistake.
        self.assertNotIn("dataset_key", result)

        # Check that the original data is preserved
        self.assertIn("name", result)
        self.assertIn("test_tensor", result)
0 commit comments