Commit 9be84c0

Revert "[distributed] remove pynccl's redundant change_state (vllm-project#11749)"
This reverts commit 9e764e7.
1 parent 88e020d · commit 9be84c0

File tree

3 files changed: +62 −28 lines

tests/distributed/test_pynccl.py

Lines changed: 37 additions & 27 deletions
@@ -59,7 +59,8 @@ def worker_fn():
                                      device=get_world_group().device)
     tensor = torch.ones(16, 1024, 1024,
                         dtype=torch.float32).cuda(pynccl_comm.rank)
-    tensor = pynccl_comm.all_reduce(tensor)
+    with pynccl_comm.change_state(enable=True):
+        tensor = pynccl_comm.all_reduce(tensor)
     torch.cuda.synchronize()
     assert torch.all(tensor == pynccl_comm.world_size).cpu().item()

@@ -80,16 +81,17 @@ def multiple_allreduce_worker_fn():
     group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
     pynccl_comm = PyNcclCommunicator(group=group, device=device)
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    # two groups can communicate independently
-    if torch.distributed.get_rank() in [0, 1]:
-        tensor = pynccl_comm.all_reduce(tensor)
-        tensor = pynccl_comm.all_reduce(tensor)
-        torch.cuda.synchronize()
-        assert torch.all(tensor == 4).cpu().item()
-    else:
-        tensor = pynccl_comm.all_reduce(tensor)
-        torch.cuda.synchronize()
-        assert torch.all(tensor == 2).cpu().item()
+    with pynccl_comm.change_state(enable=True):
+        # two groups can communicate independently
+        if torch.distributed.get_rank() in [0, 1]:
+            tensor = pynccl_comm.all_reduce(tensor)
+            tensor = pynccl_comm.all_reduce(tensor)
+            torch.cuda.synchronize()
+            assert torch.all(tensor == 4).cpu().item()
+        else:
+            tensor = pynccl_comm.all_reduce(tensor)
+            torch.cuda.synchronize()
+            assert torch.all(tensor == 2).cpu().item()


 @pytest.mark.skipif(torch.cuda.device_count() < 4,
@@ -135,7 +137,8 @@ def worker_fn_with_cudagraph():
     # run something in the default stream to initialize torch engine
     a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
     torch.cuda.synchronize()
-    with torch.cuda.graph(graph):
+    with torch.cuda.graph(graph), \
+            pynccl_comm.change_state(enable=True):
         a_out = pynccl_comm.all_reduce(a)
     torch.cuda.synchronize()
     graph.replay()
@@ -164,7 +167,8 @@ def all_gather_worker_fn():
         for r in range(world_size)
     ]).to(device)

-    pynccl_comm.all_gather(result, tensor)
+    with pynccl_comm.change_state(enable=True):
+        pynccl_comm.all_gather(result, tensor)
     torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)

@@ -201,7 +205,8 @@ def reduce_scatter_worker_fn():
     expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size]
                    for tensor in all_tensors).to(device)

-    pynccl_comm.reduce_scatter(result, tensor)
+    with pynccl_comm.change_state(enable=True):
+        pynccl_comm.reduce_scatter(result, tensor)
     torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)

@@ -228,13 +233,15 @@ def send_recv_worker_fn():
     else:
         tensor = torch.empty(16, 1024, 1024,
                              dtype=torch.float32).cuda(pynccl_comm.rank)
-
-    if pynccl_comm.rank == 0:
-        pynccl_comm.send(tensor,
-                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
-    else:
-        pynccl_comm.recv(tensor,
-                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
+    with pynccl_comm.change_state(enable=True):
+        if pynccl_comm.rank == 0:
+            pynccl_comm.send(tensor,
+                             dst=(pynccl_comm.rank + 1) %
+                             pynccl_comm.world_size)
+        else:
+            pynccl_comm.recv(tensor,
+                             src=(pynccl_comm.rank - 1) %
+                             pynccl_comm.world_size)
     torch.cuda.synchronize()
     assert torch.all(tensor == 1).cpu().item()

@@ -265,12 +272,15 @@ def multiple_send_recv_worker_fn():
                         1024,
                         dtype=torch.float32,
                         device=device)
-    if torch.distributed.get_rank() in [0, 1]:
-        pynccl_comm.send(tensor,
-                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
-    else:
-        pynccl_comm.recv(tensor,
-                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
+    with pynccl_comm.change_state(enable=True):
+        if torch.distributed.get_rank() in [0, 1]:
+            pynccl_comm.send(tensor,
+                             dst=(pynccl_comm.rank + 1) %
+                             pynccl_comm.world_size)
+        else:
+            pynccl_comm.recv(tensor,
+                             src=(pynccl_comm.rank - 1) %
+                             pynccl_comm.world_size)
     torch.cuda.synchronize()
     if torch.distributed.get_rank() in [0, 2]:
         assert torch.all(tensor == 1).cpu().item()
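
Every collective in these tests is again wrapped in change_state(enable=True): with this revert, the communicator's enabled/disabled flag is toggled per block and restored afterward rather than being permanently on. A minimal self-contained sketch of that pattern (the Comm class below is a hypothetical stand-in, not vLLM's PyNcclCommunicator):

from contextlib import contextmanager


class Comm:
    """Hypothetical stand-in for a communicator with a disabled flag."""

    def __init__(self):
        self.available = True
        self.disabled = True  # assumption: starts disabled until enabled

    @contextmanager
    def change_state(self, enable=None):
        if enable is None:
            enable = self.available  # same default rule as the diff below
        old_disable = self.disabled
        self.disabled = not enable
        yield
        self.disabled = old_disable


comm = Comm()
with comm.change_state(enable=True):
    assert not comm.disabled  # collectives would go through NCCL here
assert comm.disabled          # prior state restored on exit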

vllm/distributed/device_communicators/pynccl.py

Lines changed: 17 additions & 0 deletions
@@ -1,3 +1,4 @@
+from contextlib import contextmanager
 from typing import Optional, Union

 # ===================== import region =====================
@@ -212,3 +213,19 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
         self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
                                 ncclDataTypeEnum.from_torch(tensor.dtype), src,
                                 self.comm, cudaStream_t(stream.cuda_stream))
+
+    @contextmanager
+    def change_state(self, enable: Optional[bool] = None):
+        """
+        A context manager to change the state of the communicator.
+        """
+        if enable is None:
+            # guess a default value when not specified
+            enable = self.available
+
+        old_disable = self.disabled
+
+        self.disabled = not enable
+        yield
+
+        self.disabled = old_disable
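
One observation on the restored implementation (not part of the commit itself): self.disabled is reset only when the with-block exits normally, because the yield is not wrapped in try/finally. A sketch of an exception-safe variant with otherwise identical behavior:

from contextlib import contextmanager

@contextmanager
def change_state(self, enable=None):
    if enable is None:
        # guess a default value when not specified
        enable = self.available
    old_disable = self.disabled
    self.disabled = not enable
    try:
        yield
    finally:
        # restored even if the body raises
        self.disabled = old_disable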

vllm/distributed/parallel_state.py

Lines changed: 8 additions & 1 deletion
@@ -305,7 +305,14 @@ def graph_capture(
         stream.wait_stream(curr_stream)

         with torch.cuda.stream(stream), maybe_ca_context:
-            yield graph_capture_context
+            pynccl_comm = self.pynccl_comm
+            maybe_pynccl_context: Any
+            if not pynccl_comm:
+                maybe_pynccl_context = nullcontext()
+            else:
+                maybe_pynccl_context = pynccl_comm.change_state()
+            with maybe_pynccl_context:
+                yield graph_capture_context

     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         """
