@@ -189,7 +189,7 @@ def generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads):
 @pytest.mark.parametrize("num_heads", [16, 32, 64])
 @pytest.mark.parametrize("causal", [False, True])
 @pytest.mark.parametrize("page_size", [1])
-@pytest.mark.parametrize("backend", ["fa2", "fa3"])
+@pytest.mark.parametrize("backend", ["fa3"])
 @pytest.mark.parametrize("dtype", [torch.half])
 def test_batch_mla_varlen_page_attention(
     batch_size,
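For context on what this hunk does to the test matrix: stacked `@pytest.mark.parametrize` decorators compose as a Cartesian product, so each test runs once per combination of all parameter lists, and dropping `"fa2"` here halves the generated cases for this test. A minimal, self-contained sketch (names are illustrative, not from this file):

```python
import pytest

# Stacked parametrize decorators multiply: this body runs
# len([1]) * len(["fa3"]) == 1 time; restoring ["fa2", "fa3"]
# would double the number of generated test cases.
@pytest.mark.parametrize("page_size", [1])
@pytest.mark.parametrize("backend", ["fa3"])
def test_param_product(page_size, backend):
    assert page_size == 1
    assert backend == "fa3"
```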
@@ -311,15 +311,16 @@ def test_batch_mla_varlen_page_attention(
     # torch.testing.assert_close(lse_i, lse_ref, rtol=1e-3, atol=1e-3)


-@pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5, 6, 7])
+@pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5, 6, 7, 157])
 @pytest.mark.parametrize("kv_len", [0, 17, 33, 96, 97, 114, 514, 1024])
 @pytest.mark.parametrize(
     "qo_len", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
 )
 @pytest.mark.parametrize("num_heads", [16])
 @pytest.mark.parametrize("causal", [False, True])
-@pytest.mark.parametrize("page_size", [1])
+@pytest.mark.parametrize("page_size", [1, 16])
 @pytest.mark.parametrize("backend", ["fa2", "fa3"])
+@pytest.mark.parametrize("use_cuda_graph", [True, False])
 @pytest.mark.parametrize("dtype", [torch.half])
 def test_batch_mla_page_attention(
     batch_size,
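Allowing `page_size=16` matters because the per-request page count used below (`pages_num`, defined earlier in this file) is derived from `kv_len` and `page_size`. A hypothetical helper showing the expected relationship, assuming the usual ceiling division:

```python
import math

def pages_needed(kv_len: int, page_size: int) -> int:
    # Each request occupies ceil(kv_len / page_size) pages of the paged
    # KV cache; the last page may be only partially filled.
    return math.ceil(kv_len / page_size)

assert pages_needed(97, 1) == 97   # page_size=1: one token per page
assert pages_needed(97, 16) == 7   # page_size=16: 6 full pages + 1 partial
```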
@@ -329,6 +330,7 @@ def test_batch_mla_page_attention(
     causal,
     page_size,
     backend,
+    use_cuda_graph,
     dtype,
 ):
     if not mla_is_fa3_supported(torch.device("cuda")):
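The guard above presumably ends in a `pytest.skip` so the suite passes on GPUs that cannot run the FA3 MLA kernels. A hedged sketch of such a check, assuming it keys off compute capability (the real `mla_is_fa3_supported` in this repository may test something else entirely):

```python
import pytest
import torch

def mla_is_fa3_supported_sketch(device: torch.device) -> bool:
    # Assumption: FA3 kernels target SM90 (Hopper); the real predicate
    # may check a different capability or a build-time flag.
    major, _minor = torch.cuda.get_device_capability(device)
    return major >= 9

def skip_unless_supported() -> None:
    if not mla_is_fa3_supported_sketch(torch.device("cuda")):
        pytest.skip("FA3 MLA kernels not supported on this GPU")
```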
@@ -362,12 +364,51 @@ def test_batch_mla_page_attention(
     sm_scale = 1.0 / ((128 + 64) ** 0.5)  # use head dimension before matrix absorption
     workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)
     wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
-        workspace_buffer, backend=backend
+        workspace_buffer,
+        backend=backend,
+        use_cuda_graph=True,
+        qo_indptr=torch.empty(batch_size + 1, dtype=torch.int32, device="cuda"),
+        kv_indptr=torch.empty(batch_size + 1, dtype=torch.int32, device="cuda"),
+        kv_indices=torch.empty(1048576, dtype=torch.int32, device="cuda"),
+        kv_len_arr=torch.empty(batch_size, dtype=torch.int32, device="cuda"),
     )
     q_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len
     kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * pages_num
     kv_indices = torch.arange(0, batch_size * pages_num).to(0).int()
     kv_lens = torch.full((batch_size,), kv_len, dtype=torch.int32).to(0)
+
+    if use_cuda_graph:
+        kv_indptr_warmup = torch.zeros(batch_size + 1).to(0).int()
+        kv_indices_warmup = torch.arange(0, batch_size).to(0).int()
+        kv_lens_warmup = torch.full((batch_size,), 0, dtype=torch.int32).to(0)
+        wrapper.plan(
+            q_indptr,
+            kv_indptr_warmup,
+            kv_indices_warmup,
+            kv_lens_warmup,
+            num_heads,
+            head_dim_ckv,
+            head_dim_kpe,
+            page_size,
+            causal,
+            sm_scale,
+            q_nope.dtype,
+            ckv.dtype,
+        )
+
+        # warmup
+        s = torch.cuda.Stream()
+        s.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(s):
+            for _ in range(3):
+                o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True)
+        torch.cuda.current_stream().wait_stream(s)
+
+        # capture
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g):
+            o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True)
+
     wrapper.plan(
         q_indptr,
         kv_indptr,
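The warmup/capture sequence added in this hunk follows PyTorch's standard stream-capture recipe: run the op a few times on a side stream so one-time allocations and autotuning happen outside capture, then record a single call into a `torch.cuda.CUDAGraph`. The same pattern on a plain matmul, independent of the MLA wrapper:

```python
import torch

a = torch.randn(64, 64, device="cuda")
b = torch.randn(64, 64, device="cuda")
out = torch.empty(64, 64, device="cuda")

# Warmup on a side stream so lazy allocations happen before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        torch.matmul(a, b, out=out)
torch.cuda.current_stream().wait_stream(s)

# Capture one invocation; replay re-executes it against the same tensors.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    torch.matmul(a, b, out=out)

a.copy_(torch.randn(64, 64, device="cuda"))  # mutate a captured input...
g.replay()                                   # ...and replay picks it up
torch.testing.assert_close(out, a @ b)
```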
@@ -382,7 +423,12 @@ def test_batch_mla_page_attention(
         q_nope.dtype,
         ckv.dtype,
     )
-    o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True)
+    if use_cuda_graph:
+        o.fill_(0)
+        lse.fill_(0)
+        g.replay()
+    else:
+        o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True)

     k, v = generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads)
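Two details in the replay path are worth spelling out. First, `o.fill_(0)` / `lse.fill_(0)` clear the captured output tensors, so the correctness checks further down can only pass if `g.replay()` actually rewrote them. Second, the second `wrapper.plan(...)` can change what the captured graph computes because a replayed graph reads whatever currently sits in the tensors it was captured with (here, presumably the preallocated `qo_indptr`/`kv_indptr`/`kv_indices`/`kv_len_arr` buffers that planning fills in). A minimal demonstration of that property:

```python
import torch

src = torch.zeros(4, device="cuda")
dst = torch.empty(4, device="cuda")

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    dst.copy_(src)  # the graph is bound to src/dst, not to their values

src.fill_(7)  # update the captured input in place after capture
dst.fill_(0)  # clear the output so we can tell replay really wrote it
g.replay()
torch.testing.assert_close(dst, torch.full((4,), 7.0, device="cuda"))
```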
@@ -408,3 +454,4 @@ def test_batch_mla_page_attention(
     test_batch_mla_varlen_page_attention(
         155, 1024, 8, 128, 128, 16, False, 1, "fa3", torch.half
     )
+    test_batch_mla_page_attention(1, 1024, 128, 128, False, 1, "fa2", True, torch.half)