
Commit 238563f

fix: fatal bugfix in batch decode operator (#177)
`BatchDecodeWithPagedKVCacheWrapper` never dispatched into the kernel: the `DISPATCH_*` block in its body was commented out, so the wrapper returned `cudaSuccess` without doing any work. This commit inlines the dispatch logic (previously in `BatchDecodeWithPagedKVCacheWrapperDispatched`) directly into the wrapper and enables it.
1 parent 44d3c03 commit 238563f

3 files changed: +42 −54 lines

Diff for: include/flashinfer/decode_attention_decl.cuh (+39 −51)

@@ -166,51 +166,15 @@ cudaError_t BatchDecodeWithPagedKVCache(
  * \note This wrapper function should be only called after we call BeginForward function in the
  * BatchDecodeHandler.
  */
-template <PageStorage page_storage, QKVLayout kv_layout, uint32_t GROUP_SIZE, uint32_t HEAD_DIM,
-          PosEncodingMode pos_encoding_mode, typename DTypeIn, typename DTypeOut, typename IdType>
-cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched(
-    BatchDecodeHandler* handler, DTypeIn* q, IdType* q_offset,
-    paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
-    float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) {
-  paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> new_paged_kv = paged_kv;
-  kv_partition_info_t<IdType> kv_partition_info;
-  DTypeOut* tmp = handler->GetTempFloatBuffer<DTypeOut>();
-
-  if (handler->IsForwardStarted()) {
-    if (tmp != nullptr) {
-      // create auxiliary information for cooperative kernels
-      new_paged_kv.batch_size = handler->GetBatchSizeAfterPartition();
-      new_paged_kv.indptr = handler->GetNewIndPtr<IdType>();
-      new_paged_kv.last_page_len = handler->GetNewLastPageLen<IdType>();
-      kv_partition_info.batch_size_before_partition = handler->GetBatchSizeBeforePartition();
-      kv_partition_info.chunk_indptr = handler->GetChunkIndPtr<IdType>();
-      kv_partition_info.batch_idx_map = handler->GetBatchIdxMap<IdType>();
-      kv_partition_info.chunk_start_pos = handler->GetChunkStartPos<IdType>();
-      kv_partition_info.seq_lens_before_partition = handler->GetSeqLengthsBeforePartition<IdType>();
-    }
-  } else {
-    std::ostringstream err_msg;
-    err_msg << "Please call BatchDecodeHandler's BeginForward() before calling "
-               "BatchDecodeWithPagedKVCacheWrapper()";
-    throw std::runtime_error(err_msg.str());
-  }
-
-  return BatchDecodeWithPagedKVCacheDispatched<GROUP_SIZE, HEAD_DIM, page_storage, kv_layout,
-                                               pos_encoding_mode, DTypeIn, DTypeOut, IdType>(
-      q, q_offset, new_paged_kv, kv_partition_info, o, tmp, lse, sm_scale, rope_scale, rope_theta,
-      stream);
-  return cudaSuccess;
-}
-
-template <PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
+template <PageStorage page_storage, QKVLayout KV_LAYOUT, typename DTypeIn, typename DTypeOut,
           typename IdType>
 cudaError_t BatchDecodeWithPagedKVCacheWrapper(
     BatchDecodeHandler* handler, DTypeIn* q, IdType* q_offset,
-    paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
+    paged_kv_t<page_storage, KV_LAYOUT, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
     uint32_t num_qo_heads, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone,
     std::optional<float> maybe_sm_scale = std::nullopt, float rope_scale = 1.f,
     float rope_theta = 1e4, cudaStream_t stream = nullptr) {
-  const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(paged_kv.head_dim)));
+  float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(paged_kv.head_dim)));
   const uint32_t num_kv_heads = paged_kv.num_heads;
   if (num_qo_heads % num_kv_heads != 0) {
     std::ostringstream err_msg;
@@ -219,18 +183,42 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper(
     throw std::invalid_argument(err_msg.str());
   }
 
-  // DISPATCH_GQA_GROUP_SIZE(
-  //     num_qo_heads / num_kv_heads, GROUP_SIZE,
-  //     {DISPATCH_HEAD_DIM(
-  //         paged_kv.head_dim, HEAD_DIM,
-  //         {DISPATCH_POS_ENCODING_MODE(
-  //             pos_encoding_mode, POS_ENCODING_MODE, {DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, {
-  //               return BatchDecodeWithPagedKVCacheWrapperDispatched<
-  //                   page_storage, KV_LAYOUT, GROUP_SIZE, HEAD_DIM, POS_ENCODING_MODE, DTypeIn,
-  //                   DTypeOut, IdType>(handler, q, q_offset, paged_kv, o, lse, sm_scale,
-  //                   rope_scale,
-  //                                     rope_theta, stream);
-  //             })})})});
+  DISPATCH_GQA_GROUP_SIZE(
+      num_qo_heads / num_kv_heads, GROUP_SIZE,
+      {DISPATCH_HEAD_DIM(
+          paged_kv.head_dim, HEAD_DIM,
+          {DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, {
+            paged_kv_t<page_storage, KV_LAYOUT, DTypeIn, IdType> new_paged_kv = paged_kv;
+            kv_partition_info_t<IdType> kv_partition_info;
+            DTypeOut* tmp = handler->GetTempFloatBuffer<DTypeOut>();
+
+            if (handler->IsForwardStarted()) {
+              if (tmp != nullptr) {
+                // create auxiliary information for cooperative kernels
+                new_paged_kv.batch_size = handler->GetBatchSizeAfterPartition();
+                new_paged_kv.indptr = handler->GetNewIndPtr<IdType>();
+                new_paged_kv.last_page_len = handler->GetNewLastPageLen<IdType>();
+                kv_partition_info.batch_size_before_partition =
+                    handler->GetBatchSizeBeforePartition();
+                kv_partition_info.chunk_indptr = handler->GetChunkIndPtr<IdType>();
+                kv_partition_info.batch_idx_map = handler->GetBatchIdxMap<IdType>();
+                kv_partition_info.chunk_start_pos = handler->GetChunkStartPos<IdType>();
+                kv_partition_info.seq_lens_before_partition =
+                    handler->GetSeqLengthsBeforePartition<IdType>();
+              }
+            } else {
+              std::ostringstream err_msg;
+              err_msg << "Please call BatchDecodeHandler's BeginForward() before calling "
+                         "BatchDecodeWithPagedKVCacheWrapper()";
+              throw std::runtime_error(err_msg.str());
+            }
+
+            return BatchDecodeWithPagedKVCacheDispatched<GROUP_SIZE, HEAD_DIM, page_storage,
+                                                         KV_LAYOUT, POS_ENCODING_MODE, DTypeIn,
+                                                         DTypeOut, IdType>(
+                q, q_offset, new_paged_kv, kv_partition_info, o, tmp, lse, sm_scale, rope_scale,
+                rope_theta, stream);
+          })})});
   return cudaSuccess;
 }
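Note on the fix: the now-enabled DISPATCH_GQA_GROUP_SIZE / DISPATCH_HEAD_DIM / DISPATCH_POS_ENCODING_MODE chain is what routes the call into BatchDecodeWithPagedKVCacheDispatched; these macros map runtime values (GQA group size = num_qo_heads / num_kv_heads, head dimension, positional-encoding mode) onto compile-time template parameters. The snippet below is a minimal, hypothetical sketch of that general pattern, not the macro definitions used in this repository; SKETCH_DISPATCH_GROUP_SIZE, RunKernelFor, and the supported case values are illustrative assumptions only.

// Sketch: turn a runtime group size into a compile-time constant, then run
// the macro body (which may use it as a template argument).
#include <cstdint>
#include <cstdio>
#include <stdexcept>

#define SKETCH_DISPATCH_GROUP_SIZE(group_size, GROUP_SIZE, ...)  \
  switch (group_size) {                                          \
    case 1: {                                                    \
      constexpr uint32_t GROUP_SIZE = 1;                         \
      __VA_ARGS__                                                \
      break;                                                     \
    }                                                            \
    case 4: {                                                    \
      constexpr uint32_t GROUP_SIZE = 4;                         \
      __VA_ARGS__                                                \
      break;                                                     \
    }                                                            \
    case 8: {                                                    \
      constexpr uint32_t GROUP_SIZE = 8;                         \
      __VA_ARGS__                                                \
      break;                                                     \
    }                                                            \
    default:                                                     \
      throw std::invalid_argument("unsupported group size");     \
  }

template <uint32_t GROUP_SIZE>
void RunKernelFor() {
  // stand-in for a templated launch such as
  // BatchDecodeWithPagedKVCacheDispatched<GROUP_SIZE, ...>(...)
  std::printf("dispatched with GROUP_SIZE = %u\n", GROUP_SIZE);
}

int main() {
  uint32_t num_qo_heads = 32, num_kv_heads = 4;  // runtime values, as in the tests
  SKETCH_DISPATCH_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE,
                             { RunKernelFor<GROUP_SIZE>(); });
  return 0;
}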

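Also visible in the first hunk: when the caller passes no maybe_sm_scale, the wrapper defaults the softmax scale to 1/sqrt(head_dim). A minimal standalone sketch of that fallback (only the value_or expression mirrors the header; the surrounding program is hypothetical):

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <optional>

int main() {
  const uint32_t head_dim = 128;                        // one of the head_dim values used in the tests
  std::optional<float> maybe_sm_scale = std::nullopt;   // caller provided no explicit scale
  // Same fallback as in BatchDecodeWithPagedKVCacheWrapper: 1/sqrt(head_dim).
  float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim)));
  std::printf("sm_scale = %f\n", sm_scale);             // ~0.088388 for head_dim = 128
  return 0;
}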
Diff for: python/tests/test_batch_prefill_kernels.py (+1 −1)

@@ -24,7 +24,7 @@
 @pytest.mark.parametrize("batch_size", [12, 17])
 @pytest.mark.parametrize("kv_len", [54, 97])
 @pytest.mark.parametrize("qo_len", [37, 17])
-@pytest.mark.parametrize("page_size", [1, 8, 16])
+@pytest.mark.parametrize("page_size", [1, 16])
 @pytest.mark.parametrize("num_kv_heads", [4])
 @pytest.mark.parametrize("num_qo_heads", [4, 32])
 @pytest.mark.parametrize("head_dim", [128, 256])
Diff for: python/tests/test_shared_prefix_kernels.py (+2 −2)

@@ -58,7 +58,7 @@ def test_batch_decode_with_shared_prefix_padded_kv_cache(
 @pytest.mark.parametrize("shared_kv_len", [54, 97, 1979])
 @pytest.mark.parametrize("num_heads", [8, 16])
 @pytest.mark.parametrize("head_dim", [128, 256])
-@pytest.mark.parametrize("page_size", [1, 4, 16])
+@pytest.mark.parametrize("page_size", [1, 16])
 def test_batch_decode_with_shared_prefix_paged_kv_cache(
     batch_size, unique_kv_len, shared_kv_len, num_heads, head_dim, page_size
 ):
@@ -131,7 +131,7 @@ def test_batch_decode_with_shared_prefix_paged_kv_cache(
 @pytest.mark.parametrize("num_heads", [8, 16])
 @pytest.mark.parametrize("causal", [False, True])
 @pytest.mark.parametrize("head_dim", [128, 256])
-@pytest.mark.parametrize("page_size", [1, 4, 16])
+@pytest.mark.parametrize("page_size", [1, 16])
 def test_batch_prefill_with_shared_prefix_paged_kv_cache(
     batch_size, unique_kv_len, shared_kv_len, num_heads, causal, head_dim, page_size
 ):
