@@ -166,51 +166,15 @@ cudaError_t BatchDecodeWithPagedKVCache(
  * \note This wrapper function should be only called after we call BeginForward function in the
  * BatchDecodeHandler.
  */
-template <PageStorage page_storage, QKVLayout kv_layout, uint32_t GROUP_SIZE, uint32_t HEAD_DIM,
-          PosEncodingMode pos_encoding_mode, typename DTypeIn, typename DTypeOut, typename IdType>
-cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched(
-    BatchDecodeHandler* handler, DTypeIn* q, IdType* q_offset,
-    paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
-    float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) {
-  paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> new_paged_kv = paged_kv;
-  kv_partition_info_t<IdType> kv_partition_info;
-  DTypeOut* tmp = handler->GetTempFloatBuffer<DTypeOut>();
-
-  if (handler->IsForwardStarted()) {
-    if (tmp != nullptr) {
-      // create auxiliary information for cooperative kernels
-      new_paged_kv.batch_size = handler->GetBatchSizeAfterPartition();
-      new_paged_kv.indptr = handler->GetNewIndPtr<IdType>();
-      new_paged_kv.last_page_len = handler->GetNewLastPageLen<IdType>();
-      kv_partition_info.batch_size_before_partition = handler->GetBatchSizeBeforePartition();
-      kv_partition_info.chunk_indptr = handler->GetChunkIndPtr<IdType>();
-      kv_partition_info.batch_idx_map = handler->GetBatchIdxMap<IdType>();
-      kv_partition_info.chunk_start_pos = handler->GetChunkStartPos<IdType>();
-      kv_partition_info.seq_lens_before_partition = handler->GetSeqLengthsBeforePartition<IdType>();
-    }
-  } else {
-    std::ostringstream err_msg;
-    err_msg << "Please call BatchDecodeHandler's BeginForward() before calling "
-               "BatchDecodeWithPagedKVCacheWrapper()";
-    throw std::runtime_error(err_msg.str());
-  }
-
-  return BatchDecodeWithPagedKVCacheDispatched<GROUP_SIZE, HEAD_DIM, page_storage, kv_layout,
-                                               pos_encoding_mode, DTypeIn, DTypeOut, IdType>(
-      q, q_offset, new_paged_kv, kv_partition_info, o, tmp, lse, sm_scale, rope_scale, rope_theta,
-      stream);
-  return cudaSuccess;
-}
-
-template <PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
+template <PageStorage page_storage, QKVLayout KV_LAYOUT, typename DTypeIn, typename DTypeOut,
           typename IdType>
 cudaError_t BatchDecodeWithPagedKVCacheWrapper(
     BatchDecodeHandler* handler, DTypeIn* q, IdType* q_offset,
-    paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
+    paged_kv_t<page_storage, KV_LAYOUT, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
     uint32_t num_qo_heads, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone,
     std::optional<float> maybe_sm_scale = std::nullopt, float rope_scale = 1.f,
     float rope_theta = 1e4, cudaStream_t stream = nullptr) {
-  const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(paged_kv.head_dim)));
+  float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(paged_kv.head_dim)));
   const uint32_t num_kv_heads = paged_kv.num_heads;
   if (num_qo_heads % num_kv_heads != 0) {
     std::ostringstream err_msg;
@@ -219,18 +183,42 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper(
     throw std::invalid_argument(err_msg.str());
   }
 
-  // DISPATCH_GQA_GROUP_SIZE(
-  //     num_qo_heads / num_kv_heads, GROUP_SIZE,
-  //     {DISPATCH_HEAD_DIM(
-  //         paged_kv.head_dim, HEAD_DIM,
-  //         {DISPATCH_POS_ENCODING_MODE(
-  //             pos_encoding_mode, POS_ENCODING_MODE, {DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, {
-  //               return BatchDecodeWithPagedKVCacheWrapperDispatched<
-  //                   page_storage, KV_LAYOUT, GROUP_SIZE, HEAD_DIM, POS_ENCODING_MODE, DTypeIn,
-  //                   DTypeOut, IdType>(handler, q, q_offset, paged_kv, o, lse, sm_scale,
-  //                   rope_scale,
-  //                   rope_theta, stream);
-  //             })})})});
+  DISPATCH_GQA_GROUP_SIZE(
+      num_qo_heads / num_kv_heads, GROUP_SIZE,
+      {DISPATCH_HEAD_DIM(
+          paged_kv.head_dim, HEAD_DIM,
+          {DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, {
+            paged_kv_t<page_storage, KV_LAYOUT, DTypeIn, IdType> new_paged_kv = paged_kv;
+            kv_partition_info_t<IdType> kv_partition_info;
+            DTypeOut* tmp = handler->GetTempFloatBuffer<DTypeOut>();
+
+            if (handler->IsForwardStarted()) {
+              if (tmp != nullptr) {
+                // create auxiliary information for cooperative kernels
+                new_paged_kv.batch_size = handler->GetBatchSizeAfterPartition();
+                new_paged_kv.indptr = handler->GetNewIndPtr<IdType>();
+                new_paged_kv.last_page_len = handler->GetNewLastPageLen<IdType>();
+                kv_partition_info.batch_size_before_partition =
+                    handler->GetBatchSizeBeforePartition();
+                kv_partition_info.chunk_indptr = handler->GetChunkIndPtr<IdType>();
+                kv_partition_info.batch_idx_map = handler->GetBatchIdxMap<IdType>();
+                kv_partition_info.chunk_start_pos = handler->GetChunkStartPos<IdType>();
+                kv_partition_info.seq_lens_before_partition =
+                    handler->GetSeqLengthsBeforePartition<IdType>();
+              }
+            } else {
+              std::ostringstream err_msg;
+              err_msg << "Please call BatchDecodeHandler's BeginForward() before calling "
+                         "BatchDecodeWithPagedKVCacheWrapper()";
+              throw std::runtime_error(err_msg.str());
+            }
+
+            return BatchDecodeWithPagedKVCacheDispatched<GROUP_SIZE, HEAD_DIM, page_storage,
+                                                         KV_LAYOUT, POS_ENCODING_MODE, DTypeIn,
+                                                         DTypeOut, IdType>(
+                q, q_offset, new_paged_kv, kv_partition_info, o, tmp, lse, sm_scale, rope_scale,
+                rope_theta, stream);
+          })})});
   return cudaSuccess;
 }
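For context on the change above: the separate BatchDecodeWithPagedKVCacheWrapperDispatched helper is removed, and the handler/partition bookkeeping now lives directly inside the DISPATCH_* macros in the wrapper; the wrapper still throws std::runtime_error if the handler's BeginForward() was not called first. The DISPATCH_* macros follow the common idiom of switching on a runtime value to bind a compile-time constant before instantiating a templated kernel. Below is a minimal, self-contained sketch of that idiom; the macro name DISPATCH_HEAD_DIM_SKETCH, the function run_templated, and the supported values 128/256 are placeholders for illustration, not FlashInfer's actual definitions.

#include <cstdint>
#include <iostream>
#include <sstream>
#include <stdexcept>

// Hypothetical stand-in for a kernel launcher that needs HEAD_DIM as a template parameter.
template <uint32_t HEAD_DIM>
void run_templated() {
  std::cout << "instantiated with HEAD_DIM = " << HEAD_DIM << std::endl;
}

// Sketch of the dispatch idiom: switch on the runtime value, bind a constexpr constant
// under the requested name, then expand the caller-supplied block so it can instantiate
// templates with that constant. The supported cases here are placeholders.
#define DISPATCH_HEAD_DIM_SKETCH(head_dim, HEAD_DIM, ...) \
  switch (head_dim) {                                     \
    case 128: {                                           \
      constexpr uint32_t HEAD_DIM = 128;                  \
      __VA_ARGS__                                         \
      break;                                              \
    }                                                     \
    case 256: {                                           \
      constexpr uint32_t HEAD_DIM = 256;                  \
      __VA_ARGS__                                         \
      break;                                              \
    }                                                     \
    default: {                                            \
      std::ostringstream err;                             \
      err << "unsupported head_dim: " << head_dim;        \
      throw std::invalid_argument(err.str());             \
    }                                                     \
  }

int main() {
  uint32_t head_dim = 128;  // runtime value, analogous to paged_kv.head_dim above
  DISPATCH_HEAD_DIM_SKETCH(head_dim, HEAD_DIM, { run_templated<HEAD_DIM>(); });
  return 0;
}

Nesting several such macros, as the new wrapper body does with GROUP_SIZE, HEAD_DIM, and POS_ENCODING_MODE, yields one template instantiation per supported combination while keeping a single runtime entry point.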