Commit 7e9cc7f

perf: initial cuda graph support (#256)
As requested in #187, this PR adds initial `CUDAGraph` compatibility for the flashinfer batch decode attention kernels. It is the first step towards full CUDAGraph support; CUDAGraph-compatible prefill operators will be implemented in later PRs.

# Proposed APIs

We add another wrapper, `CUDAGraphBatchDecodeWithPagedKVCacheWrapper`, and users need to pre-allocate the page data structure buffers to initialize this wrapper class. Once initialized, these buffers stay pinned on the GPU for the lifetime of the wrapper. The behavior of `CUDAGraphBatchDecodeWithPagedKVCacheWrapper` differs slightly from `BatchDecodeWithPagedKVCacheWrapper`: in CUDAGraph mode we always run a fixed set of kernels regardless of the input shape (the original implementation dispatches to different kernels depending on the input shape). This PR also fixes the addresses of all kernel input pointers to accommodate the constraints of CUDAGraph capture.

# Examples

See `test_cuda_graph_batch_decode_with_paged_kv_cache` in the unittests. `begin_forward` functions should not be captured, as some of the operators they use are not allowed inside a CUDAGraph capture.

cc @AgrawalAmey @LiuXiaoxuanPKU @comaniac
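To make the capture constraint concrete, here is a standalone sketch of the capture/replay pattern the new wrapper targets. This is not flashinfer code: `decode_kernel` is a hypothetical stand-in for the fixed decode kernel, the "plan" comment marks where `begin_forward`-style setup runs outside the capture, and `cudaGraphInstantiate` uses the CUDA 12 signature.

```cuda
// Minimal stream-capture/replay sketch. Everything here is illustrative; decode_kernel
// is a placeholder for the fixed decode kernel, not an actual flashinfer symbol.
#include <cuda_runtime.h>

__global__ void decode_kernel(float* out, const float* in, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i] * 2.f;  // stands in for the real attention computation
}

int main() {
  const int n = 1 << 20;
  float *in, *out;
  cudaMalloc(&in, n * sizeof(float));
  cudaMalloc(&out, n * sizeof(float));

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Step 1: plan/begin_forward-style work runs OUTSIDE the capture, because it may use
  // host allocations, synchronization, or other calls that cannot be captured.

  // Step 2: capture only the fixed kernel launch. The captured launch bakes in the
  // pointers and launch configuration, which is why the wrapper pins its buffers and
  // always dispatches the same kernel regardless of the runtime batch shape.
  cudaGraph_t graph;
  cudaGraphExec_t graph_exec;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  decode_kernel<<<(n + 255) / 256, 256, 0, stream>>>(out, in, n);
  cudaStreamEndCapture(stream, &graph);
  cudaGraphInstantiate(&graph_exec, graph, 0);  // CUDA 12 signature

  // Step 3: replay. Only the contents of the pre-allocated buffers may change between
  // replays; their addresses must not.
  for (int step = 0; step < 8; ++step) {
    cudaGraphLaunch(graph_exec, stream);
  }
  cudaStreamSynchronize(stream);

  cudaGraphExecDestroy(graph_exec);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(stream);
  cudaFree(in);
  cudaFree(out);
  return 0;
}
```

The important property is that a captured launch fixes device pointers and launch dimensions, so the wrapper keeps those constant and only the buffer contents vary between replays.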
1 parent ed20304 commit 7e9cc7f

12 files changed: +710, -96 lines

Diff for: docs/api/python/decode.rst (+3)

```diff
@@ -24,3 +24,6 @@ Batch Decoding
 
 .. autoclass:: BatchDecodeWithPagedKVCacheWrapper
     :members:
+
+.. autoclass:: CUDAGraphDecodeWithPagedKVCacheWrapper
+    :members:
```

Diff for: include/flashinfer/attention/handler.cuh (+120, -21)

```diff
@@ -17,6 +17,7 @@
 #define FLASHINFER_HANDLER_CUH_
 
 #include <algorithm>
+#include <cstddef>
 #include <memory>
 #include <unordered_map>
 #include <vector>
@@ -101,7 +102,7 @@ template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKVL
 cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
     uint32_t& tmp_size, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch,
     uint32_t& new_batch_size, uint32_t batch_size, IdType* kv_indptr, const uint32_t num_qo_heads,
-    const uint32_t page_size, cudaStream_t stream) {
+    const uint32_t page_size, bool enable_cuda_graph, cudaStream_t stream) {
   constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL);
   constexpr uint32_t num_stages_smem = 2U;
   constexpr uint32_t bdx = HEAD_DIM / vec_size;
@@ -126,8 +127,10 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
   FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
       &num_blocks_per_sm, partition_kv_kernel, num_threads, smem_size));
   max_grid_size = num_blocks_per_sm * num_sm;
-  if (batch_size * num_kv_heads >= max_grid_size) {
+  if (batch_size * num_kv_heads >= max_grid_size && !enable_cuda_graph) {
     // do not use partition-kv kernel
+    // TODO(Zihao): if enable_cuda_graph, we should always use partition-kv kernel
+    // so that only one kernel will be captured in the graph.
     tmp_size = 0;
     new_batch_size = batch_size;
   } else {
@@ -299,39 +302,42 @@ class BatchDecodeHandler {
                                                             DTypeOut, IdType>;
     FLASHINFER_CUDA_CALL(work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch,
                                               new_batch_size, batch_size, indptr, num_qo_heads,
-                                              page_size, stream_));
+                                              page_size,
+                                              /*enable_cuda_graph=*/false, stream_));
     batch_size_after_partition_ = new_batch_size;
     if (tmp_size > 0) {
       AlignedAlloactor allocator(buffer, workspace_size_in_bytes);
       float_buffer_ = allocator.aligned_alloc<void*>(tmp_size, 16);
       new_indptr_ =
           allocator.aligned_alloc<void*>((batch_size_after_partition_ + 1) * sizeof(IdType), 16);
-      void* new_indptr_h_ = host_buffer_;
+      void* new_indptr_h_ = page_locked_buffer_;
       new_last_page_len_ =
          allocator.aligned_alloc<void*>(batch_size_after_partition_ * sizeof(IdType), 16);
       void* new_last_page_len_h_ =
-          (char*)host_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
+          (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
       chunk_indptr_ =
           allocator.aligned_alloc<void*>((batch_size_before_partition_ + 1) * sizeof(IdType), 16);
-      void* chunk_indptr_h_ = (char*)host_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
+      void* chunk_indptr_h_ =
+          (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
       batch_idx_map_ =
           allocator.aligned_alloc<void*>(batch_size_after_partition_ * sizeof(IdType), 16);
-      void* batch_idx_map_h_ = (char*)host_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
+      void* batch_idx_map_h_ =
+          (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
       chunk_start_pos_ =
           allocator.aligned_alloc<void*>(batch_size_after_partition_ * sizeof(IdType), 16);
       void* chunk_start_pos_h_ =
-          (char*)host_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
+          (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
       seq_lengths_before_partition_ =
           allocator.aligned_alloc<void*>(batch_size_after_partition_ * sizeof(IdType), 16);
       void* seq_lengths_before_partition_h_ =
-          (char*)host_buffer_ + ((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
+          (char*)page_locked_buffer_ + ((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
       size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_;
       FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
           max_num_pages_per_batch, batch_size, page_size, indptr, last_page_len,
           (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_, (IdType*)chunk_indptr_h_,
           (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_,
-          (IdType*)seq_lengths_before_partition_h_, new_indptr_, host_buffer_, num_bytes_to_copy,
-          stream_));
+          (IdType*)seq_lengths_before_partition_h_, new_indptr_, page_locked_buffer_,
+          num_bytes_to_copy, stream_));
     }
     forward_started_ = true;
     return cudaSuccess;
@@ -353,6 +359,11 @@ class BatchDecodeHandler {
 
   bool IsForwardStarted() const { return forward_started_; }
 
+  void UpdatePageLockedBufferSize(size_t max_workspace_size_in_bytes) {
+    cudaFreeHost(page_locked_buffer_);
+    cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes);
+  }
+
   uint32_t GetBatchSizeBeforePartition() const { return batch_size_before_partition_; }
 
   uint32_t GetBatchSizeAfterPartition() const { return batch_size_after_partition_; }
@@ -372,17 +383,19 @@ class BatchDecodeHandler {
         seq_lengths_before_partition_(nullptr),
         forward_started_(false),
         stream_(nullptr) {
-    cudaMallocHost(&host_buffer_, max_workspace_size_in_bytes);
+    cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes);
   }
   ~BatchDecodeHandler() {
     EndForward();
-    cudaFreeHost(host_buffer_);
+    cudaFreeHost(page_locked_buffer_);
   }
 
- private:
+  virtual bool IsCUDAGraphMode() const { return false; }
+
+ protected:
   uint32_t batch_size_before_partition_;
   uint32_t batch_size_after_partition_;
-  void* host_buffer_;
+  void* page_locked_buffer_;
   void* float_buffer_;
   void* new_indptr_;
   void* new_last_page_len_;
@@ -394,6 +407,86 @@ class BatchDecodeHandler {
   cudaStream_t stream_;
 };
 
+class CUDAGraphBatchDecodeHandler : public BatchDecodeHandler {
+ public:
+  template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKVLayout kv_layout,
+            PosEncodingMode POS_ENCODING_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
+  cudaError_t CUDAGraphBeginForwardDispatched(void* buffer, size_t workspace_size_in_bytes,
+                                              IdType* indptr, IdType* last_page_len,
+                                              uint32_t batch_size, uint32_t num_qo_heads,
+                                              uint32_t page_size) {
+    batch_size_before_partition_ = batch_size;
+    uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size;
+    auto work_estimation_func =
+        BatchDecodeWithPagedKVCacheWorkEstimationDispatched<GROUP_SIZE, HEAD_DIM, page_storage,
+                                                            kv_layout, POS_ENCODING_MODE, DTypeIn,
+                                                            DTypeOut, IdType>;
+    FLASHINFER_CUDA_CALL(work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch,
+                                              new_batch_size, batch_size, indptr, num_qo_heads,
+                                              page_size,
+                                              /*enable_cuda_graph=*/true, stream_));
+    // NOTE(Zihao): max_batch_size_after_partition_ is determined in handler initialization.
+    // the value should not be changed during the lifetime of the handler.
+    // So it should be compatible with CUDAGraph which requires fixed pointer.
+    batch_size_after_partition_ = new_batch_size;
+    size_t max_tmp_size = num_qo_heads * max_batch_size_after_partition_ *
+                          (HEAD_DIM * sizeof(DTypeOut) + 2 * sizeof(float));
+    AlignedAlloactor allocator(buffer, workspace_size_in_bytes);
+    float_buffer_ = allocator.aligned_alloc<void*>(max_tmp_size, 16);
+    new_indptr_ =
+        allocator.aligned_alloc<void*>((max_batch_size_after_partition_ + 1) * sizeof(IdType), 16);
+
+    void* new_indptr_h_ = page_locked_buffer_;
+    new_last_page_len_ =
+        allocator.aligned_alloc<void*>(max_batch_size_after_partition_ * sizeof(IdType), 16);
+    void* new_last_page_len_h_ =
+        (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
+    chunk_indptr_ =
+        allocator.aligned_alloc<void*>((max_batch_size_after_partition_ + 1) * sizeof(IdType), 16);
+    void* chunk_indptr_h_ =
+        (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
+    batch_idx_map_ =
+        allocator.aligned_alloc<void*>(max_batch_size_after_partition_ * sizeof(IdType), 16);
+    void* batch_idx_map_h_ =
+        (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
+    chunk_start_pos_ =
+        allocator.aligned_alloc<void*>(max_batch_size_after_partition_ * sizeof(IdType), 16);
+    void* chunk_start_pos_h_ =
+        (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
+    seq_lengths_before_partition_ =
+        allocator.aligned_alloc<void*>(max_batch_size_after_partition_ * sizeof(IdType), 16);
+    void* seq_lengths_before_partition_h_ =
+        (char*)page_locked_buffer_ + ((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
+
+    size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_;
+    FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
+        max_num_pages_per_batch, batch_size, page_size, indptr, last_page_len,
+        (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_, (IdType*)chunk_indptr_h_,
+        (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_,
+        (IdType*)seq_lengths_before_partition_h_, new_indptr_, page_locked_buffer_,
+        num_bytes_to_copy, stream_));
+    forward_started_ = true;
+    return cudaSuccess;
+  }
+  CUDAGraphBatchDecodeHandler(size_t max_batch_size) {
+    int dev_id = 0, num_sm = 0, max_thread_blocks_per_sm = 0;
+    cudaGetDevice(&dev_id);
+    cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id);
+    cudaDeviceGetAttribute(&max_thread_blocks_per_sm, cudaDevAttrMaxBlocksPerMultiprocessor,
+                           dev_id);
+    max_batch_size_after_partition_ =
+        std::max<size_t>(max_thread_blocks_per_sm * num_sm, max_batch_size);
+    std::cout << max_thread_blocks_per_sm * num_sm << " " << max_batch_size << std::endl;
+    size_t max_workspace_size_in_bytes =
+        6 * (sizeof(uint64_t) * (max_batch_size_after_partition_ + 1) + 16);
+    cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes);
+  }
+  bool IsCUDAGraphMode() const override { return true; }
+
+ private:
+  uint32_t max_batch_size_after_partition_;
+};
+
 class BatchPrefillHandler {
  public:
   template <typename IdType>
@@ -412,6 +505,11 @@ class BatchPrefillHandler {
 
   bool IsForwardStarted() const { return request_indices_ != nullptr; }
 
+  void UpdatePageLockedBufferSize(size_t max_workspace_size_in_bytes) {
+    cudaFreeHost(page_locked_buffer_);
+    cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes);
+  }
+
   template <typename IdType>
   cudaError_t BeginForward(void* buffer, size_t workspace_size_in_bytes, IdType* qo_indptr,
                            uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
@@ -429,14 +527,15 @@ class BatchPrefillHandler {
     AlignedAlloactor allocator(buffer, workspace_size_in_bytes);
     request_indices_ =
         allocator.aligned_alloc<void*>(sizeof(IdType) * request_indices_vec.size(), 16);
-    void* request_indices_h_ = host_buffer_;
+    void* request_indices_h_ = page_locked_buffer_;
     tile_indices_ = allocator.aligned_alloc<void*>(sizeof(IdType) * tile_indices_vec.size(), 16);
-    void* tile_indices_h_ = (char*)host_buffer_ + ((char*)tile_indices_ - (char*)request_indices_);
+    void* tile_indices_h_ =
+        (char*)page_locked_buffer_ + ((char*)tile_indices_ - (char*)request_indices_);
     std::copy(request_indices_vec.begin(), request_indices_vec.end(), (IdType*)request_indices_h_);
     std::copy(tile_indices_vec.begin(), tile_indices_vec.end(), (IdType*)tile_indices_h_);
     size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)request_indices_;
 
-    FLASHINFER_CUDA_CALL(cudaMemcpyAsync(request_indices_, host_buffer_, num_bytes_to_copy,
+    FLASHINFER_CUDA_CALL(cudaMemcpyAsync(request_indices_, page_locked_buffer_, num_bytes_to_copy,
                                          cudaMemcpyHostToDevice, stream_));
 
     return cudaSuccess;
@@ -462,15 +561,15 @@ class BatchPrefillHandler {
         num_qo_tiles_(0U),
         forward_started_(false),
         stream_(nullptr) {
-    cudaMallocHost(&host_buffer_, max_workspace_size_in_bytes);
+    cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes);
   }
   ~BatchPrefillHandler() {
     EndForward();
-    cudaFreeHost(host_buffer_);
+    cudaFreeHost(page_locked_buffer_);
   }
 
  private:
-  void* host_buffer_;
+  void* page_locked_buffer_;
   void* request_indices_;
   void* tile_indices_;
   uint32_t num_frags_x_;
```
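The NOTE(Zihao) comment in `CUDAGraphBeginForwardDispatched` above captures the key invariant: every auxiliary buffer is sized with `max_batch_size_after_partition_`, which is fixed at handler construction, so the bump allocator carves out the same offsets on every `begin_forward` call and the device pointers seen by the captured kernel never move. The sketch below illustrates that invariant; `BumpAllocator` and the sizes are illustrative stand-ins, not the actual `AlignedAlloactor`.

```cpp
// Illustrative sketch (not flashinfer code): a fixed-capacity bump allocator in the
// spirit of the handler's AlignedAlloactor. Because every request is sized with the
// fixed maximum batch size, the carved-out offsets are identical on every call.
#include <cassert>
#include <cstddef>
#include <cstdint>

struct BumpAllocator {
  char* base;
  size_t capacity;
  size_t offset;

  void* aligned_alloc(size_t size, size_t alignment) {
    offset = (offset + alignment - 1) & ~(alignment - 1);  // round up to the alignment
    void* ptr = base + offset;
    offset += size;
    assert(offset <= capacity);
    return ptr;
  }
};

int main() {
  alignas(16) static char workspace[1 << 20];
  const size_t max_batch = 4096;  // analogue of max_batch_size_after_partition_

  void* first_call_indptr = nullptr;
  for (int call = 0; call < 2; ++call) {
    // Each "begin_forward" re-runs the same allocation sequence over the same workspace.
    BumpAllocator alloc{workspace, sizeof(workspace), 0};
    void* new_indptr = alloc.aligned_alloc((max_batch + 1) * sizeof(int32_t), 16);
    void* new_last_page_len = alloc.aligned_alloc(max_batch * sizeof(int32_t), 16);
    (void)new_last_page_len;
    if (call == 0) {
      first_call_indptr = new_indptr;
    } else {
      // Sizes depend only on max_batch, so the pointers are stable across calls,
      // which is what a captured CUDA graph requires.
      assert(new_indptr == first_call_indptr);
    }
  }
  return 0;
}
```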

Diff for: include/flashinfer/attention/prefill.cuh (+20, -16)

```diff
@@ -309,8 +309,8 @@ __device__ __forceinline__ void load_q_global_smem(uint32_t q_idx_base,
 #pragma unroll
     for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) {
       // load q fragment from gmem to smem
-      q_smem->load_128b_async<SharedMemFillMode::kFillZero>(q_smem_offset_w, q_ptr,
-                                                            q_idx < qo_upper_bound && group_id < group_size);
+      q_smem->load_128b_async<SharedMemFillMode::kFillZero>(
+          q_smem_offset_w, q_ptr, q_idx < qo_upper_bound && group_id < group_size);
       q_smem_offset_w = q_smem->advance_offset_by_column<8>(q_smem_offset_w, fyo);
       q_ptr += 8 * num_elems_per_128b<DTypeIn>();
     }
@@ -933,7 +933,7 @@ __global__ void SinglePrefillWithKVCacheKernel(
   constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b<DTypeOut>();
 
   static_assert(num_frags_z * num_frags_y % num_warps == 0);
-  static_assert(group_size == 1 || group_size >= 4 && group_size <=8);
+  static_assert(group_size == 1 || group_size >= 4 && group_size <= 8);
 
   extern __shared__ uint8_t smem[];
 
@@ -1341,7 +1341,8 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
                  kv_len = (paged_kv.indptr[request_idx + 1] - paged_kv.indptr[request_idx] - 1) *
                               paged_kv.page_size +
                           paged_kv.last_page_len[request_idx];
-  const uint32_t qo_upper_bound = min(qo_len, (tile_idx + 1) * (num_rows_per_cta / aligned_group_size));
+  const uint32_t qo_upper_bound =
+      min(qo_len, (tile_idx + 1) * (num_rows_per_cta / aligned_group_size));
 
   constexpr bool partition_kv = false;
   constexpr uint32_t head_dim = num_frags_y * 16;
@@ -1364,7 +1365,8 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
   }
   init_states<num_frags_x, num_frags_y>(o_frag, m, d);
 
-  const uint32_t qo_idx_base = ((tile_idx * num_warps + ty) * num_frags_x * 16) / aligned_group_size;
+  const uint32_t qo_idx_base =
+      ((tile_idx * num_warps + ty) * num_frags_x * 16) / aligned_group_size;
   const uint32_t qo_n_stride = get_n_stride_impl<QKVLayout::kNHD, head_dim>(num_qo_heads),
                  qo_h_stride = get_h_stride_impl<QKVLayout::kNHD, head_dim>(qo_len);
   smem_t qo_smem(smem);
@@ -1386,12 +1388,12 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
 
   if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) {
     if (q_offset == nullptr) {
-      q_smem_inplace_apply_rotary_multiply_sm_scale<aligned_group_size, num_warps, num_frags_x, num_frags_y,
-                                                    DTypeIn>(qo_idx_base, qo_len, kv_len, &qo_smem,
-                                                             &q_smem_offset_r, rope_freq, sm_scale);
+      q_smem_inplace_apply_rotary_multiply_sm_scale<aligned_group_size, num_warps, num_frags_x,
+                                                    num_frags_y, DTypeIn>(
+          qo_idx_base, qo_len, kv_len, &qo_smem, &q_smem_offset_r, rope_freq, sm_scale);
     } else {
-      q_smem_inplace_apply_rotary_with_pos_multiply_sm_scale<aligned_group_size, num_warps, num_frags_x,
-                                                              num_frags_y, DTypeIn>(
+      q_smem_inplace_apply_rotary_with_pos_multiply_sm_scale<aligned_group_size, num_warps,
+                                                              num_frags_x, num_frags_y, DTypeIn>(
           qo_indptr[request_idx] + qo_idx_base, q_offset, &qo_smem, &q_smem_offset_r, rope_freq,
           sm_scale);
     }
@@ -1418,14 +1420,16 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
   cp_async::commit_group();
 
   const uint32_t num_iterations = ceil_div(
-      (causal ? min(kv_len,
-                    kv_len - qo_len + ((tile_idx + 1) * num_frags_x * num_warps * 16) / aligned_group_size)
-              : kv_len),
+      (causal
+           ? min(kv_len, kv_len - qo_len +
+                             ((tile_idx + 1) * num_frags_x * num_warps * 16) / aligned_group_size)
+           : kv_len),
       16 * num_frags_z);
 
   const uint32_t mask_iteration =
       (causal
-           ? min(kv_len + (tile_idx * num_warps * num_frags_x * 16) / aligned_group_size - qo_len, kv_len)
+           ? min(kv_len + (tile_idx * num_warps * num_frags_x * 16) / aligned_group_size - qo_len,
+                 kv_len)
            : kv_len) /
       (16 * num_frags_z);
 
@@ -1453,8 +1457,8 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
     }
     // apply mask
     if (iter >= mask_iteration) {
-      mask_s<partition_kv, causal, aligned_group_size, num_warps, num_frags_x, num_frags_y, num_frags_z>(
-          qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, s_frag);
+      mask_s<partition_kv, causal, aligned_group_size, num_warps, num_frags_x, num_frags_y,
+             num_frags_z>(qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, s_frag);
     }
 
     // compute m,d states in online softmax
```

Diff for: include/flashinfer/utils.cuh (+2, -2)

```diff
@@ -282,8 +282,8 @@ std::tuple<IdType, IdType, std::vector<IdType>, std::vector<IdType>> split_qo_in
   uint32_t num_qo_tiles = 0;
 
   for (uint32_t i = 0; i < batch_size; ++i) {
-    for (uint32_t j = qo_indptr_h[i] * aligned_gqa_group_size; j < qo_indptr_h[i + 1] * aligned_gqa_group_size;
-         j += num_rows_per_cta) {
+    for (uint32_t j = qo_indptr_h[i] * aligned_gqa_group_size;
+         j < qo_indptr_h[i + 1] * aligned_gqa_group_size; j += num_rows_per_cta) {
       request_indices.push_back(i);
       tile_indices.push_back((j - qo_indptr_h[i] * aligned_gqa_group_size) / num_rows_per_cta);
       ++num_qo_tiles;
```
