Commit 3fff008

Fix PagedPrefill Python API and some typos (#441)

Fix two small bugs:

1. `"NHD"` and `"HND"` were used confusingly: several docstrings and comments described the `HND` layout but labeled it `NHD`.
2. `BatchPrefillWithPagedKVCacheWrapper` uses `self._custom_mask_buf` to judge whether a custom mask was provided, but `begin_forward` left it uninitialized (it wrote to `self._custom_mask` instead), so the custom mask was silently ignored.

Here is the code snippet to reproduce the 2nd bug:

```python
import torch
import flashinfer

# Try to reproduce the bug under a speculative-decoding case.
device = torch.device("cuda:0")
num_heads = 32
num_qo_heads = num_heads
num_kv_heads = 32
head_dim = 128
page_size = 4
max_num_pages = 4
batch_size = 1
seq_len = 4

query = torch.randn(seq_len, num_heads, head_dim, dtype=torch.bfloat16, device=device)
packed_kv_cache = torch.randn(
    max_num_pages, 2, page_size, num_kv_heads, head_dim,
    dtype=torch.bfloat16, device=device,
)
ragged_key_cache = packed_kv_cache[:, 0].reshape(-1, num_kv_heads, head_dim)
ragged_value_cache = packed_kv_cache[:, 1].reshape(-1, num_kv_heads, head_dim)

# [4, 15]-shaped custom attention mask.
attn_mask = torch.tensor(
    [
        [True, True, True, True, True, True, True, True, False, False, False, True, False, False, False],
        [True, True, True, True, True, True, True, False, True, False, False, False, True, False, False],
        [True, True, True, True, True, True, True, True, False, False, False, False, False, True, False],
        [True, True, True, True, True, True, True, False, False, True, False, False, False, False, True],
    ],
    device=device,
)
mask = attn_mask.reshape(-1)
# packed_mask = flashinfer.quantization.packbits(mask)

workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
paged_prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
    workspace_buffer, "NHD"
)
kv_page_indices = torch.arange(max_num_pages).int().to("cuda:0")
kv_page_indptr = torch.tensor([0, 4], dtype=torch.int32, device="cuda:0")
# 1 <= kv_last_page_len <= page_size
kv_last_page_len = torch.tensor([3], dtype=torch.int32, device="cuda:0")
qo_indptr = torch.tensor([0, 4], dtype=torch.int32, device="cuda:0")

# Create auxiliary data structures for batch prefill attention.
paged_prefill_wrapper.begin_forward(
    qo_indptr,
    kv_page_indptr,
    kv_page_indices,
    kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
    mask,
    q_data_type=torch.bfloat16,
)
# assert torch.equal(paged_prefill_wrapper._custom_mask, packed_mask)
# assert paged_prefill_wrapper._custom_mask_buf is not None
q = query
o = paged_prefill_wrapper.forward(q, packed_kv_cache, causal=False)
paged_prefill_wrapper.end_forward()

# Ragged attention over the same K/V for comparison.
workspace_buffer_ragged = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
ragged_prefill_wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
    workspace_buffer_ragged, "NHD"
)
kv_indptr = torch.tensor([0, 15], dtype=torch.int32, device="cuda:0")
ragged_prefill_wrapper.begin_forward(
    qo_indptr,
    kv_indptr,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    mask,
    q_data_type="bfloat16",
)
ragged_o = ragged_prefill_wrapper.forward(q, ragged_key_cache, ragged_value_cache)
ragged_prefill_wrapper.end_forward()

print("query shape: ", q.shape)
print("paged vs ragged allclose: ", torch.allclose(o, ragged_o, rtol=1e-3, atol=1e-3))
print("paged vs ragged equal: ", torch.equal(o, ragged_o))
assert torch.allclose(o, ragged_o, rtol=1e-3, atol=1e-3)
assert torch.equal(o, ragged_o)
```
1 parent 6ac28f4 commit 3fff008
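
For reference, the two KV-cache layouts that the corrected docstrings and comments distinguish differ only in the order of the `page_size` and `num_kv_heads` dimensions. A minimal sketch of the relationship (illustrative sizes, not from the commit):

```python
import torch

max_num_pages, page_size, num_kv_heads, head_dim = 4, 4, 32, 128

# NHD: [max_num_pages, 2, page_size, num_kv_heads, head_dim]
nhd_cache = torch.randn(max_num_pages, 2, page_size, num_kv_heads, head_dim)

# HND: [max_num_pages, 2, num_kv_heads, page_size, head_dim], i.e. the same
# data with the page_size and num_kv_heads dimensions swapped.
hnd_cache = nhd_cache.permute(0, 1, 3, 2, 4).contiguous()

assert hnd_cache.shape == (max_num_pages, 2, num_kv_heads, page_size, head_dim)
```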

5 files changed: +11 −11 lines

Diff for: python/csrc/batch_prefill.cu (+2 −2)

```diff
@@ -289,11 +289,11 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu
 
   if (paged_kv_defined) {
     // [max_num_pages, 2, num_kv_heads, page_size, head_dim] for HND
-    // [max_num_pages, 2, page_size, num_kv_heads, head_dim] for HND
+    // [max_num_pages, 2, page_size, num_kv_heads, head_dim] for NHD
     CHECK_DIM(5, paged_kv_cache.value());
   } else {
     // [max_num_pages, num_kv_heads, page_size, head_dim] for HND
-    // [max_num_pages, page_size, num_kv_heads, head_dim] for HND
+    // [max_num_pages, page_size, num_kv_heads, head_dim] for NHD
     CHECK_DIM(4, paged_k_cache.value());
     CHECK_DIM(4, paged_v_cache.value());
   }
```
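
The corrected comments sit above `CHECK_DIM` guards: the wrapper accepts either a single packed 5-D KV cache or separate 4-D key and value caches. A minimal shape sketch of the two conventions (NHD layout, illustrative sizes; not part of the commit):

```python
import torch

max_num_pages, page_size, num_kv_heads, head_dim = 4, 4, 32, 128

# Packed convention: one 5-D tensor; dim 1 selects key (0) or value (1).
paged_kv_cache = torch.randn(max_num_pages, 2, page_size, num_kv_heads, head_dim)

# Separate convention: two 4-D tensors, matching the CHECK_DIM(4, ...) branch.
paged_k_cache = paged_kv_cache[:, 0].clone()
paged_v_cache = paged_kv_cache[:, 1].clone()

assert paged_kv_cache.dim() == 5
assert paged_k_cache.dim() == 4 and paged_v_cache.dim() == 4
```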

Diff for: python/flashinfer/cascade.py (+2 −2)

```diff
@@ -374,7 +374,7 @@ def forward(
             ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if
             :attr:`kv_layout` is ``NHD``, and
             ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if
-            :attr:`kv_layout` is ``NHD``. Where ``paged_kv_cache[:, 0]`` is the key-cache and
+            :attr:`kv_layout` is ``HND``. Where ``paged_kv_cache[:, 0]`` is the key-cache and
             ``paged_kv_cache[:, 1]`` is the value-cache.
 
         allow_fp16_qk_reduction : bool
@@ -631,7 +631,7 @@ def forward(
             ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if
             :attr:`kv_layout` is ``NHD``, and
             ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if
-            :attr:`kv_layout` is ``NHD``. Where ``paged_kv_cache[:, 0]`` is the key-cache and
+            :attr:`kv_layout` is ``HND``. Where ``paged_kv_cache[:, 0]`` is the key-cache and
             ``paged_kv_cache[:, 1]`` is the value-cache.
 
         causal : bool
```

Diff for: python/flashinfer/decode.py (+2 −2)

```diff
@@ -577,7 +577,7 @@ def forward(
             ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if
             :attr:`kv_layout` is ``NHD``, and
             ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if
-            :attr:`kv_layout` is ``NHD``. Where ``paged_kv_cache[:, 0]`` is the key-cache and
+            :attr:`kv_layout` is ``HND``. Where ``paged_kv_cache[:, 0]`` is the key-cache and
             ``paged_kv_cache[:, 1]`` is the value-cache.
 
         pos_encoding_mode : str
@@ -696,7 +696,7 @@ def forward_return_lse(
             ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if
             :attr:`kv_layout` is ``NHD``, and
             ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if
-            :attr:`kv_layout` is ``NHD``. Where ``paged_kv_cache[:, 0]`` is the key-cache and
+            :attr:`kv_layout` is ``HND``. Where ``paged_kv_cache[:, 0]`` is the key-cache and
             ``paged_kv_cache[:, 1]`` is the value-cache.
 
         pos_encoding_mode : str
```

Diff for: python/flashinfer/page.py (+1 −1)

```diff
@@ -65,7 +65,7 @@ def append_paged_kv_cache(
             ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if
             :attr:`kv_layout` is ``NHD``, and
             ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if
-            :attr:`kv_layout` is ``NHD``. Where ``paged_kv_cache[:, 0]`` is the key-cache and
+            :attr:`kv_layout` is ``HND``. Where ``paged_kv_cache[:, 0]`` is the key-cache and
             ``paged_kv_cache[:, 1]`` is the value-cache.
 
         kv_indices : torch.Tensor
```

Diff for: python/flashinfer/prefill.py (+4 −4)

```diff
@@ -778,8 +778,8 @@ def begin_forward(
         self._paged_kv_indices_buf = paged_kv_indices
         self._paged_kv_last_page_len_buf = paged_kv_last_page_len
         if packed_custom_mask is not None:
-            self._custom_mask = packed_custom_mask
-            self._qk_indptr = qk_indptr
+            self._custom_mask_buf = packed_custom_mask
+            self._qk_indptr_buf = qk_indptr
         empty_q_data = torch.empty(
             0,
             dtype=(
@@ -843,7 +843,7 @@ def forward(
             ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if
             :attr:`kv_layout` is ``NHD``, and
             ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if
-            :attr:`kv_layout` is ``NHD``. Where ``paged_kv_cache[:, 0]`` is the key-cache and
+            :attr:`kv_layout` is ``HND``. Where ``paged_kv_cache[:, 0]`` is the key-cache and
             ``paged_kv_cache[:, 1]`` is the value-cache.
 
         causal : bool
@@ -969,7 +969,7 @@ def forward_return_lse(
             ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if
             :attr:`kv_layout` is ``NHD``, and
             ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if
-            :attr:`kv_layout` is ``NHD``. Where ``paged_kv_cache[:, 0]`` is the key-cache and
+            :attr:`kv_layout` is ``HND``. Where ``paged_kv_cache[:, 0]`` is the key-cache and
             ``paged_kv_cache[:, 1]`` is the value-cache.
 
         causal : bool
```
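
Why the old attribute names broke custom masks: `forward` decides whether to run the masked kernel by checking `self._custom_mask_buf`, which the pre-fix `begin_forward` never set. A toy model of that control flow (plain Python, not the real wrapper):

```python
class WrapperSketch:
    """Toy stand-in for BatchPrefillWithPagedKVCacheWrapper's mask handling."""

    def __init__(self):
        self._custom_mask_buf = None  # the attribute forward() inspects

    def begin_forward_old(self, packed_custom_mask):
        self._custom_mask = packed_custom_mask  # pre-fix: wrong attribute

    def begin_forward_fixed(self, packed_custom_mask):
        self._custom_mask_buf = packed_custom_mask  # post-fix

    def forward_uses_custom_mask(self):
        return self._custom_mask_buf is not None


w = WrapperSketch()
w.begin_forward_old(b"\xff")
assert not w.forward_uses_custom_mask()  # mask silently ignored (the bug)

w = WrapperSketch()
w.begin_forward_fixed(b"\xff")
assert w.forward_uses_custom_mask()      # masked path taken after the fix
```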
