--- a/tests/kernels/test_moe.py
+++ b/tests/kernels/test_moe.py
@@ -3,6 +3,8 @@
 Run `pytest tests/kernels/test_moe.py`.
 """
+import unittest.mock as mock
+
 import pytest
 import torch
 from torch.nn import Parameter
@@ -40,6 +42,7 @@
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("ep_size", EP_SIZE)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("padding", [True, False])
 def test_fused_moe(
     m: int,
     n: int,
@@ -48,20 +51,20 @@ def test_fused_moe(
     topk: int,
     ep_size: int,
     dtype: torch.dtype,
+    padding: bool,
 ):
+    if padding:
+        padding_size = 128
+        envs.VLLM_MOE_PADDING = True
+    else:
+        padding_size = 0
+        envs.VLLM_MOE_PADDING = False
+
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
     w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
 
     score = torch.randn((m, e), device="cuda", dtype=dtype)
-
-    # Pad the input if use padding
-    if envs.VLLM_MOE_PADDING:
-        w1 = F.pad(w1, (0, 128), "constant", 0)
-        torch.cuda.empty_cache()
-        w2 = F.pad(w2, (0, 128), "constant", 0)
-        torch.cuda.empty_cache()
-
     if ep_size > 1:
         local_e = e // ep_size
         e_ids = torch.randint(0,
@@ -75,16 +78,7 @@ def test_fused_moe(
     else:
         e_map = None
 
-    triton_output = fused_moe(a,
-                              w1,
-                              w2,
-                              score,
-                              topk,
-                              global_num_experts=e,
-                              expert_map=e_map,
-                              renormalize=False)
     torch_output = torch_moe(a, w1, w2, score, topk, e_map)
-    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
     iterative_output = iterative_moe(a,
                                      w1,
                                      w2,
@@ -93,6 +87,26 @@ def test_fused_moe(
                                      global_num_experts=e,
                                      expert_map=e_map,
                                      renormalize=False)
+    # Pad the MoE weights in the last dimension when padding is enabled
+    if envs.VLLM_MOE_PADDING:
+        w1 = F.pad(w1, (0, 128), "constant", 0)
+        torch.cuda.empty_cache()
+        w2 = F.pad(w2, (0, 128), "constant", 0)
+        torch.cuda.empty_cache()
+
+    with mock.patch(
+            'vllm.model_executor.layers.fused_moe.fused_moe.padding_size',
+            padding_size):
+        triton_output = fused_moe(a,
+                                  w1,
+                                  w2,
+                                  score,
+                                  topk,
+                                  global_num_experts=e,
+                                  expert_map=e_map,
+                                  renormalize=False)
+
+    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
    torch.testing.assert_close(iterative_output,
                               torch_output,
                               atol=1e-2,
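Reviewer note on the padding itself: `F.pad` takes its pad widths for the last dimension first, as a `(left, right)` pair, so `F.pad(w1, (0, 128), "constant", 0)` appends 128 zeros along the final axis of each expert's weight matrix and leaves every other dimension untouched. A minimal standalone sketch (the shapes here are illustrative, not the test's):

    import torch
    import torch.nn.functional as F

    # Illustrative shape only: (experts, out_features, in_features)
    w = torch.randn(4, 8, 16)

    # (0, 128) means: 0 zeros before and 128 zeros after the LAST dim.
    padded = F.pad(w, (0, 128), "constant", 0)

    print(padded.shape)                        # torch.Size([4, 8, 144])
    assert torch.equal(padded[..., :16], w)    # original values preserved
    assert padded[..., 16:].abs().sum() == 0   # the new tail is all zeros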
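And on the `mock.patch` usage: patching with a dotted-string target temporarily replaces a module-level attribute (here the `padding_size` constant read inside the fused-MoE module) and restores it when the `with` block exits, so each parametrized case sees a consistent value without leaking state into other tests. A self-contained sketch of the pattern, using a throwaway module rather than vLLM's real one:

    import sys
    import types
    import unittest.mock as mock

    # Hypothetical stand-in for a module that exposes a top-level constant,
    # e.g. vllm.model_executor.layers.fused_moe.fused_moe.padding_size.
    demo = types.ModuleType("demo_fused_moe")
    demo.padding_size = 0
    sys.modules["demo_fused_moe"] = demo

    with mock.patch("demo_fused_moe.padding_size", 128):
        assert demo.padding_size == 128  # patched for the duration of the block
    assert demo.padding_size == 0        # automatically restored on exit

Patching the constant directly, rather than only flipping `envs.VLLM_MOE_PADDING`, is presumably needed because the module reads the value at import time, so changing the env flag alone would not affect an already-imported `padding_size`.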