
Commit 8be0a34

Add unit test for Mixtral MoE layer (vllm-project#2677)

pcmoritz authored and alexm-redhat committed

1 parent 9769731 · commit 8be0a34

File tree: 5 files changed (+119, -55 lines)

Dockerfile

Lines changed: 6 additions & 0 deletions

```diff
@@ -7,6 +7,12 @@ FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev
 RUN apt-get update -y \
     && apt-get install -y python3-pip git
 
+# Workaround for https://github.com/openai/triton/issues/2507 and
+# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
+# this won't be needed for future versions of this docker image
+# or future versions of triton.
+RUN ldconfig /usr/local/cuda-12.1/compat/
+
 WORKDIR /workspace
 
 # install build and runtime dependencies
```
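The added `ldconfig` call registers the CUDA compatibility libraries shipped under `/usr/local/cuda-12.1/compat/` with the dynamic linker, so Triton's JIT can resolve the CUDA driver library inside the container. A minimal sanity check, assuming only that an image was built with this workaround applied, is to confirm the driver library now loads without extra environment setup:

```python
import ctypes

# If the ldconfig workaround took effect, the compat driver library is on the
# default linker search path and loads without setting LD_LIBRARY_PATH.
ctypes.CDLL("libcuda.so.1")  # raises OSError if the library cannot be found
print("libcuda.so.1 resolved")
```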

tests/kernels/test_fused_moe.py

Lines changed: 0 additions & 50 deletions
This file was deleted.

tests/kernels/test_moe.py (new file)

Lines changed: 104 additions & 0 deletions

```python
"""Tests for the MOE layers.

Run `pytest tests/kernels/test_moe.py`.
"""

import pytest
import torch

from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.models.mixtral import MixtralMoE


def torch_moe(a, w1, w2, topk_weight, topk_ids):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D)
    out = torch.zeros(B * topk_ids.shape[1],
                      w2.shape[1],
                      dtype=a.dtype,
                      device=a.device)
    topk_ids = topk_ids.view(-1)
    topk_weight = topk_weight.view(-1)
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            out[mask] = SiluAndMul()(
                a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
    return (out.view(B, -1, w2.shape[1]) *
            topk_weight.view(B, -1, 1)).sum(dim=1)


@pytest.mark.parametrize("m", [512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
):
    a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10

    score = torch.randn((m, e), device='cuda', dtype=dtype)
    score = torch.softmax(score, dim=-1)
    topk_weight, topk_ids = torch.topk(score, topk)

    triton_output = fused_moe(a, w1, w2, topk_weight, topk_ids, False)
    torch_output = torch_moe(a, w1, w2, topk_weight, topk_ids)
    assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0)


@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.float16, torch.bfloat16])
@torch.inference_mode()
def test_mixtral_moe(dtype: torch.dtype):
    "Make sure our Mixtral MoE implementation agrees with the one from huggingface."

    # Instantiate our and huggingface's MoE blocks
    config = MixtralConfig()
    hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
    vllm_moe = MixtralMoE(
        num_experts=config.num_local_experts,
        top_k=config.num_experts_per_tok,
        hidden_size=config.hidden_size,
        intermediate_size=config.intermediate_size,
        params_dtype=dtype,
        tp_size=1,
    )

    # Load the weights
    vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data
    for i in range(config.num_local_experts):
        weights = (hf_moe.experts[i].w1.weight.data,
                   hf_moe.experts[i].w3.weight.data)
        vllm_moe.ws[i][:] = torch.cat(weights, dim=0)
        vllm_moe.w2s[i][:] = hf_moe.experts[i].w2.weight.data

    # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
    inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")

    # Run forward passes for both MoE blocks
    hf_states, _ = hf_moe.forward(inputs)
    vllm_states = vllm_moe.forward(inputs)

    mixtral_moe_tol = {
        torch.float32: 1e-3,
        torch.float16: 1e-3,
        torch.bfloat16: 1e-2,
    }

    assert torch.allclose(hf_states,
                          vllm_states,
                          rtol=mixtral_moe_tol[dtype],
                          atol=mixtral_moe_tol[dtype])
```
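For reference, the fused kernel under test can be invoked directly with the same tensor layout the test uses: `w1` stacks each expert's gate and up projections as `(num_experts, 2 * intermediate, hidden)`, and `w2` holds the down projections as `(num_experts, hidden, intermediate)`. A minimal standalone sketch, with hypothetical sizes and assuming a CUDA device and the `fused_moe` signature exercised above:

```python
import torch

from vllm.model_executor.layers.fused_moe import fused_moe

# Tiny MoE problem: 4 tokens, hidden size 128, intermediate size 256,
# 8 experts, top-2 routing (sizes chosen only for illustration).
m, k, n, e, topk = 4, 128, 256, 8, 2
dtype = torch.float16

a = torch.randn((m, k), device="cuda", dtype=dtype) / 10          # token activations
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10  # gate + up projections
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10      # down projections

# Router scores -> per-token top-k expert weights and indices.
score = torch.softmax(torch.randn((m, e), device="cuda", dtype=dtype), dim=-1)
topk_weight, topk_ids = torch.topk(score, topk)

out = fused_moe(a, w1, w2, topk_weight, topk_ids, False)  # last arg: inplace
print(out.shape)  # torch.Size([4, 128]) -- one hidden vector per token
```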

vllm/model_executor/layers/fused_moe.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -235,7 +235,9 @@ def fused_moe(hidden_states: torch.Tensor,
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
     assert w2.is_contiguous(), "Expert weights2 must be contiguous"
-    assert hidden_states.dtype in [torch.float16, torch.bfloat16]
+    assert hidden_states.dtype in [
+        torch.float32, torch.float16, torch.bfloat16
+    ]
     M, _ = hidden_states.shape
     E, N, _ = w1.shape
 
```

vllm/model_executor/models/mixtral.py

Lines changed: 6 additions & 4 deletions

```diff
@@ -70,13 +70,14 @@ def __init__(
         hidden_size: int,
         intermediate_size: int,
         params_dtype: Optional[torch.dtype] = None,
+        tp_size: Optional[int] = None,
     ):
         super().__init__()
-        tp_size = get_tensor_model_parallel_world_size()
+        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
         self.num_total_experts = num_experts
         self.top_k = top_k
         self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size // tp_size
+        self.intermediate_size = intermediate_size // self.tp_size
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -141,8 +142,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                                         selected_experts,
                                         inplace=True)
 
-        final_hidden_states = tensor_model_parallel_all_reduce(
-            final_hidden_states)
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)
 
         return final_hidden_states.view(batch_size, sequence_length,
                                         hidden_size)
```
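Passing `tp_size=1` lets the layer be built and run on a single GPU without a tensor-parallel process group, since the all-reduce is skipped; this is what `test_mixtral_moe` relies on. A minimal sketch of that single-GPU usage, assuming (as the test does) that the layer allocates its weights on the current CUDA device:

```python
import torch
from transformers import MixtralConfig

from vllm.model_executor.models.mixtral import MixtralMoE

config = MixtralConfig()

with torch.inference_mode():
    # tp_size=1 keeps the full intermediate size on one device and skips the
    # tensor-parallel all-reduce in forward().
    moe = MixtralMoE(
        num_experts=config.num_local_experts,
        top_k=config.num_experts_per_tok,
        hidden_size=config.hidden_size,
        intermediate_size=config.intermediate_size,
        params_dtype=torch.float16,
        tp_size=1,
    )

    x = torch.randn((1, 16, config.hidden_size),
                    dtype=torch.float16, device="cuda")
    out = moe.forward(x)  # output shape matches the input: [1, 16, hidden_size]
```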
