
Commit 177ad85

Revert "[distributed] remove pynccl's redundant stream (vllm-project#11744)"

This reverts commit 635b897.

1 parent 9be84c0

3 files changed (+24, -12 lines)

tests/distributed/test_pynccl.py

Lines changed: 3 additions & 2 deletions

@@ -137,8 +137,9 @@ def worker_fn_with_cudagraph():
         # run something in the default stream to initialize torch engine
         a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
         torch.cuda.synchronize()
-        with torch.cuda.graph(graph), \
-                pynccl_comm.change_state(enable=True):
+        with torch.cuda.graph(
+                graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
+                    enable=True):
             a_out = pynccl_comm.all_reduce(a)
         torch.cuda.synchronize()
         graph.replay()
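For readers unfamiliar with capturing work on a non-default stream, the pattern exercised by this test looks roughly like the sketch below. It assumes a single CUDA device and uses illustrative names (capture_and_replay, capture_stream, static_input); a simple multiply stands in for the all_reduce, so none of this is vLLM code.

# Sketch (not vLLM code): capture kernels into a CUDA graph on a dedicated
# stream, then replay them. Assumes one CUDA device is available.
import torch

def capture_and_replay():
    assert torch.cuda.is_available(), "requires a CUDA device"
    capture_stream = torch.cuda.Stream()      # analogous to pynccl_comm.stream
    static_input = torch.ones((4, 4), device="cuda")
    torch.cuda.synchronize()                  # finish warmup work before capture

    graph = torch.cuda.CUDAGraph()
    # torch.cuda.graph() makes `capture_stream` current for the duration of the
    # block, so every kernel launched inside is recorded on that stream.
    with torch.cuda.graph(graph, stream=capture_stream):
        static_output = static_input * 2      # stand-in for the all_reduce above

    static_input.fill_(3.0)                   # mutate the captured buffer in place
    graph.replay()                            # re-runs the recorded kernels
    torch.cuda.synchronize()
    return static_output                      # now holds 6.0 everywhere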

vllm/distributed/device_communicators/pynccl.py

Lines changed: 19 additions & 9 deletions

@@ -51,6 +51,7 @@ def __init__(
         if self.world_size == 1:
             self.available = False
             self.disabled = True
+            self.stream = None
             return
         try:
             self.nccl = NCCLLibrary(library_path)
@@ -59,6 +60,7 @@ def __init__(
             # e.g. in a non-GPU environment
             self.available = False
             self.disabled = True
+            self.stream = None
             return

         self.available = True
@@ -96,12 +98,12 @@ def __init__(
         with torch.cuda.device(device):
             self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                 self.world_size, self.unique_id, self.rank)
+            self.stream = torch.cuda.Stream()

-            stream = torch.cuda.current_stream()
             # A small all_reduce for warmup.
             data = torch.zeros(1, device=device)
             self.all_reduce(data)
-            stream.synchronize()
+            self.stream.synchronize()
             del data

     def all_reduce(self,
@@ -120,7 +122,7 @@ def all_reduce(self,
         out_tensor = torch.empty_like(in_tensor)

         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
                                 buffer_type(out_tensor.data_ptr()),
                                 in_tensor.numel(),
@@ -142,7 +144,7 @@ def all_gather(self,
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclAllGather(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
@@ -163,7 +165,7 @@ def reduce_scatter(self,
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclReduceScatter(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
@@ -178,7 +180,7 @@ def send(self, tensor: torch.Tensor, dst: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                            self.comm, cudaStream_t(stream.cuda_stream))
@@ -190,7 +192,7 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), src,
                            self.comm, cudaStream_t(stream.cuda_stream))
@@ -202,7 +204,7 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         if src == self.rank:
             sendbuff = buffer_type(tensor.data_ptr())
             # NCCL requires the sender also to have a receive buffer
@@ -215,17 +217,25 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
                                 self.comm, cudaStream_t(stream.cuda_stream))

     @contextmanager
-    def change_state(self, enable: Optional[bool] = None):
+    def change_state(self,
+                     enable: Optional[bool] = None,
+                     stream: Optional[torch.cuda.Stream] = None):
         """
         A context manager to change the state of the communicator.
         """
         if enable is None:
             # guess a default value when not specified
             enable = self.available

+        if stream is None:
+            stream = self.stream
+
         old_disable = self.disabled
+        old_stream = self.stream

+        self.stream = stream
         self.disabled = not enable
         yield

         self.disabled = old_disable
+        self.stream = old_stream
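The restored change_state() is a plain save/swap/restore context manager: it records the current disabled flag and stream, installs the requested values for the duration of the with block, and puts the old ones back afterwards. A minimal sketch of that pattern, using an illustrative FakeComm stand-in rather than the real PyNcclCommunicator (the try/finally is an extra safety net, not something the diff above adds):

# Sketch (not vLLM code): the save/swap/restore pattern behind change_state().
from contextlib import contextmanager
from typing import Optional

class FakeComm:
    def __init__(self):
        self.disabled = True
        self.stream = "default-stream"   # placeholder for a torch.cuda.Stream

    @contextmanager
    def change_state(self, enable: Optional[bool] = None, stream=None):
        if stream is None:
            stream = self.stream
        old_disabled, old_stream = self.disabled, self.stream
        self.disabled, self.stream = not enable, stream
        try:
            yield
        finally:
            # restore the previous state even if the body raises
            self.disabled, self.stream = old_disabled, old_stream

comm = FakeComm()
with comm.change_state(enable=True, stream="capture-stream"):
    assert not comm.disabled and comm.stream == "capture-stream"
assert comm.disabled and comm.stream == "default-stream"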

vllm/distributed/parallel_state.py

Lines changed: 2 additions & 1 deletion

@@ -310,7 +310,8 @@ def graph_capture(
             if not pynccl_comm:
                 maybe_pynccl_context = nullcontext()
             else:
-                maybe_pynccl_context = pynccl_comm.change_state()
+                maybe_pynccl_context = pynccl_comm.change_state(
+                    stream=torch.cuda.current_stream())
             with maybe_pynccl_context:
                 yield graph_capture_context
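The else branch now hands the communicator whatever stream is current at capture time, while the nullcontext() fallback keeps the subsequent with statement uniform when no pynccl communicator exists. A minimal sketch of that pattern (the helper name maybe_pynccl_context_for and its comm argument are illustrative, not vLLM APIs):

# Sketch (not vLLM code): a no-op context when there is no communicator,
# otherwise point the communicator at the caller's current CUDA stream.
from contextlib import nullcontext
import torch

def maybe_pynccl_context_for(comm=None):
    if not comm:
        return nullcontext()
    # During cudagraph capture the "current" stream is the capture stream, so
    # collectives get recorded on the same stream as the rest of the graph.
    return comm.change_state(stream=torch.cuda.current_stream())

# usage: the with statement works whether or not a communicator exists
# with maybe_pynccl_context_for(comm):
#     ...launch collectives / capture the graph...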
