Skip to content

Commit 8d97fbb

Browse files
authored
Make parallel mapper work on dataset specific transformations
Differential Revision: D73187135 Pull Request resolved: #1486
1 parent ac259e1 commit 8d97fbb

File tree

2 files changed

+77
-0
lines changed

2 files changed

+77
-0
lines changed

test/nodes/test_multi_node_weighted_sampler.py

+65
Original file line numberDiff line numberDiff line change
@@ -308,3 +308,68 @@ def test_multi_node_weighted_large_sample_size_with_prefetcher(self, midpoint, s
308308
stop_criteria,
309309
)
310310
run_test_save_load_state(self, node, midpoint)
311+
312+
def test_multi_node_weighted_sampler_tag_output_dict_items(self) -> None:
    """With tag_output=True, dict items gain a 'dataset_key' entry naming their source dataset."""
    sampler = MultiNodeWeightedSampler(
        self.datasets,
        self.weights,
        tag_output=True,
    )
    output = list(sampler)

    # Valid tags are the dataset names the fixture registered.
    expected_tags = [f"ds{i}" for i in range(self._num_datasets)]

    # Each yielded dict must carry the tag, keep its original payload
    # fields, and the tag must match the item's own 'name' field.
    for item in output:
        self.assertIn("dataset_key", item)
        tag = item["dataset_key"]
        self.assertIn(tag, expected_tags)
        self.assertIn("name", item)
        self.assertIn("test_tensor", item)
        self.assertEqual(tag, item["name"])
333+
334+
def test_multi_node_weighted_sampler_tag_output_non_dict_items(self) -> None:
    """With tag_output=True, non-dict items are wrapped into {'data': ..., 'dataset_key': ...} dicts."""
    # Build datasets that yield plain ints rather than dicts.
    plain_sources = {}
    for idx in range(self._num_datasets):
        plain_sources[f"ds{idx}"] = IterableWrapper(range(idx * 10, (idx + 1) * 10))

    sampler = MultiNodeWeightedSampler(
        plain_sources,
        self.weights,
        tag_output=True,
    )
    output = list(sampler)

    expected_tags = [f"ds{i}" for i in range(self._num_datasets)]

    # Every yielded item must have been promoted to a dict holding the
    # original value under 'data' and a valid tag under 'dataset_key'.
    for item in output:
        self.assertIsInstance(item, dict)
        self.assertIn("data", item)
        self.assertIn("dataset_key", item)
        self.assertIn(item["dataset_key"], expected_tags)
358+
359+
def test_multi_node_weighted_sampler_tag_output_false(self) -> None:
    """Test MultiNodeWeightedSampler with tag_output=False (default behavior).

    With tagging disabled, no tag key is injected and the items pass
    through unchanged.
    """
    node = MultiNodeWeightedSampler(
        self.datasets,
        self.weights,
        tag_output=False,
    )

    results = list(node)

    # The sampler injects its tag under the key "dataset_key"; when
    # tag_output is False that exact key must be absent. (The previous
    # assertion checked the unrelated key "dataset", which would pass
    # even if tagging incorrectly leaked into the output.)
    for result in results:
        self.assertNotIn("dataset_key", result)

        # Check that the original data is preserved
        self.assertIn("name", result)
        self.assertIn("test_tensor", result)

torchdata/nodes/samplers/multi_node_weighted_sampler.py

+12
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class MultiNodeWeightedSampler(BaseNode[T]):
4848
world_size (int): The world size of the distributed environment. Default is None, in
4949
which case the world size will be obtained from the distributed environment.
5050
seed (int): The seed for the random number generator. Default is 0.
51+
tag_output (bool): Whether to tag the output with the dataset name. Default is False.
5152
"""
5253

5354
DATASET_NODE_STATES_KEY = "dataset_node_states"
@@ -64,6 +65,7 @@ def __init__(
6465
rank: Optional[int] = None,
6566
world_size: Optional[int] = None,
6667
seed: int = 0,
68+
tag_output: bool = False,
6769
) -> None:
6870
super().__init__()
6971

@@ -74,6 +76,7 @@ def __init__(
7476
self._num_yielded = 0
7577
self._started = False
7678
self.seed = seed
79+
self.tag_output = tag_output
7780

7881
# Setup rank and world size
7982
if rank is None or world_size is None:
@@ -194,8 +197,17 @@ def next(self) -> T:
194197

195198
# If we didn't throw StopIteration, increment the number of items yielded and return the item
196199
self._num_yielded += 1
200+
201+
# If tag_output is True, add the dataset key to the output
202+
if self.tag_output:
203+
if isinstance(item, dict): # type: ignore[used-before-def]
204+
item["dataset_key"] = key # type: ignore[used-before-def]
205+
else:
206+
item = {"dataset_key": key, "data": item}
207+
197208
return item
198209

210+
199211
def get_state(self) -> Dict[str, Any]:
200212
return {
201213
self.DATASETS_EXHAUSTED_KEY: copy.deepcopy(self._datasets_exhausted),

0 commit comments

Comments
 (0)