@@ -59,8 +59,7 @@ def worker_fn():
                                      device=get_world_group().device)
     tensor = torch.ones(16, 1024, 1024,
                         dtype=torch.float32).cuda(pynccl_comm.rank)
-    with pynccl_comm.change_state(enable=True):
-        tensor = pynccl_comm.all_reduce(tensor)
+    tensor = pynccl_comm.all_reduce(tensor)
     torch.cuda.synchronize()
     assert torch.all(tensor == pynccl_comm.world_size).cpu().item()

@@ -81,17 +80,16 @@ def multiple_allreduce_worker_fn():
     group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
     pynccl_comm = PyNcclCommunicator(group=group, device=device)
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    with pynccl_comm.change_state(enable=True):
-        # two groups can communicate independently
-        if torch.distributed.get_rank() in [0, 1]:
-            tensor = pynccl_comm.all_reduce(tensor)
-            tensor = pynccl_comm.all_reduce(tensor)
-            torch.cuda.synchronize()
-            assert torch.all(tensor == 4).cpu().item()
-        else:
-            tensor = pynccl_comm.all_reduce(tensor)
-            torch.cuda.synchronize()
-            assert torch.all(tensor == 2).cpu().item()
+    # two groups can communicate independently
+    if torch.distributed.get_rank() in [0, 1]:
+        tensor = pynccl_comm.all_reduce(tensor)
+        tensor = pynccl_comm.all_reduce(tensor)
+        torch.cuda.synchronize()
+        assert torch.all(tensor == 4).cpu().item()
+    else:
+        tensor = pynccl_comm.all_reduce(tensor)
+        torch.cuda.synchronize()
+        assert torch.all(tensor == 2).cpu().item()


 @pytest.mark.skipif(torch.cuda.device_count() < 4,
@@ -137,8 +135,7 @@ def worker_fn_with_cudagraph():
     # run something in the default stream to initialize torch engine
     a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
     torch.cuda.synchronize()
-    with torch.cuda.graph(graph), \
-            pynccl_comm.change_state(enable=True):
+    with torch.cuda.graph(graph):
         a_out = pynccl_comm.all_reduce(a)
     torch.cuda.synchronize()
     graph.replay()
@@ -167,8 +164,7 @@ def all_gather_worker_fn():
         for r in range(world_size)
     ]).to(device)

-    with pynccl_comm.change_state(enable=True):
-        pynccl_comm.all_gather(result, tensor)
+    pynccl_comm.all_gather(result, tensor)
     torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)

@@ -205,8 +201,7 @@ def reduce_scatter_worker_fn():
     expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size]
                    for tensor in all_tensors).to(device)

-    with pynccl_comm.change_state(enable=True):
-        pynccl_comm.reduce_scatter(result, tensor)
+    pynccl_comm.reduce_scatter(result, tensor)
     torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)

@@ -233,15 +228,13 @@ def send_recv_worker_fn():
     else:
         tensor = torch.empty(16, 1024, 1024,
                              dtype=torch.float32).cuda(pynccl_comm.rank)
-    with pynccl_comm.change_state(enable=True):
-        if pynccl_comm.rank == 0:
-            pynccl_comm.send(tensor,
-                             dst=(pynccl_comm.rank + 1) %
-                             pynccl_comm.world_size)
-        else:
-            pynccl_comm.recv(tensor,
-                             src=(pynccl_comm.rank - 1) %
-                             pynccl_comm.world_size)
+
+    if pynccl_comm.rank == 0:
+        pynccl_comm.send(tensor,
+                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
+    else:
+        pynccl_comm.recv(tensor,
+                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
     torch.cuda.synchronize()
     assert torch.all(tensor == 1).cpu().item()

@@ -272,15 +265,12 @@ def multiple_send_recv_worker_fn():
                              1024,
                              dtype=torch.float32,
                              device=device)
-    with pynccl_comm.change_state(enable=True):
-        if torch.distributed.get_rank() in [0, 1]:
-            pynccl_comm.send(tensor,
-                             dst=(pynccl_comm.rank + 1) %
-                             pynccl_comm.world_size)
-        else:
-            pynccl_comm.recv(tensor,
-                             src=(pynccl_comm.rank - 1) %
-                             pynccl_comm.world_size)
+    if torch.distributed.get_rank() in [0, 1]:
+        pynccl_comm.send(tensor,
+                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
+    else:
+        pynccl_comm.recv(tensor,
+                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
     torch.cuda.synchronize()
     if torch.distributed.get_rank() in [0, 2]:
         assert torch.all(tensor == 1).cpu().item()
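
Note: the pattern these tests now exercise is plain method calls on the communicator; no change_state(enable=True) wrapper is needed around collectives or point-to-point ops. Below is a minimal sketch of that usage outside the test harness. The import path and the gloo bootstrap group are assumptions (the tests themselves construct the communicator via vLLM's get_world_group() helpers), and the script must be launched with one process per GPU, e.g. `torchrun --nproc-per-node=2 sketch.py` on a 2-GPU machine.

    # Minimal sketch of the post-change usage pattern (assumed import path; the
    # tests use vLLM's get_world_group() helpers instead of raw torch.distributed).
    import torch
    import torch.distributed as dist

    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator


    def main() -> None:
        # Use a non-NCCL (gloo) group for bootstrap/metadata exchange,
        # mirroring how the tests build the communicator from a CPU group.
        dist.init_process_group(backend="gloo")
        rank = dist.get_rank()
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)

        pynccl_comm = PyNcclCommunicator(group=dist.group.WORLD, device=device)

        # Collectives are invoked directly; no change_state(enable=True) context.
        tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
        tensor = pynccl_comm.all_reduce(tensor)
        torch.cuda.synchronize()
        assert torch.all(tensor == pynccl_comm.world_size).cpu().item()


    if __name__ == "__main__":
        main()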