Skip to content

Commit 60459e4

Browse files
authored
refactor: refactor decode handler (#294)
Replace the optional `fixed_grid_size` parameter with a required `padded_batch_size`: kernels now take the padded batch size directly and compute the grid as `(padded_batch_size, num_kv_heads)`, instead of deriving the batch dimension from an optional fixed grid size.
1 parent 4c5e28b commit 60459e4

File tree

7 files changed

+28
-31
lines changed

7 files changed

+28
-31
lines changed

include/flashinfer/attention/decode.cuh

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -849,13 +849,11 @@ template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKVL
849849
cudaError_t BatchDecodeWithPagedKVCacheDispatched(
850850
DTypeIn* q, IdType* q_offset, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
851851
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s,
852-
float* lse, bool* block_valid_mask, std::optional<uint32_t> fixed_grid_size, float sm_scale,
852+
float* lse, bool* block_valid_mask, uint32_t padded_batch_size, float sm_scale,
853853
float rope_scale, float rope_theta, cudaStream_t stream) {
854854
const float rope_rcp_scale = 1.f / rope_scale;
855855
const float rope_rcp_theta = 1.f / rope_theta;
856856
const uint32_t num_kv_heads = paged_kv.num_heads;
857-
const uint32_t batch_size = paged_kv.batch_size;
858-
const uint32_t grid_size = fixed_grid_size.value_or(batch_size * num_kv_heads);
859857
const uint32_t num_qo_heads = num_kv_heads * GROUP_SIZE;
860858

861859
constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL);
@@ -872,7 +870,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
872870

873871
if (tmp_v == nullptr) {
874872
// do not use partition-kv kernel
875-
dim3 nblks(grid_size / num_kv_heads, num_kv_heads);
873+
dim3 nblks(padded_batch_size, num_kv_heads);
876874
dim3 nthrs(bdx, bdy, bdz);
877875
auto kernel =
878876
BatchDecodeWithPagedKVCacheKernel</*partition_kv=*/false, POS_ENCODING_MODE,
@@ -913,7 +911,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
913911
(void*)&sm_scale,
914912
(void*)&rope_rcp_scale,
915913
(void*)&rope_rcp_theta};
916-
dim3 nblks(grid_size / num_kv_heads, num_kv_heads);
914+
dim3 nblks(padded_batch_size, num_kv_heads);
917915
dim3 nthrs(bdx, bdy, bdz);
918916
FLASHINFER_CUDA_CALL(
919917
cudaLaunchKernel((void*)partition_kv_kernel, nblks, nthrs, args, smem_size, stream));

include/flashinfer/attention/handler.cuh

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ class BatchDecodeHandler {
289289
return (IdType*)seq_lengths_before_partition_;
290290
}
291291

292-
uint32_t GetFixedGridSize() const { return fixed_grid_size_; }
292+
uint32_t GetPaddedBatchSize() const { return padded_batch_size_; }
293293

294294
bool* GetBlockValidMask() const { return block_valid_mask_; }
295295

@@ -320,7 +320,7 @@ class BatchDecodeHandler {
320320
}
321321
size_t padded_batch_size_after_partition = max_grid_size / num_kv_heads;
322322
if (tmp_size > 0) {
323-
fixed_grid_size_ = padded_batch_size_after_partition * num_kv_heads;
323+
padded_batch_size_ = padded_batch_size_after_partition;
324324
AlignedAllocator allocator(buffer, workspace_size_in_bytes);
325325
tmp_v_ = allocator.aligned_alloc<void>(
326326
num_qo_heads * padded_batch_size_after_partition * HEAD_DIM * sizeof(DTypeOut), 16);
@@ -367,11 +367,13 @@ class BatchDecodeHandler {
367367
/*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_));
368368
} else {
369369
block_valid_mask_ = nullptr;
370-
fixed_grid_size_ = num_kv_heads * batch_size;
370+
padded_batch_size_ = num_kv_heads * batch_size;
371371
}
372372
} else {
373373
// NOTE(Zihao): we don't use block_valid_mask when CUDAGraph is disabled.
374374
block_valid_mask_ = nullptr;
375+
// do not pad the batch size when not using CUDAGraph
376+
padded_batch_size_ = batch_size_after_partition_;
375377
if (tmp_size > 0) {
376378
AlignedAllocator allocator(buffer, workspace_size_in_bytes);
377379
tmp_v_ = allocator.aligned_alloc<void>(tmp_size, 16);
@@ -418,7 +420,7 @@ class BatchDecodeHandler {
418420

419421
cudaError_t EndForward() {
420422
forward_started_ = false;
421-
fixed_grid_size_ = 0;
423+
padded_batch_size_ = 0;
422424
batch_size_before_partition_ = 0;
423425
batch_size_after_partition_ = 0;
424426
block_valid_mask_ = nullptr;
@@ -492,7 +494,7 @@ class BatchDecodeHandler {
492494
void* seq_lengths_before_partition_;
493495
bool forward_started_;
494496
bool cuda_graph_enabled_;
495-
uint32_t fixed_grid_size_;
497+
uint32_t padded_batch_size_;
496498
uint32_t fixed_batch_size_;
497499
cudaStream_t stream_;
498500
};

include/flashinfer/decode_attention_decl.cuh

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKVL
4040
cudaError_t BatchDecodeWithPagedKVCacheDispatched(
4141
DTypeIn* q, IdType* q_offset, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
4242
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s,
43-
float* lse, bool* block_valid_mask, std::optional<uint32_t> fixed_grid_size, float sm_scale,
43+
float* lse, bool* block_valid_mask, uint32_t padded_batch_size, float sm_scale,
4444
float rope_scale, float rope_theta, cudaStream_t stream);
4545

4646
template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, QKVLayout KV_LAYOUT,
@@ -84,10 +84,8 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched(
8484
return BatchDecodeWithPagedKVCacheDispatched<GROUP_SIZE, HEAD_DIM, page_storage, KV_LAYOUT,
8585
POS_ENCODING_MODE, DTypeIn, DTypeOut, IdType>(
8686
q, q_offset, new_paged_kv, kv_partition_info, o, tmp_v, tmp_s, lse,
87-
handler->GetBlockValidMask(),
88-
(handler->IsCUDAGraphEnabled() ? std::optional<uint32_t>(handler->GetFixedGridSize())
89-
: std::nullopt),
90-
sm_scale, rope_scale, rope_theta, stream);
87+
handler->GetBlockValidMask(), handler->GetPaddedBatchSize(), sm_scale, rope_scale, rope_theta,
88+
stream);
9189
}
9290

9391
} // namespace flashinfer

python/generate_batch_paged_decode_inst.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ def get_cu_file_str(
3939
paged_kv_t<page_storage, {kv_layout}, {dtype_in}, {idtype}> paged_kv,
4040
kv_partition_info_t<{idtype}> kv_partition_info,
4141
{dtype_out}* o, {dtype_out}* tmp_v, float* tmp_s, float* lse,
42-
bool* block_valid_mask,
43-
std::optional<uint32_t> fixed_grid_size,
42+
bool* block_valid_mask, uint32_t padded_batch_size,
4443
float sm_scale, float rope_scale,
4544
float rope_theta, cudaStream_t stream);
4645

src/bench_batch_decode.cu

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,9 @@ void bench_flashinfer_batch_decode(nvbench::state& state) {
8989
} else {
9090
state.exec([&](nvbench::launch&) {
9191
cudaError_t status =
92-
BatchDecodeWithPagedKVCache<PageStorage::kIndices, kv_layout, T, T, int32_t>(
92+
BatchDecodeWithPagedKVCacheNoSplitKV<PageStorage::kIndices, kv_layout, T, T, int32_t>(
9393
thrust::raw_pointer_cast(q.data()), /*q_offset=*/nullptr, paged_kv,
94-
kv_partition_info_t<int32_t>(), thrust::raw_pointer_cast(o.data()), /*tmp_v=*/nullptr,
95-
/*tmp_s=*/nullptr,
94+
kv_partition_info_t<int32_t>(), thrust::raw_pointer_cast(o.data()),
9695
/*lse=*/nullptr, num_qo_heads, pos_encoding_mode);
9796
if (status != cudaSuccess) {
9897
state.skip("CUDA error: " + std::string(cudaGetErrorString(status)));

src/flashinfer_ops.cuh

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -209,10 +209,10 @@ cudaError_t BatchDecodeWithPaddedKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTy
209209

210210
template <PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
211211
typename IdType>
212-
cudaError_t BatchDecodeWithPagedKVCache(
212+
cudaError_t BatchDecodeWithPagedKVCacheNoSplitKV(
213213
DTypeIn* q, IdType* q_offset, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
214-
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s,
215-
float* lse, uint32_t num_qo_heads, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone,
214+
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* o, float* lse, uint32_t num_qo_heads,
215+
PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone,
216216
std::optional<float> maybe_sm_scale = std::nullopt, float rope_scale = 1.f,
217217
float rope_theta = 1e4, cudaStream_t stream = nullptr) {
218218
const uint32_t num_kv_heads = paged_kv.num_heads;
@@ -233,9 +233,10 @@ cudaError_t BatchDecodeWithPagedKVCache(
233233
return BatchDecodeWithPagedKVCacheDispatched<GROUP_SIZE, HEAD_DIM, page_storage,
234234
kv_layout, POS_ENCODING_MODE, DTypeIn,
235235
DTypeOut, IdType>(
236-
q, q_offset, paged_kv, kv_partition_info, o, tmp_v, tmp_s, lse,
237-
/*block_valid_mask=*/nullptr, std::nullopt, sm_scale, rope_scale, rope_theta,
238-
stream);
236+
q, q_offset, paged_kv, kv_partition_info, o, /*tmp_v=*/nullptr, /*tmp_s=*/nullptr,
237+
lse,
238+
/*block_valid_mask=*/nullptr, /*padded_batch_size=*/paged_kv.batch_size, sm_scale,
239+
rope_scale, rope_theta, stream);
239240
})})});
240241

