
Commit 977d3fe

unittest: add unittests for MLA + cudagraph (#890)
This update also addresses an issue in the scheduler where kv_len_limit could round down to zero when the total KV length is zero, which could cause the program to hang.
1 parent dc18f66 commit 977d3fe

2 files changed: +53 -6 lines
include/flashinfer/attention/scheduler.cuh (+1 -1)

@@ -1098,7 +1098,7 @@ inline cudaError_t MLAPlan(void* float_buffer, size_t float_workspace_size_in_by
       total_kv_lens += effective_kv_len;
     }
   }
-  int kv_len_limit = ceil_div(ceil_div(total_kv_lens, num_clusters), 512L) * 512L;
+  int kv_len_limit = ceil_div(std::max(ceil_div(total_kv_lens, num_clusters), 1L), 512L) * 512L;
 
   // step 1. load-balancing scheduling algorithm
   MinHeap cluster_cost_heap(num_clusters);
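The scheduler fix above guards against a degenerate input: when total_kv_lens is 0 (for example, a plan over all-zero KV lengths, which is exactly what the new test's CUDA-graph warm-up produces), the old expression rounds kv_len_limit down to 0, and a zero work quota is the likely source of the hang mentioned in the commit description. Below is a minimal Python sketch of the before/after arithmetic; the helper and function names are hypothetical, not FlashInfer code.

def ceil_div(a: int, b: int) -> int:
    # Integer ceiling division, mirroring the C++ ceil_div used in MLAPlan.
    return (a + b - 1) // b


def kv_len_limit_before(total_kv_lens: int, num_clusters: int) -> int:
    # Pre-fix expression: per-cluster share, rounded up to a multiple of 512.
    return ceil_div(ceil_div(total_kv_lens, num_clusters), 512) * 512


def kv_len_limit_after(total_kv_lens: int, num_clusters: int) -> int:
    # Post-fix expression: the per-cluster share is clamped to at least 1 first.
    return ceil_div(max(ceil_div(total_kv_lens, num_clusters), 1), 512) * 512


# Degenerate case: no KV at all. The old limit collapses to 0; the new one keeps
# a single 512-token chunk, so every cluster still has a positive quota.
assert kv_len_limit_before(0, 8) == 0
assert kv_len_limit_after(0, 8) == 512
# Non-degenerate inputs are unaffected by the clamp.
assert kv_len_limit_before(16384, 8) == kv_len_limit_after(16384, 8) == 2048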

tests/test_deepseek_mla.py (+52 -5)

@@ -189,7 +189,7 @@ def generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads):
 @pytest.mark.parametrize("num_heads", [16, 32, 64])
 @pytest.mark.parametrize("causal", [False, True])
 @pytest.mark.parametrize("page_size", [1])
-@pytest.mark.parametrize("backend", ["fa2", "fa3"])
+@pytest.mark.parametrize("backend", ["fa3"])
 @pytest.mark.parametrize("dtype", [torch.half])
 def test_batch_mla_varlen_page_attention(
     batch_size,
@@ -311,15 +311,16 @@ def test_batch_mla_varlen_page_attention(
     # torch.testing.assert_close(lse_i, lse_ref, rtol=1e-3, atol=1e-3)
 
 
-@pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5, 6, 7])
+@pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5, 6, 7, 157])
 @pytest.mark.parametrize("kv_len", [0, 17, 33, 96, 97, 114, 514, 1024])
 @pytest.mark.parametrize(
     "qo_len", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
 )
 @pytest.mark.parametrize("num_heads", [16])
 @pytest.mark.parametrize("causal", [False, True])
-@pytest.mark.parametrize("page_size", [1])
+@pytest.mark.parametrize("page_size", [1, 16])
 @pytest.mark.parametrize("backend", ["fa2", "fa3"])
+@pytest.mark.parametrize("use_cuda_graph", [True, False])
 @pytest.mark.parametrize("dtype", [torch.half])
 def test_batch_mla_page_attention(
     batch_size,
@@ -329,6 +330,7 @@ def test_batch_mla_page_attention(
     causal,
     page_size,
     backend,
+    use_cuda_graph,
     dtype,
 ):
     if not mla_is_fa3_supported(torch.device("cuda")):
@@ -362,12 +364,51 @@ def test_batch_mla_page_attention(
     sm_scale = 1.0 / ((128 + 64) ** 0.5)  # use head dimension before matrix absorption
     workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)
     wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
-        workspace_buffer, backend=backend
+        workspace_buffer,
+        backend=backend,
+        use_cuda_graph=True,
+        qo_indptr=torch.empty(batch_size + 1, dtype=torch.int32, device="cuda"),
+        kv_indptr=torch.empty(batch_size + 1, dtype=torch.int32, device="cuda"),
+        kv_indices=torch.empty(1048576, dtype=torch.int32, device="cuda"),
+        kv_len_arr=torch.empty(batch_size, dtype=torch.int32, device="cuda"),
     )
     q_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len
     kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * pages_num
     kv_indices = torch.arange(0, batch_size * pages_num).to(0).int()
     kv_lens = torch.full((batch_size,), kv_len, dtype=torch.int32).to(0)
+
+    if use_cuda_graph:
+        kv_indptr_warmup = torch.zeros(batch_size + 1).to(0).int()
+        kv_indices_warmup = torch.arange(0, batch_size).to(0).int()
+        kv_lens_warmup = torch.full((batch_size,), 0, dtype=torch.int32).to(0)
+        wrapper.plan(
+            q_indptr,
+            kv_indptr_warmup,
+            kv_indices_warmup,
+            kv_lens_warmup,
+            num_heads,
+            head_dim_ckv,
+            head_dim_kpe,
+            page_size,
+            causal,
+            sm_scale,
+            q_nope.dtype,
+            ckv.dtype,
+        )
+
+        # warmup
+        s = torch.cuda.Stream()
+        s.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(s):
+            for _ in range(3):
+                o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True)
+        torch.cuda.current_stream().wait_stream(s)
+
+        # capture
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g):
+            o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True)
+
     wrapper.plan(
         q_indptr,
         kv_indptr,
@@ -382,7 +423,12 @@ def test_batch_mla_page_attention(
         q_nope.dtype,
         ckv.dtype,
     )
-    o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True)
+    if use_cuda_graph:
+        o.fill_(0)
+        lse.fill_(0)
+        g.replay()
+    else:
+        o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True)
 
     k, v = generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads)
 
@@ -408,3 +454,4 @@ def test_batch_mla_page_attention(
     test_batch_mla_varlen_page_attention(
         155, 1024, 8, 128, 128, 16, False, 1, "fa3", torch.half
     )
+    test_batch_mla_page_attention(1, 1024, 128, 128, False, 1, "fa2", True, torch.half)
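The use_cuda_graph branch added above follows the standard PyTorch CUDA Graphs idiom: warm up on a side stream so one-time lazy initialization is not recorded, capture a single wrapper.run call, then update the plan in place and replay the captured graph into the same o/lse buffers. For reference, here is a minimal self-contained sketch of that idiom with a toy matmul standing in for wrapper.run; the names step, x, and w are illustrative only and not from the test.

import torch

# Fixed-shape tensors: CUDA graphs require stable shapes and stable addresses.
x = torch.randn(1024, 1024, device="cuda")
w = torch.randn(1024, 1024, device="cuda")


def step():
    # Toy stand-in for the fixed-shape kernel launch being captured.
    return x @ w


# Warm up on a side stream, as in the test, so setup work stays out of the capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        y = step()
torch.cuda.current_stream().wait_stream(s)

# Capture once; later replays rerun the recorded kernels on the same buffers.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    y = step()

# Refresh inputs in place (never rebind the tensors), then replay.
x.copy_(torch.randn(1024, 1024, device="cuda"))
g.replay()
torch.cuda.synchronize()

In the test itself, the pre-allocated qo_indptr/kv_indptr/kv_indices/kv_len_arr buffers passed to BatchMLAPagedAttentionWrapper appear to play this in-place-update role: the second wrapper.plan call refreshes them after capture, and the fill_(0) calls before g.replay() make it observable that the replay really recomputes o and lse.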
