Skip to content

Commit ac72b1c

Browse files
authored
bugfix: fix decode kernels output for empty kv cache (#363)
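When a request in the batch has an empty KV cache, the output of the decode kernels does not match the output of the prefill kernels. This PR fixes the issue by reserving at least one chunk for every request in `chunk_indptr`, even when that request has no KV pages. Thanks to @MasterJH5574 for reporting this bug.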
1 parent 1b84fab commit ac72b1c

File tree

3 files changed

+84
-5
lines changed

3 files changed

+84
-5
lines changed

Diff for: include/flashinfer/attention/handler.cuh

+1-1
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ cudaError_t PartitionPagedKVCacheComputeAuxiliaryInfo(
238238
for (uint32_t batch_idx = 0; batch_idx < old_batch_size; batch_idx++) {
239239
uint32_t num_chunks =
240240
ceil_div(old_indptr_h[batch_idx + 1] - old_indptr_h[batch_idx], max_num_pages_per_batch);
241-
chunk_indptr_vec.push_back(chunk_indptr_vec.back() + num_chunks);
241+
chunk_indptr_vec.push_back(chunk_indptr_vec.back() + std::max(num_chunks, 1U));
242242
if (num_chunks == 0) {
243243
new_page_indptr_vec.push_back(old_indptr_h[batch_idx]);
244244
new_last_page_len_vec.push_back(0);

Diff for: python/tests/test_decode_prefill_lse.py

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
"""
2+
Copyright (c) 2024 by FlashInfer team.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import flashinfer
18+
import numpy as np
19+
import torch
20+
import pytest
21+
22+
23+
def test_mlc_failed_case():
24+
kv_layout = "HND"
25+
num_pages = 12
26+
kv_indptr_1 = torch.tensor([0, 0, 9]).int().to(0)
27+
kv_indices_1 = torch.tensor([3, 4, 5, 6, 7, 8, 9, 10, 11]).int().to(0)
28+
kv_last_page_len_1 = torch.tensor([0, 1]).int().to(0)
29+
num_qo_heads = 32
30+
num_kv_heads = 32
31+
page_size = 16
32+
head_dim = 128
33+
q = torch.randn(2, num_qo_heads, head_dim).to(0).half()
34+
kv_data = torch.randn(12, 2, num_kv_heads, page_size, head_dim).to(0).half()
35+
36+
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)
37+
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, kv_layout)
38+
wrapper.begin_forward(
39+
kv_indptr_1,
40+
kv_indices_1,
41+
kv_last_page_len_1,
42+
num_qo_heads,
43+
num_kv_heads,
44+
head_dim,
45+
page_size,
46+
pos_encoding_mode="NONE",
47+
data_type=torch.float16,
48+
q_data_type=torch.float16,
49+
)
50+
o_1, lse_1 = wrapper.forward_return_lse(q, kv_data)
51+
52+
wrapper_tensor_cores = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
53+
workspace_buffer, kv_layout, use_tensor_cores=True
54+
)
55+
wrapper_tensor_cores.begin_forward(
56+
kv_indptr_1,
57+
kv_indices_1,
58+
kv_last_page_len_1,
59+
num_qo_heads,
60+
num_kv_heads,
61+
head_dim,
62+
page_size,
63+
pos_encoding_mode="NONE",
64+
data_type=torch.float16,
65+
q_data_type=torch.float16,
66+
)
67+
o_1_tc, lse_1_tc = wrapper_tensor_cores.forward_return_lse(
68+
q, kv_data
69+
)
70+
71+
np.testing.assert_allclose(
72+
lse_1.cpu().numpy(), lse_1_tc.cpu().numpy(), rtol=1e-3, atol=1e-3
73+
)
74+
np.testing.assert_allclose(
75+
o_1.cpu().numpy(), o_1_tc.cpu().numpy(), rtol=1e-3, atol=1e-3
76+
)
77+
78+
if __name__ == "__main__":
79+
test_mlc_failed_case()

Diff for: python/tests/test_tensor_cores_decode.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def test_batch_decode_tensor_cores(
104104
num_kv_heads,
105105
head_dim,
106106
page_size,
107-
"NONE",
107+
pos_encoding_mode=pos_encoding_mode,
108108
data_type=torch.float16,
109109
q_data_type=torch.float16,
110110
)
@@ -121,7 +121,7 @@ def test_batch_decode_tensor_cores(
121121
num_kv_heads,
122122
head_dim,
123123
page_size,
124-
"NONE",
124+
pos_encoding_mode=pos_encoding_mode,
125125
data_type=torch.float16,
126126
q_data_type=torch.float16,
127127
)
@@ -187,7 +187,7 @@ def test_batch_decode_tensor_cores_cuda_graph(
187187
num_kv_heads,
188188
head_dim,
189189
page_size,
190-
"NONE",
190+
pos_encoding_mode=pos_encoding_mode,
191191
data_type=torch.float16,
192192
q_data_type=torch.float16,
193193
)
@@ -226,7 +226,7 @@ def test_batch_decode_tensor_cores_cuda_graph(
226226
num_kv_heads,
227227
head_dim,
228228
page_size,
229-
"NONE",
229+
pos_encoding_mode=pos_encoding_mode,
230230
data_type=torch.float16,
231231
q_data_type=torch.float16,
232232
)

0 commit comments

Comments
 (0)