10
10
from unittest import TestCase
11
11
12
12
from torch .testing ._internal .common_utils import instantiate_parametrized_tests , parametrize
13
- from torchdata .dataloader2 import DataLoader2 , DataLoader2Iterator , PrototypeMultiProcessingReadingService
13
+ from torchdata .dataloader2 import DataLoader2 , DataLoader2Iterator , MultiProcessingReadingService
14
14
from torchdata .datapipes .iter import IterableWrapper
15
15
16
16
@@ -29,9 +29,9 @@ def _add_one(x: int) -> int:
29
29
dp_parametrize = parametrize ("dp" , test_dps )
30
30
31
31
32
- class TestPrototypeMultiProcessingReadingService (TestCase ):
32
+ class TestMultiProcessingReadingService (TestCase ):
33
33
r"""
34
- This tests specific functionalities of PrototypeMultiProcessingReadingService , notably
34
+ This tests specific functionalities of MultiProcessingReadingService , notably
35
35
`pause`, `resume`, `snapshot`.
36
36
"""
37
37
@@ -40,7 +40,7 @@ def test_reading_service_pause_resume_0_worker(self, ctx) -> None:
40
40
41
41
# Functional Test: Verifies that this ReadingService will raise error when `pause/resume` is used
42
42
# with `num_workers = 0`
43
- rs0 = PrototypeMultiProcessingReadingService (
43
+ rs0 = MultiProcessingReadingService (
44
44
num_workers = 0 , worker_prefetch_cnt = 0 , main_prefetch_cnt = 0 , multiprocessing_context = ctx
45
45
)
46
46
dl0 : DataLoader2 = DataLoader2 (dp1 , reading_service = rs0 )
@@ -64,7 +64,7 @@ def test_reading_service_pause_resume(self, ctx, dp, n_workers, worker_prefetch_
64
64
65
65
# Functional Test: Testing various configuration of DataPipe/ReadingService to ensure the pipeline
66
66
# properly pauses and resumes
67
- rs = PrototypeMultiProcessingReadingService (
67
+ rs = MultiProcessingReadingService (
68
68
num_workers = n_workers ,
69
69
worker_prefetch_cnt = worker_prefetch_cnt ,
70
70
main_prefetch_cnt = main_prefetch_cnt ,
@@ -93,7 +93,7 @@ def test_reading_service_pause_resume(self, ctx, dp, n_workers, worker_prefetch_
93
93
def test_reading_service_pause_stop_yield (self , ctx , dp , n_workers , worker_prefetch_cnt , main_prefetch_cnt ) -> None :
94
94
95
95
# Functional Test: Confirms that `dl` will stop yielding elements after `_pause` is called
96
- rs = PrototypeMultiProcessingReadingService (
96
+ rs = MultiProcessingReadingService (
97
97
num_workers = n_workers ,
98
98
worker_prefetch_cnt = worker_prefetch_cnt ,
99
99
main_prefetch_cnt = main_prefetch_cnt ,
@@ -117,7 +117,7 @@ def test_reading_service_pause_stop_yield(self, ctx, dp, n_workers, worker_prefe
117
117
@parametrize ("n_workers,worker_prefetch_cnt,main_prefetch_cnt" , [(1 , 0 , 0 ), (1 , 0 , 2 ), (2 , 0 , 0 ), (2 , 2 , 2 )])
118
118
def test_reading_service_limit (self , dp , n_workers , worker_prefetch_cnt , main_prefetch_cnt ) -> None :
119
119
120
- rs = PrototypeMultiProcessingReadingService (
120
+ rs = MultiProcessingReadingService (
121
121
num_workers = n_workers , worker_prefetch_cnt = worker_prefetch_cnt , main_prefetch_cnt = main_prefetch_cnt
122
122
)
123
123
@@ -209,10 +209,10 @@ def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_pr
209
209
# those DPs belong to a dispatching process and only do pause if worker_id == 0
210
210
# There might still be a race condition, need to look into the messages
211
211
212
- # rs1 = PrototypeMultiProcessingReadingService (num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=0)
213
- # rs2 = PrototypeMultiProcessingReadingService (num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=2)
214
- # rs3 = PrototypeMultiProcessingReadingService (num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=0)
215
- # rs4 = PrototypeMultiProcessingReadingService (num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=2)
212
+ # rs1 = MultiProcessingReadingService (num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=0)
213
+ # rs2 = MultiProcessingReadingService (num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=2)
214
+ # rs3 = MultiProcessingReadingService (num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=0)
215
+ # rs4 = MultiProcessingReadingService (num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=2)
216
216
# rss = [rs1, rs2, rs3, rs4]
217
217
218
218
# for n, rs in enumerate(rss):
@@ -284,7 +284,7 @@ def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_pr
284
284
# pass
285
285
286
286
287
- instantiate_parametrized_tests (TestPrototypeMultiProcessingReadingService )
287
+ instantiate_parametrized_tests (TestMultiProcessingReadingService )
288
288
289
289
290
290
if __name__ == "__main__" :
0 commit comments