
Commit c111ca6

refactor: move gqa_group_size from template parameter to input arguments (#301)
#262 is out of sync with main; this PR rebases the code on the main branch.
1 parent bb1783b commit c111ca6

27 files changed (+1303, -1553 lines)
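
To make the intent of the change concrete, here is an illustrative before/after sketch with simplified stand-in signatures (BeginForwardOld/BeginForwardNew and their trimmed parameter lists are hypothetical, not FlashInfer's actual declarations): the GQA group size used to be a compile-time template parameter, so every supported GROUP_SIZE had to be instantiated ahead of time; after this commit the caller passes num_kv_heads at runtime and the group size is derived as num_qo_heads / num_kv_heads inside the handler.

// Illustrative only: simplified stand-in signatures, not FlashInfer's real API.
#include <cstdint>
#include <cstdio>

// Before: the GQA group size is a template parameter, so callers (and the
// CMake-generated instantiations) must enumerate every supported GROUP_SIZE.
template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM>
void BeginForwardOld(uint32_t num_qo_heads, uint32_t page_size) {
  std::printf("old: GROUP_SIZE=%u, num_qo_heads=%u, page_size=%u\n", GROUP_SIZE, num_qo_heads,
              page_size);
}

// After: only HEAD_DIM (and dtypes, omitted here) stays compile-time; the group
// size is recovered at runtime from num_qo_heads / num_kv_heads.
template <uint32_t HEAD_DIM>
void BeginForwardNew(uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t page_size) {
  uint32_t group_size = num_qo_heads / num_kv_heads;  // dispatched internally
  std::printf("new: group_size=%u, num_qo_heads=%u, page_size=%u\n", group_size, num_qo_heads,
              page_size);
}

int main() {
  BeginForwardOld</*GROUP_SIZE=*/4, /*HEAD_DIM=*/128>(/*num_qo_heads=*/32, /*page_size=*/16);
  BeginForwardNew</*HEAD_DIM=*/128>(/*num_qo_heads=*/32, /*num_kv_heads=*/8, /*page_size=*/16);
  return 0;
}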

CMakeLists.txt (+133, -146)
Large diffs are not rendered by default.

cmake/config.cmake (-1)
@@ -22,7 +22,6 @@ set(FLASHINFER_FASTDIV_TEST ON)
 set(FLASHINFER_DISTRIBUTED ON)
 # The following configurations can impact the binary
 # size of the generated library
-set(FLASHINFER_GEN_GROUP_SIZES 1 4 6 8)
 set(FLASHINFER_GEN_LOGITS_POST_HOOKS 0)
 set(FLASHINFER_GEN_PAGE_SIZES 1 16 32)
 set(FLASHINFER_GEN_HEAD_DIMS 64 128 256)

include/flashinfer/attention/decode.cuh (+190, -186)
Large diffs are not rendered by default.

include/flashinfer/attention/handler.cuh (+112, -108)
@@ -297,121 +297,125 @@ class BatchDecodeHandler {
 
   bool* GetBlockValidMask() const { return block_valid_mask_; }
 
-  template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage,
-            LogitsPostHook LOGITS_POST_HOOK, QKVLayout kv_layout, PosEncodingMode POS_ENCODING_MODE,
-            typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType>
+  template <uint32_t HEAD_DIM, PageStorage page_storage, LogitsPostHook LOGITS_POST_HOOK,
+            QKVLayout kv_layout, PosEncodingMode POS_ENCODING_MODE, typename DTypeQ,
+            typename DTypeKV, typename DTypeOut, typename IdType>
   cudaError_t BeginForwardDispatched(void* buffer, size_t workspace_size_in_bytes, IdType* indptr,
                                      IdType* last_page_len, uint32_t batch_size,
-                                     uint32_t num_qo_heads, uint32_t page_size) {
+                                     uint32_t num_qo_heads, uint32_t num_kv_heads,
+                                     uint32_t page_size) {
     batch_size_before_partition_ = batch_size;
-    uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE;
     uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size;
-    auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimationDispatched<
-        GROUP_SIZE, HEAD_DIM, page_storage, LOGITS_POST_HOOK, kv_layout, POS_ENCODING_MODE, DTypeQ,
-        DTypeKV, DTypeOut, IdType>;
-    FLASHINFER_CUDA_CALL(work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch,
-                                              new_batch_size, batch_size, indptr, num_qo_heads,
-                                              page_size,
-                                              /*enable_cuda_graph=*/IsCUDAGraphEnabled(), stream_));
-    batch_size_after_partition_ = new_batch_size;
-    if (IsCUDAGraphEnabled()) {
-      if (batch_size != fixed_batch_size_) {
-        std::ostringstream err_msg;
-        err_msg << "The running batch size " << batch_size
-                << " is not compatible with the fixed batch size " << fixed_batch_size_
-                << " initialized for CUDAGraph";
-        throw std::runtime_error(err_msg.str());
-      }
-      size_t padded_batch_size = max_grid_size / num_kv_heads;
-      if (tmp_size > 0) {
-        padded_batch_size_ = padded_batch_size;
-        AlignedAllocator allocator(buffer, workspace_size_in_bytes);
-        tmp_v_ = allocator.aligned_alloc<void>(
-            num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(DTypeOut), 16);
-        tmp_s_ =
-            allocator.aligned_alloc<void>(num_qo_heads * padded_batch_size * 2 * sizeof(float), 16);
-        new_indptr_ = allocator.aligned_alloc<void>((padded_batch_size + 1) * sizeof(IdType), 16);
-
-        void* new_indptr_h_ = page_locked_buffer_;
-        new_last_page_len_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
-        void* new_last_page_len_h_ =
-            (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
-        chunk_indptr_ = allocator.aligned_alloc<void>((padded_batch_size + 1) * sizeof(IdType), 16);
-        void* chunk_indptr_h_ =
-            (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
-        batch_idx_map_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
-        void* batch_idx_map_h_ =
-            (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
-        chunk_start_pos_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
-        void* chunk_start_pos_h_ =
-            (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
-        seq_lengths_before_partition_ =
-            allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
-        void* seq_lengths_before_partition_h_ =
-            (char*)page_locked_buffer_ +
-            ((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
-        block_valid_mask_ = allocator.aligned_alloc<bool>(padded_batch_size * sizeof(bool), 16);
-        bool* block_valid_mask_h_ =
-            (bool*)page_locked_buffer_ + ((bool*)block_valid_mask_ - (bool*)new_indptr_);
-        std::fill(block_valid_mask_h_, block_valid_mask_h_ + padded_batch_size, 0);
-
-        size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_;
-        FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
-            max_num_pages_per_batch, batch_size, padded_batch_size, page_size, indptr,
-            last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_,
-            (IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_,
-            (IdType*)seq_lengths_before_partition_h_, block_valid_mask_h_,
-            /*device_buffer=*/new_indptr_,
-            /*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_));
+    DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, {
+      auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimationDispatched<
+          GROUP_SIZE, HEAD_DIM, page_storage, LOGITS_POST_HOOK, kv_layout, POS_ENCODING_MODE,
+          DTypeQ, DTypeKV, DTypeOut, IdType>;
+      FLASHINFER_CUDA_CALL(
+          work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size,
+                               batch_size, indptr, num_qo_heads, page_size,
+                               /*enable_cuda_graph=*/IsCUDAGraphEnabled(), stream_));
+      batch_size_after_partition_ = new_batch_size;
+      if (IsCUDAGraphEnabled()) {
+        if (batch_size != fixed_batch_size_) {
+          std::ostringstream err_msg;
+          err_msg << "The running batch size " << batch_size
+                  << " is not compatible with the fixed batch size " << fixed_batch_size_
+                  << " initialized for CUDAGraph";
+          throw std::runtime_error(err_msg.str());
+        }
+        size_t padded_batch_size = max_grid_size / num_kv_heads;
+        if (tmp_size > 0) {
+          padded_batch_size_ = padded_batch_size;
+          AlignedAllocator allocator(buffer, workspace_size_in_bytes);
+          tmp_v_ = allocator.aligned_alloc<void>(
+              num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(DTypeOut), 16);
+          tmp_s_ = allocator.aligned_alloc<void>(
+              num_qo_heads * padded_batch_size * 2 * sizeof(float), 16);
+          new_indptr_ = allocator.aligned_alloc<void>((padded_batch_size + 1) * sizeof(IdType), 16);
+
+          void* new_indptr_h_ = page_locked_buffer_;
+          new_last_page_len_ =
+              allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
+          void* new_last_page_len_h_ =
+              (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
+          chunk_indptr_ =
+              allocator.aligned_alloc<void>((padded_batch_size + 1) * sizeof(IdType), 16);
+          void* chunk_indptr_h_ =
+              (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
+          batch_idx_map_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
+          void* batch_idx_map_h_ =
+              (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
+          chunk_start_pos_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
+          void* chunk_start_pos_h_ =
+              (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
+          seq_lengths_before_partition_ =
+              allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
+          void* seq_lengths_before_partition_h_ =
+              (char*)page_locked_buffer_ +
+              ((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
+          block_valid_mask_ = allocator.aligned_alloc<bool>(padded_batch_size * sizeof(bool), 16);
+          bool* block_valid_mask_h_ =
+              (bool*)page_locked_buffer_ + ((bool*)block_valid_mask_ - (bool*)new_indptr_);
+          std::fill(block_valid_mask_h_, block_valid_mask_h_ + padded_batch_size, 0);
+
+          size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_;
+          FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
+              max_num_pages_per_batch, batch_size, padded_batch_size, page_size, indptr,
+              last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_,
+              (IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_,
+              (IdType*)seq_lengths_before_partition_h_, block_valid_mask_h_,
+              /*device_buffer=*/new_indptr_,
+              /*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_));
+        } else {
+          block_valid_mask_ = nullptr;
+          padded_batch_size_ = batch_size;
+        }
       } else {
+        // NOTE(Zihao): we don't use block_valid_mask when CUDAGraph is disabled.
         block_valid_mask_ = nullptr;
-        padded_batch_size_ = batch_size;
-      }
-    } else {
-      // NOTE(Zihao): we don't use block_valid_mask when CUDAGraph is disabled.
-      block_valid_mask_ = nullptr;
-      // do not pad the batch size when not using CUDAGraph
-      padded_batch_size_ = batch_size_after_partition_;
-      if (tmp_size > 0) {
-        AlignedAllocator allocator(buffer, workspace_size_in_bytes);
-        tmp_v_ = allocator.aligned_alloc<void>(tmp_size, 16);
-        tmp_s_ = (char*)tmp_v_ +
-                 num_qo_heads * batch_size_after_partition_ * HEAD_DIM * sizeof(DTypeOut);
-        new_indptr_ =
-            allocator.aligned_alloc<void>((batch_size_after_partition_ + 1) * sizeof(IdType), 16);
-        void* new_indptr_h_ = page_locked_buffer_;
-        new_last_page_len_ =
-            allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
-        void* new_last_page_len_h_ =
-            (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
-        chunk_indptr_ =
-            allocator.aligned_alloc<void>((batch_size_before_partition_ + 1) * sizeof(IdType), 16);
-        void* chunk_indptr_h_ =
-            (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
-        batch_idx_map_ =
-            allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
-        void* batch_idx_map_h_ =
-            (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
-        chunk_start_pos_ =
-            allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
-        void* chunk_start_pos_h_ =
-            (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
-        seq_lengths_before_partition_ =
-            allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
-        void* seq_lengths_before_partition_h_ =
-            (char*)page_locked_buffer_ +
-            ((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
-        size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_;
-        FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
-            max_num_pages_per_batch, batch_size, batch_size_after_partition_, page_size, indptr,
-            last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_,
-            (IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_,
-            (IdType*)seq_lengths_before_partition_h_,
-            /*block_valid_mask_h=*/nullptr,
-            /*device_buffer=*/new_indptr_,
-            /*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_));
+        // do not pad the batch size when not using CUDAGraph
+        padded_batch_size_ = batch_size_after_partition_;
+        if (tmp_size > 0) {
+          AlignedAllocator allocator(buffer, workspace_size_in_bytes);
+          tmp_v_ = allocator.aligned_alloc<void>(tmp_size, 16);
+          tmp_s_ = (char*)tmp_v_ +
+                   num_qo_heads * batch_size_after_partition_ * HEAD_DIM * sizeof(DTypeOut);
+          new_indptr_ =
+              allocator.aligned_alloc<void>((batch_size_after_partition_ + 1) * sizeof(IdType), 16);
+          void* new_indptr_h_ = page_locked_buffer_;
+          new_last_page_len_ =
+              allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
+          void* new_last_page_len_h_ =
+              (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
+          chunk_indptr_ = allocator.aligned_alloc<void>(
+              (batch_size_before_partition_ + 1) * sizeof(IdType), 16);
+          void* chunk_indptr_h_ =
+              (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
+          batch_idx_map_ =
+              allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
+          void* batch_idx_map_h_ =
+              (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
+          chunk_start_pos_ =
+              allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
+          void* chunk_start_pos_h_ =
+              (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
+          seq_lengths_before_partition_ =
+              allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
+          void* seq_lengths_before_partition_h_ =
+              (char*)page_locked_buffer_ +
+              ((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
+          size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_;
+          FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
+              max_num_pages_per_batch, batch_size, batch_size_after_partition_, page_size, indptr,
+              last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_,
+              (IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_,
+              (IdType*)seq_lengths_before_partition_h_,
+              /*block_valid_mask_h=*/nullptr,
+              /*device_buffer=*/new_indptr_,
+              /*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_));
+        }
       }
-    }
+    });
     forward_started_ = true;
     return cudaSuccess;
   }
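
For readers unfamiliar with the DISPATCH_GQA_GROUP_SIZE pattern used above: it maps a runtime group size onto a compile-time constant by branching over a fixed set of supported values and re-entering the caller-provided block with GROUP_SIZE bound as a constexpr. Below is a minimal, self-contained sketch of that pattern; the macro body and the supported sizes (1, 4, 8) are assumptions for illustration and do not reproduce FlashInfer's actual definition.

// Minimal sketch of a runtime-to-compile-time dispatch macro.
// The macro body and supported sizes are assumptions, not FlashInfer's real macro.
#include <cstdint>
#include <cstdio>
#include <stdexcept>

#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...)         \
  switch (group_size) {                                              \
    case 1: {                                                        \
      constexpr uint32_t GROUP_SIZE = 1;                             \
      __VA_ARGS__                                                    \
      break;                                                         \
    }                                                                \
    case 4: {                                                        \
      constexpr uint32_t GROUP_SIZE = 4;                             \
      __VA_ARGS__                                                    \
      break;                                                         \
    }                                                                \
    case 8: {                                                        \
      constexpr uint32_t GROUP_SIZE = 8;                             \
      __VA_ARGS__                                                    \
      break;                                                         \
    }                                                                \
    default:                                                         \
      throw std::invalid_argument("unsupported GQA group size");     \
  }

// Stand-in for a launcher that needs the group size as a template parameter.
template <uint32_t GROUP_SIZE>
void RunDecode(uint32_t num_qo_heads) {
  std::printf("group_size=%u, num_qo_heads=%u\n", GROUP_SIZE, num_qo_heads);
}

int main() {
  uint32_t num_qo_heads = 32, num_kv_heads = 8;  // runtime values, e.g. from a model config
  // The caller passes num_kv_heads; the mapping to a compile-time constant
  // happens here instead of in the caller's template arguments.
  DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, {
    RunDecode<GROUP_SIZE>(num_qo_heads);
  });
  return 0;
}

The trade-off is the usual one for this kind of dispatch: the public interface no longer exposes GROUP_SIZE as a template parameter, at the cost of compiling one branch per supported value inside the library.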
