
Commit 9a88f89

custom allreduce + torch.compile (#10121)
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
1 parent 519e8e4 · commit 9a88f89

File tree

docs/source/getting_started/debugging.rst
tests/distributed/test_pynccl.py
tests/distributed/test_utils.py
vllm/distributed/device_communicators/pynccl.py
vllm/distributed/parallel_state.py
vllm/v1/worker/gpu_model_runner.py

6 files changed: +59 −101 lines

docs/source/getting_started/debugging.rst (−1)
@@ -86,7 +86,6 @@ If GPU/CPU communication cannot be established, you can use the following Python
 from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
 
 pynccl = PyNcclCommunicator(group=gloo_group, device=local_rank)
-pynccl.disabled = False
 
 s = torch.cuda.Stream()
 with torch.cuda.stream(s):
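With this commit the sanity-check snippet no longer has to flip `pynccl.disabled` by hand, since the communicator is usable right after construction (see the pynccl.py change below). For a quick stand-alone check of GPU communication that does not depend on vLLM at all, a minimal sketch using plain torch.distributed (not the full script from debugging.rst; the file name check_nccl.py is just a placeholder) is:

    # Minimal NCCL sanity check; launch with: torchrun --nproc-per-node=2 check_nccl.py
    import os

    import torch
    import torch.distributed as dist

    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
    torch.cuda.set_device(local_rank)

    data = torch.ones(1, device="cuda")
    dist.all_reduce(data)  # sums the ones across all ranks
    assert data.item() == dist.get_world_size(), "NCCL all-reduce check failed"
    print(f"rank {dist.get_rank()}: NCCL all-reduce OK")
    dist.destroy_process_group()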

tests/distributed/test_pynccl.py (+6 −9)
@@ -60,7 +60,7 @@ def worker_fn():
     tensor = torch.ones(16, 1024, 1024,
                         dtype=torch.float32).cuda(pynccl_comm.rank)
     with pynccl_comm.change_state(enable=True):
-        pynccl_comm.all_reduce(tensor)
+        tensor = pynccl_comm.all_reduce(tensor)
     result = tensor.mean().cpu().item()
     assert result == pynccl_comm.world_size
 

@@ -84,12 +84,12 @@ def multiple_allreduce_worker_fn():
     with pynccl_comm.change_state(enable=True):
         # two groups can communicate independently
         if torch.distributed.get_rank() in [0, 1]:
-            pynccl_comm.all_reduce(tensor)
-            pynccl_comm.all_reduce(tensor)
+            tensor = pynccl_comm.all_reduce(tensor)
+            tensor = pynccl_comm.all_reduce(tensor)
             result = tensor.mean().cpu().item()
             assert result == 4
         else:
-            pynccl_comm.all_reduce(tensor)
+            tensor = pynccl_comm.all_reduce(tensor)
             result = tensor.mean().cpu().item()
             assert result == 2
 

@@ -140,14 +140,11 @@ def worker_fn_with_cudagraph():
     with torch.cuda.graph(
             graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
                 enable=True):
-        # operation during the graph capture is recorded but not executed
-        # see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture  # noqa
-        pynccl_comm.all_reduce(a)
+        a_out = pynccl_comm.all_reduce(a)
     pynccl_comm.stream.synchronize()
-    assert a.mean().cpu().item() == pynccl_comm.world_size**0
     graph.replay()
     pynccl_comm.stream.synchronize()
-    assert a.mean().cpu().item() == pynccl_comm.world_size**1
+    assert a_out.mean().cpu().item() == pynccl_comm.world_size**1
 
 
 @worker_fn_wrapper
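The rewritten test asserts on `a_out`, the tensor returned by the now out-of-place all-reduce, and only after `graph.replay()`: work recorded during CUDA graph capture is not executed until replay, so the output buffer holds no meaningful values before that point. A self-contained sketch of the same capture/replay behavior, with a toy out-of-place op standing in for `pynccl_comm.all_reduce` (no multi-GPU setup needed), might look like:

    import torch

    def toy_all_reduce(t: torch.Tensor) -> torch.Tensor:
        # Stand-in for an out-of-place all-reduce on a 2-rank "world".
        return t * 2

    a = torch.ones(4, device="cuda")

    # Warm up on a side stream before capture, as recommended by PyTorch.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        toy_all_reduce(a)
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        a_out = toy_all_reduce(a)  # recorded into the graph, not executed

    # a_out is undefined until the graph actually runs.
    graph.replay()
    torch.cuda.synchronize()
    assert a_out.mean().item() == 2.0  # holds only after replay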

tests/distributed/test_utils.py (−2)
@@ -70,14 +70,12 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2):
                                        rank=rank,
                                        world_size=WORLD_SIZE)
     pynccl1 = PyNcclCommunicator(pg1, device=rank)
-    pynccl1.disabled = False
     if rank <= 2:
         pg2 = StatelessProcessGroup.create(host="127.0.0.1",
                                            port=port2,
                                            rank=rank,
                                            world_size=3)
         pynccl2 = PyNcclCommunicator(pg2, device=rank)
-        pynccl2.disabled = False
     data = torch.tensor([rank]).cuda()
     pynccl1.all_reduce(data)
     pg1.barrier()

vllm/distributed/device_communicators/pynccl.py (+13 −13)
@@ -106,30 +106,30 @@ def __init__(
         self.stream.synchronize()
         del data
 
-        # by default it is disabled, e.g. in profiling models and prefill phase.
-        # to use it, use under `with obj.change_state(enable=True)`, usually
-        # when we are using CUDA graph.
-        self.disabled = True
-
     def all_reduce(self,
-                   tensor: torch.Tensor,
+                   in_tensor: torch.Tensor,
                    op: ReduceOp = ReduceOp.SUM,
-                   stream=None):
+                   stream=None) -> torch.Tensor:
         if self.disabled:
-            return
+            return None
         # nccl communicator created on a specific device
         # will only work on tensors on the same device
         # otherwise it will cause "illegal memory access"
-        assert tensor.device == self.device, (
+        assert in_tensor.device == self.device, (
             f"this nccl communicator is created to work on {self.device}, "
-            f"but the input tensor is on {tensor.device}")
+            f"but the input tensor is on {in_tensor.device}")
+
+        out_tensor = torch.empty_like(in_tensor)
+
         if stream is None:
             stream = self.stream
-        self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()),
-                                buffer_type(tensor.data_ptr()), tensor.numel(),
-                                ncclDataTypeEnum.from_torch(tensor.dtype),
+        self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
+                                buffer_type(out_tensor.data_ptr()),
+                                in_tensor.numel(),
+                                ncclDataTypeEnum.from_torch(in_tensor.dtype),
                                 ncclRedOpTypeEnum.from_torch(op), self.comm,
                                 cudaStream_t(stream.cuda_stream))
+        return out_tensor
 
     def all_gather(self,
                    output_tensor: torch.Tensor,
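`all_reduce` is now out-of-place: it reads `in_tensor`, reduces into a freshly allocated `out_tensor`, and returns it (or `None` when the communicator is disabled), never mutating the caller's buffer. A rough stand-alone equivalent of that contract, sketched on top of stock `torch.distributed` rather than the raw NCCL bindings, would be:

    import torch
    import torch.distributed as dist

    def all_reduce_out_of_place(in_tensor: torch.Tensor, group=None) -> torch.Tensor:
        # Never write into the caller's tensor; reduce into a fresh buffer instead.
        out_tensor = torch.empty_like(in_tensor)
        out_tensor.copy_(in_tensor)
        dist.all_reduce(out_tensor, op=dist.ReduceOp.SUM, group=group)
        return out_tensor

    # usage (after dist.init_process_group): out = all_reduce_out_of_place(x)
    # x is left untouched, which is what makes the op friendly to torch.compile.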

vllm/distributed/parallel_state.py (+36 −74)
@@ -96,42 +96,24 @@ def _register_group(group: "GroupCoordinator") -> None:
     _groups[group.unique_name] = weakref.ref(group)
 
 
-if supports_custom_op():
-
-    def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
-        assert group_name in _groups, f"Group {group_name} is not found."
-        group = _groups[group_name]()
-        if group is None:
-            raise ValueError(f"Group {group_name} is destroyed.")
-        group._all_reduce_in_place(tensor)
-
-    def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None:
-        return
+def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
+    assert group_name in _groups, f"Group {group_name} is not found."
+    group = _groups[group_name]()
+    if group is None:
+        raise ValueError(f"Group {group_name} is destroyed.")
+    return group._all_reduce_out_place(tensor)
 
-    direct_register_custom_op(
-        op_name="inplace_all_reduce",
-        op_func=inplace_all_reduce,
-        mutates_args=["tensor"],
-        fake_impl=inplace_all_reduce_fake,
-    )
 
-    def outplace_all_reduce(tensor: torch.Tensor,
-                            group_name: str) -> torch.Tensor:
-        assert group_name in _groups, f"Group {group_name} is not found."
-        group = _groups[group_name]()
-        if group is None:
-            raise ValueError(f"Group {group_name} is destroyed.")
-        return group._all_reduce_out_place(tensor)
+def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
+    return torch.empty_like(tensor)
 
-    def outplace_all_reduce_fake(tensor: torch.Tensor,
-                                 group_name: str) -> torch.Tensor:
-        return torch.empty_like(tensor)
 
+if supports_custom_op():
     direct_register_custom_op(
-        op_name="outplace_all_reduce",
-        op_func=outplace_all_reduce,
+        op_name="all_reduce",
+        op_func=all_reduce,
         mutates_args=[],
-        fake_impl=outplace_all_reduce_fake,
+        fake_impl=all_reduce_fake,
     )
 
 
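The two custom ops (`inplace_all_reduce` / `outplace_all_reduce`) collapse into a single out-of-place `vllm.all_reduce` op with a trivial fake implementation, which is what lets `torch.compile` trace through tensor-parallel all-reduces. vLLM registers it through its own `direct_register_custom_op` helper; purely as an illustration of the same idea on stock PyTorch >= 2.4 (the `demo::` namespace and function names below are made up), one could write:

    import torch
    import torch.distributed as dist
    from torch.library import custom_op

    @custom_op("demo::all_reduce", mutates_args=())
    def demo_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
        # Eager implementation: reduce into a fresh buffer, leave the input alone.
        # group_name is unused here; vLLM resolves it to a GroupCoordinator.
        out = tensor.clone()
        dist.all_reduce(out)
        return out

    @demo_all_reduce.register_fake
    def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
        # Shape/dtype-only implementation used while tracing with torch.compile.
        return torch.empty_like(tensor)

    # Inside a compiled model, the call site then looks like:
    #   y = torch.ops.demo.all_reduce(x, group_name="tp")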

@@ -317,30 +299,13 @@ def graph_capture(
             stream.wait_stream(curr_stream)
 
         with torch.cuda.stream(stream), maybe_ca_context:
-            # In graph mode, we have to be very careful about the collective
-            # operations. The current status is:
-            # allreduce \ Mode   |  Eager  |  Graph  |
-            # --------------------------------------------
-            # custom allreduce   | enabled | enabled |
-            # PyNccl             | disabled| enabled |
-            # torch.distributed  | enabled | disabled|
-            #
-            # Note that custom allreduce will have a runtime check, if the
-            # tensor size is too large, it will fallback to the next
-            # available option.
-            # In summary: When using CUDA graph, we use
-            # either custom all-reduce kernel or pynccl. When not using
-            # CUDA graph, we use either custom all-reduce kernel or
-            # PyTorch NCCL. We always prioritize using custom all-reduce
-            # kernel but fall back to PyTorch or pynccl if it is
-            # disabled or not supported.
             pynccl_comm = self.pynccl_comm
             maybe_pynccl_context: Any
             if not pynccl_comm:
                 maybe_pynccl_context = nullcontext()
             else:
                 maybe_pynccl_context = pynccl_comm.change_state(
-                    enable=True, stream=torch.cuda.current_stream())
+                    stream=torch.cuda.current_stream())
             with maybe_pynccl_context:
                 yield graph_capture_context
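With PyNccl usable by default, `change_state` no longer takes an `enable=` flag during graph capture; it only points the communicator at the capture stream. A rough sketch of that kind of context manager (an assumption about its shape, not vLLM's actual implementation) might be:

    import contextlib

    import torch

    class ToyComm:
        def __init__(self):
            self.stream = torch.cuda.Stream()

        @contextlib.contextmanager
        def change_state(self, stream=None):
            # Temporarily run collectives on a different stream (e.g. the
            # CUDA graph capture stream), restoring the old one on exit.
            old_stream = self.stream
            if stream is not None:
                self.stream = stream
            try:
                yield
            finally:
                self.stream = old_stream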

@@ -356,8 +321,8 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         coordinator.
 
         In addition, PyTorch custom ops do not support mutation or returning
-        a new tensor in the same op. So we need to figure out if the op is
-        in-place or out-of-place ahead of time.
+        a new tensor in the same op. So we always make the all-reduce operation
+        out-of-place.
         """
         # Bypass the function if we are using only 1 GPU.
         if self.world_size == 1:
@@ -368,10 +333,6 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
             ipex.distributed.all_reduce(input_, group=self.device_group)
             return input_
 
-        if not supports_custom_op():
-            self._all_reduce_in_place(input_)
-            return input_
-
         if self.tpu_communicator is not None and \
                 not self.tpu_communicator.disabled:
             # TPU handles Dynamo with its own logic.
@@ -385,30 +346,31 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
                 not self.xpu_communicator.disabled:
             return self.xpu_communicator.all_reduce(input_)
 
-        if self.ca_comm is not None and \
-            not self.ca_comm.disabled and \
-            self.ca_comm.should_custom_ar(input_):
-            return torch.ops.vllm.outplace_all_reduce(
-                input_, group_name=self.unique_name)
-        else:
-            torch.ops.vllm.inplace_all_reduce(input_,
-                                              group_name=self.unique_name)
-            return input_
+        return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)
 
     def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
+        # always try custom allreduce first,
+        # and then pynccl.
         ca_comm = self.ca_comm
-        assert ca_comm is not None
-        assert not ca_comm.disabled
-        out = ca_comm.custom_all_reduce(input_)
-        assert out is not None
-        return out
-
-    def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
+        if ca_comm is not None and not ca_comm.disabled and \
+                ca_comm.should_custom_ar(input_):
+            out = ca_comm.custom_all_reduce(input_)
+            assert out is not None
+            return out
         pynccl_comm = self.pynccl_comm
-        if (pynccl_comm is not None and not pynccl_comm.disabled):
-            pynccl_comm.all_reduce(input_)
-        else:
-            torch.distributed.all_reduce(input_, group=self.device_group)
+        assert pynccl_comm is not None
+        # TODO: pynccl should not use `stream=`
+        # it can just always use the current stream.
+        out = pynccl_comm.all_reduce(input_,
+                                     stream=torch.cuda.current_stream())
+        if out is None:
+            # fall back to the default all-reduce using PyTorch.
+            # this usually happens during testing.
+            # when we run the model, allreduce only happens for the TP
+            # group, where we always have either custom allreduce or pynccl.
+            out = input_.clone()
+            torch.distributed.all_reduce(out, group=self.device_group)
+        return out
 
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         world_size = self.world_size

vllm/v1/worker/gpu_model_runner.py (+4 −2)
@@ -10,6 +10,7 @@
 
 from vllm.compilation.compile_context import set_compile_context
 from vllm.config import CompilationLevel, VllmConfig
+from vllm.distributed.parallel_state import graph_capture
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
@@ -570,8 +571,9 @@ def capture_model(self) -> None:
         # Trigger CUDA graph capture for specific shapes.
         # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
-        for num_tokens in reversed(self.cudagraph_batch_sizes):
-            self._dummy_run(self.model, num_tokens, self.kv_caches)
+        with graph_capture():
+            for num_tokens in reversed(self.cudagraph_batch_sizes):
+                self._dummy_run(self.model, num_tokens, self.kv_caches)
 
         end_time = time.perf_counter()
         end_free_gpu_memory = torch.cuda.mem_get_info()[0]
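The dummy runs for the V1 worker are now wrapped in `graph_capture()`, so collectives recorded into CUDA graphs go through the communicators set up above. The comment about capturing large shapes first refers to CUDA graph memory-pool sharing; a minimal stand-alone sketch of that pattern with stock PyTorch (the `run_fn` callable is a placeholder for the model forward) could look like:

    import torch

    def capture_for_sizes(run_fn, batch_sizes, hidden=64):
        # Capture the largest batch first so later, smaller captures can reuse
        # the memory pool allocated for it.
        pool = torch.cuda.graph_pool_handle()
        graphs = {}
        for n in sorted(batch_sizes, reverse=True):
            x = torch.zeros(n, hidden, device="cuda")
            run_fn(x)                     # warm-up before capture
            torch.cuda.synchronize()
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g, pool=pool):
                y = run_fn(x)             # recorded; replayed later via g.replay()
            graphs[n] = (g, x, y)
        return graphs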