241242
return cudaSuccess;

src/test_batch_decode.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,11 @@ void _TestBatchDecodingKernelCorrectness(size_t page_size, size_t batch_size, si
107107

108108
if (!cooperative) {
109109
// use non-cooperative kernel
110-
cudaError_t status =
111-
flashinfer::BatchDecodeWithPagedKVCache<PageStorage::kIndices, kv_layout, T, T, int32_t>(
112-
thrust::raw_pointer_cast(q_device.data()), /*q_offset=*/nullptr, paged_kv,
113-
kv_partition_info_t<int32_t>(), thrust::raw_pointer_cast(o_device.data()),
114-
/*tmp_v=*/nullptr, /*tmp_s=*/nullptr, /*lse=*/nullptr, num_qo_heads, pos_encoding_mode);
110+
cudaError_t status = flashinfer::BatchDecodeWithPagedKVCacheNoSplitKV<PageStorage::kIndices,
111+
kv_layout, T, T, int32_t>(
112+
thrust::raw_pointer_cast(q_device.data()), /*q_offset=*/nullptr, paged_kv,
113+
kv_partition_info_t<int32_t>(), thrust::raw_pointer_cast(o_device.data()),
114+
/*lse=*/nullptr, num_qo_heads, pos_encoding_mode);
115115
EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status));
116116
} else {
117117
cudaError_t status = flashinfer::BatchDecodeWithPagedKVCacheWrapper<PageStorage::kIndices,

0 commit comments

Comments
 (0)