
Add monkey patch #24


Status: Closed · wants to merge 1 commit

Conversation

wangxiyuan (Collaborator):

Some PRs for plugin support have not been merged into vLLM yet. This PR adds monkey patches to vllm-ascend so that vllm-ascend works with vLLM directly.

The patch code should be removed once the related functionality is supported by vLLM natively.
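
For context, the mechanism is plain module-level attribute replacement. A minimal sketch of the idea, assuming a hypothetical GroupCoordinatorPatch class shipped by vllm-ascend (the module path and class name are illustrative, not the exact layout of this PR):

import vllm.distributed.parallel_state as parallel_state

# Hypothetical import: the patched coordinator provided by vllm-ascend.
from vllm_ascend.patch.parallel_state import GroupCoordinatorPatch

# Swap the upstream class at import time, so any GroupCoordinator constructed
# afterwards through the module attribute is the NPU-aware version, without
# requiring changes inside vLLM itself.
parallel_state.GroupCoordinator = GroupCoordinatorPatch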

@wangxiyuan wangxiyuan changed the title from "Add monckey patch" to "Add monkey patch" on Feb 10, 2025
from vllm.platforms import current_platform
device_comm_cls = resolve_obj_by_qualname(
    current_platform.get_device_communicator_cls())
self.communicator = device_comm_cls(group=self.device_group,
Collaborator:

I think we should check that use_xxx_communicator (any of them is fine because they remain the same) and world_size > 1 are true before creating the communicator.
https://github.com/vllm-project/vllm/blob/main/vllm/distributed/parallel_state.py#L167-L169

Besides the model-parallel groups, there will be a world group, which won't use any device communication. Adding this check will reduce the time spent creating the world group.

Collaborator Author:

I added the world_size check in the new patch. There is no use_xxx_communicator in vLLM.

Collaborator:

I mean use_tpu_communicator, use_xpu_communicator, or use_hpu_communicator; any one of them is OK.

Collaborator Author:

They are checked in super().__init__(), right?

Collaborator:

For example, the check of use_tpu_communicator in super().__init__() only works for the TPU communicator; we would reuse it here for the NPU communicator, because there is no boolean for NPU to control this check.
I think we could just use use_tpu_communicator, since all the use_xxx_communicator flags remain the same in vLLM.

Collaborator Author:

I got your idea, thanks. I'll update it then.
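
For reference, a rough sketch of the guarded creation discussed above, assuming the attributes set by GroupCoordinator's super().__init__() (use_tpu_communicator, world_size, device_group) and resolve_obj_by_qualname from vllm.utils; the exact constructor arguments follow the patch itself:

from vllm.platforms import current_platform
from vllm.utils import resolve_obj_by_qualname

# Per the discussion above, the world group does not use device communication,
# so this guard skips building a communicator for it; world_size > 1 also
# skips trivial single-rank groups.
if self.use_tpu_communicator and self.world_size > 1:
    device_comm_cls = resolve_obj_by_qualname(
        current_platform.get_device_communicator_cls())
    self.communicator = device_comm_cls(group=self.device_group)  # remaining args as in the patch
else:
    self.communicator = None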

from vllm.platforms import current_platform
device_comm_cls = resolve_obj_by_qualname(
    current_platform.get_device_communicator_cls())
self.communicator = device_comm_cls(group=self.device_group,
Yikun (Collaborator), Feb 10, 2025:

Does this still depend on the vllm-project/vllm CommunicatorBase? It seems CommunicatorBase should also move to vllm-ascend?

https://github.com/vllm-project/vllm-ascend/blob/main/vllm_ascend/communicator.py#L21

Collaborator Author:

I removed CommunicatorBase in the new patch set.
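
A minimal sketch of the standalone shape after dropping the base class (illustrative only: the class and attribute names are assumptions, and the real communicator in this PR carries the full set of collective ops reviewed below):

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup


class NPUCommunicator:
    # Thin wrapper around a torch.distributed process group; no vLLM base class.

    def __init__(self, group: ProcessGroup) -> None:
        self.group = group
        self.rank = dist.get_rank(group)
        self.world_size = dist.get_world_size(group)

    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        # In-place sum across the group; return the tensor for convenience.
        dist.all_reduce(x, group=self.group)
        return x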

# Remove this file once this functionality is supported by vLLM via
# https://github.com/vllm-project/vllm/pull/11324.

from vllm.distributed.parallel_state import GroupCoordinator
Collaborator:

Unrelated, but just curious: should vllm be a dependency of vllm-ascend, as one line in requirements and pyproject?

wangxiyuan (Collaborator Author), Feb 10, 2025:

Hmm, let's have a try; we can add it.

IMO, though, it may raise an error, because there is no CPU version of PyTorch on PyPI.

Once it's added, the install steps in the future, as I see them, would be:

  1. Install the CPU version of PyTorch by hand (torch==2.5.1+cpu).
  2. pip install vllm-ascend

Collaborator:

No worries, we can do it in a follow-up.

@wangxiyuan wangxiyuan force-pushed the add_patch branch 2 times, most recently from 4da98ee to 57f3aca on February 10, 2025 07:50

def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
    dist.all_reduce(x, group=self.group)
    return x

def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1):
Contributor:

Do we have any UTs to check the functionality?

Collaborator Author:

The communicator test needs more than one NPU card, which is not supported by the current CI. We're working on multi-card support for the CI system.

In this case, we need to test this PR by hand locally and be careful when merging it.
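
Until CI has multi-card coverage, a rough local smoke test could look like the sketch below. Assumptions (not part of this PR): torch_npu is installed, the hccl backend is available, and the script is launched with torchrun --nproc_per_node=2. It exercises the same collective the communicator wraps:

import os

import torch
import torch_npu  # noqa: F401  (registers the NPU device type and the hccl backend)
import torch.distributed as dist


def main() -> None:
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.npu.set_device(local_rank)
    dist.init_process_group(backend="hccl")

    # Each rank contributes (rank + 1); with 2 ranks every element should be 3.
    x = torch.ones(4, device=f"npu:{local_rank}") * (local_rank + 1)
    dist.all_reduce(x)
    assert torch.allclose(x.cpu(), torch.full((4,), 3.0)), x

    dist.destroy_process_group()


if __name__ == "__main__":
    main()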

        output_tensor = None
    return output_tensor

def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
Contributor:

ditto

wangxiyuan (Collaborator Author):

Do not merge until it's fully tested locally. Thanks.

Yikun (Collaborator) commented on Feb 10, 2025:

vllm-ascend/mypy.ini, lines 12 to 14 at 7006835:

; Remove this after https://github.com/vllm-project/vllm/pull/11324 merged
[mypy-vllm.distributed.device_communicators.base_communicator]
ignore_missing_imports = True

This should also be removed

Signed-off-by: wangxiyuan <[email protected]>
Yikun (Collaborator) left a comment:

LGTM if it passes in a multi-card env.

@wangxiyuan wangxiyuan closed this Feb 11, 2025
@wangxiyuan wangxiyuan deleted the add_patch branch February 11, 2025 02:46
@wangxiyuan wangxiyuan restored the add_patch branch February 11, 2025 02:53
wangxiyuan (Collaborator Author):

See #30

wangxiyuan pushed a commit that referenced this pull request Feb 12, 2025
### What this PR does / why we need it?
- Remove on communicator mypy to address:
#24 (comment)
- Add mypy.ini to trigger list

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI passed

Signed-off-by: Yikun Jiang <[email protected]>
ttanzhiqiang pushed a commit to ttanzhiqiang/vllm-ascend that referenced this pull request Apr 27, 2025
…m-project#45)

### What this PR does / why we need it?
- Remove on communicator mypy to address:
vllm-project#24 (comment)
- Add mypy.ini to trigger list

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI passed

Signed-off-by: Yikun Jiang <[email protected]>