
Commit 3ee6551

Fixing the shape to use in padding calculation (vllm-project#464)
* Fixing the shape to use in padding calculation
* Assertion on the int8 quantized moe
* Properly testing for padding
1 parent 0feb91a commit 3ee6551
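For context on the padding this change is concerned with: when VLLM_MOE_PADDING is enabled, the expert weight tensors are zero-padded along their last dimension (by 128 elements, matching the F.pad calls in the test below), while the activations are left untouched. A minimal sketch of that layout, using illustrative sizes that are assumptions rather than values from the commit:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only (not taken from the commit): 8 experts,
# hidden size k=1024, intermediate size n=2048, padding of 128 elements.
e, n, k, padding_size = 8, 2048, 1024, 128

w1 = torch.randn(e, 2 * n, k)   # up/gate projection weights
w2 = torch.randn(e, k, n)       # down projection weights

# VLLM_MOE_PADDING-style padding: zero-pad the last (reduction) dimension
# of each expert's weights; activations are never padded.
w1_padded = F.pad(w1, (0, padding_size), "constant", 0)
w2_padded = F.pad(w2, (0, padding_size), "constant", 0)

print(w1_padded.shape)  # torch.Size([8, 4096, 1152]) -> k + padding_size
print(w2_padded.shape)  # torch.Size([8, 1024, 2176]) -> n + padding_size
```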

File tree

2 files changed: +34 -18 lines changed


tests/kernels/test_moe.py

Lines changed: 31 additions & 17 deletions
@@ -3,6 +3,8 @@
 
 Run `pytest tests/kernels/test_moe.py`.
 """
+import unittest.mock as mock
+
 import pytest
 import torch
 from torch.nn import Parameter
@@ -40,6 +42,7 @@
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("ep_size", EP_SIZE)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("padding", [True, False])
 def test_fused_moe(
     m: int,
     n: int,
@@ -48,20 +51,20 @@ def test_fused_moe(
     topk: int,
     ep_size: int,
     dtype: torch.dtype,
+    padding: bool,
 ):
+    if padding:
+        padding_size = 128
+        envs.VLLM_MOE_PADDING = True
+    else:
+        padding_size = 0
+        envs.VLLM_MOE_PADDING = False
+
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
     w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
 
     score = torch.randn((m, e), device="cuda", dtype=dtype)
-
-    # Pad the input if use padding
-    if envs.VLLM_MOE_PADDING:
-        w1 = F.pad(w1, (0, 128), "constant", 0)
-        torch.cuda.empty_cache()
-        w2 = F.pad(w2, (0, 128), "constant", 0)
-        torch.cuda.empty_cache()
-
     if ep_size > 1:
         local_e = e // ep_size
         e_ids = torch.randint(0,
@@ -75,16 +78,7 @@ def test_fused_moe(
     else:
         e_map = None
 
-    triton_output = fused_moe(a,
-                              w1,
-                              w2,
-                              score,
-                              topk,
-                              global_num_experts=e,
-                              expert_map=e_map,
-                              renormalize=False)
     torch_output = torch_moe(a, w1, w2, score, topk, e_map)
-    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
     iterative_output = iterative_moe(a,
                                      w1,
                                      w2,
@@ -93,6 +87,26 @@
                                      global_num_experts=e,
                                      expert_map=e_map,
                                      renormalize=False)
+    # Pad the input if use padding
+    if envs.VLLM_MOE_PADDING:
+        w1 = F.pad(w1, (0, 128), "constant", 0)
+        torch.cuda.empty_cache()
+        w2 = F.pad(w2, (0, 128), "constant", 0)
+        torch.cuda.empty_cache()
+
+    with mock.patch(
+            'vllm.model_executor.layers.fused_moe.fused_moe.padding_size',
+            padding_size):
+        triton_output = fused_moe(a,
+                                  w1,
+                                  w2,
+                                  score,
+                                  topk,
+                                  global_num_experts=e,
+                                  expert_map=e_map,
+                                  renormalize=False)
+
+    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
     torch.testing.assert_close(iterative_output,
                                torch_output,
                                atol=1e-2,
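Two details of the reworked test are worth spelling out. The padding of w1/w2 and the fused_moe call now happen after the torch_moe reference is computed, so the reference path always sees unpadded weights; and the module-level padding_size consumed by the kernel is swapped in with mock.patch, so the padding parametrization exercises both the padded (128) and unpadded (0) paths. A small, self-contained sketch of that patching pattern, using a stand-in module rather than the real vllm module:

```python
import types
import unittest.mock as mock

# A stand-in module with a module-level constant, mirroring how
# vllm.model_executor.layers.fused_moe.fused_moe exposes `padding_size`.
fake_kernel_module = types.ModuleType("fake_kernel_module")
fake_kernel_module.padding_size = 0

def run_kernel():
    # The kernel path reads the constant at call time, which is why
    # patching the module attribute is enough to flip behaviour.
    return fake_kernel_module.padding_size

with mock.patch.object(fake_kernel_module, "padding_size", 128):
    assert run_kernel() == 128   # padded path while the patch is active

assert run_kernel() == 0         # original value restored on exit
```

Because mock.patch restores the original attribute when the with block exits, the padded and unpadded parametrized cases cannot leak state into each other.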

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 3 additions & 1 deletion
@@ -719,6 +719,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
             block_shape is not None and block_shape[1] > 0:
         assert B_scale is not None and B_scale.ndim == 3
         assert B_zp is None or B_zp.ndim == 3
+        assert padding_size == 0, "MoE padding is not supported " \
+            "with GPTQ/AWQ quantization"
 
         fused_moe_kernel_gptq_awq[grid](
             A,
@@ -770,7 +772,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
            expert_ids,
            num_tokens_post_padded,
            B.shape[1],
-           A.shape[1] - padding_size,
+           B.shape[2] - padding_size,
            EM,
            topk_ids.numel(),
            A.stride(0),
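The one-line kernel-argument change is the heart of the fix: the reduction extent passed to the kernel is now derived from the padded weight tensor (B.shape[2] - padding_size) rather than the activation width (A.shape[1] - padding_size). Assuming, consistent with the test above, that only B is zero-padded on its last dimension while A keeps its original width, only the former recovers the true K. A short check of that arithmetic, with illustrative sizes that are assumptions rather than values from the commit:

```python
import torch
import torch.nn.functional as F

# Illustrative sizes (assumptions, not from the commit).
m, k, n, e, padding_size = 16, 1024, 2048, 4, 128

A = torch.randn(m, k)            # activations: never padded
B = torch.randn(e, 2 * n, k)     # expert weights
B = F.pad(B, (0, padding_size))  # padded on the last dimension only

true_k = k

# Corrected expression: recovers K from the padded weight tensor.
assert B.shape[2] - padding_size == true_k

# Previous expression: subtracted the padding from the *unpadded*
# activation width, yielding K - 128, i.e. a wrong reduction extent.
assert A.shape[1] - padding_size == true_k - padding_size
```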
