Skip to content

Commit e6b9401

Browse files
avshalommanGWS0428
authored andcommitted
[Hardware][TPU] workaround fix for MoE on TPU (vllm-project#11764)
1 parent 6a0f069 commit e6b9401

File tree

3 files changed

+60
-1
lines changed

3 files changed

+60
-1
lines changed

tests/kernels/test_moe.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from vllm.model_executor.layers.fused_moe import fused_moe
1515
from vllm.model_executor.layers.fused_moe.fused_moe import (
1616
fused_topk, moe_align_block_size)
17+
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
18+
fused_moe as iterative_moe)
1719
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
1820
marlin_quantize)
1921
from vllm.model_executor.models.mixtral import MixtralMoE
@@ -46,6 +48,11 @@ def test_fused_moe(
4648
triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
4749
torch_output = torch_moe(a, w1, w2, score, topk)
4850
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
51+
iterative_output = iterative_moe(a, w1, w2, score, topk, renormalize=False)
52+
torch.testing.assert_close(iterative_output,
53+
torch_output,
54+
atol=2e-2,
55+
rtol=0)
4956

5057

5158
@pytest.mark.parametrize("dtype",

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
else:
2121
fused_experts = None # type: ignore
2222
if current_platform.is_tpu():
23-
from .moe_pallas import fused_moe as fused_moe_pallas
23+
# the iterative moe implementation is used until the moe_pallas is fixed
24+
from .moe_torch_iterative import fused_moe as fused_moe_pallas
2425
else:
2526
fused_moe_pallas = None # type: ignore
2627
logger = init_logger(__name__)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import torch
2+
import torch.nn.functional as F
3+
4+
5+
def fused_moe(
6+
hidden_states: torch.Tensor,
7+
w1: torch.Tensor,
8+
w2: torch.Tensor,
9+
gating_output: torch.Tensor,
10+
topk: int,
11+
renormalize: bool,
12+
) -> torch.Tensor:
13+
"""
14+
Args:
15+
hidden_states: [*, hidden_size]
16+
w1: [num_experts, intermediate_size * 2, hidden_size]
17+
w2: [num_experts, hidden_size, intermediate_size]
18+
gating_output: [*, num_experts]
19+
"""
20+
orig_shape = hidden_states.shape
21+
hidden_size = hidden_states.shape[-1]
22+
num_tokens = hidden_states.shape[:-1].numel()
23+
num_experts = w1.shape[0]
24+
intermediate_size = w2.shape[-1]
25+
dtype = hidden_states.dtype
26+
27+
hidden_states = hidden_states.view(num_tokens, hidden_size)
28+
gating_output = gating_output.view(num_tokens, num_experts)
29+
topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
30+
topk_weights, selected_experts = topk_weights.topk(topk, dim=-1)
31+
if renormalize:
32+
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
33+
topk_weights = topk_weights.to(dtype)
34+
35+
final_hidden_states = None
36+
for expert_idx in range(num_experts):
37+
expert_w1 = w1[expert_idx]
38+
expert_w2 = w2[expert_idx]
39+
expert_mask = (selected_experts == expert_idx)
40+
expert_weights = (topk_weights * expert_mask).sum(dim=-1, keepdim=True)
41+
x = F.linear(hidden_states, expert_w1)
42+
gate = F.silu(x[:, :intermediate_size])
43+
x = x[:, intermediate_size:] * gate
44+
x = F.linear(x, expert_w2)
45+
current_hidden_states = x * expert_weights
46+
if final_hidden_states is None:
47+
final_hidden_states = current_hidden_states
48+
else:
49+
final_hidden_states = final_hidden_states + current_hidden_states
50+
51+
return final_hidden_states.view(orig_shape) # type: ignore

0 commit comments

Comments
 (0)