@@ -360,7 +360,7 @@ def fp8_perm(m, idx):
360
360
return m [idx , ...]
361
361
362
362
363
- def test_moe_permute (a , a_s , topk_ids , num_groups , topk , block_m ):
363
+ def _moe_permute (a , a_s , topk_ids , num_groups , topk , block_m ):
364
364
M , K = a .shape
365
365
366
366
sorted_token_ids , m_indices , num_pad = moe_align_block_size (
@@ -379,7 +379,7 @@ def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
379
379
return a , a_s , m_indices , inv_perm
380
380
381
381
382
- def test_moe_unpermute (out , inv_perm , topk , K , topk_weight ):
382
+ def _moe_unpermute (out , inv_perm , topk , K , topk_weight ):
383
383
M = topk_weight .shape [0 ]
384
384
out = out [inv_perm , ...]
385
385
tmp_out = out .view (- 1 , topk , K )
@@ -401,8 +401,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
401
401
402
402
a_q , a_s = per_token_group_quant_fp8 (a , block_m )
403
403
404
- a_q , a_s , m_indices , inv_perm = test_moe_permute (a_q , a_s , topk_ids ,
405
- num_groups , topk , block_m )
404
+ a_q , a_s , m_indices , inv_perm = _moe_permute (a_q , a_s , topk_ids ,
405
+ num_groups , topk , block_m )
406
406
407
407
inter_out = torch .zeros ((a_q .shape [0 ], N * 2 ),
408
408
dtype = torch .bfloat16 ,
@@ -419,7 +419,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
419
419
deep_gemm .m_grouped_gemm_fp8_fp8_bf16_nt_contiguous (
420
420
(act_out_q , act_out_s ), (w2 , w2_s ), out , m_indices )
421
421
422
- final_out = test_moe_unpermute (out , inv_perm , topk , K , topk_weight )
422
+ final_out = _moe_unpermute (out , inv_perm , topk , K , topk_weight )
423
423
424
424
return final_out
425
425
0 commit comments