
Commit 08b7a10

cennn authored and rasmith committed
[distributed] remove pynccl's redundant stream (vllm-project#11744)
1 parent 3b22418 commit 08b7a10

3 files changed: +12 lines, −24 lines
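In short, the commit drops the communicator-owned torch.cuda.Stream and lets every pynccl collective default to the caller's current stream. A minimal sketch of that pattern, with an illustrative class and method name (the real logic lives in PyNcclCommunicator in the diff below):

import torch
from typing import Optional

class StreamDefaultingComm:
    """Illustrative mock, not the vLLM class: isolates the stream-handling
    pattern this commit switches pynccl to."""

    def resolve_stream(self, stream: Optional[torch.cuda.Stream] = None) -> int:
        # Before the patch each collective fell back to a dedicated
        # self.stream created in __init__; after the patch it falls back
        # to whatever stream is current. Under CUDA graph capture,
        # torch.cuda.current_stream() is the capture stream, so the NCCL
        # launch is recorded into the graph without extra plumbing.
        if stream is None:
            stream = torch.cuda.current_stream()
        return stream.cuda_stream  # raw handle handed to the NCCL C API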

tests/distributed/test_pynccl.py

Lines changed: 2 additions & 3 deletions
@@ -137,9 +137,8 @@ def worker_fn_with_cudagraph():
         # run something in the default stream to initialize torch engine
         a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
         torch.cuda.synchronize()
-        with torch.cuda.graph(
-                graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
-                    enable=True):
+        with torch.cuda.graph(graph), \
+                pynccl_comm.change_state(enable=True):
             a_out = pynccl_comm.all_reduce(a)
         torch.cuda.synchronize()
         graph.replay()

vllm/distributed/device_communicators/pynccl.py

Lines changed: 9 additions & 19 deletions
@@ -51,7 +51,6 @@ def __init__(
         if self.world_size == 1:
             self.available = False
             self.disabled = True
-            self.stream = None
             return
         try:
             self.nccl = NCCLLibrary(library_path)
@@ -60,7 +59,6 @@ def __init__(
             # e.g. in a non-GPU environment
             self.available = False
             self.disabled = True
-            self.stream = None
             return

         self.available = True
@@ -98,12 +96,12 @@ def __init__(
         with torch.cuda.device(device):
             self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                 self.world_size, self.unique_id, self.rank)
-            self.stream = torch.cuda.Stream()

+            stream = torch.cuda.current_stream()
             # A small all_reduce for warmup.
             data = torch.zeros(1, device=device)
             self.all_reduce(data)
-            self.stream.synchronize()
+            stream.synchronize()
             del data

     def all_reduce(self,
@@ -122,7 +120,7 @@ def all_reduce(self,
         out_tensor = torch.empty_like(in_tensor)

         if stream is None:
-            stream = self.stream
+            stream = torch.cuda.current_stream()
         self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
                                 buffer_type(out_tensor.data_ptr()),
                                 in_tensor.numel(),
@@ -144,7 +142,7 @@ def all_gather(self,
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = self.stream
+            stream = torch.cuda.current_stream()
         self.nccl.ncclAllGather(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
@@ -165,7 +163,7 @@ def reduce_scatter(self,
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = self.stream
+            stream = torch.cuda.current_stream()
         self.nccl.ncclReduceScatter(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
@@ -180,7 +178,7 @@ def send(self, tensor: torch.Tensor, dst: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = self.stream
+            stream = torch.cuda.current_stream()
         self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                            self.comm, cudaStream_t(stream.cuda_stream))
@@ -192,7 +190,7 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = self.stream
+            stream = torch.cuda.current_stream()
         self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), src,
                            self.comm, cudaStream_t(stream.cuda_stream))
@@ -204,7 +202,7 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = self.stream
+            stream = torch.cuda.current_stream()
         if src == self.rank:
             sendbuff = buffer_type(tensor.data_ptr())
             # NCCL requires the sender also to have a receive buffer
@@ -217,25 +215,17 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
                                 self.comm, cudaStream_t(stream.cuda_stream))

     @contextmanager
-    def change_state(self,
-                     enable: Optional[bool] = None,
-                     stream: Optional[torch.cuda.Stream] = None):
+    def change_state(self, enable: Optional[bool] = None):
         """
         A context manager to change the state of the communicator.
         """
         if enable is None:
             # guess a default value when not specified
             enable = self.available

-        if stream is None:
-            stream = self.stream
-
         old_disable = self.disabled
-        old_stream = self.stream

-        self.stream = stream
         self.disabled = not enable
         yield

         self.disabled = old_disable
-        self.stream = old_stream
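With the change_state hunk above, the context manager only toggles whether the communicator is enabled; the stream is always picked up implicitly. A hedged caller-side sketch, assuming an already-initialized communicator pynccl_comm and a CUDA tensor x on the matching device (both placeholders, not from this diff):

import torch

# To steer the collective onto a particular stream, make that stream
# current; no stream argument goes to change_state() or all_reduce().
s = torch.cuda.Stream()
with torch.cuda.stream(s), pynccl_comm.change_state(enable=True):
    out = pynccl_comm.all_reduce(x)  # runs on `s` via torch.cuda.current_stream()
s.synchronize()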

vllm/distributed/parallel_state.py

Lines changed: 1 addition & 2 deletions
@@ -310,8 +310,7 @@ def graph_capture(
         if not pynccl_comm:
             maybe_pynccl_context = nullcontext()
         else:
-            maybe_pynccl_context = pynccl_comm.change_state(
-                stream=torch.cuda.current_stream())
+            maybe_pynccl_context = pynccl_comm.change_state()
         with maybe_pynccl_context:
             yield graph_capture_context

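The explicit stream=torch.cuda.current_stream() argument became redundant because, inside torch.cuda.graph capture, the current stream already is the capture stream, which is what the communicator now picks up by default. A small standalone check of that assumption (not part of the patch):

import torch

g = torch.cuda.CUDAGraph()
x = torch.zeros(1, device="cuda")
torch.cuda.synchronize()
with torch.cuda.graph(g):
    # torch.cuda.graph captures on a side stream; current_stream() reports it.
    capture_stream = torch.cuda.current_stream()
    y = x + 1  # give the capture something to record
print(capture_stream == torch.cuda.default_stream())  # False: capture uses a side stream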
