Commit d6fc629

[Kernel][Minor] Re-fuse triton moe weight application (#16071)
Signed-off-by: Bill Nell <[email protected]>
1 parent: af51d80 · commit: d6fc629

File tree

1 file changed: +18, -24 lines


vllm/model_executor/layers/fused_moe/fused_moe.py (+18, -24)
@@ -1297,30 +1297,24 @@ def fused_experts_impl(hidden_states: torch.Tensor,
         qintermediate_cache2 = intermediate_cache2
         a2q_scale = a2_scale
 
-        invoke_fused_moe_kernel(
-            qintermediate_cache2,
-            w2,
-            intermediate_cache3,
-            a2q_scale,
-            w2_scale,
-            w2_zp,
-            curr_topk_weights,
-            sorted_token_ids,
-            expert_ids,
-            num_tokens_post_padded,
-            False, #True,
-            1,
-            config,
-            compute_type=compute_type,
-            use_fp8_w8a8=use_fp8_w8a8,
-            use_int8_w8a16=use_int8_w8a16,
-            use_int4_w4a16=use_int4_w4a16,
-            block_shape=block_shape)
-
-        if True:
-            intermediate_cache3 = intermediate_cache3.view(-1, top_k_num, K)
-            intermediate_cache3.mul_(
-                curr_topk_weights.view(tokens_in_chunk, -1, 1))
+        invoke_fused_moe_kernel(qintermediate_cache2,
+                                w2,
+                                intermediate_cache3,
+                                a2q_scale,
+                                w2_scale,
+                                w2_zp,
+                                curr_topk_weights,
+                                sorted_token_ids,
+                                expert_ids,
+                                num_tokens_post_padded,
+                                True,
+                                1,
+                                config,
+                                compute_type=compute_type,
+                                use_fp8_w8a8=use_fp8_w8a8,
+                                use_int8_w8a16=use_int8_w8a16,
+                                use_int4_w4a16=use_int4_w4a16,
+                                block_shape=block_shape)
 
         ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
                     out_hidden_states[begin_chunk_idx:end_chunk_idx])
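The change re-fuses the routing-weight application into the Triton MoE kernel: the old code called invoke_fused_moe_kernel with its boolean mul_routed_weight argument set to False and then scaled intermediate_cache3 by curr_topk_weights in a separate in-place mul_ (behind the leftover "False, #True" and "if True:" toggles); the new code passes True, so the kernel applies the weights in its own epilogue and the extra elementwise pass over intermediate_cache3 goes away. Below is a minimal PyTorch sketch of why the two paths agree; it is not vLLM's Triton kernel, and the helper expert_matmul plus the toy shapes are invented for illustration:

import torch

def expert_matmul(x, w2, topk_weights, mul_routed_weight):
    # x: (tokens, top_k, N); w2: (N, K); topk_weights: (tokens, top_k).
    out = x @ w2  # stand-in for the per-(token, slot) down-projection
    if mul_routed_weight:
        # Fused path: scale by the routing weights inside the "kernel".
        out = out * topk_weights.unsqueeze(-1)
    return out

torch.manual_seed(0)
x = torch.randn(4, 2, 8)                      # 4 tokens, top_k=2, N=8
w2 = torch.randn(8, 16)                       # down-projection to K=16
w = torch.softmax(torch.randn(4, 2), dim=-1)  # routing weights

# Old path: unscaled kernel output, then a separate in-place multiply
# (mirroring the removed .view(-1, top_k_num, K).mul_(...) lines).
unfused = expert_matmul(x, w2, w, mul_routed_weight=False)
unfused = unfused.view(-1, 2, 16).mul_(w.view(4, -1, 1))

# New path: the kernel applies the routing weights itself.
fused = expert_matmul(x, w2, w, mul_routed_weight=True)

assert torch.allclose(unfused, fused, atol=1e-6)

Since both paths perform the same multiply, the fused version should change no results; it only avoids one extra read and write of intermediate_cache3 per chunk.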
