[V1] EP/TP MoE + DP Attention #13931

Merged
merged 24 commits into from Mar 5, 2025
Changes from 4 commits
13 changes: 7 additions & 6 deletions examples/offline_inference/data_parallel.py
Member:

this is for local testing, right? maybe revert it in the ci?

in addition, I think you need to add VLLM_TEST_ENABLE_EP to test it locally.
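
A minimal sketch of a local run with EP enabled (the exact invocation below is an assumption for illustration, not something this PR prescribes):

# Hypothetical local launch: opt in to the expert-parallel test path, then run
# the data-parallel example script from the repo root.
import os
import subprocess

env = dict(os.environ, VLLM_TEST_ENABLE_EP="1")
subprocess.run(
    ["python", "examples/offline_inference/data_parallel.py"],
    check=True,
    env=env,
)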

@@ -7,6 +7,9 @@
from vllm import LLM, SamplingParams
from vllm.utils import get_open_port

GPUs_per_dp_rank = 2
DP_size = 2


def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
os.environ["VLLM_DP_RANK"] = str(dp_rank)
@@ -48,8 +51,8 @@ def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
max_tokens=16 * (dp_rank + 1))

# Create an LLM.
llm = LLM(model="facebook/opt-125m",
tensor_parallel_size=2,
llm = LLM(model="neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8",
tensor_parallel_size=GPUs_per_dp_rank,
enforce_eager=True)
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
@@ -62,14 +65,12 @@ def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):

if __name__ == "__main__":
from multiprocessing import Process
dp_size = 2
GPUs_per_dp_rank = 2
dp_master_ip = "127.0.0.1"
dp_master_port = get_open_port()
procs = []
for i in range(dp_size):
for i in range(DP_size):
proc = Process(target=main,
args=(dp_size, i, dp_master_ip, dp_master_port,
args=(DP_size, i, dp_master_ip, dp_master_port,
GPUs_per_dp_rank))
proc.start()
procs.append(proc)
34 changes: 34 additions & 0 deletions vllm/distributed/device_communicators/base_device_communicator.py
@@ -61,6 +61,40 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
input_size[dim + 1:])
return output_tensor

def reduce_scatter(self,
input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")

if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()

# Note: This will produce an incorrect answer if we don't make
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
input_tensor = input_.movedim(0, dim).contiguous()

assert input_tensor.shape[0] % world_size == 0
chunk_size = input_tensor.shape[0] // world_size
output_shape = (chunk_size, ) + input_tensor.shape[1:]

output_tensor = torch.empty(output_shape,
dtype=input_tensor.dtype,
device=input_tensor.device)

# Perform reduce-scatter operation
torch.distributed.reduce_scatter_tensor(output_tensor,
input_tensor,
group=self.device_group)

# Reshape before returning
return output_tensor.movedim(0, dim).contiguous()

def gather(self,
input_: torch.Tensor,
dst: int = 0,
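As a reading aid, here is a single-process sketch of the reduce-scatter shape contract implemented above. It is illustrative only: the collective is simulated with plain tensor ops instead of torch.distributed.

import torch

world_size, chunk, hidden = 2, 3, 4
# Each simulated rank contributes a tensor of shape (world_size * chunk, hidden).
inputs = [torch.randn(world_size * chunk, hidden) for _ in range(world_size)]

# "reduce": elementwise sum of every rank's contribution.
summed = torch.stack(inputs).sum(dim=0)
# "scatter": rank r keeps only its own chunk along dim 0.
per_rank_outputs = summed.split(chunk, dim=0)

for out in per_rank_outputs:
    assert out.shape == (chunk, hidden)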
25 changes: 25 additions & 0 deletions vllm/distributed/device_communicators/cuda_communicator.py
@@ -70,6 +70,31 @@ def all_reduce(self, input_):
torch.distributed.all_reduce(out, group=self.device_group)
return out

def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
world_size = self.world_size
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()

# Note: This will produce an incorrect answer if we don't make
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
input_tensor = input_.movedim(0, dim).contiguous()

assert input_tensor.shape[0] % world_size == 0
chunk_size = input_tensor.shape[0] // world_size
output_shape = (chunk_size, ) + input_tensor.shape[1:]

output = torch.empty(output_shape,
dtype=input_tensor.dtype,
device=input_tensor.device)

pynccl_comm.reduce_scatter(output, input_)

# Reshape before returning
return output.movedim(0, dim).contiguous()

def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
35 changes: 35 additions & 0 deletions vllm/distributed/parallel_state.py
@@ -114,10 +114,26 @@ def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
return group._all_reduce_out_place(tensor)


def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int,
group_name: str) -> torch.Tensor:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
return group.reduce_scatter(tensor, dim)


def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
return torch.empty_like(tensor)


def reduce_scatter_fake(tensor: torch.Tensor, dim: int, world_size: int,
group_name: str) -> torch.Tensor:
new_shape = list(tensor.shape)
new_shape[dim] = tensor.shape[dim] // world_size
return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device)


if supports_custom_op():
direct_register_custom_op(
op_name="all_reduce",
@@ -126,6 +142,13 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
fake_impl=all_reduce_fake,
)

direct_register_custom_op(
op_name="reduce_scatter",
op_func=reduce_scatter,
mutates_args=[],
fake_impl=reduce_scatter_fake,
)


class GroupCoordinator:
"""
@@ -322,6 +345,18 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:

return self.device_communicator.all_gather(input_, dim)

def reduce_scatter(self,
input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")

return self.device_communicator.reduce_scatter(input_, dim)

def gather(self,
input_: torch.Tensor,
dst: int = 0,
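One note on the registration above: the fake_impl supplies a shape-only stand-in that tracing (e.g. under torch.compile) can use in place of the real collective, so its output shape must match the real op. A tiny sketch of that shape contract, using nothing but plain tensor shapes:

import torch

def reduce_scatter_out_shape(shape: torch.Size, dim: int, world_size: int) -> list:
    # Mirrors reduce_scatter_fake: the scattered dim shrinks by world_size.
    new_shape = list(shape)
    new_shape[dim] = shape[dim] // world_size
    return new_shape

assert reduce_scatter_out_shape(torch.Size([8, 16]), dim=0, world_size=4) == [2, 16]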
120 changes: 97 additions & 23 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -7,9 +7,10 @@
import torch

import vllm.envs as envs
from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.base_config import (
@@ -245,6 +246,51 @@ def forward_tpu(
forward_native = forward_cuda


def determine_expert_map(
ep_size: int, ep_rank: int,
global_num_experts: int) -> Tuple[int, Optional[torch.Tensor]]:
"""
Calculates how many experts should be assigned to each rank for EP and
creates a mapping from global to local expert index. Experts are
distributed evenly across ranks. Any remaining are assigned to the
last rank.
Comment on lines +258 to +259

Member:

Maybe this behavior could be improved by using the more even partitioning strat from #13839

Collaborator Author:

Not a bad idea! I think we can try this in a future PR. cc @cakeng


Args:
ep_size (int): The size of the expert parallel group
global_num_experts (int): The total number of experts in the model.

Returns:
Tuple[int, Optional[torch.Tensor]]: A tuple containing:
- local_num_experts (int): The number of experts assigned
to the current rank.
- expert_map (Optional[torch.Tensor]): A tensor of shape
(global_num_experts,) mapping from global to local index.
Contains -1 for experts not assigned to the current rank.
Returns None if ep_size is 1.
"""
assert ep_size > 0
if ep_size == 1:
return (global_num_experts, None)

local_num_experts = global_num_experts // ep_size

# Create a tensor of size num_experts filled with -1
expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
# Create a expert map for the local experts
if ep_rank < (ep_size - 1):
# Each non-last rank gets local_num_experts experts.
expert_map[ep_rank * local_num_experts:
(ep_rank + 1) * local_num_experts] = \
torch.arange(0, local_num_experts, dtype=torch.int32)
else:
# All remaining experts are assigned to the last rank.
local_num_experts = (global_num_experts - ep_rank * local_num_experts)

expert_map[-local_num_experts:] = \
torch.arange(0, local_num_experts, dtype=torch.int32)
return (local_num_experts, expert_map)
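
# Worked example (illustrative only, not part of this diff): 8 experts on 3 EP
# ranks gives ranks 0 and 1 two experts each, and the last rank the remaining 4.
#   >>> determine_expert_map(ep_size=3, ep_rank=0, global_num_experts=8)
#   (2, tensor([ 0,  1, -1, -1, -1, -1, -1, -1], dtype=torch.int32))
#   >>> determine_expert_map(ep_size=3, ep_rank=2, global_num_experts=8)
#   (4, tensor([-1, -1, -1, -1,  0,  1,  2,  3], dtype=torch.int32))
# A more even split (as suggested in the review thread, cf. #13839) might
# instead distribute the remainder as 3/3/2; that strategy is not reproduced here.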


class FusedMoE(torch.nn.Module):
"""FusedMoE layer for MoE models.

@@ -294,14 +340,27 @@ def __init__(

self.tp_size = (tp_size if tp_size is not None else
get_tensor_model_parallel_world_size())
self.dp_size = get_dp_group().world_size
self.dp_rank = get_dp_group().rank_in_group
self.global_num_experts = num_experts

if envs.VLLM_TEST_ENABLE_EP:
self.ep_size = self.tp_size
self.ep_size = self.tp_size * self.dp_size
self.ep_rank = (get_tensor_model_parallel_rank() +
self.tp_size * self.dp_rank)
self.tp_size = 1

self.local_num_experts, self.expert_map = determine_expert_map(
ep_size=self.ep_size,
ep_rank=self.ep_rank,
global_num_experts=self.global_num_experts)
else:
self.ep_size = 1
self.local_num_experts = self.global_num_experts
self.expert_map = None
self.top_k = top_k
self.global_num_experts = num_experts
self.local_num_experts = self.global_num_experts // self.ep_size

assert intermediate_size % self.tp_size == 0
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
@@ -315,26 +374,6 @@ def __init__(
self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias
self.activation = activation
self.expert_map = None

if self.ep_size > 1:
# Create a tensor of size num_experts filled with -1
self.expert_map = torch.full((self.global_num_experts, ),
-1,
dtype=torch.int32)
# Create a expert map for the local experts
ep_rank = get_tensor_model_parallel_rank()
if ep_rank < (self.ep_size - 1):
# Each non-last rank gets local_num_experts experts.
self.expert_map[ep_rank * self.local_num_experts:
(ep_rank + 1) * self.local_num_experts] = \
torch.arange(0, self.local_num_experts, dtype=torch.int32)
else:
# All remaining experts are assigned to the last rank.
self.local_num_experts = (self.global_num_experts -
ep_rank * self.local_num_experts)
self.expert_map[-self.local_num_experts:] = \
torch.arange(0, self.local_num_experts, dtype=torch.int32)

if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "
@@ -645,10 +684,32 @@ def select_experts(hidden_states: torch.Tensor,

return topk_weights, topk_ids

def naive_multicast(self, x: torch.Tensor, max_num_tokens: int):
assert (len(x.shape) == 2)
num_tokens = x.size(0)
buffer = torch.zeros((self.dp_size, max_num_tokens, x.size(1)),
device=x.device,
dtype=x.dtype)

buffer[self.dp_rank, :num_tokens, :].copy_(x)

x = get_dp_group().all_reduce(buffer)
x = x.view(-1, x.size(-1))
return x

def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
assert self.quant_method is not None

if self.dp_size > 1:
num_tokens_across_dp = get_forward_context().num_tokens_across_dp
max_num_tokens = max(num_tokens_across_dp)
num_tokens = hidden_states.size(0)

assert num_tokens_across_dp is not None
hidden_states = self.naive_multicast(hidden_states, max_num_tokens)
router_logits = self.naive_multicast(router_logits, max_num_tokens)

# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
@@ -667,6 +728,19 @@ def forward(self, hidden_states: torch.Tensor,
activation=self.activation,
)

if self.dp_size > 1:
if True:
all_hidden_states = get_dp_group().all_reduce(
final_hidden_states)
all_hidden_states = all_hidden_states.view(
self.dp_size, -1, all_hidden_states.size(-1))
final_hidden_states = all_hidden_states[
self.dp_rank, :num_tokens, :]
else:
final_hidden_states = get_dp_group().reduce_scatter(
final_hidden_states, 0)
final_hidden_states = final_hidden_states[:num_tokens, :]

if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
# Default set to False. (May have to add shared expert outputs.)
final_hidden_states = tensor_model_parallel_all_reduce(
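To make the data-parallel path in forward() easier to follow, here is a single-process sketch of the pad, multicast, compute, slice-back pattern. It is illustrative only: the DP all_reduce is replaced by explicit concatenation, and a toy shape-preserving function stands in for the fused MoE kernel.

import torch

def toy_moe(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for quant_method.apply(...); shape-preserving.
    return x * 2.0

dp_size, hidden = 2, 4
num_tokens_across_dp = [3, 1]      # ragged token counts per DP rank
max_num_tokens = max(num_tokens_across_dp)

# Each rank pads its tokens to max_num_tokens; this is the buffer that
# naive_multicast fills locally before the all_reduce makes it global.
padded = []
for n in num_tokens_across_dp:
    buf = torch.zeros(max_num_tokens, hidden)
    buf[:n] = torch.randn(n, hidden)
    padded.append(buf)

# After the multicast, every rank sees all dp_size * max_num_tokens rows.
gathered = torch.cat(padded, dim=0)
out = toy_moe(gathered)

# Each rank then slices back only its own, unpadded tokens.
out = out.view(dp_size, max_num_tokens, hidden)
for rank, n in enumerate(num_tokens_across_dp):
    local_out = out[rank, :n, :]
    assert local_out.shape == (n, hidden)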