Skip to content

Commit aff4cf0

Browse files
authored
bugfix: fix wrong padded_batch_size_ (#296)
In #294, we set `padded_batch_size_` to `num_kv_heads * batch_size`; it should have been `batch_size`.
1 parent 60459e4 commit aff4cf0

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

include/flashinfer/attention/handler.cuh

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -318,56 +318,56 @@ class BatchDecodeHandler {
318318
<< " initialized for CUDAGraph";
319319
throw std::runtime_error(err_msg.str());
320320
}
321-
size_t padded_batch_size_after_partition = max_grid_size / num_kv_heads;
321+
size_t padded_batch_size = max_grid_size / num_kv_heads;
322322
if (tmp_size > 0) {
323-
padded_batch_size_ = padded_batch_size_after_partition;
323+
padded_batch_size_ = padded_batch_size;
324324
AlignedAllocator allocator(buffer, workspace_size_in_bytes);
325325
tmp_v_ = allocator.aligned_alloc<void>(
326-
num_qo_heads * padded_batch_size_after_partition * HEAD_DIM * sizeof(DTypeOut), 16);
326+
num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(DTypeOut), 16);
327327
tmp_s_ = allocator.aligned_alloc<void>(
328-
num_qo_heads * padded_batch_size_after_partition * 2 * sizeof(float), 16);
328+
num_qo_heads * padded_batch_size * 2 * sizeof(float), 16);
329329
new_indptr_ = allocator.aligned_alloc<void>(
330-
(padded_batch_size_after_partition + 1) * sizeof(IdType), 16);
330+
(padded_batch_size + 1) * sizeof(IdType), 16);
331331

332332
void* new_indptr_h_ = page_locked_buffer_;
333333
new_last_page_len_ =
334-
allocator.aligned_alloc<void>(padded_batch_size_after_partition * sizeof(IdType), 16);
334+
allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
335335
void* new_last_page_len_h_ =
336336
(char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
337337
chunk_indptr_ = allocator.aligned_alloc<void>(
338-
(padded_batch_size_after_partition + 1) * sizeof(IdType), 16);
338+
(padded_batch_size + 1) * sizeof(IdType), 16);
339339
void* chunk_indptr_h_ =
340340
(char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
341341
batch_idx_map_ =
342-
allocator.aligned_alloc<void>(padded_batch_size_after_partition * sizeof(IdType), 16);
342+
allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
343343
void* batch_idx_map_h_ =
344344
(char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
345345
chunk_start_pos_ =
346-
allocator.aligned_alloc<void>(padded_batch_size_after_partition * sizeof(IdType), 16);
346+
allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
347347
void* chunk_start_pos_h_ =
348348
(char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
349349
seq_lengths_before_partition_ =
350-
allocator.aligned_alloc<void>(padded_batch_size_after_partition * sizeof(IdType), 16);
350+
allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
351351
void* seq_lengths_before_partition_h_ =
352352
(char*)page_locked_buffer_ +
353353
((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
354354
block_valid_mask_ =
355-
allocator.aligned_alloc<bool>(padded_batch_size_after_partition * sizeof(bool), 16);
355+
allocator.aligned_alloc<bool>(padded_batch_size * sizeof(bool), 16);
356356
bool* block_valid_mask_h_ =
357357
(bool*)page_locked_buffer_ + ((bool*)block_valid_mask_ - (bool*)new_indptr_);
358-
std::fill(block_valid_mask_h_, block_valid_mask_h_ + padded_batch_size_after_partition, 0);
358+
std::fill(block_valid_mask_h_, block_valid_mask_h_ + padded_batch_size, 0);
359359

360360
size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_;
361361
FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
362-
max_num_pages_per_batch, batch_size, padded_batch_size_after_partition, page_size,
362+
max_num_pages_per_batch, batch_size, padded_batch_size, page_size,
363363
indptr, last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_,
364364
(IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_,
365365
(IdType*)seq_lengths_before_partition_h_, block_valid_mask_h_,
366366
/*device_buffer=*/new_indptr_,
367367
/*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_));
368368
} else {
369369
block_valid_mask_ = nullptr;
370-
padded_batch_size_ = num_kv_heads * batch_size;
370+
padded_batch_size_ = batch_size;
371371
}
372372
} else {
373373
// NOTE(Zihao): we don't use block_valid_mask when CUDAGraph is disabled.

0 commit comments

Comments
 (0)