@@ -318,56 +318,56 @@ class BatchDecodeHandler {
              << " initialized for CUDAGraph";
      throw std::runtime_error(err_msg.str());
    }
-    size_t padded_batch_size_after_partition = max_grid_size / num_kv_heads;
+    size_t padded_batch_size = max_grid_size / num_kv_heads;
    if (tmp_size > 0) {
-      padded_batch_size_ = padded_batch_size_after_partition;
+      padded_batch_size_ = padded_batch_size;
      AlignedAllocator allocator(buffer, workspace_size_in_bytes);
      tmp_v_ = allocator.aligned_alloc<void>(
-          num_qo_heads * padded_batch_size_after_partition * HEAD_DIM * sizeof(DTypeOut), 16);
+          num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(DTypeOut), 16);
      tmp_s_ = allocator.aligned_alloc<void>(
-          num_qo_heads * padded_batch_size_after_partition * 2 * sizeof(float), 16);
+          num_qo_heads * padded_batch_size * 2 * sizeof(float), 16);
      new_indptr_ = allocator.aligned_alloc<void>(
-          (padded_batch_size_after_partition + 1) * sizeof(IdType), 16);
+          (padded_batch_size + 1) * sizeof(IdType), 16);

      void* new_indptr_h_ = page_locked_buffer_;
      new_last_page_len_ =
-          allocator.aligned_alloc<void>(padded_batch_size_after_partition * sizeof(IdType), 16);
+          allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
      void* new_last_page_len_h_ =
          (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
      chunk_indptr_ = allocator.aligned_alloc<void>(
-          (padded_batch_size_after_partition + 1) * sizeof(IdType), 16);
+          (padded_batch_size + 1) * sizeof(IdType), 16);
      void* chunk_indptr_h_ =
          (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
      batch_idx_map_ =
-          allocator.aligned_alloc<void>(padded_batch_size_after_partition * sizeof(IdType), 16);
+          allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
      void* batch_idx_map_h_ =
          (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
      chunk_start_pos_ =
-          allocator.aligned_alloc<void>(padded_batch_size_after_partition * sizeof(IdType), 16);
+          allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
      void* chunk_start_pos_h_ =
          (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
      seq_lengths_before_partition_ =
-          allocator.aligned_alloc<void>(padded_batch_size_after_partition * sizeof(IdType), 16);
+          allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
      void* seq_lengths_before_partition_h_ =
          (char*)page_locked_buffer_ +
          ((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
      block_valid_mask_ =
-          allocator.aligned_alloc<bool>(padded_batch_size_after_partition * sizeof(bool), 16);
+          allocator.aligned_alloc<bool>(padded_batch_size * sizeof(bool), 16);
      bool* block_valid_mask_h_ =
          (bool*)page_locked_buffer_ + ((bool*)block_valid_mask_ - (bool*)new_indptr_);
-      std::fill(block_valid_mask_h_, block_valid_mask_h_ + padded_batch_size_after_partition, 0);
+      std::fill(block_valid_mask_h_, block_valid_mask_h_ + padded_batch_size, 0);

      size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_;
      FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
-          max_num_pages_per_batch, batch_size, padded_batch_size_after_partition, page_size,
+          max_num_pages_per_batch, batch_size, padded_batch_size, page_size,
          indptr, last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_,
          (IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_,
          (IdType*)seq_lengths_before_partition_h_, block_valid_mask_h_,
          /*device_buffer=*/new_indptr_,
          /*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_));
    } else {
      block_valid_mask_ = nullptr;
-      padded_batch_size_ = num_kv_heads * batch_size;
+      padded_batch_size_ = batch_size;
    }
  } else {
    // NOTE(Zihao): we don't use block_valid_mask when CUDAGraph is disabled.
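For context on the padding this hunk touches: a CUDAGraph capture replays kernels with a fixed launch grid, so the handler sizes its workspace buffers for the largest batch the grid can serve (max_grid_size / num_kv_heads) rather than the current batch. Below is a minimal standalone sketch of that padding rule, assuming a hypothetical helper name ChoosePaddedBatchSize (not part of FlashInfer's API); the parameter names mirror the diff.

#include <cstddef>

// Hypothetical helper mirroring the padding logic above: one thread block is
// launched per (kv_head, batch_entry) pair, so the number of batch entries
// per KV head is capped by the maximum grid size.
std::size_t ChoosePaddedBatchSize(std::size_t max_grid_size, std::size_t num_kv_heads,
                                  std::size_t batch_size, std::size_t tmp_size) {
  std::size_t padded_batch_size = max_grid_size / num_kv_heads;
  // When the partitioned-KV path is taken (tmp_size > 0), buffers are sized
  // for the padded batch; otherwise the plain batch size suffices, which is
  // what the else branch now stores into padded_batch_size_.
  return tmp_size > 0 ? padded_batch_size : batch_size;
}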