diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py
new file mode 100644
index 00000000000..eb12f8834b4
--- /dev/null
+++ b/vllm/distributed/device_communicators/base_device_communicator.py
@@ -0,0 +1,117 @@
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+
+class DeviceCommunicatorBase:
+    """
+    Base class for device-specific communicators.
+    It can use the `cpu_group` to initialize the communicator.
+    If the device has PyTorch integration (PyTorch can recognize its
+    communication backend), the `device_group` will also be given.
+    """
+
+    def __init__(self,
+                 cpu_group: ProcessGroup,
+                 device: Optional[torch.device] = None,
+                 device_group: Optional[ProcessGroup] = None,
+                 unique_name: str = ""):
+        self.device = device or torch.device("cpu")
+        self.cpu_group = cpu_group
+        self.device_group = device_group
+        self.unique_name = unique_name
+        self.rank = dist.get_rank(cpu_group)
+        self.world_size = dist.get_world_size(cpu_group)
+        self.ranks = dist.get_process_group_ranks(cpu_group)
+        self.global_rank = dist.get_rank()
+        self.global_world_size = dist.get_world_size()
+        self.rank_in_group = dist.get_group_rank(self.cpu_group,
+                                                 self.global_rank)
+
+    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
+        dist.all_reduce(input_, group=self.device_group)
+        return input_
+
+    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        if dim < 0:
+            # Convert negative dim to positive.
+            dim += input_.dim()
+        input_size = input_.size()
+        # NOTE: we have to use concat-style all-gather here,
+        # stack-style all-gather has compatibility issues with
+        # torch.compile; see https://github.com/pytorch/pytorch/issues/138795
+        output_size = (input_size[0] * self.world_size, ) + input_size[1:]
+        # Allocate output tensor.
+        output_tensor = torch.empty(output_size,
+                                    dtype=input_.dtype,
+                                    device=input_.device)
+        # All-gather.
+        dist.all_gather_into_tensor(output_tensor,
+                                    input_,
+                                    group=self.device_group)
+        # Reshape
+        output_tensor = output_tensor.reshape((self.world_size, ) +
+                                              input_size)
+        output_tensor = output_tensor.movedim(0, dim)
+        output_tensor = output_tensor.reshape(input_size[:dim] +
+                                              (self.world_size *
+                                               input_size[dim], ) +
+                                              input_size[dim + 1:])
+        return output_tensor
+
+    def gather(self,
+               input_: torch.Tensor,
+               dst: int = 0,
+               dim: int = -1) -> Optional[torch.Tensor]:
+        """
+        NOTE: We assume that the input tensor is on the same device across
+        all the ranks.
+        NOTE: `dst` is the local rank of the destination rank.
+        """
+        world_size = self.world_size
+        assert -input_.dim() <= dim < input_.dim(), (
+            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
+        if dim < 0:
+            # Convert negative dim to positive.
+            dim += input_.dim()
+
+        # Allocate output tensor.
+        if self.rank_in_group == dst:
+            gather_list = [torch.empty_like(input_) for _ in range(world_size)]
+        else:
+            gather_list = None
+        # Gather.
+        torch.distributed.gather(input_,
+                                 gather_list,
+                                 dst=self.ranks[dst],
+                                 group=self.device_group)
+        if self.rank_in_group == dst:
+            output_tensor = torch.cat(gather_list, dim=dim)
+        else:
+            output_tensor = None
+        return output_tensor
+
+    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
+        """Sends a tensor to the destination rank in a non-blocking way"""
+        """NOTE: `dst` is the local rank of the destination rank."""
+        if dst is None:
+            dst = (self.rank_in_group + 1) % self.world_size
+        torch.distributed.send(tensor, self.ranks[dst], self.device_group)
+
+    def recv(self,
+             size: torch.Size,
+             dtype: torch.dtype,
+             src: Optional[int] = None) -> torch.Tensor:
+        """Receives a tensor from the source rank."""
+        """NOTE: `src` is the local rank of the source rank."""
+        if src is None:
+            src = (self.rank_in_group - 1) % self.world_size
+
+        tensor = torch.empty(size, dtype=dtype, device=self.device)
+        torch.distributed.recv(tensor, self.ranks[src], self.device_group)
+        return tensor
+
+    def destroy(self):
+        pass
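The trickiest part of the base class is the concat-style all-gather plus the reshape/movedim dance (stack-style all-gather currently has torch.compile issues, per the NOTE above). The shape bookkeeping is easier to verify on a toy tensor; the sketch below simulates the gathered buffer for a made-up world size without creating any process group, so it runs in a single process.

```python
import torch

# Simulate what dist.all_gather_into_tensor would produce for `world_size`
# ranks, then apply the same reshape/movedim sequence as
# DeviceCommunicatorBase.all_gather and check it matches torch.cat along dim.
world_size, dim = 4, 1
per_rank = [torch.randn(2, 3, 5) for _ in range(world_size)]
input_size = per_rank[0].size()

# Concat-style gather: rank r's tensor occupies rows [r*2, (r+1)*2) of a
# flat buffer of shape (2 * world_size, 3, 5).
gathered = torch.cat(per_rank, dim=0)

# Same post-processing as the base class.
out = gathered.reshape((world_size, ) + input_size)
out = out.movedim(0, dim)
out = out.reshape(input_size[:dim] + (world_size * input_size[dim], ) +
                  input_size[dim + 1:])

assert torch.equal(out, torch.cat(per_rank, dim=dim))  # shape (2, 12, 5)
```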
diff --git a/vllm/distributed/device_communicators/cpu_communicator.py b/vllm/distributed/device_communicators/cpu_communicator.py
new file mode 100644
index 00000000000..4e86396e713
--- /dev/null
+++ b/vllm/distributed/device_communicators/cpu_communicator.py
@@ -0,0 +1,35 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+from torch.distributed import ProcessGroup
+
+from .base_device_communicator import DeviceCommunicatorBase
+
+
+class CpuCommunicator(DeviceCommunicatorBase):
+
+    def __init__(self,
+                 cpu_group: ProcessGroup,
+                 device: Optional[torch.device] = None,
+                 device_group: Optional[ProcessGroup] = None,
+                 unique_name: str = ""):
+        super().__init__(cpu_group, device, device_group, unique_name)
+        self.ipex_available = False
+        self.dist_module = torch.distributed
+        try:
+            import intel_extension_for_pytorch as ipex
+            self.ipex_available = True
+            self.dist_module = ipex.distributed
+        except ImportError:
+            """
+            Intel IPEX not found. Falling back to PyTorch native
+            all_reduce for CPU (e.g. macOS)
+            """
+            pass
+
+    def all_reduce(self, input_):
+        # The collective is in-place and returns None, so return the tensor.
+        self.dist_module.all_reduce(input_, group=self.device_group)
+        return input_
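The CpuCommunicator above is the easiest subclass to smoke-test because it needs neither GPUs nor IPEX. A hedged, single-process sketch (the Gloo setup and port number are illustrative, not part of the patch):

```python
import os

import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.cpu_communicator import (
    CpuCommunicator)

# One-process "cluster" so the example runs anywhere.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)
cpu_group = dist.new_group(ranks=[0], backend="gloo")

comm = CpuCommunicator(cpu_group=cpu_group,
                       device=torch.device("cpu"),
                       device_group=cpu_group)
x = torch.ones(4)
# With a single rank the all-reduce is a no-op, but it exercises the same
# ipex-or-torch.distributed dispatch used in the multi-rank case.
print(comm.all_reduce(x))

dist.destroy_process_group()
```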
diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py
new file mode 100644
index 00000000000..f806f8b39ef
--- /dev/null
+++ b/vllm/distributed/device_communicators/cuda_communicator.py
@@ -0,0 +1,106 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+from torch.distributed import ProcessGroup
+
+from .base_device_communicator import DeviceCommunicatorBase
+
+
+class CudaCommunicator(DeviceCommunicatorBase):
+
+    def __init__(self,
+                 cpu_group: ProcessGroup,
+                 device: Optional[torch.device] = None,
+                 device_group: Optional[ProcessGroup] = None,
+                 unique_name: str = ""):
+        super().__init__(cpu_group, device, device_group, unique_name)
+        if "pp" in unique_name:
+            # pipeline parallel does not need custom allreduce
+            use_custom_allreduce = False
+        else:
+            from vllm.distributed.parallel_state import (
+                _ENABLE_CUSTOM_ALL_REDUCE)
+            use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
+        use_pynccl = True
+
+        self.use_pynccl = use_pynccl
+        self.use_custom_allreduce = use_custom_allreduce
+
+        # lazy import to avoid documentation build error
+        from vllm.distributed.device_communicators.custom_all_reduce import (
+            CustomAllreduce)
+        from vllm.distributed.device_communicators.pynccl import (
+            PyNcclCommunicator)
+
+        self.pynccl_comm: Optional[PyNcclCommunicator] = None
+        if use_pynccl and self.world_size > 1:
+            self.pynccl_comm = PyNcclCommunicator(
+                group=self.cpu_group,
+                device=self.device,
+            )
+
+        self.ca_comm: Optional[CustomAllreduce] = None
+        if use_custom_allreduce and self.world_size > 1:
+            # Initialize a custom fast all-reduce implementation.
+            self.ca_comm = CustomAllreduce(
+                group=self.cpu_group,
+                device=self.device,
+            )
+
+    def all_reduce(self, input_):
+        # always try custom allreduce first,
+        # and then pynccl.
+        ca_comm = self.ca_comm
+        if ca_comm is not None and not ca_comm.disabled and \
+           ca_comm.should_custom_ar(input_):
+            out = ca_comm.custom_all_reduce(input_)
+            assert out is not None
+            return out
+        pynccl_comm = self.pynccl_comm
+        assert pynccl_comm is not None
+        out = pynccl_comm.all_reduce(input_)
+        if out is None:
+            # fall back to the default all-reduce using PyTorch.
+            # this usually happens during testing.
+            # when we run the model, allreduce only happens for the TP
+            # group, where we always have either custom allreduce or pynccl.
+            out = input_.clone()
+            torch.distributed.all_reduce(out, group=self.device_group)
+        return out
+
+    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
+        """Sends a tensor to the destination rank in a non-blocking way"""
+        """NOTE: `dst` is the local rank of the destination rank."""
+        if dst is None:
+            dst = (self.rank_in_group + 1) % self.world_size
+
+        pynccl_comm = self.pynccl_comm
+        if pynccl_comm is not None and not pynccl_comm.disabled:
+            pynccl_comm.send(tensor, dst)
+        else:
+            torch.distributed.send(tensor, self.ranks[dst], self.device_group)
+
+    def recv(self,
+             size: torch.Size,
+             dtype: torch.dtype,
+             src: Optional[int] = None) -> torch.Tensor:
+        """Receives a tensor from the source rank."""
+        """NOTE: `src` is the local rank of the source rank."""
+        if src is None:
+            src = (self.rank_in_group - 1) % self.world_size
+
+        tensor = torch.empty(size, dtype=dtype, device=self.device)
+        pynccl_comm = self.pynccl_comm
+        if pynccl_comm is not None and not pynccl_comm.disabled:
+            pynccl_comm.recv(tensor, src)
+        else:
+            torch.distributed.recv(tensor, self.ranks[src], self.device_group)
+        return tensor
+
+    def destroy(self):
+        if self.pynccl_comm is not None:
+            self.pynccl_comm = None
+        if self.ca_comm is not None:
+            self.ca_comm = None
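For send/recv, both the base class and CudaCommunicator fall back to the next and previous local rank when `dst`/`src` are omitted, which is the neighbour pattern pipeline parallelism relies on. A tiny standalone illustration of that wrap-around (hypothetical 4-rank group, no process group needed):

```python
# Default peers chosen by send()/recv() when dst/src are not given.
world_size = 4
for rank_in_group in range(world_size):
    dst = (rank_in_group + 1) % world_size  # send() default: next rank
    src = (rank_in_group - 1) % world_size  # recv() default: previous rank
    print(f"rank {rank_in_group}: sends to {dst}, receives from {src}")
# rank 0: sends to 1, receives from 3
# rank 3: sends to 0, receives from 2
```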
diff --git a/vllm/distributed/device_communicators/hpu_communicator.py b/vllm/distributed/device_communicators/hpu_communicator.py
index 3f85da98aca..9536a7f883e 100644
--- a/vllm/distributed/device_communicators/hpu_communicator.py
+++ b/vllm/distributed/device_communicators/hpu_communicator.py
@@ -2,45 +2,40 @@
 
 import torch
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
 
 from vllm.platforms import current_platform
 
+from .base_device_communicator import DeviceCommunicatorBase
+
 if current_platform.is_hpu():
     import habana_frameworks.torch as htorch  # noqa: F401
 
 
-class HpuCommunicator:
-
-    def __init__(self, group: ProcessGroup):
-        if not current_platform.is_hpu():
-            self.disabled = True
-            return
-        self.disabled = False
-        self.group = group
-        self.world_size = dist.get_world_size(self.group)
+class HpuCommunicator(DeviceCommunicatorBase):
 
-    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
+    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
         # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
         # (which is required for tensor parallel HPUGraph inference)
         htorch.core.mark_step()
-        dist.all_reduce(x, group=self.group)
-        return x
+        dist.all_reduce(input_, group=self.device_group)
+        return input_
 
-    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         world_size = self.world_size
         if dim < 0:
             # Convert negative dim to positive.
-            dim += x.dim()
-        input_size = x.size()
+            dim += input_.dim()
+        input_size = input_.size()
         # Allocate output tensor.
         output_tensor = torch.empty((world_size, ) + input_size,
-                                    dtype=x.dtype,
-                                    device=x.device)
+                                    dtype=input_.dtype,
+                                    device=input_.device)
         # All-gather.
         htorch.core.mark_step()
-        dist.all_gather_into_tensor(output_tensor, x, group=self.group)
+        dist.all_gather_into_tensor(output_tensor,
+                                    input_,
+                                    group=self.device_group)
         # Reshape
         output_tensor = output_tensor.movedim(0, dim)
         output_tensor = output_tensor.reshape(input_size[:dim] +
diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py
index 7af7c65f642..524e655b6b4 100644
--- a/vllm/distributed/device_communicators/tpu_communicator.py
+++ b/vllm/distributed/device_communicators/tpu_communicator.py
@@ -1,13 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import os
+from typing import Optional
 
 import torch
-import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
 from vllm.platforms import current_platform
 
+from .base_device_communicator import DeviceCommunicatorBase
+
 if current_platform.is_tpu():
     import torch_xla.core.xla_model as xm
     import torch_xla.runtime as xr
@@ -16,19 +18,20 @@
     from vllm.executor import ray_utils
 
 
-class TpuCommunicator:
+class TpuCommunicator(DeviceCommunicatorBase):
 
-    def __init__(self, group: ProcessGroup):
-        if not current_platform.is_tpu():
-            self.disabled = True
-            return
-        self.disabled = False
+    def __init__(self,
+                 cpu_group: ProcessGroup,
+                 device: Optional[torch.device] = None,
+                 device_group: Optional[ProcessGroup] = None,
+                 unique_name: str = ""):
+        super().__init__(cpu_group, device, device_group, unique_name)
 
         # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node
         # must be used together. Therefore, the local rank and world size can
         # be simply calculated as follows.
-        global_rank = dist.get_rank(group)
-        global_world_size = dist.get_world_size(group)
+        global_rank = self.global_rank
+        global_world_size = self.global_world_size
 
         # Calculate how many TPU nodes are in the current deployment. This
         # is the Ray placement group if it is deployed with Ray. Default
@@ -55,9 +58,9 @@ def __init__(self, group: ProcessGroup):
         pjrt.initialize_multiprocess(local_rank, local_world_size)
         xr._init_world_size_ordinal()
 
-    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
-        return xm.all_reduce(xm.REDUCE_SUM, x)
+    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
+        return xm.all_reduce(xm.REDUCE_SUM, input_)
 
-    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         assert dim == -1, "TPUs only support dim=-1 for all-gather."
-        return xm.all_gather(x, dim=dim)
+        return xm.all_gather(input_, dim=dim)
diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py
deleted file mode 100644
index 79ccc101e08..00000000000
--- a/vllm/distributed/device_communicators/xpu_communicator.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-
-import torch
-import torch.distributed as dist
-from torch.distributed import ProcessGroup
-
-from vllm.platforms import current_platform
-
-
-class XpuCommunicator:
-
-    def __init__(self, group: ProcessGroup):
-        if not current_platform.is_xpu():
-            self.disabled = True
-            return
-        self.disabled = False
-        self.group = group
-        self.world_size = dist.get_world_size(self.group)
-
-    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
-        dist.all_reduce(x, group=self.group)
-        return x
-
-    def gather(self,
-               input_: torch.Tensor,
-               rank_in_group: int,
-               dst: int = 0,
-               dim: int = -1):
-        # For xpu path, gather doesn't work properly together with ray
-        # cluster so we use all_gather instead for now.
-        input_size = input_.size()
-        # Allocate output tensor.
-        output_tensor = torch.empty((self.world_size, ) + input_size,
-                                    dtype=input_.dtype,
-                                    device=input_.device)
-        # All-gather.
-        torch.distributed.all_gather_into_tensor(output_tensor,
-                                                 input_,
-                                                 group=self.group)
-        if rank_in_group == dst:
-            # Reshape
-            output_tensor = output_tensor.movedim(0, dim)
-            output_tensor = output_tensor.reshape(input_size[:dim] +
-                                                  (self.world_size *
-                                                   input_size[dim], ) +
-                                                  input_size[dim + 1:])
-        else:
-            output_tensor = None
-        return output_tensor
""" # available attributes: @@ -150,11 +152,8 @@ class GroupCoordinator: rank_in_group: int # rank inside the group cpu_group: ProcessGroup # group for CPU communication device_group: ProcessGroup # group for device communication - use_pynccl: bool # a hint of whether to use PyNccl - use_custom_allreduce: bool # a hint of whether to use CustomAllreduce - # communicators are only created for world size > 1 - pynccl_comm: Optional[Any] # PyNccl communicator - ca_comm: Optional[Any] # Custom allreduce communicator + use_device_communicator: bool # whether to use device communicator + device_communicator: DeviceCommunicatorBase # device communicator mq_broadcaster: Optional[Any] # shared memory broadcaster def __init__( @@ -162,11 +161,7 @@ def __init__( group_ranks: List[List[int]], local_rank: int, torch_distributed_backend: Union[str, Backend], - use_pynccl: bool, - use_custom_allreduce: bool, - use_tpu_communicator: bool, - use_hpu_communicator: bool, - use_xpu_communicator: bool, + use_device_communicator: bool, use_message_queue_broadcaster: bool = False, group_name: Optional[str] = None, ): @@ -196,56 +191,26 @@ def __init__( assert self.device_group is not None from vllm.platforms import current_platform + + # TODO: fix it for other platforms if current_platform.is_cuda_alike(): self.device = torch.device(f"cuda:{local_rank}") else: self.device = torch.device("cpu") - self.use_pynccl = use_pynccl - self.use_custom_allreduce = use_custom_allreduce - self.use_tpu_communicator = use_tpu_communicator - self.use_hpu_communicator = use_hpu_communicator - self.use_xpu_communicator = use_xpu_communicator - - # lazy import to avoid documentation build error - from vllm.distributed.device_communicators.custom_all_reduce import ( - CustomAllreduce) - from vllm.distributed.device_communicators.pynccl import ( - PyNcclCommunicator) - - self.pynccl_comm: Optional[PyNcclCommunicator] = None - if use_pynccl and self.world_size > 1: - self.pynccl_comm = PyNcclCommunicator( - group=self.cpu_group, - device=self.device, - ) + self.use_device_communicator = use_device_communicator - self.ca_comm: Optional[CustomAllreduce] = None - if use_custom_allreduce and self.world_size > 1: - # Initialize a custom fast all-reduce implementation. 
- self.ca_comm = CustomAllreduce( - group=self.cpu_group, + self.device_communicator: DeviceCommunicatorBase = None # type: ignore + if use_device_communicator and self.world_size > 1: + device_comm_cls = resolve_obj_by_qualname( + current_platform.get_device_communicator_cls()) + self.device_communicator = device_comm_cls( + cpu_group=self.cpu_group, device=self.device, + device_group=self.device_group, + unique_name=self.unique_name, ) - from vllm.distributed.device_communicators.tpu_communicator import ( - TpuCommunicator) - self.tpu_communicator: Optional[TpuCommunicator] = None - if use_tpu_communicator and self.world_size > 1: - self.tpu_communicator = TpuCommunicator(group=self.cpu_group) - - from vllm.distributed.device_communicators.hpu_communicator import ( - HpuCommunicator) - self.hpu_communicator: Optional[HpuCommunicator] - if use_hpu_communicator and self.world_size > 1: - self.hpu_communicator = HpuCommunicator(group=self.device_group) - - from vllm.distributed.device_communicators.xpu_communicator import ( - XpuCommunicator) - self.xpu_communicator: Optional[XpuCommunicator] - if use_xpu_communicator and self.world_size > 1: - self.xpu_communicator = XpuCommunicator(group=self.device_group) - from vllm.distributed.device_communicators.shm_broadcast import ( MessageQueue) self.mq_broadcaster: Optional[MessageQueue] = None @@ -253,6 +218,9 @@ def __init__( self.mq_broadcaster = MessageQueue.create_from_process_group( self.cpu_group, 1 << 22, 6) + from vllm.platforms import current_platform + self.use_custom_op_call = current_platform.is_cuda_alike() + @property def first_rank(self): """Return the global rank of the first process in the group""" @@ -296,9 +264,16 @@ def graph_capture( else: stream = graph_capture_context.stream - ca_comm = self.ca_comm - maybe_ca_context = nullcontext( - ) if ca_comm is None else ca_comm.capture() + # only cuda uses this function, + # so we don't abstract it into the base class + maybe_ca_context = nullcontext() + from vllm.distributed.device_communicators.cuda_communicator import ( + CudaCommunicator) + if self.device_communicator is not None: + assert isinstance(self.device_communicator, CudaCommunicator) + ca_comm = self.device_communicator.ca_comm + if ca_comm is not None: + maybe_ca_context = ca_comm.capture() # type: ignore # ensure all initialization operations complete before attempting to # capture the graph on another stream @@ -328,54 +303,14 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: if self.world_size == 1: return input_ - if input_.is_cpu: - try: - import intel_extension_for_pytorch as ipex - ipex.distributed.all_reduce(input_, group=self.device_group) - return input_ - except ImportError: - """ - Intel IPEX not found. Falling back to PyTorch native - all_reduce for CPU - """ - torch.distributed.all_reduce(input_, group=self.device_group) - return input_ - - if self.tpu_communicator is not None and \ - not self.tpu_communicator.disabled: - # TPU handles Dynamo with its own logic. 
- return self.tpu_communicator.all_reduce(input_) - - if self.hpu_communicator is not None and \ - not self.hpu_communicator.disabled: - return self.hpu_communicator.all_reduce(input_) - - if self.xpu_communicator is not None and \ - not self.xpu_communicator.disabled: - return self.xpu_communicator.all_reduce(input_) - - return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name) + if self.use_custom_op_call: + return torch.ops.vllm.all_reduce(input_, + group_name=self.unique_name) + else: + return self._all_reduce_out_place(input_) def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: - # always try custom allreduce first, - # and then pynccl. - ca_comm = self.ca_comm - if ca_comm is not None and not ca_comm.disabled and \ - ca_comm.should_custom_ar(input_): - out = ca_comm.custom_all_reduce(input_) - assert out is not None - return out - pynccl_comm = self.pynccl_comm - assert pynccl_comm is not None - out = pynccl_comm.all_reduce(input_) - if out is None: - # fall back to the default all-reduce using PyTorch. - # this usually happens during testing. - # when we run the model, allreduce only happens for the TP - # group, where we always have either custom allreduce or pynccl. - out = input_.clone() - torch.distributed.all_reduce(out, group=self.device_group) - return out + return self.device_communicator.all_reduce(input_) def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: world_size = self.world_size @@ -385,40 +320,7 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: assert -input_.dim() <= dim < input_.dim(), ( f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") - # For TPUs, use TPU communicator. - tpu_comm = self.tpu_communicator - if tpu_comm is not None and not tpu_comm.disabled: - return tpu_comm.all_gather(input_, dim) - - # For HPUs, use HPU communicator. - hpu_comm = self.hpu_communicator - if hpu_comm is not None and not hpu_comm.disabled: - return hpu_comm.all_gather(input_, dim) - - if dim < 0: - # Convert negative dim to positive. - dim += input_.dim() - input_size = input_.size() - # NOTE: we have to use concat-style all-gather here, - # stack-style all-gather has compatibility issues with - # torch.compile . see https://github.com/pytorch/pytorch/issues/138795 - output_size = (input_size[0] * world_size, ) + input_size[1:] - # Allocate output tensor. - output_tensor = torch.empty(output_size, - dtype=input_.dtype, - device=input_.device) - # All-gather. - torch.distributed.all_gather_into_tensor(output_tensor, - input_, - group=self.device_group) - # Reshape - output_tensor = output_tensor.reshape((world_size, ) + input_size) - output_tensor = output_tensor.movedim(0, dim) - output_tensor = output_tensor.reshape(input_size[:dim] + - (world_size * - input_size[dim], ) + - input_size[dim + 1:]) - return output_tensor + return self.device_communicator.all_gather(input_, dim) def gather(self, input_: torch.Tensor, @@ -433,30 +335,7 @@ def gather(self, # Bypass the function if we are using only 1 GPU. if world_size == 1: return input_ - assert -input_.dim() <= dim < input_.dim(), ( - f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") - if dim < 0: - # Convert negative dim to positive. - dim += input_.dim() - if self.xpu_communicator is not None and \ - not self.xpu_communicator.disabled: - return self.xpu_communicator.gather(input_, self.rank_in_group, - dst, dim) - # Allocate output tensor. 
- if self.rank_in_group == dst: - gather_list = [torch.empty_like(input_) for _ in range(world_size)] - else: - gather_list = None - # Gather. - torch.distributed.gather(input_, - gather_list, - dst=self.ranks[dst], - group=self.device_group) - if self.rank_in_group == dst: - output_tensor = torch.cat(gather_list, dim=dim) - else: - output_tensor = None - return output_tensor + return self.device_communicator.gather(input_, dst, dim) def broadcast(self, input_: torch.Tensor, src: int = 0): """Broadcast the input tensor. @@ -798,14 +677,7 @@ def barrier(self): def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: """Sends a tensor to the destination rank in a non-blocking way""" """NOTE: `dst` is the local rank of the destination rank.""" - if dst is None: - dst = (self.rank_in_group + 1) % self.world_size - - pynccl_comm = self.pynccl_comm - if pynccl_comm is not None and not pynccl_comm.disabled: - pynccl_comm.send(tensor, dst) - else: - torch.distributed.send(tensor, self.ranks[dst], self.device_group) + self.device_communicator.send(tensor, dst) def recv(self, size: torch.Size, @@ -813,16 +685,7 @@ def recv(self, src: Optional[int] = None) -> torch.Tensor: """Receives a tensor from the source rank.""" """NOTE: `src` is the local rank of the source rank.""" - if src is None: - src = (self.rank_in_group - 1) % self.world_size - - tensor = torch.empty(size, dtype=dtype, device=self.device) - pynccl_comm = self.pynccl_comm - if pynccl_comm is not None and not pynccl_comm.disabled: - pynccl_comm.recv(tensor, src) - else: - torch.distributed.recv(tensor, self.ranks[src], self.device_group) - return tensor + return self.device_communicator.recv(size, dtype, src) def destroy(self): if self.device_group is not None: @@ -831,10 +694,8 @@ def destroy(self): if self.cpu_group is not None: torch.distributed.destroy_process_group(self.cpu_group) self.cpu_group = None - if self.pynccl_comm is not None: - self.pynccl_comm = None - if self.ca_comm is not None: - self.ca_comm = None + if self.device_communicator is not None: + self.device_communicator.destroy() if self.mq_broadcaster is not None: self.mq_broadcaster = None @@ -853,11 +714,7 @@ def init_world_group(ranks: List[int], local_rank: int, group_ranks=[ranks], local_rank=local_rank, torch_distributed_backend=backend, - use_pynccl=False, - use_custom_allreduce=False, - use_tpu_communicator=False, - use_hpu_communicator=False, - use_xpu_communicator=False, + use_device_communicator=False, group_name="world", ) @@ -866,23 +723,15 @@ def init_model_parallel_group( group_ranks: List[List[int]], local_rank: int, backend: str, - use_custom_allreduce: Optional[bool] = None, use_message_queue_broadcaster: bool = False, group_name: Optional[str] = None, ) -> GroupCoordinator: - if use_custom_allreduce is None: - use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE - from vllm.platforms import current_platform + return GroupCoordinator( group_ranks=group_ranks, local_rank=local_rank, torch_distributed_backend=backend, - use_pynccl=current_platform.is_cuda_alike(), - use_custom_allreduce=current_platform.is_cuda_alike() - and use_custom_allreduce, - use_tpu_communicator=True, - use_hpu_communicator=True, - use_xpu_communicator=True, + use_device_communicator=True, use_message_queue_broadcaster=use_message_queue_broadcaster, group_name=group_name, ) @@ -1053,11 +902,9 @@ def initialize_model_parallel( for i in range(num_pipeline_model_parallel_groups): ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) 
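GroupCoordinator now chooses its communicator by resolving the dotted path returned by `current_platform.get_device_communicator_cls()`. The real helper is `vllm.utils.resolve_obj_by_qualname`; the importlib-based stand-in below is only a sketch of the idea, not its actual implementation.

```python
import importlib


def resolve_obj_by_qualname_sketch(qualname: str):
    """Turn 'package.module.Name' into the named object."""
    module_name, obj_name = qualname.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), obj_name)


# Roughly what the coordinator does on a CUDA-like platform:
comm_cls = resolve_obj_by_qualname_sketch(
    "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator")
print(comm_cls.__name__)  # CudaCommunicator
```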
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index a9216c2322e..ab8982a3a6e 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -146,3 +146,10 @@ def is_pin_memory_available(cls) -> bool:
     @classmethod
     def get_punica_wrapper(cls) -> str:
         return "vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"
+
+    @classmethod
+    def get_device_communicator_cls(cls) -> str:
+        """
+        Get device specific communicator class for distributed communication.
+        """
+        return "vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator"  # noqa
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 2c40a798736..5b073125614 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -233,6 +233,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
     def get_punica_wrapper(cls) -> str:
         return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
 
+    @classmethod
+    def get_device_communicator_cls(cls) -> str:
+        return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
+
     # NVML utils
     # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py
index 78ddb67bb3f..4c842b52511 100644
--- a/vllm/platforms/hpu.py
+++ b/vllm/platforms/hpu.py
@@ -88,3 +88,7 @@ def is_pin_memory_available(cls):
     @classmethod
     def get_punica_wrapper(cls) -> str:
         return "vllm.lora.punica_wrapper.punica_hpu.PunicaWrapperHPU"
+
+    @classmethod
+    def get_device_communicator_cls(cls) -> str:
+        return "vllm.distributed.device_communicators.hpu_communicator.HpuCommunicator"  # noqa
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 19adc2af8c6..58948ad1aba 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -322,6 +322,13 @@ def get_punica_wrapper(cls) -> str:
         """
         raise NotImplementedError
 
+    @classmethod
+    def get_device_communicator_cls(cls) -> str:
+        """
+        Get device specific communicator class for distributed communication.
+        """
+        return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"  # noqa
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index d57cce4231d..393b8a18527 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -186,3 +186,7 @@ def get_current_memory_usage(cls,
         torch.cuda.reset_peak_memory_stats(device)
         return torch.cuda.mem_get_info(device)[1] - torch.cuda.mem_get_info(
             device)[0]
+
+    @classmethod
+    def get_device_communicator_cls(cls) -> str:
+        return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index 0c81d6a9389..cdf835a52c0 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -115,3 +115,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
     def is_pin_memory_available(cls):
         logger.warning("Pin memory is not supported on TPU.")
         return False
+
+    @classmethod
+    def get_device_communicator_cls(cls) -> str:
+        return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator"  # noqa
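With every platform advertising its communicator through `get_device_communicator_cls()`, adding a backend no longer touches `parallel_state.py`. A hypothetical sketch of what that could look like for an out-of-tree accelerator; the `my_plugin`/`MyAccel*` names are invented, and a real platform would also define the usual `Platform` attributes:

```python
import torch

from vllm.distributed.device_communicators.base_device_communicator import (
    DeviceCommunicatorBase)
from vllm.platforms.interface import Platform


class MyAccelCommunicator(DeviceCommunicatorBase):
    """Hypothetical communicator: override only what the backend needs."""

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        # A real backend would call its own collective here; the base class
        # already falls back to torch.distributed on `device_group`.
        return super().all_reduce(input_)


class MyAccelPlatform(Platform):
    """Hypothetical platform: only the new hook is shown."""

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        return "my_plugin.communicator.MyAccelCommunicator"
```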