@@ -1297,30 +1297,24 @@ def fused_experts_impl(hidden_states: torch.Tensor,
1297
1297
qintermediate_cache2 = intermediate_cache2
1298
1298
a2q_scale = a2_scale
1299
1299
1300
- invoke_fused_moe_kernel (
1301
- qintermediate_cache2 ,
1302
- w2 ,
1303
- intermediate_cache3 ,
1304
- a2q_scale ,
1305
- w2_scale ,
1306
- w2_zp ,
1307
- curr_topk_weights ,
1308
- sorted_token_ids ,
1309
- expert_ids ,
1310
- num_tokens_post_padded ,
1311
- False , #True,
1312
- 1 ,
1313
- config ,
1314
- compute_type = compute_type ,
1315
- use_fp8_w8a8 = use_fp8_w8a8 ,
1316
- use_int8_w8a16 = use_int8_w8a16 ,
1317
- use_int4_w4a16 = use_int4_w4a16 ,
1318
- block_shape = block_shape )
1319
-
1320
- if True :
1321
- intermediate_cache3 = intermediate_cache3 .view (- 1 , top_k_num , K )
1322
- intermediate_cache3 .mul_ (
1323
- curr_topk_weights .view (tokens_in_chunk , - 1 , 1 ))
1300
+ invoke_fused_moe_kernel (qintermediate_cache2 ,
1301
+ w2 ,
1302
+ intermediate_cache3 ,
1303
+ a2q_scale ,
1304
+ w2_scale ,
1305
+ w2_zp ,
1306
+ curr_topk_weights ,
1307
+ sorted_token_ids ,
1308
+ expert_ids ,
1309
+ num_tokens_post_padded ,
1310
+ True ,
1311
+ 1 ,
1312
+ config ,
1313
+ compute_type = compute_type ,
1314
+ use_fp8_w8a8 = use_fp8_w8a8 ,
1315
+ use_int8_w8a16 = use_int8_w8a16 ,
1316
+ use_int4_w4a16 = use_int4_w4a16 ,
1317
+ block_shape = block_shape )
1324
1318
1325
1319
ops .moe_sum (intermediate_cache3 .view (* intermediate_cache3 .shape ),
1326
1320
out_hidden_states [begin_chunk_idx :end_chunk_idx ])
0 commit comments