Commit 4127635

bugfix: Fix arguments of plan for split QK/VO head dims (#795)
#765 introduced changes to the API of `plan`, including renaming `head_dim` to `head_dim_qk` and adding `head_dim_vo`. However, some call sites were not updated to reflect these changes, resulting in failing unit tests.

This PR addresses the issue by updating the relevant calls, which should resolve the following unit test failures after merging:

- `tests/test_block_sparse.py::test_block_sparse_attention`
- `tests/test_non_contiguous_prefill.py::test_batch_paged_prefill_packed_input`

---------

Signed-off-by: abmfy <[email protected]>
1 parent 504b990 commit 4127635

2 files changed: +4 -1 lines changed


flashinfer/sparse.py (+3)

@@ -355,6 +355,7 @@ def plan(
  q_data_type,
  indptr.dtype,
  head_dim,
+ head_dim,
  PosEncodingMode[pos_encoding_mode].value,
  False, # use_sliding_window
  logits_soft_cap > 0, # use_logits_soft_cap
@@ -374,6 +375,7 @@ def plan(
  -1, # window_left
  logits_soft_cap, # logits_soft_cap
  head_dim,
+ head_dim,
  torch.empty(0, dtype=q_data_type),
  torch.empty(0, dtype=kv_data_type),
  get_cuda_stream(device),
@@ -442,6 +444,7 @@ def plan(
  self.C, # page_size
  False, # is_cuda_graph_enabled,
  head_dim,
+ head_dim,
  causal,
  get_cuda_stream(device),
  )
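
The three hunks above each pass `head_dim` a second time in a positional argument list. The sketch below illustrates why a positional call site breaks when a parameter is inserted mid-signature. It is only an illustration: `plan_stub` and its parameter list are hypothetical stand-ins, not FlashInfer's actual binding, and the real failure mode of the underlying call may differ.

```python
# Hypothetical stand-in for a planning function whose signature gained a
# second head-dimension parameter. NOT the real FlashInfer binding.
def plan_stub(q_data_type, idx_dtype, head_dim_qk, head_dim_vo, pos_encoding_mode):
    return {
        "head_dim_qk": head_dim_qk,
        "head_dim_vo": head_dim_vo,
        "pos_encoding_mode": pos_encoding_mode,
    }


head_dim = 128

# Old-style call site: written for the previous signature, so it is one
# positional argument short and Python rejects it.
try:
    plan_stub("float16", "int32", head_dim, 0)
    old_call_error = None
except TypeError as exc:
    old_call_error = str(exc)

# Updated call site: the same value is passed for both the QK and VO
# head dimensions, mirroring the `+ head_dim,` lines in the diff.
new_call = plan_stub("float16", "int32", head_dim, head_dim, 0)

print(old_call_error)
print(new_call)
```

Passing the same value twice keeps the previous behaviour (equal QK and VO head dimensions) while satisfying the extended argument list.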

tests/test_non_contiguous_prefill.py (+1 -1)

@@ -177,7 +177,7 @@ def test_batch_paged_prefill_packed_input(
  paged_kv_last_page_len=paged_kv_last_page_len,
  num_qo_heads=num_qo_heads,
  num_kv_heads=num_kv_heads,
- head_dim=head_dim,
+ head_dim_qk=head_dim,
  page_size=page_size,
  causal=causal,
  )
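
The test change above is a keyword rename rather than an added positional argument. A minimal sketch of that failure mode follows, assuming a keyword-only signature in which `head_dim_vo` falls back to `head_dim_qk` when omitted. `paged_plan_stub` is a hypothetical function written for this illustration, and the defaulting behaviour is an assumption, not something this diff confirms.

```python
# Hypothetical keyword-only planning function after the rename.
# The `head_dim_vo=None` fallback is an assumption made for this sketch.
def paged_plan_stub(*, num_qo_heads, num_kv_heads, head_dim_qk,
                    head_dim_vo=None, page_size=1, causal=False):
    if head_dim_vo is None:
        head_dim_vo = head_dim_qk  # assumed default: VO matches QK
    return head_dim_qk, head_dim_vo


# Old keyword name: rejected because `head_dim` is no longer a parameter.
try:
    paged_plan_stub(num_qo_heads=8, num_kv_heads=2, head_dim=64,
                    page_size=16, causal=True)
    rename_error = None
except TypeError as exc:
    rename_error = str(exc)

# New keyword name, as in the updated test.
dims = paged_plan_stub(num_qo_heads=8, num_kv_heads=2, head_dim_qk=64,
                       page_size=16, causal=True)

print(rename_error)
print(dims)
```

Unlike a positional mismatch, a stale keyword fails at the call itself with a message naming the unknown argument, which makes this kind of breakage easy to locate from the test traceback.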
