@@ -297,121 +297,125 @@ class BatchDecodeHandler {
 
   bool* GetBlockValidMask() const { return block_valid_mask_; }
 
-  template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage,
-            LogitsPostHook LOGITS_POST_HOOK, QKVLayout kv_layout, PosEncodingMode POS_ENCODING_MODE,
-            typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType>
+  template <uint32_t HEAD_DIM, PageStorage page_storage, LogitsPostHook LOGITS_POST_HOOK,
+            QKVLayout kv_layout, PosEncodingMode POS_ENCODING_MODE, typename DTypeQ,
+            typename DTypeKV, typename DTypeOut, typename IdType>
   cudaError_t BeginForwardDispatched(void* buffer, size_t workspace_size_in_bytes, IdType* indptr,
                                      IdType* last_page_len, uint32_t batch_size,
-                                     uint32_t num_qo_heads, uint32_t page_size) {
+                                     uint32_t num_qo_heads, uint32_t num_kv_heads,
+                                     uint32_t page_size) {
     batch_size_before_partition_ = batch_size;
-    uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE;
     uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size;
-    auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimationDispatched<
-        GROUP_SIZE, HEAD_DIM, page_storage, LOGITS_POST_HOOK, kv_layout, POS_ENCODING_MODE, DTypeQ,
-        DTypeKV, DTypeOut, IdType>;
-    FLASHINFER_CUDA_CALL(work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch,
-                                              new_batch_size, batch_size, indptr, num_qo_heads,
-                                              page_size,
-                                              /*enable_cuda_graph=*/IsCUDAGraphEnabled(), stream_));
-    batch_size_after_partition_ = new_batch_size;
-    if (IsCUDAGraphEnabled()) {
-      if (batch_size != fixed_batch_size_) {
-        std::ostringstream err_msg;
-        err_msg << "The running batch size " << batch_size
-                << " is not compatible with the fixed batch size " << fixed_batch_size_
-                << " initialized for CUDAGraph";
-        throw std::runtime_error(err_msg.str());
-      }
-      size_t padded_batch_size = max_grid_size / num_kv_heads;
-      if (tmp_size > 0) {
-        padded_batch_size_ = padded_batch_size;
-        AlignedAllocator allocator(buffer, workspace_size_in_bytes);
-        tmp_v_ = allocator.aligned_alloc<void>(
-            num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(DTypeOut), 16);
-        tmp_s_ =
-            allocator.aligned_alloc<void>(num_qo_heads * padded_batch_size * 2 * sizeof(float), 16);
-        new_indptr_ = allocator.aligned_alloc<void>((padded_batch_size + 1) * sizeof(IdType), 16);
-
-        void* new_indptr_h_ = page_locked_buffer_;
-        new_last_page_len_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
-        void* new_last_page_len_h_ =
-            (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
-        chunk_indptr_ = allocator.aligned_alloc<void>((padded_batch_size + 1) * sizeof(IdType), 16);
-        void* chunk_indptr_h_ =
-            (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
-        batch_idx_map_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
-        void* batch_idx_map_h_ =
-            (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
-        chunk_start_pos_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
-        void* chunk_start_pos_h_ =
-            (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
-        seq_lengths_before_partition_ =
-            allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
-        void* seq_lengths_before_partition_h_ =
-            (char*)page_locked_buffer_ +
-            ((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
-        block_valid_mask_ = allocator.aligned_alloc<bool>(padded_batch_size * sizeof(bool), 16);
-        bool* block_valid_mask_h_ =
-            (bool*)page_locked_buffer_ + ((bool*)block_valid_mask_ - (bool*)new_indptr_);
-        std::fill(block_valid_mask_h_, block_valid_mask_h_ + padded_batch_size, 0);
-
-        size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_;
-        FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
-            max_num_pages_per_batch, batch_size, padded_batch_size, page_size, indptr,
-            last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_,
-            (IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_,
-            (IdType*)seq_lengths_before_partition_h_, block_valid_mask_h_,
-            /*device_buffer=*/new_indptr_,
-            /*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_));
+    DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, {
+      auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimationDispatched<
+          GROUP_SIZE, HEAD_DIM, page_storage, LOGITS_POST_HOOK, kv_layout, POS_ENCODING_MODE,
+          DTypeQ, DTypeKV, DTypeOut, IdType>;
+      FLASHINFER_CUDA_CALL(
+          work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size,
+                               batch_size, indptr, num_qo_heads, page_size,
+                               /*enable_cuda_graph=*/IsCUDAGraphEnabled(), stream_));
+      batch_size_after_partition_ = new_batch_size;
+      if (IsCUDAGraphEnabled()) {
+        if (batch_size != fixed_batch_size_) {
+          std::ostringstream err_msg;
+          err_msg << "The running batch size " << batch_size
+                  << " is not compatible with the fixed batch size " << fixed_batch_size_
+                  << " initialized for CUDAGraph";
+          throw std::runtime_error(err_msg.str());
+        }
+        size_t padded_batch_size = max_grid_size / num_kv_heads;
+        if (tmp_size > 0) {
+          padded_batch_size_ = padded_batch_size;
+          AlignedAllocator allocator(buffer, workspace_size_in_bytes);
+          tmp_v_ = allocator.aligned_alloc<void>(
+              num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(DTypeOut), 16);
+          tmp_s_ = allocator.aligned_alloc<void>(
+              num_qo_heads * padded_batch_size * 2 * sizeof(float), 16);
+          new_indptr_ = allocator.aligned_alloc<void>((padded_batch_size + 1) * sizeof(IdType), 16);
+
+          void* new_indptr_h_ = page_locked_buffer_;
+          new_last_page_len_ =
+              allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
+          void* new_last_page_len_h_ =
+              (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
+          chunk_indptr_ =
+              allocator.aligned_alloc<void>((padded_batch_size + 1) * sizeof(IdType), 16);
+          void* chunk_indptr_h_ =
+              (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
+          batch_idx_map_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
+          void* batch_idx_map_h_ =
+              (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
+          chunk_start_pos_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
+          void* chunk_start_pos_h_ =
+              (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
+          seq_lengths_before_partition_ =
+              allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
+          void* seq_lengths_before_partition_h_ =
+              (char*)page_locked_buffer_ +
+              ((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
+          block_valid_mask_ = allocator.aligned_alloc<bool>(padded_batch_size * sizeof(bool), 16);
+          bool* block_valid_mask_h_ =
+              (bool*)page_locked_buffer_ + ((bool*)block_valid_mask_ - (bool*)new_indptr_);
+          std::fill(block_valid_mask_h_, block_valid_mask_h_ + padded_batch_size, 0);
+
+          size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_;
+          FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
+              max_num_pages_per_batch, batch_size, padded_batch_size, page_size, indptr,
+              last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_,
+              (IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_,
+              (IdType*)seq_lengths_before_partition_h_, block_valid_mask_h_,
+              /*device_buffer=*/new_indptr_,
+              /*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_));
+        } else {
+          block_valid_mask_ = nullptr;
+          padded_batch_size_ = batch_size;
+        }
       } else {
+        // NOTE(Zihao): we don't use block_valid_mask when CUDAGraph is disabled.
         block_valid_mask_ = nullptr;
-        padded_batch_size_ = batch_size;
-      }
-    } else {
-      // NOTE(Zihao): we don't use block_valid_mask when CUDAGraph is disabled.
-      block_valid_mask_ = nullptr;
-      // do not pad the batch size when not using CUDAGraph
-      padded_batch_size_ = batch_size_after_partition_;
-      if (tmp_size > 0) {
-        AlignedAllocator allocator(buffer, workspace_size_in_bytes);
-        tmp_v_ = allocator.aligned_alloc<void>(tmp_size, 16);
-        tmp_s_ = (char*)tmp_v_ +
-                 num_qo_heads * batch_size_after_partition_ * HEAD_DIM * sizeof(DTypeOut);
-        new_indptr_ =
-            allocator.aligned_alloc<void>((batch_size_after_partition_ + 1) * sizeof(IdType), 16);
-        void* new_indptr_h_ = page_locked_buffer_;
-        new_last_page_len_ =
-            allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
-        void* new_last_page_len_h_ =
-            (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
-        chunk_indptr_ =
-            allocator.aligned_alloc<void>((batch_size_before_partition_ + 1) * sizeof(IdType), 16);
-        void* chunk_indptr_h_ =
-            (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
-        batch_idx_map_ =
-            allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
-        void* batch_idx_map_h_ =
-            (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
-        chunk_start_pos_ =
-            allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
-        void* chunk_start_pos_h_ =
-            (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
-        seq_lengths_before_partition_ =
-            allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
-        void* seq_lengths_before_partition_h_ =
-            (char*)page_locked_buffer_ +
-            ((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
-        size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_;
-        FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
-            max_num_pages_per_batch, batch_size, batch_size_after_partition_, page_size, indptr,
-            last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_,
-            (IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_,
-            (IdType*)seq_lengths_before_partition_h_,
-            /*block_valid_mask_h=*/nullptr,
-            /*device_buffer=*/new_indptr_,
-            /*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_));
+        // do not pad the batch size when not using CUDAGraph
+        padded_batch_size_ = batch_size_after_partition_;
+        if (tmp_size > 0) {
+          AlignedAllocator allocator(buffer, workspace_size_in_bytes);
+          tmp_v_ = allocator.aligned_alloc<void>(tmp_size, 16);
+          tmp_s_ = (char*)tmp_v_ +
+                   num_qo_heads * batch_size_after_partition_ * HEAD_DIM * sizeof(DTypeOut);
+          new_indptr_ =
+              allocator.aligned_alloc<void>((batch_size_after_partition_ + 1) * sizeof(IdType), 16);
+          void* new_indptr_h_ = page_locked_buffer_;
+          new_last_page_len_ =
+              allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
+          void* new_last_page_len_h_ =
+              (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
+          chunk_indptr_ = allocator.aligned_alloc<void>(
+              (batch_size_before_partition_ + 1) * sizeof(IdType), 16);
+          void* chunk_indptr_h_ =
+              (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
+          batch_idx_map_ =
+              allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
+          void* batch_idx_map_h_ =
+              (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
+          chunk_start_pos_ =
+              allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
+          void* chunk_start_pos_h_ =
+              (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
+          seq_lengths_before_partition_ =
+              allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
+          void* seq_lengths_before_partition_h_ =
+              (char*)page_locked_buffer_ +
+              ((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
+          size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_;
+          FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
+              max_num_pages_per_batch, batch_size, batch_size_after_partition_, page_size, indptr,
+              last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_,
+              (IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_,
+              (IdType*)seq_lengths_before_partition_h_,
+              /*block_valid_mask_h=*/nullptr,
+              /*device_buffer=*/new_indptr_,
+              /*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_));
+        }
       }
-    }
+    });
     forward_started_ = true;
     return cudaSuccess;
   }
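
Note on the change above: `GROUP_SIZE` is dropped from `BeginForwardDispatched`'s template parameter list and is instead derived at runtime from `num_qo_heads / num_kv_heads` via `DISPATCH_GQA_GROUP_SIZE`, which is why `num_kv_heads` becomes an explicit argument and the function body now closes with `});`. The sketch below shows the general shape of such a runtime-to-compile-time dispatch macro, purely as an illustration: the macro name `DISPATCH_GQA_GROUP_SIZE_SKETCH`, the toy `RunDecode` function, and the supported case list (1/2/4/8) are assumptions for the sketch, not FlashInfer's actual definition.

```cpp
#include <cstdint>
#include <iostream>
#include <stdexcept>

// Illustrative stand-in for a GQA group-size dispatch macro: it maps a runtime
// group size onto a constexpr value so the wrapped block can use GROUP_SIZE as a
// template argument. The supported sizes here are an assumption for the sketch.
#define DISPATCH_GQA_GROUP_SIZE_SKETCH(group_size, GROUP_SIZE, ...) \
  switch (group_size) {                                             \
    case 1: {                                                       \
      constexpr uint32_t GROUP_SIZE = 1;                            \
      __VA_ARGS__                                                   \
      break;                                                        \
    }                                                               \
    case 2: {                                                       \
      constexpr uint32_t GROUP_SIZE = 2;                            \
      __VA_ARGS__                                                   \
      break;                                                        \
    }                                                               \
    case 4: {                                                       \
      constexpr uint32_t GROUP_SIZE = 4;                            \
      __VA_ARGS__                                                   \
      break;                                                        \
    }                                                               \
    case 8: {                                                       \
      constexpr uint32_t GROUP_SIZE = 8;                            \
      __VA_ARGS__                                                   \
      break;                                                        \
    }                                                               \
    default:                                                        \
      throw std::invalid_argument("unsupported GQA group size");    \
  }

// Toy function templated on the group size, standing in for the dispatched
// work-estimation / decode kernels in the handler above.
template <uint32_t GROUP_SIZE>
void RunDecode(uint32_t num_kv_heads) {
  std::cout << "group_size=" << GROUP_SIZE
            << " num_qo_heads=" << GROUP_SIZE * num_kv_heads << "\n";
}

int main() {
  uint32_t num_qo_heads = 32, num_kv_heads = 8;  // GQA with 4 query heads per KV head
  DISPATCH_GQA_GROUP_SIZE_SKETCH(num_qo_heads / num_kv_heads, GROUP_SIZE, {
    RunDecode<GROUP_SIZE>(num_kv_heads);  // GROUP_SIZE is a compile-time constant here
  });
  return 0;
}
```

Dispatching this way keeps the inner kernels specialized on a compile-time group size while the handler's entry point only takes runtime head counts.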