Skip to content

Commit 9f8ee65

Browse files
bnellnmnishith-fujitsu
authored andcommitted
[Bugfix] Fix function names in test_block_fp8.py (vllm-project#16033)
Signed-off-by: Bill Nell <[email protected]>
1 parent 7b60bed commit 9f8ee65

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

tests/kernels/test_block_fp8.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def fp8_perm(m, idx):
360360
return m[idx, ...]
361361

362362

363-
def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
363+
def _moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
364364
M, K = a.shape
365365

366366
sorted_token_ids, m_indices, num_pad = moe_align_block_size(
@@ -379,7 +379,7 @@ def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
379379
return a, a_s, m_indices, inv_perm
380380

381381

382-
def test_moe_unpermute(out, inv_perm, topk, K, topk_weight):
382+
def _moe_unpermute(out, inv_perm, topk, K, topk_weight):
383383
M = topk_weight.shape[0]
384384
out = out[inv_perm, ...]
385385
tmp_out = out.view(-1, topk, K)
@@ -401,8 +401,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
401401

402402
a_q, a_s = per_token_group_quant_fp8(a, block_m)
403403

404-
a_q, a_s, m_indices, inv_perm = test_moe_permute(a_q, a_s, topk_ids,
405-
num_groups, topk, block_m)
404+
a_q, a_s, m_indices, inv_perm = _moe_permute(a_q, a_s, topk_ids,
405+
num_groups, topk, block_m)
406406

407407
inter_out = torch.zeros((a_q.shape[0], N * 2),
408408
dtype=torch.bfloat16,
@@ -419,7 +419,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
419419
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
420420
(act_out_q, act_out_s), (w2, w2_s), out, m_indices)
421421

422-
final_out = test_moe_unpermute(out, inv_perm, topk, K, topk_weight)
422+
final_out = _moe_unpermute(out, inv_perm, topk, K, topk_weight)
423423

424424
return final_out
425425

0 commit comments

Comments
 (0)