
Commit c04755e

bugfix: fix the JIT warmup arguments in unittests (#775)
Follow-up of #765; fixes the JIT warmup utility functions.
1 parent a0443d5 · commit c04755e

2 files changed: +18 −8 lines


tests/jit_utils.py (+4 −2)

@@ -53,7 +53,8 @@ def jit_decode_attention_func_args(
                     q_dtype,
                     kv_dtype,
                     q_dtype,
-                    head_dim,
+                    head_dim,  # head_dim_qk
+                    head_dim,  # head_dim_vo
                     pos_encoding_mode,
                     use_sliding_window,
                     use_logits_soft_cap,
@@ -68,7 +69,8 @@ def jit_decode_attention_func_args(
                     kv_dtype,
                     q_dtype,
                     torch.int32,
-                    head_dim,
+                    head_dim,  # head_dim_qk
+                    head_dim,  # head_dim_vo
                     pos_encoding_mode,
                     use_sliding_window,
                     use_logits_soft_cap,
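In short, the patched helper now forwards the same head_dim value twice: once as the query/key head dimension and once as the value/output head dimension. A minimal sketch of that pattern follows; the function name and parameter list below are placeholders for illustration, not the real helper's full structure.

# Sketch only: forwards one head_dim as both head_dim_qk and head_dim_vo,
# mirroring the diff above. All names here are hypothetical placeholders.
def decode_module_args(q_dtype, kv_dtype, head_dim, pos_encoding_mode,
                       use_sliding_window, use_logits_soft_cap):
    return [
        q_dtype,
        kv_dtype,
        q_dtype,  # output dtype, as in the diff
        head_dim,  # head_dim_qk
        head_dim,  # head_dim_vo
        pos_encoding_mode,
        use_sliding_window,
        use_logits_soft_cap,
    ]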

tests/test_jit_warmup.py (+14 −6)

@@ -36,7 +36,8 @@ def test_warmpup_llama():
                     torch.float16,
                     torch.float16,
                     torch.int32,
-                    128,
+                    128,  # head_dim_qk
+                    128,  # head_dim_vo
                     PosEncodingMode.NONE.value,
                     False,  # use_sliding_window
                     False,  # use_logits_soft_cap
@@ -45,11 +46,13 @@ def test_warmpup_llama():
             (
                 flashinfer.prefill.gen_batch_prefill_module,
                 [
+                    "fa2",  # backend
                     torch.float16,
                     torch.float16,
                     torch.float16,
                     torch.int32,
-                    128,
+                    128,  # head_dim_qk
+                    128,  # head_dim_vo
                     PosEncodingMode.NONE.value,
                     False,  # use_sliding_window
                     False,  # use_logits_soft_cap
@@ -75,7 +78,8 @@ def test_warmpup_llama_sm90():
                     torch.float16,
                     torch.float16,
                     torch.int32,
-                    128,
+                    128,  # head_dim_qk
+                    128,  # head_dim_vo
                     PosEncodingMode.NONE.value,
                     False,  # use_sliding_window
                     False,  # use_logits_soft_cap
@@ -84,25 +88,29 @@ def test_warmpup_llama_sm90():
             (
                 flashinfer.prefill.gen_batch_prefill_module,
                 [
+                    "fa2",  # backend
                     torch.float16,
                     torch.float16,
                     torch.float16,
                     torch.int32,
-                    128,
+                    128,  # head_dim_qk
+                    128,  # head_dim_vo
                     PosEncodingMode.NONE.value,
                     False,  # use_sliding_window
                     False,  # use_logits_soft_cap
                     False,  # use_fp16_qk_reduction
                 ],
             ),
             (
-                flashinfer.prefill.gen_batch_prefill_sm90_module,
+                flashinfer.prefill.gen_batch_prefill_module,
                 [
+                    "fa3",  # backend
                     torch.float16,
                     torch.float16,
                     torch.float16,
                     torch.int32,
-                    128,
+                    128,  # head_dim_qk
+                    128,  # head_dim_vo
                     PosEncodingMode.NONE.value,
                     False,  # use_sliding_window
                     False,  # use_logits_soft_cap
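For context, these tuples feed a parallel JIT warmup helper that calls each module generator with the listed arguments. A minimal sketch of warming up one such entry is shown below, assuming flashinfer.jit.parallel_load_modules and the flashinfer.utils.PosEncodingMode import path used by these tests; both are inferred from the surrounding file and are not shown in this diff.

import torch
import flashinfer
from flashinfer.utils import PosEncodingMode

# Sketch: warm up the FA2 batch-prefill JIT module with the argument order
# from the updated test (backend first, then dtypes, index dtype, the two
# head dims, positional-encoding mode, and the boolean feature flags).
# parallel_load_modules taking a list of (generator, args) pairs is assumed
# from the surrounding test, not from this diff.
flashinfer.jit.parallel_load_modules(
    [
        (
            flashinfer.prefill.gen_batch_prefill_module,
            [
                "fa2",  # backend
                torch.float16,  # query dtype
                torch.float16,  # key/value dtype
                torch.float16,  # output dtype
                torch.int32,  # index dtype
                128,  # head_dim_qk
                128,  # head_dim_vo
                PosEncodingMode.NONE.value,
                False,  # use_sliding_window
                False,  # use_logits_soft_cap
                False,  # use_fp16_qk_reduction
            ],
        ),
    ]
)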
