Skip to content

Commit 5602659

Browse files
authored
feat: Separate Q and KV dtypes for decode (#286)
Closes #285. Modified unit tests pass. May need some extra validation.
1 parent 1250b68 commit 5602659

19 files changed

+473
-371
lines changed

include/flashinfer/attention/decode.cuh

+59-54
Large diffs are not rendered by default.

include/flashinfer/attention/handler.cuh

+16-14
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@ namespace flashinfer {
3131

3232
template <bool partition_kv, PosEncodingMode pos_encoding_mode, uint32_t num_stages_smem,
3333
uint32_t tile_size_per_bdx, uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz,
34-
PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
35-
typename IdType>
34+
PageStorage page_storage, QKVLayout kv_layout, typename DTypeQ, typename DTypeKV,
35+
typename DTypeOut, typename IdType>
3636
__global__ void BatchDecodeWithPagedKVCacheKernel(
37-
DTypeIn* __restrict__ q, IdType* __restrict__ q_offset,
38-
paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
37+
DTypeQ* __restrict__ q, IdType* __restrict__ q_offset,
38+
paged_kv_t<page_storage, kv_layout, DTypeKV, IdType> paged_kv,
3939
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* __restrict__ o,
4040
DTypeOut* __restrict__ tmp_v, float* __restrict__ tmp_s, float* __restrict__ lse,
4141
bool* __restrict__ block_valid_mask, float sm_scale, float rope_rcp_scale,
@@ -86,7 +86,7 @@ std::pair<uint32_t, uint32_t> PartitionPagedKVCacheBinarySearchMinNumPagePerBatc
8686
* \brief Estimate the temporary buffer size and the maximum grid size for the
8787
* partition-kv BatchDecodeWithPagedKVCache kernel
8888
* \tparam page_storage Whether to store indices or pointers of each active page
89-
* \tparam DTypeIn A template type indicates the input data type
89+
* \tparam DTypeKV A template type indicates the key-value data type
9090
* \tparam DTypeOut A template type indicates the output data type
9191
* \tparam IdType A template type indicates the index data type
9292
* \param tmp_size The estimated temporary buffer size, return 0 if not use partition-kv kernel
@@ -100,27 +100,29 @@ std::pair<uint32_t, uint32_t> PartitionPagedKVCacheBinarySearchMinNumPagePerBatc
100100
* \return status Indicates whether CUDA calls are successful
101101
*/
102102
template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKVLayout kv_layout,
103-
PosEncodingMode POS_ENCODING_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
103+
PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType>
104104
cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
105105
uint32_t& tmp_size, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch,
106106
uint32_t& new_batch_size, uint32_t batch_size, IdType* kv_indptr, const uint32_t num_qo_heads,
107107
const uint32_t page_size, bool enable_cuda_graph, cudaStream_t stream) {
108-
constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL);
108+
constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL);
109109
constexpr uint32_t num_stages_smem = 2U;
110110
constexpr uint32_t bdx = HEAD_DIM / vec_size;
111111
static_assert(bdx <= 32);
112112
constexpr uint32_t bdy = GROUP_SIZE;
113113
constexpr uint32_t num_threads = std::max(128U, bdx * bdy);
114114
constexpr uint32_t bdz = num_threads / (bdx * bdy);
115-
constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeIn) == 1 ? 2U : 4U) : 1U;
115+
constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 2U : 4U) : 1U;
116116
const uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE;
117117
const uint32_t smem_size =
118-
2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeIn) +
119-
std::max(tile_size_per_bdx * num_threads * sizeof(DTypeIn*), 2 * bdy * bdz * sizeof(float));
118+
2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) +
119+
std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float));
120120

