 import torch
 
 import flashinfer
+from flashinfer.jit.attention import (
+    gen_batch_mla_module,
+    gen_batch_prefill_module,
+    gen_single_prefill_module,
+)
 from flashinfer.utils import is_sm90a_supported
 
 
+@pytest.fixture(autouse=True, scope="module")
+def warmup_jit():
+    try:
+        modules = []
+        for backend in ["fa2", "fa3"]:
+            if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")):
+                continue
+
+            modules.append(
+                (
+                    gen_single_prefill_module,
+                    [
+                        backend,
+                        torch.float16,
+                        torch.float16,
+                        torch.float16,
+                        192,
+                        128,
+                        0,
+                        False,
+                        False,
+                        False,
+                    ],
+                )
+            )
+
+        for backend in ["fa2", "fa3"]:
+            if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")):
+                continue
+
+            modules.append(
+                (
+                    gen_batch_prefill_module,
+                    [
+                        backend,
+                        torch.float16,
+                        torch.float16,
+                        torch.float16,
+                        torch.int32,
+                        192,
+                        128,
+                        0,
+                        False,
+                        False,
+                        False,
+                    ],
+                )
+            )
+
+        for backend in ["fa2", "fa3"]:
+            if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")):
+                continue
+
+            modules.append(
+                (
+                    gen_batch_mla_module,
+                    [
+                        backend,
+                        torch.float16,
+                        torch.float16,
+                        torch.float16,
+                        torch.int32,
+                        512,
+                        64,
+                        False,
+                    ],
+                )
+            )
+
+        flashinfer.jit.parallel_load_modules(modules)
+    except Exception as e:
+        # abort the test session if warmup fails
+        pytest.exit(str(e))
+    finally:
+        yield
+
+
 def attention_ref(
     batch_size,
     q: torch.Tensor,
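Note on the warmup pattern added above: an autouse, module-scoped fixture runs once before the tests in this file and pre-builds the JIT attention kernels, so the first parametrized test does not pay the full serial compile cost. Below is a minimal sketch of the same idea, assuming only the `flashinfer.jit.parallel_load_modules` entry point and the `(generator, argument-list)` spec format used in the diff; the argument values are copied verbatim from the fixture and their individual meanings are not asserted here.

```python
import pytest
import torch

import flashinfer
from flashinfer.jit.attention import gen_single_prefill_module


@pytest.fixture(autouse=True, scope="module")
def warmup_jit_minimal():
    # One (generator, argument-list) spec per kernel variant to pre-compile;
    # parallel_load_modules builds them concurrently instead of letting each
    # test trigger a lazy, serial JIT build on first use.
    specs = [
        (
            gen_single_prefill_module,
            # values copied from the warmup_jit fixture in the diff above
            ["fa2", torch.float16, torch.float16, torch.float16,
             192, 128, 0, False, False, False],
        ),
    ]
    try:
        flashinfer.jit.parallel_load_modules(specs)
    except Exception as e:
        # abort the whole session rather than let every test recompile and fail
        pytest.exit(str(e))
    yield
```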
@@ -83,7 +165,7 @@ def test_single_prefill_with_kv_cache(
     backend,
     dtype,
 ):
-    if not is_sm90a_supported(torch.device("cuda")):
+    if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")):
         pytest.skip("FA3 is not supported on this device")
     torch.manual_seed(42)
     head_dim_qk = 192
@@ -117,7 +199,7 @@ def test_batch_prefill_with_ragged_kv_cache(
     backend,
     dtype,
 ):
-    if not is_sm90a_supported(torch.device("cuda")):
+    if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")):
         pytest.skip("FA3 is not supported on this device")
     torch.manual_seed(42)
     kv_layout = "NHD"
@@ -188,17 +270,15 @@ def generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads):
     return k, v
 
 
-@pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5, 6, 7])
-@pytest.mark.parametrize("kv_len_0", [0, 1, 2, 3, 4, 11])
-@pytest.mark.parametrize("kv_len_1", [17, 19, 33, 79, 114])
+@pytest.mark.parametrize("batch_size", [1, 3, 5, 7])
+@pytest.mark.parametrize("kv_len_0", [0, 1, 3, 11])
+@pytest.mark.parametrize("kv_len_1", [17, 33, 79, 114])
 @pytest.mark.parametrize("kv_len_2", [514, 2743, 8736])
-@pytest.mark.parametrize(
-    "qo_len", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
-)
-@pytest.mark.parametrize("num_heads", [16, 32, 64])
+@pytest.mark.parametrize("qo_len", [1, 3, 5, 7, 9, 11, 13, 15, 17])
+@pytest.mark.parametrize("num_heads", [16, 64])
 @pytest.mark.parametrize("causal", [False, True])
 @pytest.mark.parametrize("page_size", [1])
-@pytest.mark.parametrize("backend", ["fa3"])
+@pytest.mark.parametrize("backend", ["fa2", "fa3"])
 @pytest.mark.parametrize("dtype", [torch.half])
 def test_batch_mla_varlen_page_attention(
     batch_size,
@@ -212,7 +292,7 @@ def test_batch_mla_varlen_page_attention(
     backend,
     dtype,
 ):
-    if not is_sm90a_supported(torch.device("cuda")):
+    if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")):
         pytest.skip("FA3 is not supported on this device")
     if causal and qo_len > min(kv_len_0, kv_len_1, kv_len_2):
         pytest.skip("qo_len > kv_len not supported for causal attention")
@@ -336,7 +416,7 @@ def test_batch_mla_oob_kv_nan(
 def test_batch_mla_oob_kv_nan(
     batch_size, kv_len, qo_len, num_heads, causal, page_size, backend, dtype
 ):
-    if not is_sm90a_supported(torch.device("cuda")):
+    if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")):
         pytest.skip("FA3 is not supported on this device")
     if causal and qo_len > kv_len:
         pytest.skip("qo_len > kv_len not supported for causal attention")
@@ -405,16 +485,14 @@ def test_batch_mla_oob_kv_nan(
     torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3)
 
 
-@pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5, 6, 7, 157])
+@pytest.mark.parametrize("batch_size", [1, 3, 5, 7, 157])
 @pytest.mark.parametrize("kv_len", [0, 17, 33, 96, 97, 114, 514, 1024])
-@pytest.mark.parametrize(
-    "qo_len", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
-)
+@pytest.mark.parametrize("qo_len", [1, 3, 5, 7, 9, 11, 13, 15, 17])
 @pytest.mark.parametrize("num_heads", [16])
 @pytest.mark.parametrize("causal", [False, True])
 @pytest.mark.parametrize("page_size", [1, 16])
 @pytest.mark.parametrize("backend", ["fa2", "fa3"])
-@pytest.mark.parametrize("use_cuda_graph", [True, False])
+@pytest.mark.parametrize("use_cuda_graph", [False])
 @pytest.mark.parametrize("dtype", [torch.half])
 def test_batch_mla_page_attention(
     batch_size,
@@ -427,7 +505,7 @@ def test_batch_mla_page_attention(
     use_cuda_graph,
     dtype,
 ):
-    if not is_sm90a_supported(torch.device("cuda")):
+    if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")):
         pytest.skip("FA3 is not supported on this device")
     if causal and qo_len > kv_len:
         pytest.skip("qo_len > kv_len not supported for causal attention")