 #define FLASHINFER_HANDLER_CUH_
 
 #include <algorithm>
+#include <cstddef>
 #include <memory>
 #include <unordered_map>
 #include <vector>
@@ -101,7 +102,7 @@ template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKVL
 cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
     uint32_t& tmp_size, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch,
     uint32_t& new_batch_size, uint32_t batch_size, IdType* kv_indptr, const uint32_t num_qo_heads,
-    const uint32_t page_size, cudaStream_t stream) {
+    const uint32_t page_size, bool enable_cuda_graph, cudaStream_t stream) {
   constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL);
   constexpr uint32_t num_stages_smem = 2U;
   constexpr uint32_t bdx = HEAD_DIM / vec_size;
@@ -126,8 +127,10 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
   FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
       &num_blocks_per_sm, partition_kv_kernel, num_threads, smem_size));
   max_grid_size = num_blocks_per_sm * num_sm;
-  if (batch_size * num_kv_heads >= max_grid_size) {
+  if (batch_size * num_kv_heads >= max_grid_size && !enable_cuda_graph) {
     // do not use partition-kv kernel
+    // TODO(Zihao): if enable_cuda_graph, we should always use partition-kv kernel
+    // so that only one kernel will be captured in the graph.
     tmp_size = 0;
     new_batch_size = batch_size;
   } else {
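
Context for the hunk above: a CUDA graph replays exactly the kernel launches that were recorded at capture time, so the choice between the single-kernel and the partition-kv decode path cannot be re-made per step once a graph is in use; the new `enable_cuda_graph` flag (and the TODO) pins that choice up front. The sketch below is generic CUDA runtime code illustrating that constraint, not FlashInfer code; the three-argument `cudaGraphInstantiate` is the CUDA 12 signature and is an assumption about the toolkit.

```cpp
#include <cuda_runtime.h>

// Generic illustration (not FlashInfer code): whatever kernels are enqueued between
// BeginCapture and EndCapture are frozen into the graph; replay cannot branch to a
// different kernel, which is why graph mode must commit to the partition-kv path.
void CaptureAndReplay(cudaStream_t stream) {
  cudaGraph_t graph = nullptr;
  cudaGraphExec_t graph_exec = nullptr;

  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  // ... launch the decode kernel(s) chosen for this configuration ...
  cudaStreamEndCapture(stream, &graph);

  cudaGraphInstantiate(&graph_exec, graph, /*flags=*/0);  // CUDA 12 signature (assumed)
  for (int step = 0; step < 3; ++step) {
    cudaGraphLaunch(graph_exec, stream);  // replays the same kernels with the same pointers
  }
  cudaStreamSynchronize(stream);

  cudaGraphExecDestroy(graph_exec);
  cudaGraphDestroy(graph);
}
```
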
@@ -299,39 +302,42 @@ class BatchDecodeHandler {
                                                             DTypeOut, IdType>;
     FLASHINFER_CUDA_CALL(work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch,
                                               new_batch_size, batch_size, indptr, num_qo_heads,
-                                              page_size, stream_));
+                                              page_size,
+                                              /*enable_cuda_graph=*/false, stream_));
     batch_size_after_partition_ = new_batch_size;
     if (tmp_size > 0) {
       AlignedAlloactor allocator(buffer, workspace_size_in_bytes);
       float_buffer_ = allocator.aligned_alloc<void*>(tmp_size, 16);
       new_indptr_ =
           allocator.aligned_alloc<void*>((batch_size_after_partition_ + 1) * sizeof(IdType), 16);
-      void* new_indptr_h_ = host_buffer_;
+      void* new_indptr_h_ = page_locked_buffer_;
       new_last_page_len_ =
           allocator.aligned_alloc<void*>(batch_size_after_partition_ * sizeof(IdType), 16);
       void* new_last_page_len_h_ =
-          (char*)host_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
+          (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
       chunk_indptr_ =
           allocator.aligned_alloc<void*>((batch_size_before_partition_ + 1) * sizeof(IdType), 16);
-      void* chunk_indptr_h_ = (char*)host_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
+      void* chunk_indptr_h_ =
+          (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
       batch_idx_map_ =
           allocator.aligned_alloc<void*>(batch_size_after_partition_ * sizeof(IdType), 16);
-      void* batch_idx_map_h_ = (char*)host_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
+      void* batch_idx_map_h_ =
+          (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
       chunk_start_pos_ =
           allocator.aligned_alloc<void*>(batch_size_after_partition_ * sizeof(IdType), 16);
       void* chunk_start_pos_h_ =
-          (char*)host_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
+          (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
       seq_lengths_before_partition_ =
           allocator.aligned_alloc<void*>(batch_size_after_partition_ * sizeof(IdType), 16);
       void* seq_lengths_before_partition_h_ =
-          (char*)host_buffer_ + ((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
+          (char*)page_locked_buffer_ + ((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
       size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_;
       FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
           max_num_pages_per_batch, batch_size, page_size, indptr, last_page_len,
           (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_, (IdType*)chunk_indptr_h_,
           (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_,
-          (IdType*)seq_lengths_before_partition_h_, new_indptr_, host_buffer_, num_bytes_to_copy,
-          stream_));
+          (IdType*)seq_lengths_before_partition_h_, new_indptr_, page_locked_buffer_,
+          num_bytes_to_copy, stream_));
     }
     forward_started_ = true;
     return cudaSuccess;
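
The hunk above also shows the staging pattern that the `page_locked_buffer_` rename makes explicit: every device sub-buffer carved out by the allocator has a host twin at the same byte offset inside the pinned buffer, so all auxiliary arrays are filled on the CPU and pushed to the GPU with a single `cudaMemcpyAsync` of `num_bytes_to_copy`. A minimal standalone sketch of that pattern, with made-up sizes and offsets (simplified, not the handler code itself):

```cpp
#include <cuda_runtime.h>
#include <cstdint>

int main() {
  constexpr size_t kBytes = 4096;
  void *dev_base = nullptr, *pinned = nullptr;
  cudaMalloc(&dev_base, kBytes);    // device workspace carved into sub-buffers
  cudaMallocHost(&pinned, kBytes);  // page-locked staging buffer with the same layout

  // Each device sub-buffer at offset `off` has a host twin at the same offset,
  // so everything is written on the CPU and pushed with ONE async copy.
  size_t indptr_off = 0, last_page_len_off = 1024;  // illustrative offsets
  auto* h_indptr = reinterpret_cast<int32_t*>((char*)pinned + indptr_off);
  auto* h_last_page_len = reinterpret_cast<int32_t*>((char*)pinned + last_page_len_off);
  h_indptr[0] = 0;
  h_last_page_len[0] = 1;

  cudaStream_t stream;
  cudaStreamCreate(&stream);
  cudaMemcpyAsync(dev_base, pinned, kBytes, cudaMemcpyHostToDevice, stream);
  cudaStreamSynchronize(stream);

  cudaStreamDestroy(stream);
  cudaFreeHost(pinned);
  cudaFree(dev_base);
  return 0;
}
```
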
@@ -353,6 +359,11 @@ class BatchDecodeHandler {
 
   bool IsForwardStarted() const { return forward_started_; }
 
+  void UpdatePageLockedBufferSize(size_t max_workspace_size_in_bytes) {
+    cudaFreeHost(page_locked_buffer_);
+    cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes);
+  }
+
   uint32_t GetBatchSizeBeforePartition() const { return batch_size_before_partition_; }
 
   uint32_t GetBatchSizeAfterPartition() const { return batch_size_after_partition_; }
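
The new `UpdatePageLockedBufferSize` lets a caller re-size the pinned staging buffer after the handler has been constructed, e.g. before planning a forward pass whose auxiliary data no longer fits in the buffer allocated by the constructor. A hedged usage sketch: the helper name, include path, namespace, and the size formula (borrowed from the `CUDAGraphBatchDecodeHandler` constructor further down) are illustrative assumptions, not part of this PR.

```cpp
#include <cstddef>
#include <cstdint>

#include "flashinfer/handler.cuh"  // assumed include path for this header

// Hypothetical helper: grow the pinned staging buffer before planning a larger batch.
// The size formula mirrors the CUDAGraphBatchDecodeHandler constructor below, but it
// is an illustrative assumption, not an API guarantee.
void GrowPinnedWorkspace(flashinfer::BatchDecodeHandler& handler, size_t max_batch_size) {
  size_t pinned_bytes = 6 * (sizeof(uint64_t) * (max_batch_size + 1) + 16);
  handler.UpdatePageLockedBufferSize(pinned_bytes);
}
```
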
@@ -372,17 +383,19 @@ class BatchDecodeHandler {
         seq_lengths_before_partition_(nullptr),
         forward_started_(false),
         stream_(nullptr) {
-    cudaMallocHost(&host_buffer_, max_workspace_size_in_bytes);
+    cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes);
   }
   ~BatchDecodeHandler() {
     EndForward();
-    cudaFreeHost(host_buffer_);
+    cudaFreeHost(page_locked_buffer_);
   }
 
- private:
+  virtual bool IsCUDAGraphMode() const { return false; }
+
+ protected:
   uint32_t batch_size_before_partition_;
   uint32_t batch_size_after_partition_;
-  void* host_buffer_;
+  void* page_locked_buffer_;
   void* float_buffer_;
   void* new_indptr_;
   void* new_last_page_len_;
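
Making these members `protected` and adding the virtual `IsCUDAGraphMode()` hook is what lets the derived handler introduced in the next hunk reuse the same buffer slots while wrapper code keeps holding a `BatchDecodeHandler*`. A hypothetical caller-side dispatch; only `IsCUDAGraphMode()` and the CUDAGraph handler type come from this diff, the commented-out method names are assumptions:

```cpp
#include "flashinfer/handler.cuh"  // assumed include path for this header

// Hypothetical wrapper-side dispatch, shown only to illustrate why the hook is virtual.
void PlanDecode(flashinfer::BatchDecodeHandler* handler /* , ...plan arguments... */) {
  if (handler->IsCUDAGraphMode()) {
    auto* graph_handler = static_cast<flashinfer::CUDAGraphBatchDecodeHandler*>(handler);
    // graph_handler->CUDAGraphBeginForwardDispatched<...>(...);  // fixed-size, graph-safe path
    (void)graph_handler;
  } else {
    // handler->BeginForwardDispatched<...>(...);  // regular path (method name assumed)
  }
}
```
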
@@ -394,6 +407,86 @@ class BatchDecodeHandler {
   cudaStream_t stream_;
 };
 
+class CUDAGraphBatchDecodeHandler : public BatchDecodeHandler {
+ public:
+  template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKVLayout kv_layout,
+            PosEncodingMode POS_ENCODING_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
+  cudaError_t CUDAGraphBeginForwardDispatched(void* buffer, size_t workspace_size_in_bytes,
+                                              IdType* indptr, IdType* last_page_len,
+                                              uint32_t batch_size, uint32_t num_qo_heads,
+                                              uint32_t page_size) {
+    batch_size_before_partition_ = batch_size;
+    uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size;
+    auto work_estimation_func =
+        BatchDecodeWithPagedKVCacheWorkEstimationDispatched<GROUP_SIZE, HEAD_DIM, page_storage,
+                                                            kv_layout, POS_ENCODING_MODE, DTypeIn,
+                                                            DTypeOut, IdType>;
+    FLASHINFER_CUDA_CALL(work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch,
+                                              new_batch_size, batch_size, indptr, num_qo_heads,
+                                              page_size,
+                                              /*enable_cuda_graph=*/true, stream_));
+    // NOTE(Zihao): max_batch_size_after_partition_ is determined in handler initialization.
+    // the value should not be changed during the lifetime of the handler.
+    // So it should be compatible with CUDAGraph which requires fixed pointer.
+    batch_size_after_partition_ = new_batch_size;
+    size_t max_tmp_size = num_qo_heads * max_batch_size_after_partition_ *
+                          (HEAD_DIM * sizeof(DTypeOut) + 2 * sizeof(float));
+    AlignedAlloactor allocator(buffer, workspace_size_in_bytes);
+    float_buffer_ = allocator.aligned_alloc<void*>(max_tmp_size, 16);
+    new_indptr_ =
+        allocator.aligned_alloc<void*>((max_batch_size_after_partition_ + 1) * sizeof(IdType), 16);
+
+    void* new_indptr_h_ = page_locked_buffer_;
+    new_last_page_len_ =
+        allocator.aligned_alloc<void*>(max_batch_size_after_partition_ * sizeof(IdType), 16);
+    void* new_last_page_len_h_ =
+        (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
+    chunk_indptr_ =
+        allocator.aligned_alloc<void*>((max_batch_size_after_partition_ + 1) * sizeof(IdType), 16);
+    void* chunk_indptr_h_ =
+        (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
+    batch_idx_map_ =
+        allocator.aligned_alloc<void*>(max_batch_size_after_partition_ * sizeof(IdType), 16);
+    void* batch_idx_map_h_ =
+        (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
+    chunk_start_pos_ =
+        allocator.aligned_alloc<void*>(max_batch_size_after_partition_ * sizeof(IdType), 16);
+    void* chunk_start_pos_h_ =
+        (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
+    seq_lengths_before_partition_ =
+        allocator.aligned_alloc<void*>(max_batch_size_after_partition_ * sizeof(IdType), 16);
+    void* seq_lengths_before_partition_h_ =
+        (char*)page_locked_buffer_ + ((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
+
+    size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_;
+    FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
+        max_num_pages_per_batch, batch_size, page_size, indptr, last_page_len,
+        (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_, (IdType*)chunk_indptr_h_,
+        (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_,
+        (IdType*)seq_lengths_before_partition_h_, new_indptr_, page_locked_buffer_,
+        num_bytes_to_copy, stream_));
+    forward_started_ = true;
+    return cudaSuccess;
+  }
+  CUDAGraphBatchDecodeHandler(size_t max_batch_size) {
+    int dev_id = 0, num_sm = 0, max_thread_blocks_per_sm = 0;
+    cudaGetDevice(&dev_id);
+    cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id);
+    cudaDeviceGetAttribute(&max_thread_blocks_per_sm, cudaDevAttrMaxBlocksPerMultiprocessor,
+                           dev_id);
+    max_batch_size_after_partition_ =
+        std::max<size_t>(max_thread_blocks_per_sm * num_sm, max_batch_size);
+    std::cout << max_thread_blocks_per_sm * num_sm << " " << max_batch_size << std::endl;
+    size_t max_workspace_size_in_bytes =
+        6 * (sizeof(uint64_t) * (max_batch_size_after_partition_ + 1) + 16);
+    cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes);
+  }
+  bool IsCUDAGraphMode() const override { return true; }
+
+ private:
+  uint32_t max_batch_size_after_partition_;
+};
+
 class BatchPrefillHandler {
  public:
   template <typename IdType>
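
The NOTE in the hunk above is the crux of the new handler: because `max_batch_size_after_partition_` is fixed at construction, every device sub-buffer carved out by `CUDAGraphBeginForwardDispatched` keeps the same pointer and size across graph capture and replay. A usage sketch under stated assumptions: the include path, namespace, enum values (`PageStorage::kIndices`, `QKVLayout::kNHD`, `PosEncodingMode::kNone`), dtypes, head/page dimensions, and workspace size are all illustrative choices, not requirements of this PR.

```cpp
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include "flashinfer/handler.cuh"  // assumed include path for this header

using namespace flashinfer;

// Illustrative set-up for a graph-capturable decode plan; all numeric choices and
// enum values here are assumptions made for the example.
void PlanForGraphCapture(int32_t* indptr, int32_t* last_page_len, uint32_t batch_size) {
  CUDAGraphBatchDecodeHandler handler(/*max_batch_size=*/512);

  // One fixed device workspace, reused every step; it must outlive the captured graph,
  // so it is intentionally not freed here.
  size_t workspace_bytes = 128 * 1024 * 1024;
  void* workspace = nullptr;
  cudaMalloc(&workspace, workspace_bytes);

  handler.CUDAGraphBeginForwardDispatched<
      /*GROUP_SIZE=*/1, /*HEAD_DIM=*/128, PageStorage::kIndices, QKVLayout::kNHD,
      PosEncodingMode::kNone, half, half, int32_t>(
      workspace, workspace_bytes, indptr, last_page_len, batch_size,
      /*num_qo_heads=*/32, /*page_size=*/16);

  // Capture the decode kernel(s) into a CUDA graph after this point; later steps only
  // refresh the pinned buffer contents and replay the graph.
}
```
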
@@ -412,6 +505,11 @@ class BatchPrefillHandler {
 
   bool IsForwardStarted() const { return request_indices_ != nullptr; }
 
+  void UpdatePageLockedBufferSize(size_t max_workspace_size_in_bytes) {
+    cudaFreeHost(page_locked_buffer_);
+    cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes);
+  }
+
   template <typename IdType>
   cudaError_t BeginForward(void* buffer, size_t workspace_size_in_bytes, IdType* qo_indptr,
                            uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
@@ -429,14 +527,15 @@ class BatchPrefillHandler {
     AlignedAlloactor allocator(buffer, workspace_size_in_bytes);
     request_indices_ =
         allocator.aligned_alloc<void*>(sizeof(IdType) * request_indices_vec.size(), 16);
-    void* request_indices_h_ = host_buffer_;
+    void* request_indices_h_ = page_locked_buffer_;
     tile_indices_ = allocator.aligned_alloc<void*>(sizeof(IdType) * tile_indices_vec.size(), 16);
-    void* tile_indices_h_ = (char*)host_buffer_ + ((char*)tile_indices_ - (char*)request_indices_);
+    void* tile_indices_h_ =
+        (char*)page_locked_buffer_ + ((char*)tile_indices_ - (char*)request_indices_);
     std::copy(request_indices_vec.begin(), request_indices_vec.end(), (IdType*)request_indices_h_);
     std::copy(tile_indices_vec.begin(), tile_indices_vec.end(), (IdType*)tile_indices_h_);
     size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)request_indices_;
 
-    FLASHINFER_CUDA_CALL(cudaMemcpyAsync(request_indices_, host_buffer_, num_bytes_to_copy,
+    FLASHINFER_CUDA_CALL(cudaMemcpyAsync(request_indices_, page_locked_buffer_, num_bytes_to_copy,
                                          cudaMemcpyHostToDevice, stream_));
 
     return cudaSuccess;
@@ -462,15 +561,15 @@ class BatchPrefillHandler {
         num_qo_tiles_(0U),
         forward_started_(false),
         stream_(nullptr) {
-    cudaMallocHost(&host_buffer_, max_workspace_size_in_bytes);
+    cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes);
   }
   ~BatchPrefillHandler() {
     EndForward();
-    cudaFreeHost(host_buffer_);
+    cudaFreeHost(page_locked_buffer_);
   }
 
  private:
-  void* host_buffer_;
+  void* page_locked_buffer_;
   void* request_indices_;
   void* tile_indices_;
   uint32_t num_frags_x_;