Skip to content

Commit ac259e1

Browse files
Tag outgoing batch in round robin sampler (#1478)
* tag outgoing batch rr * add tag_output tests * add tests * add examples * remove tuple tag_output
1 parent e58314d commit ac259e1

File tree

2 files changed

+190
-25
lines changed

2 files changed

+190
-25
lines changed

test/nodes/test_multi_node_round_robin_sampler.py

Lines changed: 170 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,17 @@ def setUp(self) -> None:
3030
self._num_samples = 1
3131
self._num_datasets = 3
3232

33-
def get_equal_dataset(self, num_samples, num_datasets, as_list=False) -> Union[List[BaseNode], Dict[str, BaseNode]]:
33+
def get_equal_dataset(
34+
self, num_samples, num_datasets, as_list=False
35+
) -> Union[List[BaseNode[T]], Dict[str, BaseNode[T]]]:
3436
"""Returns a dictionary of datasets with the same number of samples"""
3537
if as_list:
3638
return [IterableWrapper(DummyIterableDataset(num_samples, f"ds{i}")) for i in range(num_datasets)]
3739
return {f"ds{i}": IterableWrapper(DummyIterableDataset(num_samples, f"ds{i}")) for i in range(num_datasets)}
3840

39-
def get_unequal_dataset(self, num_samples, num_datasets, as_list=False):
41+
def get_unequal_dataset(
42+
self, num_samples, num_datasets, as_list=False
43+
) -> Union[List[BaseNode[T]], Dict[str, BaseNode[T]]]:
4044
"""Returns a dictionary of datasets with the different number of samples.
4145
For example if num_samples = 1 and num_datasets = 3, the datasets will have 1, 2, 3 samples, respectively.
4246
datasets = {"ds0":[0], "ds1":[0, 1], "ds2":[0, 1, 2]}
@@ -85,6 +89,37 @@ def test_single_dataset(self, num_samples: int) -> None:
8589
pass
8690
self.assertEqual(num_sample + 1, num_samples)
8791

92+
@parameterized.expand([4, 8, 16])
93+
def test_single_dataset_tagged(self, num_samples: int) -> None:
94+
datasets = self.get_equal_dataset(num_samples, 1)
95+
sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.FIRST_DATASET_EXHAUSTED, tag_output=True)
96+
for num_sample, _ in enumerate(sampler):
97+
pass
98+
self.assertEqual(num_sample + 1, num_samples)
99+
100+
datasets = self.get_equal_dataset(num_samples, 1, as_list=True)
101+
sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.FIRST_DATASET_EXHAUSTED, tag_output=True)
102+
for num_sample, _ in enumerate(sampler):
103+
pass
104+
self.assertEqual(num_sample + 1, num_samples)
105+
106+
@parameterized.expand([4, 8, 16])
107+
def test_single_dataset_tags(self, num_samples: int) -> None:
108+
datasets = self.get_equal_dataset(num_samples, 1)
109+
sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.FIRST_DATASET_EXHAUSTED, tag_output=True)
110+
for num_sample, element in enumerate(sampler):
111+
self.assertIsInstance(element, dict)
112+
self.assertEqual(set(element.keys()), {"dataset_key", "data"})
113+
self.assertEqual(element["dataset_key"], "ds0")
114+
115+
self.assertEqual(num_sample + 1, num_samples)
116+
117+
datasets = self.get_equal_dataset(num_samples, 1, as_list=True)
118+
sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.FIRST_DATASET_EXHAUSTED, tag_output=True)
119+
for num_sample, _ in enumerate(sampler):
120+
pass
121+
self.assertEqual(num_sample + 1, num_samples)
122+
88123
@parameterized.expand([4, 8, 16])
89124
def test_single_dataset_batched(self, num_samples: int) -> None:
90125
datasets = self.get_equal_dataset(num_samples, 1)
@@ -117,7 +152,7 @@ def test_single_dataset_drop_last_batched(self, num_samples: int, drop_last: boo
117152
batch_size = 5
118153
batcher = Batcher(sampler, batch_size=batch_size, drop_last=drop_last)
119154
num_batches = 0
120-
for batch_number, batch in enumerate(batcher):
155+
for _, batch in enumerate(batcher):
121156
num_batches += 1
122157
self.assertGreater(len(batch), 0)
123158
if drop_last:
@@ -152,14 +187,14 @@ def test_stop_criteria_all_datasets_exhausted(self, num_samples, num_datasets) -
152187
datasets = self.get_unequal_dataset(num_samples, num_datasets)
153188
total_items = sum(range(num_samples, num_samples + num_datasets))
154189
sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.ALL_DATASETS_EXHAUSTED)
155-
for num_sample, item in enumerate(sampler):
190+
for num_sample, _ in enumerate(sampler):
156191
pass
157192
self.assertEqual(num_sample + 1, total_items)
158193

159194
datasets = self.get_unequal_dataset(num_samples, num_datasets, as_list=True)
160195
total_items = sum(range(num_samples, num_samples + num_datasets))
161196
sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.ALL_DATASETS_EXHAUSTED)
162-
for num_sample, item in enumerate(sampler):
197+
for num_sample, _ in enumerate(sampler):
163198
pass
164199
self.assertEqual(num_sample + 1, total_items)
165200

@@ -176,7 +211,7 @@ def test_stop_criteria_all_datasets_exhausted_batched(self, num_samples, num_dat
176211
batch_size = 3
177212
batcher = Batcher(sampler, batch_size=batch_size, drop_last=True)
178213
num_batches = 0
179-
for batch in batcher:
214+
for _ in batcher:
180215
num_batches += 1
181216
self.assertEqual(num_batches, total_items // batch_size)
182217

@@ -186,7 +221,7 @@ def test_stop_criteria_all_datasets_exhausted_batched(self, num_samples, num_dat
186221
batch_size = 3
187222
batcher = Batcher(sampler, batch_size=batch_size, drop_last=True)
188223
num_batches = 0
189-
for batch in batcher:
224+
for _ in batcher:
190225
num_batches += 1
191226
self.assertEqual(num_batches, total_items // batch_size)
192227

@@ -199,13 +234,13 @@ def test_stop_criteria_all_datasets_exhausted_batched(self, num_samples, num_dat
199234
def test_stop_criteria_first_dataset_exhausted(self, num_samples, num_datasets) -> None:
200235
datasets = self.get_unequal_dataset(num_samples, num_datasets)
201236
sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.FIRST_DATASET_EXHAUSTED)
202-
for num_sample, item in enumerate(sampler):
237+
for num_sample, _ in enumerate(sampler):
203238
pass
204239
self.assertEqual(num_sample + 1, num_datasets * num_samples)
205240

206241
datasets = self.get_unequal_dataset(num_samples, num_datasets, as_list=True)
207242
sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.FIRST_DATASET_EXHAUSTED)
208-
for num_sample, item in enumerate(sampler):
243+
for num_sample, _ in enumerate(sampler):
209244
pass
210245
self.assertEqual(num_sample + 1, num_datasets * num_samples)
211246

@@ -222,7 +257,7 @@ def test_stop_criteria_first_dataset_exhausted_batched(self, num_samples, num_da
222257
batch_size = 2
223258
batcher = Batcher(sampler, batch_size=batch_size)
224259
num_batches = 0
225-
for batch in batcher:
260+
for _ in batcher:
226261
num_batches += 1
227262
self.assertEqual(num_batches, num_samples * num_datasets // batch_size)
228263

@@ -231,7 +266,7 @@ def test_stop_criteria_first_dataset_exhausted_batched(self, num_samples, num_da
231266
batch_size = 2
232267
batcher = Batcher(sampler, batch_size=batch_size)
233268
num_batches = 0
234-
for batch in batcher:
269+
for _ in batcher:
235270
num_batches += 1
236271
self.assertEqual(num_batches, num_samples * num_datasets // batch_size)
237272

@@ -245,7 +280,7 @@ def test_stop_criteria_cycle_until_all_datasets_exhausted(self, num_samples, num
245280
num_samples = 4
246281
datasets = self.get_unequal_dataset(num_samples, num_datasets)
247282
sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED)
248-
for num_sample, item in enumerate(sampler):
283+
for num_sample, _ in enumerate(sampler):
249284
pass
250285
self.assertEqual(
251286
num_sample + 1,
@@ -254,7 +289,7 @@ def test_stop_criteria_cycle_until_all_datasets_exhausted(self, num_samples, num
254289

255290
datasets = self.get_unequal_dataset(num_samples, num_datasets, as_list=True)
256291
sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED)
257-
for num_sample, item in enumerate(sampler):
292+
for num_sample, _ in enumerate(sampler):
258293
pass
259294
self.assertEqual(
260295
num_sample + 1,
@@ -275,7 +310,7 @@ def test_stop_criteria_cycle_until_all_datasets_exhausted_batched(self, num_samp
275310
batcher = Batcher(sampler, batch_size=batch_size, drop_last=True)
276311

277312
num_batches = 0
278-
for batch in batcher:
313+
for _ in batcher:
279314
num_batches += 1
280315

281316
self.assertEqual(num_batches, 3)
@@ -373,6 +408,121 @@ def test_multi_node_round_robin_sampler_unequal_dataset_batched(self) -> None:
373408

374409
self.assertEqual(batch_number, 2)
375410

411+
def test_tag_default_default_keys(self) -> None:
412+
"""Test that the default keys for tag_output work correctly."""
413+
num_samples = 5
414+
num_datasets = 3
415+
416+
# Test with dictionary input
417+
datasets = self.get_equal_dataset(num_samples, num_datasets)
418+
default_keys = ("dataset_key", "data")
419+
sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.FIRST_DATASET_EXHAUSTED, tag_output=True)
420+
421+
for i, element in enumerate(sampler):
422+
self.assertIsInstance(element, dict)
423+
self.assertEqual(set(element.keys()), set(default_keys))
424+
self.assertTrue(element["dataset_key"].startswith("ds"))
425+
# Since we're using round-robin, datasets should appear in order
426+
expected_ds = f"ds{i % num_datasets}"
427+
self.assertEqual(element["dataset_key"], expected_ds)
428+
429+
# Also test with list input
430+
datasets = self.get_equal_dataset(num_samples, num_datasets, as_list=True)
431+
sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.FIRST_DATASET_EXHAUSTED, tag_output=True)
432+
433+
for i, element in enumerate(sampler):
434+
self.assertIsInstance(element, dict)
435+
self.assertEqual(set(element.keys()), set(default_keys))
436+
# Since we're using round-robin, datasets should appear in order
437+
expected_ds = f"ds_{i % num_datasets}" # Note list sources use ds_0 format
438+
self.assertEqual(element["dataset_key"], expected_ds)
439+
440+
def test_tag_output_with_different_stop_criteria(self) -> None:
441+
"""Test tagging with different stop criteria."""
442+
# Use unequal datasets to better test the behavior
443+
datasets = self.get_unequal_dataset(1, 3)
444+
445+
# ALL_DATASETS_EXHAUSTED
446+
sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.ALL_DATASETS_EXHAUSTED, tag_output=True)
447+
448+
results = list(sampler)
449+
# We should have 6 items total (1 + 2 + 3)
450+
self.assertEqual(len(results), 6)
451+
for item in results:
452+
self.assertIsInstance(item, dict)
453+
self.assertEqual(set(item.keys()), {"dataset_key", "data"})
454+
455+
# Count occurrences of each dataset
456+
dataset_counts = {}
457+
for item in results:
458+
ds = item["dataset_key"]
459+
dataset_counts[ds] = dataset_counts.get(ds, 0) + 1
460+
461+
self.assertEqual(dataset_counts["ds0"], 1)
462+
self.assertEqual(dataset_counts["ds1"], 2)
463+
self.assertEqual(dataset_counts["ds2"], 3)
464+
465+
# FIRST_DATASET_EXHAUSTED
466+
sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.FIRST_DATASET_EXHAUSTED, tag_output=True)
467+
468+
results = list(sampler)
469+
# We should have 3 items total (one from each dataset)
470+
self.assertEqual(len(results), 3)
471+
472+
# Count occurrences of each dataset
473+
dataset_counts = {}
474+
for item in results:
475+
ds = item["dataset_key"]
476+
dataset_counts[ds] = dataset_counts.get(ds, 0) + 1
477+
478+
self.assertEqual(dataset_counts["ds0"], 1)
479+
self.assertEqual(dataset_counts["ds1"], 1)
480+
self.assertEqual(dataset_counts["ds2"], 1)
481+
482+
# CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED
483+
sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED, tag_output=True)
484+
485+
results = list(sampler)
486+
# We should have items from all datasets, with the smallest dataset repeated until all are exhausted
487+
self.assertEqual(len(results), 11)
488+
489+
dataset_counts = {}
490+
for item in results:
491+
ds = item["dataset_key"]
492+
dataset_counts[ds] = dataset_counts.get(ds, 0) + 1
493+
494+
# ds0 is recycled, so it appears 4 times (tied with ds1); ds2 appears 3 times
495+
self.assertEqual(dataset_counts["ds0"], 4)
496+
self.assertEqual(dataset_counts["ds1"], 4)
497+
self.assertEqual(dataset_counts["ds2"], 3)
498+
499+
def test_tag_output_with_batching(self) -> None:
500+
"""Test that tagging works correctly with batching."""
501+
num_samples = 6
502+
num_datasets = 3
503+
batch_size = 3
504+
505+
datasets = self.get_equal_dataset(num_samples, num_datasets)
506+
sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.FIRST_DATASET_EXHAUSTED, tag_output=True)
507+
508+
batcher = Batcher(sampler, batch_size=batch_size)
509+
510+
for batch in batcher:
511+
self.assertEqual(len(batch), batch_size)
512+
for item in batch:
513+
self.assertIsInstance(item, dict)
514+
self.assertEqual(set(item.keys()), {"dataset_key", "data"})
515+
self.assertTrue(item["dataset_key"].startswith("ds"))
516+
517+
def test_tag_output_invalid_inputs(self) -> None:
518+
"""Test validation of invalid tag_output inputs."""
519+
datasets = self.get_equal_dataset(3, 3)
520+
521+
# Test with invalid type
522+
with self.assertRaises(TypeError) as cm:
523+
MultiNodeRoundRobinSampler(datasets, StopCriteria.FIRST_DATASET_EXHAUSTED, tag_output=123)
524+
self.assertIn("tag_output must be a boolean (True/False), got", str(cm.exception))
525+
376526
def test_unequal_batch_size(self) -> None:
377527
datasets = self.get_unequal_dataset(self._num_samples, self._num_datasets)
378528

@@ -409,28 +559,29 @@ def test_get_state(self) -> None:
409559
StopCriteria.FIRST_DATASET_EXHAUSTED,
410560
StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED,
411561
],
562+
[True, False],
412563
)
413564
)
414-
def test_save_load_state(self, midpoint: int, stop_criteria: str) -> None:
565+
def test_save_load_state(self, midpoint: int, stop_criteria: str, tag_output) -> None:
415566
num_samples = 1500
416567
num_datasets = 3
417568
datasets = self.get_equal_dataset(num_samples, num_datasets)
418-
sampler = MultiNodeRoundRobinSampler(datasets, stop_criteria)
569+
sampler = MultiNodeRoundRobinSampler(datasets, stop_criteria, tag_output)
419570
prefetcher = Prefetcher(sampler, 3)
420571
run_test_save_load_state(self, prefetcher, midpoint)
421572

422573
datasets = self.get_unequal_dataset(num_samples, num_datasets)
423-
sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.ALL_DATASETS_EXHAUSTED)
574+
sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.ALL_DATASETS_EXHAUSTED, tag_output)
424575
prefetcher = Prefetcher(sampler, 3)
425576
run_test_save_load_state(self, prefetcher, 400)
426577

427578
datasets = self.get_equal_dataset(num_samples, num_datasets, as_list=True)
428-
sampler = MultiNodeRoundRobinSampler(datasets, stop_criteria)
579+
sampler = MultiNodeRoundRobinSampler(datasets, stop_criteria, tag_output)
429580
prefetcher = Prefetcher(sampler, 3)
430581
run_test_save_load_state(self, prefetcher, midpoint)
431582

432583
datasets = self.get_unequal_dataset(num_samples, num_datasets)
433-
sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.ALL_DATASETS_EXHAUSTED)
584+
sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.ALL_DATASETS_EXHAUSTED, tag_output)
434585
prefetcher = Prefetcher(sampler, 3)
435586
run_test_save_load_state(self, prefetcher, 400)
436587

torchdata/nodes/samplers/multi_node_round_robin_sampler.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
logger = logging.getLogger(__name__)
1414

1515

16-
class MultiNodeRoundRobinSampler(BaseNode[T]):
16+
class MultiNodeRoundRobinSampler(BaseNode[Union[T, Dict[str, Any]]]):
1717
"""A node that samples from multiple datasets in a round robin fashion.
1818
This node expects to take in a list or dictionary of source nodes. If a list is provided, it is assumed that the order of the source nodes will be the same when the sampler is reset.
1919
The node implements the state using the following keys:
@@ -37,16 +37,18 @@ class MultiNodeRoundRobinSampler(BaseNode[T]):
3737
Args:
3838
source_nodes (Mapping[str, BaseNode[T]]): A dictionary of source nodes.
3939
stop_criteria (str): The stopping criteria. Default is CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED.
40+
tag_output (bool): Whether to tag the output with the dataset name. Default is False.
4041
4142
Example:
4243
>>> # Dataset A: 1 element, Dataset B: 2 elements
4344
>>> sampler = MultiNodeRoundRobinSampler(
4445
... source_nodes={"A": A_node, "B": B_node},
4546
... stop_criteria=StopCriteria.FIRST_DATASET_EXHAUSTED,
47+
... tag_output=True
4648
... )
4749
>>> list(sampler) # Yields: A, B, then A is exhausted
48-
[A_item, B_item1]
49-
If using StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED:
50+
[{'dataset_key': 'ds0', 'data': 'A_item'}, {'dataset_key': 'ds1', 'data': 'B_item1'}]
51+
If using StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED and tag_output=False:
5052
>>> list(sampler) # Yields: A, B, A (exhausted), B , A, then B is exhausted
5153
[A_item, B_item1, A_item, B_item2, A_item]
5254
"""
@@ -60,6 +62,7 @@ def __init__(
6062
self,
6163
source_nodes: Union[Mapping[str, BaseNode[T]], List[BaseNode[T]]],
6264
stop_criteria: str = StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED,
65+
tag_output: bool = False,
6366
) -> None:
6467
super().__init__()
6568
if isinstance(source_nodes, list):
@@ -75,6 +78,9 @@ def __init__(
7578
self._current_dataset_index = 0
7679
self._validate_stop_criteria()
7780
self._datasets_exhausted = [False for _ in range(self.num_datasets)]
81+
if not isinstance(tag_output, bool):
82+
raise TypeError(f"tag_output must be a boolean (True/False), got {type(tag_output)}")
83+
self.output_keys = ("dataset_key", "data") if tag_output else None
7884

7985
def _validate_stop_criteria(self) -> None:
8086
if self.stop_criteria not in [
@@ -119,7 +125,7 @@ def _check_for_stop_iteration(self) -> None:
119125
raise StopIteration()
120126
return
121127

122-
def next(self) -> T:
128+
def next(self) -> Union[T, Dict[str, Any]]:
123129
while True:
124130
self._check_for_stop_iteration()
125131
current_iterator = self.source_nodes[self._current_dataset_index]
@@ -146,8 +152,16 @@ def next(self) -> T:
146152
self.source_nodes[self._current_dataset_index].reset()
147153
item = next(self.source_nodes[self._current_dataset_index])
148154
break
149-
# If we didn't throw StopIteration, increment the number of items yielded and return the item
150-
self._current_dataset_index = (self._current_dataset_index + 1) % self.num_datasets
155+
# Capture dataset information before incrementing index
156+
dataset_idx = self._current_dataset_index
157+
dataset_name = self.dataset_keys[dataset_idx]
158+
self._current_dataset_index = (dataset_idx + 1) % self.num_datasets
159+
# Wrap item in dictionary if tagging is enabled
160+
if self.output_keys is not None:
161+
return {
162+
self.output_keys[0]: dataset_name,
163+
self.output_keys[1]: item,
164+
} # type: ignore[return-value]
151165
return item
152166

153167
def get_state(self) -> Dict[str, Any]:

0 commit comments

Comments
 (0)