Skip to content

Commit 9e764e7

Browse files
authored
[distributed] remove pynccl's redundant change_state (vllm-project#11749)
1 parent 33fc1e2 commit 9e764e7

File tree

3 files changed

+28
-62
lines changed

3 files changed

+28
-62
lines changed

tests/distributed/test_pynccl.py

Lines changed: 27 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -59,8 +59,7 @@ def worker_fn():
5959
device=get_world_group().device)
6060
tensor = torch.ones(16, 1024, 1024,
6161
dtype=torch.float32).cuda(pynccl_comm.rank)
62-
with pynccl_comm.change_state(enable=True):
63-
tensor = pynccl_comm.all_reduce(tensor)
62+
tensor = pynccl_comm.all_reduce(tensor)
6463
torch.cuda.synchronize()
6564
assert torch.all(tensor == pynccl_comm.world_size).cpu().item()
6665

@@ -81,17 +80,16 @@ def multiple_allreduce_worker_fn():
8180
group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
8281
pynccl_comm = PyNcclCommunicator(group=group, device=device)
8382
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
84-
with pynccl_comm.change_state(enable=True):
85-
# two groups can communicate independently
86-
if torch.distributed.get_rank() in [0, 1]:
87-
tensor = pynccl_comm.all_reduce(tensor)
88-
tensor = pynccl_comm.all_reduce(tensor)
89-
torch.cuda.synchronize()
90-
assert torch.all(tensor == 4).cpu().item()
91-
else:
92-
tensor = pynccl_comm.all_reduce(tensor)
93-
torch.cuda.synchronize()
94-
assert torch.all(tensor == 2).cpu().item()
83+
# two groups can communicate independently
84+
if torch.distributed.get_rank() in [0, 1]:
85+
tensor = pynccl_comm.all_reduce(tensor)
86+
tensor = pynccl_comm.all_reduce(tensor)
87+
torch.cuda.synchronize()
88+
assert torch.all(tensor == 4).cpu().item()
89+
else:
90+
tensor = pynccl_comm.all_reduce(tensor)
91+
torch.cuda.synchronize()
92+
assert torch.all(tensor == 2).cpu().item()
9593

9694

9795
@pytest.mark.skipif(torch.cuda.device_count() < 4,
@@ -137,8 +135,7 @@ def worker_fn_with_cudagraph():
137135
# run something in the default stream to initialize torch engine
138136
a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
139137
torch.cuda.synchronize()
140-
with torch.cuda.graph(graph), \
141-
pynccl_comm.change_state(enable=True):
138+
with torch.cuda.graph(graph):
142139
a_out = pynccl_comm.all_reduce(a)
143140
torch.cuda.synchronize()
144141
graph.replay()
@@ -167,8 +164,7 @@ def all_gather_worker_fn():
167164
for r in range(world_size)
168165
]).to(device)
169166

170-
with pynccl_comm.change_state(enable=True):
171-
pynccl_comm.all_gather(result, tensor)
167+
pynccl_comm.all_gather(result, tensor)
172168
torch.cuda.synchronize()
173169
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
174170

@@ -205,8 +201,7 @@ def reduce_scatter_worker_fn():
205201
expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size]
206202
for tensor in all_tensors).to(device)
207203

208-
with pynccl_comm.change_state(enable=True):
209-
pynccl_comm.reduce_scatter(result, tensor)
204+
pynccl_comm.reduce_scatter(result, tensor)
210205
torch.cuda.synchronize()
211206
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
212207

@@ -233,15 +228,13 @@ def send_recv_worker_fn():
233228
else:
234229
tensor = torch.empty(16, 1024, 1024,
235230
dtype=torch.float32).cuda(pynccl_comm.rank)
236-
with pynccl_comm.change_state(enable=True):
237-
if pynccl_comm.rank == 0:
238-
pynccl_comm.send(tensor,
239-
dst=(pynccl_comm.rank + 1) %
240-
pynccl_comm.world_size)
241-
else:
242-
pynccl_comm.recv(tensor,
243-
src=(pynccl_comm.rank - 1) %
244-
pynccl_comm.world_size)
231+
232+
if pynccl_comm.rank == 0:
233+
pynccl_comm.send(tensor,
234+
dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
235+
else:
236+
pynccl_comm.recv(tensor,
237+
src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
245238
torch.cuda.synchronize()
246239
assert torch.all(tensor == 1).cpu().item()
247240

@@ -272,15 +265,12 @@ def multiple_send_recv_worker_fn():
272265
1024,
273266
dtype=torch.float32,
274267
device=device)
275-
with pynccl_comm.change_state(enable=True):
276-
if torch.distributed.get_rank() in [0, 1]:
277-
pynccl_comm.send(tensor,
278-
dst=(pynccl_comm.rank + 1) %
279-
pynccl_comm.world_size)
280-
else:
281-
pynccl_comm.recv(tensor,
282-
src=(pynccl_comm.rank - 1) %
283-
pynccl_comm.world_size)
268+
if torch.distributed.get_rank() in [0, 1]:
269+
pynccl_comm.send(tensor,
270+
dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
271+
else:
272+
pynccl_comm.recv(tensor,
273+
src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
284274
torch.cuda.synchronize()
285275
if torch.distributed.get_rank() in [0, 2]:
286276
assert torch.all(tensor == 1).cpu().item()

vllm/distributed/device_communicators/pynccl.py

Lines changed: 0 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,3 @@
1-
from contextlib import contextmanager
21
from typing import Optional, Union
32

43
# ===================== import region =====================
@@ -213,19 +212,3 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
213212
self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
214213
ncclDataTypeEnum.from_torch(tensor.dtype), src,
215214
self.comm, cudaStream_t(stream.cuda_stream))
216-
217-
@contextmanager
218-
def change_state(self, enable: Optional[bool] = None):
219-
"""
220-
A context manager to change the state of the communicator.
221-
"""
222-
if enable is None:
223-
# guess a default value when not specified
224-
enable = self.available
225-
226-
old_disable = self.disabled
227-
228-
self.disabled = not enable
229-
yield
230-
231-
self.disabled = old_disable

vllm/distributed/parallel_state.py

Lines changed: 1 addition & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -305,14 +305,7 @@ def graph_capture(
305305
stream.wait_stream(curr_stream)
306306

307307
with torch.cuda.stream(stream), maybe_ca_context:
308-
pynccl_comm = self.pynccl_comm
309-
maybe_pynccl_context: Any
310-
if not pynccl_comm:
311-
maybe_pynccl_context = nullcontext()
312-
else:
313-
maybe_pynccl_context = pynccl_comm.change_state()
314-
with maybe_pynccl_context:
315-
yield graph_capture_context
308+
yield graph_capture_context
316309

317310
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
318311
"""

0 commit comments

Comments (0)