@@ -36,7 +36,8 @@ def test_warmpup_llama():
                     torch.float16,
                     torch.float16,
                     torch.int32,
-                    128,
+                    128,  # head_dim_qk
+                    128,  # head_dim_vo
                     PosEncodingMode.NONE.value,
                     False,  # use_sliding_window
                     False,  # use_logits_soft_cap
@@ -45,11 +46,13 @@ def test_warmpup_llama():
             (
                 flashinfer.prefill.gen_batch_prefill_module,
                 [
+                    "fa2",  # backend
                     torch.float16,
                     torch.float16,
                     torch.float16,
                     torch.int32,
-                    128,
+                    128,  # head_dim_qk
+                    128,  # head_dim_vo
                     PosEncodingMode.NONE.value,
                     False,  # use_sliding_window
                     False,  # use_logits_soft_cap
@@ -75,7 +78,8 @@ def test_warmpup_llama_sm90():
                     torch.float16,
                     torch.float16,
                     torch.int32,
-                    128,
+                    128,  # head_dim_qk
+                    128,  # head_dim_vo
                     PosEncodingMode.NONE.value,
                     False,  # use_sliding_window
                     False,  # use_logits_soft_cap
@@ -84,25 +88,29 @@ def test_warmpup_llama_sm90():
             (
                 flashinfer.prefill.gen_batch_prefill_module,
                 [
+                    "fa2",  # backend
                     torch.float16,
                     torch.float16,
                     torch.float16,
                     torch.int32,
-                    128,
+                    128,  # head_dim_qk
+                    128,  # head_dim_vo
                     PosEncodingMode.NONE.value,
                     False,  # use_sliding_window
                     False,  # use_logits_soft_cap
                     False,  # use_fp16_qk_reduction
                 ],
             ),
             (
-                flashinfer.prefill.gen_batch_prefill_sm90_module,
+                flashinfer.prefill.gen_batch_prefill_module,
                 [
+                    "fa3",  # backend
                     torch.float16,
                     torch.float16,
                     torch.float16,
                     torch.int32,
-                    128,
+                    128,  # head_dim_qk
+                    128,  # head_dim_vo
                     PosEncodingMode.NONE.value,
                     False,  # use_sliding_window
                     False,  # use_logits_soft_cap
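
For reference, below is a minimal sketch (not part of the diff) of how one of the updated argument lists above could be fed directly to flashinfer.prefill.gen_batch_prefill_module. The positions of head_dim_qk, head_dim_vo, the backend string, and the trailing boolean flags are taken from the inline comments in the diff; the meaning of the dtype positions, the import path for PosEncodingMode, and the idea that the test's loader simply unpacks the list positionally are assumptions here, not something the diff itself states.

import torch
import flashinfer
from flashinfer.utils import PosEncodingMode

# Argument list mirroring the updated prefill entry in the diff.
args = [
    "fa2",                       # backend ("fa3" is used in the SM90 test)
    torch.float16,               # assumed: query dtype
    torch.float16,               # assumed: key/value dtype
    torch.float16,               # assumed: output dtype
    torch.int32,                 # assumed: index dtype
    128,                         # head_dim_qk
    128,                         # head_dim_vo
    PosEncodingMode.NONE.value,  # positional encoding mode
    False,                       # use_sliding_window
    False,                       # use_logits_soft_cap
    False,                       # use_fp16_qk_reduction
]

# Unpacking the list positionally, as the warmup test presumably does.
flashinfer.prefill.gen_batch_prefill_module(*args)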