11
11
dispatch_fused_experts_func , dispatch_topk_func ,
12
12
torch_vllm_inplace_fused_experts , torch_vllm_outplace_fused_experts ,
13
13
vllm_topk_softmax )
14
+ from vllm .model_executor .layers .fused_moe .rocm_aiter_fused_moe import (
15
+ is_rocm_aiter_moe_enabled )
14
16
from vllm .model_executor .layers .layernorm import (
15
17
RMSNorm , dispatch_cuda_rmsnorm_func , fused_add_rms_norm , rms_norm ,
16
18
rocm_aiter_fused_add_rms_norm , rocm_aiter_rms_norm )
@@ -100,11 +102,10 @@ def test_enabled_ops_invalid(env: str):
100
102
def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
    """Check that ``dispatch_topk_func`` picks the expected top-k kernel.

    Args:
        use_rocm_aiter: Value written to the ``VLLM_ROCM_USE_AITER``
            environment variable; it is parsed with ``int()`` below, so it
            is expected to be a numeric string such as ``"0"`` or ``"1"``.
        monkeypatch: pytest fixture used to set the environment variable
            for the duration of this test only.
    """
    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
    # Drop any memoized result BEFORE dispatching, so the dispatch below is
    # evaluated against the environment set above rather than against a
    # value cached by a previously-run test/parametrization. This mirrors
    # the order used in test_fused_experts_dispatch.
    # NOTE(review): assumes dispatch_topk_func() consults
    # is_rocm_aiter_moe_enabled() — confirm against the fused_moe module.
    is_rocm_aiter_moe_enabled.cache_clear()
    topk_func = dispatch_topk_func()

    if current_platform.is_rocm() and int(use_rocm_aiter):
        # Imported lazily: only needed (and only referenced) on the
        # ROCm + AITER path.
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            rocm_aiter_topk_softmax)
        assert topk_func == rocm_aiter_topk_softmax
    else:
        assert topk_func == vllm_topk_softmax
@@ -116,11 +117,11 @@ def test_fused_experts_dispatch(use_rocm_aiter: str, inplace: bool,
116
117
monkeypatch ):
117
118
118
119
monkeypatch .setenv ("VLLM_ROCM_USE_AITER" , use_rocm_aiter )
120
+ is_rocm_aiter_moe_enabled .cache_clear ()
119
121
fused_experts_func = dispatch_fused_experts_func (inplace )
120
122
if current_platform .is_rocm () and int (use_rocm_aiter ):
121
123
from vllm .model_executor .layers .fused_moe .rocm_aiter_fused_moe import (
122
124
rocm_aiter_fused_experts )
123
-
124
125
assert fused_experts_func == rocm_aiter_fused_experts
125
126
elif inplace :
126
127
assert fused_experts_func == torch_vllm_inplace_fused_experts
0 commit comments