@@ -59,7 +59,8 @@ def worker_fn():
                                     device=get_world_group().device)
    tensor = torch.ones(16, 1024, 1024,
                        dtype=torch.float32).cuda(pynccl_comm.rank)
-    tensor = pynccl_comm.all_reduce(tensor)
+    with pynccl_comm.change_state(enable=True):
+        tensor = pynccl_comm.all_reduce(tensor)
    torch.cuda.synchronize()
    assert torch.all(tensor == pynccl_comm.world_size).cpu().item()

@@ -80,16 +81,17 @@ def multiple_allreduce_worker_fn():
    group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
    pynccl_comm = PyNcclCommunicator(group=group, device=device)
    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    # two groups can communicate independently
-    if torch.distributed.get_rank() in [0, 1]:
-        tensor = pynccl_comm.all_reduce(tensor)
-        tensor = pynccl_comm.all_reduce(tensor)
-        torch.cuda.synchronize()
-        assert torch.all(tensor == 4).cpu().item()
-    else:
-        tensor = pynccl_comm.all_reduce(tensor)
-        torch.cuda.synchronize()
-        assert torch.all(tensor == 2).cpu().item()
+    with pynccl_comm.change_state(enable=True):
+        # two groups can communicate independently
+        if torch.distributed.get_rank() in [0, 1]:
+            tensor = pynccl_comm.all_reduce(tensor)
+            tensor = pynccl_comm.all_reduce(tensor)
+            torch.cuda.synchronize()
+            assert torch.all(tensor == 4).cpu().item()
+        else:
+            tensor = pynccl_comm.all_reduce(tensor)
+            torch.cuda.synchronize()
+            assert torch.all(tensor == 2).cpu().item()


@pytest.mark.skipif(torch.cuda.device_count() < 4,
@@ -135,7 +137,8 @@ def worker_fn_with_cudagraph():
    # run something in the default stream to initialize torch engine
    a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
    torch.cuda.synchronize()
-    with torch.cuda.graph(graph):
+    with torch.cuda.graph(graph), \
+            pynccl_comm.change_state(enable=True):
        a_out = pynccl_comm.all_reduce(a)
    torch.cuda.synchronize()
    graph.replay()
@@ -164,7 +167,8 @@ def all_gather_worker_fn():
        for r in range(world_size)
    ]).to(device)

-    pynccl_comm.all_gather(result, tensor)
+    with pynccl_comm.change_state(enable=True):
+        pynccl_comm.all_gather(result, tensor)
    torch.cuda.synchronize()
    torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)

@@ -201,7 +205,8 @@ def reduce_scatter_worker_fn():
    expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size]
                   for tensor in all_tensors).to(device)

-    pynccl_comm.reduce_scatter(result, tensor)
+    with pynccl_comm.change_state(enable=True):
+        pynccl_comm.reduce_scatter(result, tensor)
    torch.cuda.synchronize()
    torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)

@@ -228,13 +233,15 @@ def send_recv_worker_fn():
    else:
        tensor = torch.empty(16, 1024, 1024,
                             dtype=torch.float32).cuda(pynccl_comm.rank)
-
-    if pynccl_comm.rank == 0:
-        pynccl_comm.send(tensor,
-                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
-    else:
-        pynccl_comm.recv(tensor,
-                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
+    with pynccl_comm.change_state(enable=True):
+        if pynccl_comm.rank == 0:
+            pynccl_comm.send(tensor,
+                             dst=(pynccl_comm.rank + 1) %
+                             pynccl_comm.world_size)
+        else:
+            pynccl_comm.recv(tensor,
+                             src=(pynccl_comm.rank - 1) %
+                             pynccl_comm.world_size)
    torch.cuda.synchronize()
    assert torch.all(tensor == 1).cpu().item()

@@ -265,12 +272,15 @@ def multiple_send_recv_worker_fn():
                        1024,
                        dtype=torch.float32,
                        device=device)
-    if torch.distributed.get_rank() in [0, 1]:
-        pynccl_comm.send(tensor,
-                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
-    else:
-        pynccl_comm.recv(tensor,
-                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
+    with pynccl_comm.change_state(enable=True):
+        if torch.distributed.get_rank() in [0, 1]:
+            pynccl_comm.send(tensor,
+                             dst=(pynccl_comm.rank + 1) %
+                             pynccl_comm.world_size)
+        else:
+            pynccl_comm.recv(tensor,
+                             src=(pynccl_comm.rank - 1) %
+                             pynccl_comm.world_size)
    torch.cuda.synchronize()
    if torch.distributed.get_rank() in [0, 2]:
        assert torch.all(tensor == 1).cpu().item()