Commit 310aca8

authored by youkaichao
[perf]fix current stream (#11870)
Signed-off-by: youkaichao <[email protected]>
1 parent a732900 · commit 310aca8
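Why this is a perf fix: torch.cuda.current_stream() constructs a fresh Stream object on every call, so hot paths that query the stream around each collective pay that cost repeatedly; this commit caches the stream instead (see the vllm/utils.py hunk below). A rough micro-benchmark sketch of the difference, assuming a CUDA-capable machine; the harness below is illustrative and not part of the commit, and timings vary by machine:

import time

import torch

torch.cuda.init()  # make sure the CUDA context exists before timing

N = 10_000

t0 = time.perf_counter()
for _ in range(N):
    torch.cuda.current_stream()  # builds a new torch.cuda.Stream wrapper each call
t1 = time.perf_counter()

cached = torch.cuda.current_stream()  # what the new helper amounts to: one cached object
t2 = time.perf_counter()
for _ in range(N):
    _ = cached  # a plain variable read
t3 = time.perf_counter()

print(f"torch.cuda.current_stream(): {(t1 - t0) / N * 1e6:.2f} us/call")
print(f"cached stream lookup:        {(t3 - t2) / N * 1e6:.2f} us/call")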

File tree: 4 files changed, +46 −15 lines


vllm/distributed/device_communicators/pynccl.py

Lines changed: 8 additions & 7 deletions
@@ -10,6 +10,7 @@
     ncclRedOpTypeEnum, ncclUniqueId)
 from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger
+from vllm.utils import current_stream
 
 logger = init_logger(__name__)
 
@@ -96,7 +97,7 @@ def __init__(
         self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
             self.world_size, self.unique_id, self.rank)
 
-        stream = torch.cuda.current_stream()
+        stream = current_stream()
         # A small all_reduce for warmup.
         data = torch.zeros(1, device=device)
         self.all_reduce(data)
@@ -119,7 +120,7 @@ def all_reduce(self,
         out_tensor = torch.empty_like(in_tensor)
 
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = current_stream()
         self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
                                 buffer_type(out_tensor.data_ptr()),
                                 in_tensor.numel(),
@@ -141,7 +142,7 @@ def all_gather(self,
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = current_stream()
         self.nccl.ncclAllGather(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
@@ -162,7 +163,7 @@ def reduce_scatter(self,
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = current_stream()
         self.nccl.ncclReduceScatter(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
@@ -177,7 +178,7 @@ def send(self, tensor: torch.Tensor, dst: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = current_stream()
         self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                            self.comm, cudaStream_t(stream.cuda_stream))
@@ -189,7 +190,7 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = current_stream()
         self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), src,
                            self.comm, cudaStream_t(stream.cuda_stream))
@@ -201,7 +202,7 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = current_stream()
         if src == self.rank:
             sendbuff = buffer_type(tensor.data_ptr())
             # NCCL requires the sender also to have a receive buffer
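With the default resolved inside the communicator, callers no longer need to fetch the current stream themselves. A hedged usage sketch, not code from this commit: comm is assumed to be an already-constructed PyNcclCommunicator and tensor a CUDA tensor on its device; setting those up requires a distributed process group and is out of scope here.

import torch

out = comm.all_reduce(tensor)                # stream=None now resolves to vllm.utils.current_stream()

side = torch.cuda.Stream()                   # the explicit keyword still exists in this version,
out2 = comm.all_reduce(tensor, stream=side)  # e.g. to launch the collective on a side stream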

vllm/distributed/parallel_state.py

Lines changed: 1 addition & 4 deletions
@@ -357,10 +357,7 @@ def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
             return out
         pynccl_comm = self.pynccl_comm
         assert pynccl_comm is not None
-        # TODO: pynccl should not use `stream=`
-        # it can just always use the current stream.
-        out = pynccl_comm.all_reduce(input_,
-                                     stream=torch.cuda.current_stream())
+        out = pynccl_comm.all_reduce(input_)
         if out is None:
             # fall back to the default all-reduce using PyTorch.
             # this usually happens during testing.
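The removed TODO is resolved by this change: the communicator defaults to the current stream internally, so the caller just passes the tensor. One plausible shape of the surrounding fallback logic, sketched from the context lines above; the torch.distributed call and the self.device_group attribute are assumptions, not shown in this hunk:

out = pynccl_comm.all_reduce(input_)   # runs on current_stream() inside the communicator
if out is None:
    # pynccl is unavailable or disabled; fall back to PyTorch's in-place all-reduce
    # (this usually happens during testing).
    torch.distributed.all_reduce(input_, group=self.device_group)
    out = input_
return out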

vllm/utils.py

Lines changed: 33 additions & 0 deletions
@@ -944,6 +944,39 @@ def find_nccl_library() -> str:
     return so_file
 
 
+prev_set_stream = torch.cuda.set_stream
+
+_current_stream = None
+
+
+def _patched_set_stream(stream: torch.cuda.Stream) -> None:
+    global _current_stream
+    _current_stream = stream
+    prev_set_stream(stream)
+
+
+torch.cuda.set_stream = _patched_set_stream
+
+
+def current_stream() -> torch.cuda.Stream:
+    """
+    replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`.
+    it turns out that `torch.cuda.current_stream()` is quite expensive,
+    as it will construct a new stream object at each call.
+    here we patch `torch.cuda.set_stream` to keep track of the current stream
+    directly, so that we can avoid calling `torch.cuda.current_stream()`.
+
+    the underlying hypothesis is that we do not call `torch._C._cuda_setStream`
+    from C/C++ code.
+    """
+    global _current_stream
+    if _current_stream is None:
+        # when this function is called before any stream is set,
+        # we return the default stream.
+        _current_stream = torch.cuda.current_stream()
+    return _current_stream
+
+
 def enable_trace_function_call_for_thread(vllm_config: "VllmConfig") -> None:
     """Set up function tracing for the current thread,
     if enabled via the VLLM_TRACE_FUNCTION environment variable
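The effect of the patch can be checked directly. A minimal sketch, assuming a CUDA-capable device and relying on the hypothesis stated in the docstring, namely that every stream switch goes through torch.cuda.set_stream rather than torch._C._cuda_setStream:

import torch

from vllm.utils import current_stream  # importing vllm.utils installs the patched setter

s = torch.cuda.Stream()
torch.cuda.set_stream(s)                      # routed through _patched_set_stream, which caches `s`
assert current_stream() is s                  # the cached object is returned as-is
assert current_stream() is current_stream()   # repeated calls allocate no new Stream objects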

vllm/worker/multi_step_model_runner.py

Lines changed: 4 additions & 4 deletions
@@ -14,7 +14,7 @@
     get_pythonized_sample_results)
 from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
                            Logprob, SequenceGroupMetadata, SequenceOutput)
-from vllm.utils import PyObjectCache, async_tensor_h2d
+from vllm.utils import PyObjectCache, async_tensor_h2d, current_stream
 from vllm.worker.model_runner import (GPUModelRunnerBase,
                                       ModelInputForGPUWithSamplingMetadata)
 from vllm.worker.model_runner_base import (
@@ -498,7 +498,7 @@ def execute_model(
         # appended sampler output from last iteration
         # - also maybe pythonize if CPU is ahead of GPU
 
-        current_stream = torch.cuda.current_stream()
+        stream = current_stream()
         if not model_input.is_first_multi_step:
             # Explicitly block on the previous step's forward to make sure we
             # don't clobber any GPU tensors still in use.
@@ -541,7 +541,7 @@ def execute_model(
             num_steps=1)
 
         # record the event for the current step so that the next step can sync
-        model_input.record_step_event(current_stream)
+        model_input.record_step_event(stream)
 
         if get_pp_group().is_last_rank and self.is_driver_worker:
             assert isinstance(output, list)
@@ -552,7 +552,7 @@ def execute_model(
             # event for the pythonization so that we only pythonize if the
             # tensors are ready. May be able to be combined with the step event
             output_ready_event = torch.cuda.Event()
-            output_ready_event.record(current_stream)
+            output_ready_event.record(stream)
             if self.parallel_config.pipeline_parallel_size > 1:
                 output[0].sampled_token_ids_cpu = output[
                     0].sampled_token_ids.cpu()
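Renaming the local from current_stream to stream also keeps it from shadowing the newly imported helper. The event pattern it feeds looks roughly like this; an illustrative sketch, not code lifted from the runner:

import torch

from vllm.utils import current_stream

stream = current_stream()                # cached lookup instead of torch.cuda.current_stream()
output_ready_event = torch.cuda.Event()
output_ready_event.record(stream)        # mark the point where this step's outputs become valid
# ... later, before reading the outputs on the CPU:
output_ready_event.synchronize()         # block until the recorded point has executed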

0 commit comments

Comments
 (0)