# [V1] EP/TP MoE + DP Attention #13931
```diff
@@ -7,9 +7,10 @@
 import torch
 
 import vllm.envs as envs
-from vllm.distributed import (get_tensor_model_parallel_rank,
+from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
+from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization.base_config import (
```
```diff
@@ -245,6 +246,51 @@ def forward_tpu(
     forward_native = forward_cuda
 
 
+def determine_expert_map(
+        ep_size: int, ep_rank: int,
+        global_num_experts: int) -> Tuple[int, Optional[torch.Tensor]]:
+    """
+    Calculates how many experts should be assigned to each rank for EP and
+    creates a mapping from global to local expert index. Experts are
+    distributed evenly across ranks. Any remaining are assigned to the
+    last rank.
+
+    Args:
+        ep_size (int): The size of the expert parallel group.
+        ep_rank (int): The rank of the current process in the expert
+            parallel group.
+        global_num_experts (int): The total number of experts in the model.
+
+    Returns:
+        Tuple[int, Optional[torch.Tensor]]: A tuple containing:
+            - local_num_experts (int): The number of experts assigned
+                to the current rank.
+            - expert_map (Optional[torch.Tensor]): A tensor of shape
+                (global_num_experts,) mapping from global to local index.
+                Contains -1 for experts not assigned to the current rank.
+                Returns None if ep_size is 1.
+    """
+    assert ep_size > 0
+    if ep_size == 1:
+        return (global_num_experts, None)
+
+    local_num_experts = global_num_experts // ep_size
+
+    # Create a tensor of size num_experts filled with -1
+    expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
+    # Create an expert map for the local experts
+    if ep_rank < (ep_size - 1):
+        # Each non-last rank gets local_num_experts experts.
+        expert_map[ep_rank * local_num_experts:
+                   (ep_rank + 1) * local_num_experts] = \
+            torch.arange(0, local_num_experts, dtype=torch.int32)
+    else:
+        # All remaining experts are assigned to the last rank.
+        local_num_experts = (global_num_experts -
+                             ep_rank * local_num_experts)
+
+        expert_map[-local_num_experts:] = \
+            torch.arange(0, local_num_experts, dtype=torch.int32)
+    return (local_num_experts, expert_map)
+
+
 class FusedMoE(torch.nn.Module):
     """FusedMoE layer for MoE models.
```

> **Comment on lines +258 to +259** ("distributed evenly across ranks. Any remaining are assigned to the last rank."): Maybe this behavior could be improved by using the more even partitioning strategy from #13839.
>
> **Reply:** Not a bad idea! I think we can try this in a future PR. cc @cakeng
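For intuition, here is a minimal standalone sketch (not part of the PR) of the same split-evenly-with-remainder-on-the-last-rank scheme, printed for 8 experts on 3 EP ranks. The uneven last rank it produces is exactly what the review comment above is about:

```python
import torch

def expert_map_sketch(ep_size: int, ep_rank: int, num_experts: int):
    """Standalone re-implementation of the partitioning in
    determine_expert_map, for illustration only."""
    local = num_experts // ep_size
    emap = torch.full((num_experts,), -1, dtype=torch.int32)
    if ep_rank < ep_size - 1:
        # Non-last ranks take a contiguous block of `local` experts.
        emap[ep_rank * local:(ep_rank + 1) * local] = torch.arange(
            local, dtype=torch.int32)
    else:
        # The last rank absorbs the remainder.
        local = num_experts - ep_rank * local
        emap[-local:] = torch.arange(local, dtype=torch.int32)
    return local, emap

for rank in range(3):
    print(rank, *expert_map_sketch(ep_size=3, ep_rank=rank, num_experts=8))
# 0 2 tensor([ 0,  1, -1, -1, -1, -1, -1, -1], dtype=torch.int32)
# 1 2 tensor([-1, -1,  0,  1, -1, -1, -1, -1], dtype=torch.int32)
# 2 4 tensor([-1, -1, -1, -1,  0,  1,  2,  3], dtype=torch.int32)
```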
```diff
@@ -294,14 +340,27 @@ def __init__(
 
         self.tp_size = (tp_size if tp_size is not None else
                         get_tensor_model_parallel_world_size())
+        self.dp_size = get_dp_group().world_size
+        self.dp_rank = get_dp_group().rank_in_group
+        self.global_num_experts = num_experts
 
         if envs.VLLM_TEST_ENABLE_EP:
-            self.ep_size = self.tp_size
+            self.ep_size = self.tp_size * self.dp_size
+            self.ep_rank = (get_tensor_model_parallel_rank() +
+                            self.tp_size * self.dp_rank)
             self.tp_size = 1
+
+            self.local_num_experts, self.expert_map = determine_expert_map(
+                ep_size=self.ep_size,
+                ep_rank=self.ep_rank,
+                global_num_experts=self.global_num_experts)
         else:
             self.ep_size = 1
+            self.local_num_experts = self.global_num_experts
+            self.expert_map = None
         self.top_k = top_k
-        self.global_num_experts = num_experts
-        self.local_num_experts = self.global_num_experts // self.ep_size
 
         assert intermediate_size % self.tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
         self.reduce_results = reduce_results
```
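Concretely, when `VLLM_TEST_ENABLE_EP` is set, the TP and DP process grids are flattened into a single EP group and TP is collapsed to 1, so expert weights are sharded only across EP. A small sketch of the rank arithmetic (the sizes here are made up for illustration):

```python
# Illustration only: how (tp_rank, dp_rank) pairs flatten into EP ranks
# under the formula in the diff, assuming tp_size=2 and dp_size=4.
tp_size, dp_size = 2, 4
ep_size = tp_size * dp_size  # 8; TP is then set to 1
for dp_rank in range(dp_size):
    for tp_rank in range(tp_size):
        ep_rank = tp_rank + tp_size * dp_rank
        print(f"dp={dp_rank} tp={tp_rank} -> ep={ep_rank}")
# ep_rank runs 0..7: each of the 8 workers owns a distinct expert shard
# instead of a TP slice of every expert.
```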
```diff
@@ -315,26 +374,6 @@ def __init__(
         self.scoring_func = scoring_func
         self.e_score_correction_bias = e_score_correction_bias
         self.activation = activation
-        self.expert_map = None
-
-        if self.ep_size > 1:
-            # Create a tensor of size num_experts filled with -1
-            self.expert_map = torch.full((self.global_num_experts, ),
-                                         -1,
-                                         dtype=torch.int32)
-            # Create an expert map for the local experts
-            ep_rank = get_tensor_model_parallel_rank()
-            if ep_rank < (self.ep_size - 1):
-                # Each non-last rank gets local_num_experts experts.
-                self.expert_map[ep_rank * self.local_num_experts:
-                                (ep_rank + 1) * self.local_num_experts] = \
-                    torch.arange(0, self.local_num_experts, dtype=torch.int32)
-            else:
-                # All remaining experts are assigned to the last rank.
-                self.local_num_experts = (self.global_num_experts -
-                                          ep_rank * self.local_num_experts)
-                self.expert_map[-self.local_num_experts:] = \
-                    torch.arange(0, self.local_num_experts, dtype=torch.int32)
 
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
```
```diff
@@ -645,10 +684,32 @@ def select_experts(hidden_states: torch.Tensor,
 
         return topk_weights, topk_ids
 
+    def naive_multicast(self, x: torch.Tensor, max_num_tokens: int):
+        assert (len(x.shape) == 2)
+        num_tokens = x.size(0)
+        buffer = torch.zeros((self.dp_size, max_num_tokens, x.size(1)),
+                             device=x.device,
+                             dtype=x.dtype)
+
+        buffer[self.dp_rank, :num_tokens, :].copy_(x)
+
+        x = get_dp_group().all_reduce(buffer)
+        x = x.view(-1, x.size(-1))
+        return x
+
     def forward(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
         assert self.quant_method is not None
 
+        if self.dp_size > 1:
+            num_tokens_across_dp = get_forward_context().num_tokens_across_dp
+            max_num_tokens = max(num_tokens_across_dp)
+            num_tokens = hidden_states.size(0)
+
+            assert num_tokens_across_dp is not None
+            hidden_states = self.naive_multicast(hidden_states, max_num_tokens)
+            router_logits = self.naive_multicast(router_logits, max_num_tokens)
+
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
             layer=self,
```
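A single-process sketch of the `naive_multicast` idea (shapes assumed, no real process group): each DP rank writes its tokens into its own slot of a zero-initialized `(dp_size, max_num_tokens, hidden)` buffer, so a sum all-reduce across ranks fills every slot. This emulates an all-gather with padding, at the cost of reducing mostly-zero buffers:

```python
import torch

# Simulate dp_size ranks locally: each "rank" builds its padded buffer,
# and summing the buffers stands in for all_reduce over the DP group.
dp_size, max_num_tokens, hidden = 2, 3, 4
per_rank_tokens = [2, 3]  # rank 0 has 2 real tokens, rank 1 has 3

buffers = []
for rank, n in enumerate(per_rank_tokens):
    buf = torch.zeros(dp_size, max_num_tokens, hidden)
    # This rank's tokens, tagged with rank+1 so they are visible below.
    buf[rank, :n] = torch.full((n, hidden), float(rank + 1))
    buffers.append(buf)

gathered = sum(buffers)           # all_reduce(sum) across the DP group
flat = gathered.view(-1, hidden)  # (dp_size * max_num_tokens, hidden)
print(flat[:, 0])                 # rank 0 rows, zero padding, rank 1 rows
```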
```diff
@@ -667,6 +728,19 @@ def forward(self, hidden_states: torch.Tensor,
             activation=self.activation,
         )
 
+        if self.dp_size > 1:
+            if True:
+                all_hidden_states = get_dp_group().all_reduce(
+                    final_hidden_states)
+                all_hidden_states = all_hidden_states.view(
+                    self.dp_size, -1, all_hidden_states.size(-1))
+                final_hidden_states = all_hidden_states[
+                    self.dp_rank, :num_tokens, :]
+            else:
+                final_hidden_states = get_dp_group().reduce_scatter(
+                    final_hidden_states, 0)
+                final_hidden_states = final_hidden_states[:num_tokens, :]
+
         if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
             # Default set to False. (May have to add shared expert outputs.)
             final_hidden_states = tensor_model_parallel_all_reduce(
```
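The two branches are alternative ways for each DP rank to recover its own share of the summed result: all-reduce everywhere and slice out your rows, or reduce-scatter so each rank receives only its contiguous chunk. A single-process sketch of the shapes, under assumed sizes (not vLLM code):

```python
import torch

dp_size, max_num_tokens, hidden = 2, 3, 4
dp_rank = 1
num_tokens = 2  # this rank's real (unpadded) token count

# Stand-in for the DP-group all_reduce: after it, every rank holds the
# same summed tensor of shape (dp_size * max_num_tokens, hidden).
summed = torch.randn(dp_size * max_num_tokens, hidden)

# Branch 1: all_reduce, then each rank slices out its own padded rows.
sliced = summed.view(dp_size, -1, hidden)[dp_rank, :num_tokens, :]

# Branch 2: reduce_scatter along dim 0 would hand each rank only its
# contiguous chunk, i.e. rows [dp_rank*max_num_tokens:(dp_rank+1)*max_num_tokens].
chunk = summed.view(dp_size, -1, hidden)[dp_rank]  # reduce_scatter's output
trimmed = chunk[:num_tokens, :]  # drop the padding rows

assert torch.equal(sliced, trimmed)  # same result, different communication
```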
> **Review comment:** This is for local testing, right? Maybe revert it in the CI? In addition, I think you need to add `VLLM_TEST_ENABLE_EP` to test it locally.
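As the comment notes, the EP path is gated behind the `VLLM_TEST_ENABLE_EP` environment variable. A minimal way to flip it on for a local run (hypothetical test snippet, not from the PR):

```python
# Hypothetical local-test setup (not from the PR): enable the
# experimental EP path via the environment before starting vLLM,
# so that envs.VLLM_TEST_ENABLE_EP reads as truthy.
import os
os.environ["VLLM_TEST_ENABLE_EP"] = "1"  # flag referenced in the diff
```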