Skip to content

Commit 3c17c62

Browse files
committed
More checks
Signed-off-by: kaixih <[email protected]>
1 parent 3efdc58 commit 3c17c62

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

tests/kernels/test_cutlass_mla_decode.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def ref_mla(
4747
@pytest.mark.parametrize("mean_seq_len", [128, 1024, 4096])
4848
@pytest.mark.parametrize("bs", [1, 2, 4])
4949
@pytest.mark.parametrize("varlen", [False, True])
50-
@pytest.mark.parametrize("block_size", [16, 128])
50+
@pytest.mark.parametrize("block_size", [16, 64, 128])
5151
def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int,
5252
varlen: bool, block_size: int):
5353
torch.set_default_dtype(dtype)
@@ -69,6 +69,12 @@ def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int,
6969
max_seq_len = seq_lens.max().item()
7070
block_num = (max_seq_len + block_size - 1) // block_size
7171

72+
# Pad block_num so that small blocks can be packed into full 128-sized
73+
# CUTLASS tiles. One 128-wide tile can hold (128 // block_size) small
74+
# blocks.
75+
pack_factor = 128 // block_size
76+
block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor
77+
7278
q = torch.randn(bs, h_q, d)
7379
block_table = torch.randint(0,
7480
bs * block_num, (bs, block_num),

vllm/_custom_ops.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1440,8 +1440,10 @@ def cutlass_mla_decode(q_nope_and_q_pe: torch.Tensor,
14401440
assert not current_platform.is_rocm()
14411441
assert q_nope_and_q_pe.ndim == 3, f"q_nope_and_q_pe must be a 3D tensor, but got {q_nope_and_q_pe.ndim}"
14421442
assert kv_c_and_k_pe_cache.ndim == 3, f"kv_c_and_k_pe_cache must be a 3D tensor, but got {kv_c_and_k_pe_cache.ndim}"
1443+
assert page_table.ndim == 2, f"page_table must be a 2D tensor, but got {page_table.ndim}"
14431444
B_q, H, D_q = q_nope_and_q_pe.shape
14441445
_, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape
1446+
B_pt, PAGE_NUM = page_table.shape
14451447

14461448
D_latent = 512
14471449
D_rope = 64
@@ -1453,6 +1455,11 @@ def cutlass_mla_decode(q_nope_and_q_pe: torch.Tensor,
14531455
assert PAGE_SIZE > 0 and (
14541456
PAGE_SIZE & (PAGE_SIZE - 1)
14551457
) == 0, f"PAGE_SIZE must be a power of 2, but got {PAGE_SIZE}"
1458+
assert B_pt == B_q, f"Batch dims must be same for page_table and q_nope_and_q_pe, but got {B_pt} and {B_q}"
1459+
1460+
# The current CUTLASS MLA implementation packs smaller pages into full 128-sized pages.
1461+
assert PAGE_NUM % (128 / PAGE_SIZE) == 0, f"PAGE_NUM must be divisible by 128 / PAGE_SIZE, but got {PAGE_NUM} and {128 / PAGE_SIZE}"
1462+
14561463

14571464
# TODO(kaixih@nvidia): support fp8
14581465
assert q_nope_and_q_pe.dtype in (torch.float16, torch.bfloat16), (

0 commit comments

Comments
 (0)