Skip to content

Commit a3b34a0

Browse files
ejguan authored and
facebook-github-bot committed
Update Prefetcher and Implement PinMemory IterDataPipe (#1014)
Summary: Fixes #1013 ## Changes - Simplify the control flow of prefetcher - Delay Exception raised from thread worker to main thread in `__iter__` - Stop prefetching whenever Exception is received - As long as `stop_iteration` is not turned on or `buffer` is not empty, continue yielding data from `__iter__`. - Add serialization test - Add `PinMemory` DataPipe - `is_replicable() -> False` to keep it in the main process - Add unit tests - Update `test_proto_multi_rs.py` to `test_mprs.py` Pull Request resolved: #1014 Reviewed By: NivekT Differential Revision: D43329696 Pulled By: ejguan fbshipit-source-id: da4326dbe2388f4e23b9a1a3a5c43da09d29185a
1 parent efcc766 commit a3b34a0

File tree

10 files changed

+235
-46
lines changed

10 files changed

+235
-46
lines changed

docs/source/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ Features described in this documentation are classified by release status:
4242
dataloader2.rst
4343
reading_service.rst
4444

45+
4546
.. toctree::
4647
:maxdepth: 2
4748
:caption: Tutorial and Examples:

docs/source/torchdata.datapipes.utils.rst

+11
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,17 @@ DataPipe Graph Visualization
1515

1616
to_graph
1717

18+
Common Utility Functions
19+
--------------------------------------
20+
.. currentmodule:: torchdata.datapipes.utils
21+
22+
.. autosummary::
23+
:nosignatures:
24+
:toctree: generated/
25+
:template: function.rst
26+
27+
pin_memory_fn
28+
1829

1930
File Object and Stream Utility
2031
-------------------------------------

test/dataloader2/test_proto_multi_rs.py renamed to test/dataloader2/test_mprs.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from unittest import TestCase
1111

1212
from torch.testing._internal.common_utils import instantiate_parametrized_tests, parametrize
13-
from torchdata.dataloader2 import DataLoader2, DataLoader2Iterator, PrototypeMultiProcessingReadingService
13+
from torchdata.dataloader2 import DataLoader2, DataLoader2Iterator, MultiProcessingReadingService
1414
from torchdata.datapipes.iter import IterableWrapper
1515

1616

@@ -29,9 +29,9 @@ def _add_one(x: int) -> int:
2929
dp_parametrize = parametrize("dp", test_dps)
3030

3131

32-
class TestPrototypeMultiProcessingReadingService(TestCase):
32+
class TestMultiProcessingReadingService(TestCase):
3333
r"""
34-
This tests specific functionalities of PrototypeMultiProcessingReadingService, notably
34+
This tests specific functionalities of MultiProcessingReadingService, notably
3535
`pause`, `resume`, `snapshot`.
3636
"""
3737

@@ -40,7 +40,7 @@ def test_reading_service_pause_resume_0_worker(self, ctx) -> None:
4040

4141
# Functional Test: Verifies that this ReadingService will raise error when `pause/resume` is used
4242
# with `num_workers = 0`
43-
rs0 = PrototypeMultiProcessingReadingService(
43+
rs0 = MultiProcessingReadingService(
4444
num_workers=0, worker_prefetch_cnt=0, main_prefetch_cnt=0, multiprocessing_context=ctx
4545
)
4646
dl0: DataLoader2 = DataLoader2(dp1, reading_service=rs0)
@@ -64,7 +64,7 @@ def test_reading_service_pause_resume(self, ctx, dp, n_workers, worker_prefetch_
6464

6565
# Functional Test: Testing various configurations of DataPipe/ReadingService to ensure the pipeline
6666
# properly pauses and resumes
67-
rs = PrototypeMultiProcessingReadingService(
67+
rs = MultiProcessingReadingService(
6868
num_workers=n_workers,
6969
worker_prefetch_cnt=worker_prefetch_cnt,
7070
main_prefetch_cnt=main_prefetch_cnt,
@@ -93,7 +93,7 @@ def test_reading_service_pause_resume(self, ctx, dp, n_workers, worker_prefetch_
9393
def test_reading_service_pause_stop_yield(self, ctx, dp, n_workers, worker_prefetch_cnt, main_prefetch_cnt) -> None:
9494

9595
# Functional Test: Confirms that `dl` will stop yielding elements after `_pause` is called
96-
rs = PrototypeMultiProcessingReadingService(
96+
rs = MultiProcessingReadingService(
9797
num_workers=n_workers,
9898
worker_prefetch_cnt=worker_prefetch_cnt,
9999
main_prefetch_cnt=main_prefetch_cnt,
@@ -117,7 +117,7 @@ def test_reading_service_pause_stop_yield(self, ctx, dp, n_workers, worker_prefe
117117
@parametrize("n_workers,worker_prefetch_cnt,main_prefetch_cnt", [(1, 0, 0), (1, 0, 2), (2, 0, 0), (2, 2, 2)])
118118
def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_prefetch_cnt) -> None:
119119

120-
rs = PrototypeMultiProcessingReadingService(
120+
rs = MultiProcessingReadingService(
121121
num_workers=n_workers, worker_prefetch_cnt=worker_prefetch_cnt, main_prefetch_cnt=main_prefetch_cnt
122122
)
123123

@@ -209,10 +209,10 @@ def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_pr
209209
# those DPs belong to a dispatching process and only do pause if worker_id == 0
210210
# There might still be a race condition, need to look into the messages
211211

212-
# rs1 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=0)
213-
# rs2 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=2)
214-
# rs3 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=0)
215-
# rs4 = PrototypeMultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=2)
212+
# rs1 = MultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=0)
213+
# rs2 = MultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=2)
214+
# rs3 = MultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=0)
215+
# rs4 = MultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=2)
216216
# rss = [rs1, rs2, rs3, rs4]
217217

218218
# for n, rs in enumerate(rss):
@@ -284,7 +284,7 @@ def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_pr
284284
# pass
285285

286286

287-
instantiate_parametrized_tests(TestPrototypeMultiProcessingReadingService)
287+
instantiate_parametrized_tests(TestMultiProcessingReadingService)
288288

289289

290290
if __name__ == "__main__":

test/test_iterdatapipe.py

+43-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from typing import Dict
1515

1616
import expecttest
17-
import torch.utils.data.datapipes.iter
17+
import torch
1818

1919
import torchdata
2020

@@ -42,6 +42,8 @@
4242
)
4343
from torchdata.datapipes.map import MapDataPipe, SequenceWrapper
4444

45+
skipIfNoCUDA = unittest.skipIf(not torch.cuda.is_available(), "CUDA is not available")
46+
4547

4648
def test_torchdata_pytorch_consistency() -> None:
4749
def extract_datapipe_names(module):
@@ -68,6 +70,14 @@ def extract_datapipe_names(module):
6870
raise AssertionError(msg + "\n".join(sorted(missing_datapipes)))
6971

7072

73+
def _convert_to_tensor(data):
74+
if isinstance(data, dict):
75+
return {k: _convert_to_tensor(v) for k, v in data.items()}
76+
elif isinstance(data, list):
77+
return [_convert_to_tensor(v) for v in data]
78+
return torch.tensor(data)
79+
80+
7181
class TestIterDataPipe(expecttest.TestCase):
7282
def test_in_memory_cache_holder_iterdatapipe(self) -> None:
7383
source_dp = IterableWrapper(range(10))
@@ -1475,6 +1485,38 @@ def test_random_splitter_iterdatapipe(self):
14751485
next(it_train)
14761486
next(it_valid) # No error, can keep going
14771487

1488+
@skipIfNoCUDA
1489+
def test_pin_memory(self):
1490+
# Tensor
1491+
dp = IterableWrapper([(i, i + 1) for i in range(10)]).map(_convert_to_tensor).pin_memory()
1492+
self.assertTrue(all(d.is_pinned() for d in dp))
1493+
1494+
# List of Tensors
1495+
dp = IterableWrapper([[(i - 1, i), (i, i + 1)] for i in range(10)]).map(_convert_to_tensor).pin_memory()
1496+
self.assertTrue(all(d0.is_pinned() and d1.is_pinned() for d0, d1 in dp))
1497+
1498+
# Dict of Tensors
1499+
dp = IterableWrapper([{str(i): (i, i + 1)} for i in range(10)]).map(_convert_to_tensor).pin_memory()
1500+
self.assertTrue(all(v.is_pinned() for d in dp for v in d.values()))
1501+
1502+
# Dict of List of Tensors
1503+
dp = (
1504+
IterableWrapper([{str(i): [(i - 1, i), (i, i + 1)]} for i in range(10)])
1505+
.map(_convert_to_tensor)
1506+
.pin_memory()
1507+
)
1508+
self.assertTrue(all(v.is_pinned() for d in dp for batch in d.values() for v in batch))
1509+
1510+
# List of Dict of Tensors
1511+
dp = IterableWrapper([{str(i): (i, i + 1)} for i in range(10)]).map(_convert_to_tensor).batch(2).pin_memory()
1512+
self.assertTrue(all(v.is_pinned() for batch in dp for d in batch for v in d.values()))
1513+
1514+
# List of List of Tensors
1515+
dp = (
1516+
IterableWrapper([[(i - 1, i), (i, i + 1)] for i in range(10)]).map(_convert_to_tensor).batch(2).pin_memory()
1517+
)
1518+
self.assertTrue(all(d0.is_pinned() and d1.is_pinned() for batch in dp for d0, d1 in batch))
1519+
14781520

14791521
if __name__ == "__main__":
14801522
unittest.main()

test/test_serialization.py

+5
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ def _filter_by_module_availability(datapipes):
9292
return [dp for dp in datapipes if dp[0] not in filter_set]
9393

9494

95+
def _convert_to_tensor(data):
96+
return torch.tensor(data)
97+
98+
9599
class TestIterDataPipeSerialization(expecttest.TestCase):
96100
def setUp(self):
97101
self.temp_dir = create_temp_dir()
@@ -272,6 +276,7 @@ def test_serializable(self):
272276
(),
273277
{},
274278
),
279+
(iterdp.Prefetcher, None, (), {}),
275280
(iterdp.ParquetDataFrameLoader, None, (), {"dtype": DTYPE}),
276281
(iterdp.RarArchiveLoader, None, (), {}),
277282
(

torchdata/datapipes/iter/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,10 @@
108108
CSVParserIterDataPipe as CSVParser,
109109
LineReaderIterDataPipe as LineReader,
110110
)
111-
from torchdata.datapipes.iter.util.prefetcher import PrefetcherIterDataPipe as Prefetcher
111+
from torchdata.datapipes.iter.util.prefetcher import (
112+
PinMemoryIterDataPipe as PinMemory,
113+
PrefetcherIterDataPipe as Prefetcher,
114+
)
112115
from torchdata.datapipes.iter.util.randomsplitter import RandomSplitterIterDataPipe as RandomSplitter
113116
from torchdata.datapipes.iter.util.rararchiveloader import RarArchiveLoaderIterDataPipe as RarArchiveLoader
114117
from torchdata.datapipes.iter.util.rows2columnar import Rows2ColumnarIterDataPipe as Rows2Columnar
@@ -187,6 +190,7 @@
187190
"OnlineReader",
188191
"ParagraphAggregator",
189192
"ParquetDataFrameLoader",
193+
"PinMemory",
190194
"Prefetcher",
191195
"RandomSplitter",
192196
"RarArchiveLoader",

torchdata/datapipes/iter/__init__.pyi.in

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ ${init_base}
1010
from .util.decompressor import CompressionType
1111
from torchdata._constants import default_timeout_in_s
1212
from torchdata.datapipes.map import MapDataPipe
13+
from torchdata.datapipes.utils import pin_memory_fn
1314
from torch.utils.data import DataChunk, IterableDataset, default_collate
1415
from torch.utils.data.datapipes._typing import _DataPipeMeta
1516
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES

0 commit comments

Comments
 (0)