Commit fd644da

[Communicator] Add monkey patch (vllm-project#30)
Some PRs for plugin support have not been merged into vllm yet. This PR adds a monkey patch to vllm-ascend so that vllm-ascend works with vllm directly. The patch code should be removed once the related functionality is supported by vllm natively.

Signed-off-by: wangxiyuan <[email protected]>
1 parent cb28d33 commit fd644da
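
For readers unfamiliar with the technique, below is a minimal, generic sketch of the monkey-patch pattern this commit relies on (illustrative names only, not vllm code): a subclass overrides the behaviour that needs to change and is then rebound over the original attribute, so code that looks the class up afterwards gets the patched implementation.

import types

# Stand-in for a third-party module whose class we cannot modify directly.
library = types.ModuleType("library")


class Coordinator:
    def all_reduce(self, x):
        return f"upstream all_reduce({x})"


library.Coordinator = Coordinator


class CoordinatorPatch(Coordinator):
    # Device-specific override; everything else is inherited unchanged.
    def all_reduce(self, x):
        return f"patched all_reduce({x})"


# Applying the patch: rebind the name on the library module. A module that
# runs this assignment at import time activates the patch as a side effect.
library.Coordinator = CoordinatorPatch

print(library.Coordinator().all_reduce(1))  # patched all_reduce(1)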

File tree

- vllm_ascend/communicator.py
- vllm_ascend/patch/__init__.py
- vllm_ascend/patch/patch_commnicator.py
- vllm_ascend/platform.py

4 files changed: 142 additions, 4 deletions

vllm_ascend/communicator.py

Lines changed: 53 additions & 3 deletions
@@ -17,12 +17,62 @@
 
 import torch
 import torch.distributed as dist
-from vllm.distributed.device_communicators.base_communicator import \
-    CommunicatorBase
 
 
-class NPUCommunicator(CommunicatorBase):
+class NPUCommunicator:
+
+    def __init__(self, group, unique_name=""):
+        self.group = group
+        self.unique_name = unique_name
+        self.rank = dist.get_rank(group)
+        self.world_size = dist.get_world_size(self.group)
+        self.ranks = dist.get_process_group_ranks(self.group)
+        global_rank = dist.get_rank()
+        self.rank_in_group = dist.get_group_rank(self.group, global_rank)
 
     def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
         dist.all_reduce(x, group=self.group)
         return x
+
+    def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1):
+        # NOTE: We assume that the input tensor is on the same device across
+        # all the ranks.
+        # NOTE: `dst` is the local rank of the destination rank.
+        # Allocate output tensor.
+        if self.rank_in_group == dst:
+            gather_list = [
+                torch.empty_like(input_) for _ in range(self.world_size)
+            ]
+        else:
+            gather_list = None
+        # Gather.
+        dist.gather(input_, gather_list, dst=self.ranks[dst], group=self.group)
+        if self.rank_in_group == dst:
+            output_tensor = torch.cat(gather_list, dim=dim)
+        else:
+            output_tensor = None
+        return output_tensor
+
+    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        if dim < 0:
+            # Convert negative dim to positive.
+            dim += input_.dim()
+        input_size = input_.size()
+        # NOTE: we have to use concat-style all-gather here,
+        # stack-style all-gather has compatibility issues with
+        # torch.compile . see https://github.com/pytorch/pytorch/issues/138795
+        output_size = (input_size[0] * self.world_size, ) + input_size[1:]
+        # Allocate output tensor.
+        output_tensor = torch.empty(output_size,
+                                    dtype=input_.dtype,
+                                    device=input_.device)
+        # All-gather.
+        dist.all_gather_into_tensor(output_tensor, input_, group=self.group)
+        # Reshape
+        output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
+        output_tensor = output_tensor.movedim(0, dim)
+        output_tensor = output_tensor.reshape(input_size[:dim] +
+                                              (self.world_size *
+                                               input_size[dim], ) +
+                                              input_size[dim + 1:])
+        return output_tensor
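
The reshape/movedim bookkeeping at the end of all_gather can be hard to read in diff form. The sketch below fakes the gathered buffer by repeating the per-rank tensor instead of running a real collective, and replays just that bookkeeping to show how the result ends up concatenated along dim:

import torch

world_size = 2
dim = -1
x = torch.arange(6).reshape(2, 3)            # per-rank input
if dim < 0:
    dim += x.dim()                           # normalize negative dim
input_size = x.size()

# Fake the concat-style all_gather_into_tensor output: one chunk per rank,
# concatenated along dimension 0 (here every "rank" contributes the same x).
gathered = torch.cat([x] * world_size, dim=0)

# Same bookkeeping as NPUCommunicator.all_gather: expose the rank dimension,
# move it next to `dim`, then fold it into `dim`.
out = gathered.reshape((world_size, ) + input_size)
out = out.movedim(0, dim)
out = out.reshape(input_size[:dim] + (world_size * input_size[dim], ) +
                  input_size[dim + 1:])
print(out.shape)  # torch.Size([2, 6])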

vllm_ascend/patch/__init__.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from vllm_ascend.patch import patch_commnicator  # noqa
vllm_ascend/patch/patch_commnicator.py

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# This file is used to monkey patch communicator in vllm to support ascend.
+# Remove this file when vllm support by
+# https://github.com/vllm-project/vllm/pull/11324.
+
+import torch
+from vllm.distributed.parallel_state import GroupCoordinator
+from vllm.utils import resolve_obj_by_qualname
+
+
+class GroupCoordinatorPatch(GroupCoordinator):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.device = torch.device(f"npu:{self.local_rank}")
+
+        from vllm.platforms import current_platform
+        device_comm_cls = resolve_obj_by_qualname(
+            current_platform.get_device_communicator_cls())
+        # we have checked and ensure that reusing tpu tag here is fine.
+        use_custom_device = kwargs.get("use_tpu_communicator", False)
+        if use_custom_device and self.world_size > 1:
+            self.communicator = device_comm_cls(group=self.device_group,
+                                                unique_name=self.unique_name)
+
+    def all_reduce(self, input_):
+        # Bypass the function if we are using only 1 device.
+        if self.world_size == 1:
+            return input_
+
+        return self.communicator.all_reduce(input_)
+
+    def gather(self, input_, dst=0, dim=-1):
+        # Bypass the function if we are using only 1 device.
+        if self.world_size == 1:
+            return input_
+        assert -input_.dim() <= dim < input_.dim(), (
+            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
+        if dim < 0:
+            # Convert negative dim to positive.
+            dim += input_.dim()
+
+        return self.communicator.gather(input_, dst, dim)
+
+    def all_gather(self, input_, dim=-1):
+        # Bypass the function if we are using only 1 device.
+        if self.world_size == 1:
+            return input_
+        assert -input_.dim() <= dim < input_.dim(), (
+            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
+        return self.communicator.all_gather(input_, dim)
+
+
+GroupCoordinator = GroupCoordinatorPatch
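
In GroupCoordinatorPatch.__init__ above, the device communicator class is chosen by resolving a dotted qualname string returned by current_platform.get_device_communicator_cls(). The helper below is a minimal stand-in for that lookup (illustrative name; vllm ships its own resolve_obj_by_qualname), using a standard-library class so the sketch runs anywhere:

import importlib


def resolve_by_qualname(qualname: str):
    # Split "package.module.ClassName" and return the named attribute.
    module_name, _, attr = qualname.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


# Standard-library example; in the patch above the string would instead name a
# device communicator class such as NPUCommunicator.
OrderedDictCls = resolve_by_qualname("collections.OrderedDict")
print(OrderedDictCls(a=1))  # an OrderedDict with a single entry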

vllm_ascend/platform.py

Lines changed: 2 additions & 1 deletion
@@ -96,8 +96,9 @@ def mem_get_info(cls) -> Tuple[int, int]:
 
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
-        # Register ops when setup.
+        # Register ops and patch when setup.
         from vllm_ascend import ops  # noqa: F401
+        from vllm_ascend import patch  # noqa: F401
 
         parallel_config = vllm_config.parallel_config
         if parallel_config.worker_cls == "auto":