121+
// Note that the dtype of Q should not impact the cudaOccupancyMaxActiveBlocksPerMultiprocessor
122+
// return, which is why we just use DTypeKV as it simplifies the API.
121123
auto partition_kv_kernel = BatchDecodeWithPagedKVCacheKernel<
122124
/*partition_kv=*/true, POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx, vec_size, bdx,
123-
bdy, bdz, page_storage, kv_layout, DTypeIn, DTypeOut, IdType>;
125+
bdy, bdz, page_storage, kv_layout, DTypeQ, DTypeKV, DTypeOut, IdType>;
124126
int num_blocks_per_sm = 0;
125127
int num_sm = 0;
126128
int dev_id = 0;
@@ -294,7 +296,7 @@ class BatchDecodeHandler {
294296
bool* GetBlockValidMask() const { return block_valid_mask_; }
295297

296298
template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKVLayout kv_layout,
297-
PosEncodingMode POS_ENCODING_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
299+
PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType>
298300
cudaError_t BeginForwardDispatched(void* buffer, size_t workspace_size_in_bytes, IdType* indptr,
299301
IdType* last_page_len, uint32_t batch_size,
300302
uint32_t num_qo_heads, uint32_t page_size) {
@@ -303,8 +305,8 @@ class BatchDecodeHandler {
303305
uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size;
304306
auto work_estimation_func =
305307
BatchDecodeWithPagedKVCacheWorkEstimationDispatched<GROUP_SIZE, HEAD_DIM, page_storage,
306-
kv_layout, POS_ENCODING_MODE, DTypeIn,
307-
DTypeOut, IdType>;
308+
kv_layout, POS_ENCODING_MODE,
309+
DTypeQ, DTypeKV, DTypeOut, IdType>;
308310
FLASHINFER_CUDA_CALL(work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch,
309311
new_batch_size, batch_size, indptr, num_qo_heads,
310312
page_size,

include/flashinfer/decode_attention_decl.cuh

+11-11
Original file line numberDiff line numberDiff line change
@@ -29,35 +29,35 @@
2929
namespace flashinfer {
3030

3131
template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, QKVLayout KV_LAYOUT,
32-
PosEncodingMode pos_encoding_mode, typename DTypeIn, typename DTypeOut>
33-
cudaError_t SingleDecodeWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o,
32+
PosEncodingMode pos_encoding_mode, typename DTypeQ, typename DTypeKV, typename DTypeOut>
33+
cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o,
3434
DTypeOut* tmp, uint32_t num_kv_heads,
3535
uint32_t seq_len, float sm_scale, float rope_scale,
3636
float rope_theta, cudaStream_t stream);
3737

3838
template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKVLayout kv_layout,
39-
PosEncodingMode POS_ENCODING_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
39+
PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType>
4040
cudaError_t BatchDecodeWithPagedKVCacheDispatched(
41-
DTypeIn* q, IdType* q_offset, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
41+
DTypeQ* q, IdType* q_offset, paged_kv_t<page_storage, kv_layout, DTypeKV, IdType> paged_kv,
4242
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s,
4343
float* lse, bool* block_valid_mask, uint32_t padded_batch_size, float sm_scale,
4444
float rope_scale, float rope_theta, cudaStream_t stream);
4545

4646
template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, QKVLayout KV_LAYOUT,
47-
PosEncodingMode POS_ENCODING_MODE, typename DTypeIn, typename DTypeOut>
48-
cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o,
47+
PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut>
48+
cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o,
4949
DTypeOut* tmp, float* lse, uint32_t batch_size,
5050
uint32_t padded_kv_len, uint32_t num_qo_heads,
5151
float sm_scale, float rope_scale,
5252
float rope_theta, cudaStream_t stream);
5353

5454
template <PageStorage page_storage, QKVLayout KV_LAYOUT, uint32_t GROUP_SIZE, uint32_t HEAD_DIM,
55-
PosEncodingMode POS_ENCODING_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
55+
PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType>
5656
cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched(
57-
BatchDecodeHandler* handler, DTypeIn* q, IdType* q_offset,
58-
paged_kv_t<page_storage, KV_LAYOUT, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
57+
BatchDecodeHandler* handler, DTypeQ* q, IdType* q_offset,
58+
paged_kv_t<page_storage, KV_LAYOUT, DTypeKV, IdType> paged_kv, DTypeOut* o, float* lse,
5959
float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) {
60-
paged_kv_t<page_storage, KV_LAYOUT, DTypeIn, IdType> new_paged_kv = paged_kv;
60+
paged_kv_t<page_storage, KV_LAYOUT, DTypeKV, IdType> new_paged_kv = paged_kv;
6161
kv_partition_info_t<IdType> kv_partition_info;
6262
DTypeOut* tmp_v = handler->GetTempV<DTypeOut>();
6363
float* tmp_s = handler->GetTempS<float>();
@@ -82,7 +82,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched(
8282
}
8383

8484
return BatchDecodeWithPagedKVCacheDispatched<GROUP_SIZE, HEAD_DIM, page_storage, KV_LAYOUT,
85-
POS_ENCODING_MODE, DTypeIn, DTypeOut, IdType>(
85+
POS_ENCODING_MODE, DTypeQ, DTypeKV, DTypeOut, IdType>(
8686
q, q_offset, new_paged_kv, kv_partition_info, o, tmp_v, tmp_s, lse,
8787
handler->GetBlockValidMask(), handler->GetPaddedBatchSize(), sm_scale, rope_scale, rope_theta,
8888
stream);

0 commit comments

Comments (0)