@@ -31,11 +31,11 @@ namespace flashinfer {
 
 template <bool partition_kv, PosEncodingMode pos_encoding_mode, uint32_t num_stages_smem,
           uint32_t tile_size_per_bdx, uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz,
-          PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
-          typename IdType>
+          PageStorage page_storage, QKVLayout kv_layout, typename DTypeQ, typename DTypeKV,
+          typename DTypeOut, typename IdType>
 __global__ void BatchDecodeWithPagedKVCacheKernel(
-    DTypeIn* __restrict__ q, IdType* __restrict__ q_offset,
-    paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
+    DTypeQ* __restrict__ q, IdType* __restrict__ q_offset,
+    paged_kv_t<page_storage, kv_layout, DTypeKV, IdType> paged_kv,
     kv_partition_info_t<IdType> kv_partition_info, DTypeOut* __restrict__ o,
     DTypeOut* __restrict__ tmp_v, float* __restrict__ tmp_s, float* __restrict__ lse,
     bool* __restrict__ block_valid_mask, float sm_scale, float rope_rcp_scale,
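
The hunk above replaces the single DTypeIn template parameter with separate DTypeQ (query) and DTypeKV (paged KV-cache) element types, so the two operands no longer have to share one storage precision. As a rough illustration of that pattern only (not FlashInfer code; MixedDtypeDotKernel and its launcher are hypothetical), a kernel templated this way can read each operand in its own type and accumulate in fp32:

// Illustrative sketch, not FlashInfer code: a kernel templated on separate
// query and KV element types, accumulating in fp32.
#include <cuda_fp16.h>
#include <cstdio>

template <typename DTypeQ, typename DTypeKV>
__global__ void MixedDtypeDotKernel(const DTypeQ* __restrict__ q,
                                    const DTypeKV* __restrict__ k,
                                    float* __restrict__ out, int d) {
  // Each thread accumulates a partial dot product in fp32, independent of the
  // storage precision of q and k.
  float acc = 0.f;
  for (int i = threadIdx.x; i < d; i += blockDim.x) {
    acc += static_cast<float>(q[i]) * static_cast<float>(k[i]);
  }
  // Block-wide tree reduction in shared memory (blockDim.x == 256 assumed).
  __shared__ float smem[256];
  smem[threadIdx.x] = acc;
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s];
    __syncthreads();
  }
  if (threadIdx.x == 0) *out = smem[0];
}

int main() {
  const int d = 128;
  float q_h[d];
  __half k_h[d];
  for (int i = 0; i < d; ++i) {
    q_h[i] = 1.f;
    k_h[i] = __float2half(0.5f);
  }
  float *q_d, *out_d;
  __half* k_d;
  cudaMalloc(&q_d, d * sizeof(float));
  cudaMalloc(&k_d, d * sizeof(__half));
  cudaMalloc(&out_d, sizeof(float));
  cudaMemcpy(q_d, q_h, d * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(k_d, k_h, d * sizeof(__half), cudaMemcpyHostToDevice);
  // fp32 queries against an fp16 "KV cache": the two dtypes vary independently.
  MixedDtypeDotKernel<float, __half><<<1, 256>>>(q_d, k_d, out_d, d);
  float out = 0.f;
  cudaMemcpy(&out, out_d, sizeof(float), cudaMemcpyDeviceToHost);
  printf("dot = %f (expected 64)\n", out);
  cudaFree(q_d);
  cudaFree(k_d);
  cudaFree(out_d);
  return 0;
}

Instantiating with, say, float queries against __half keys exercises exactly the degree of freedom the DTypeQ/DTypeKV split adds.
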
@@ -86,7 +86,7 @@ std::pair<uint32_t, uint32_t> PartitionPagedKVCacheBinarySearchMinNumPagePerBatc
  * \brief Estimate the temporary buffer size and the maximum grid size for the
  *   partition-kv BatchDecodeWithPagedKVCache kernel
  * \tparam page_storage Whether to store indices or pointers of each active page
- * \tparam DTypeIn A template type indicates the input data type
+ * \tparam DTypeKV A template type indicates the key-value data type
  * \tparam DTypeOut A template type indicates the output data type
  * \tparam IdType A template type indicates the index data type
  * \param tmp_size The estimated temporary buffer size, return 0 if not use partition-kv kernel
@@ -100,27 +100,29 @@ std::pair<uint32_t, uint32_t> PartitionPagedKVCacheBinarySearchMinNumPagePerBatc
  * \return status Indicates whether CUDA calls are successful
  */
 template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKVLayout kv_layout,
-          PosEncodingMode POS_ENCODING_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
+          PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType>
 cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
     uint32_t& tmp_size, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch,
     uint32_t& new_batch_size, uint32_t batch_size, IdType* kv_indptr, const uint32_t num_qo_heads,
     const uint32_t page_size, bool enable_cuda_graph, cudaStream_t stream) {
-  constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL);
+  constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL);
   constexpr uint32_t num_stages_smem = 2U;
   constexpr uint32_t bdx = HEAD_DIM / vec_size;
   static_assert(bdx <= 32);
   constexpr uint32_t bdy = GROUP_SIZE;
   constexpr uint32_t num_threads = std::max(128U, bdx * bdy);
   constexpr uint32_t bdz = num_threads / (bdx * bdy);
-  constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeIn) == 1 ? 2U : 4U) : 1U;
+  constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 2U : 4U) : 1U;
   const uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE;
   const uint32_t smem_size =
-      2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeIn) +
-      std::max(tile_size_per_bdx * num_threads * sizeof(DTypeIn*), 2 * bdy * bdz * sizeof(float));
+      2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) +
+      std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float));
 
+  // Note that the dtype of Q should not impact the cudaOccupancyMaxActiveBlocksPerMultiprocessor
+  // return, which is why we just use DTypeKV as it simplifies the API.
   auto partition_kv_kernel = BatchDecodeWithPagedKVCacheKernel<
       /*partition_kv=*/true, POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx, vec_size, bdx,
-      bdy, bdz, page_storage, kv_layout, DTypeIn, DTypeOut, IdType>;
+      bdy, bdz, page_storage, kv_layout, DTypeQ, DTypeKV, DTypeOut, IdType>;
   int num_blocks_per_sm = 0;
   int num_sm = 0;
   int dev_id = 0;
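
The new comment about cudaOccupancyMaxActiveBlocksPerMultiprocessor is why the work estimator only needs DTypeKV: occupancy is driven by the block size and the dynamic shared-memory footprint, and smem_size above depends only on the KV element type. A minimal sketch of that occupancy query pattern follows; EstimateMaxGridSize is a hypothetical helper, not part of FlashInfer's API, and only the CUDA runtime calls are real.

// Hypothetical helper illustrating the standard occupancy query; only the CUDA
// runtime calls (cudaGetDevice, cudaDeviceGetAttribute, cudaFuncSetAttribute,
// cudaOccupancyMaxActiveBlocksPerMultiprocessor) are real APIs.
#include <cuda_runtime.h>
#include <cstdint>

template <typename KernelPtr>
cudaError_t EstimateMaxGridSize(KernelPtr kernel, uint32_t smem_size, uint32_t num_threads,
                                uint32_t& max_grid_size) {
  int dev_id = 0, num_sm = 0, num_blocks_per_sm = 0;
  cudaError_t err = cudaGetDevice(&dev_id);
  if (err != cudaSuccess) return err;
  err = cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id);
  if (err != cudaSuccess) return err;
  // Kernels needing more than the default 48KB of dynamic shared memory must opt in.
  err = cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
  if (err != cudaSuccess) return err;
  // Occupancy is a function of the kernel's resource usage, block size, and dynamic
  // shared memory; the element type of q never enters this calculation, which is
  // what the "dtype of Q should not impact" comment above points out.
  err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, num_threads,
                                                      smem_size);
  if (err != cudaSuccess) return err;
  max_grid_size = static_cast<uint32_t>(num_blocks_per_sm) * num_sm;
  return cudaSuccess;
}

In the elided remainder of the function above, partition_kv_kernel is presumably what gets fed into this kind of query to size the grid.
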
@@ -294,7 +296,7 @@ class BatchDecodeHandler {
   bool* GetBlockValidMask() const { return block_valid_mask_; }
 
   template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKVLayout kv_layout,
-            PosEncodingMode POS_ENCODING_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
+            PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType>
   cudaError_t BeginForwardDispatched(void* buffer, size_t workspace_size_in_bytes, IdType* indptr,
                                      IdType* last_page_len, uint32_t batch_size,
                                      uint32_t num_qo_heads, uint32_t page_size) {
@@ -303,8 +305,8 @@ class BatchDecodeHandler {
     uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size;
     auto work_estimation_func =
         BatchDecodeWithPagedKVCacheWorkEstimationDispatched<GROUP_SIZE, HEAD_DIM, page_storage,
-                                                            kv_layout, POS_ENCODING_MODE, DTypeIn,
-                                                            DTypeOut, IdType>;
+                                                            kv_layout, POS_ENCODING_MODE,
+                                                            DTypeQ, DTypeKV, DTypeOut, IdType>;
     FLASHINFER_CUDA_CALL(work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch,
                                               new_batch_size, batch_size, indptr, num_qo_heads,
                                               page_size,