@@ -418,6 +418,27 @@ inline cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in
   return cudaSuccess;
 }
 
+inline uint32_t DetermineCtaTileQ(int64_t avg_packed_qo_len, uint32_t head_dim) {
+  if (avg_packed_qo_len > 64 && head_dim < 256) {
+    return 128;
+  } else {
+    auto compute_capacity = GetCudaComputeCapability();
+    if (compute_capacity.first >= 8) {
+      // Ampere or newer
+      if (avg_packed_qo_len > 16) {
+        // avg_packed_qo_len <= 64
+        return 64;
+      } else {
+        // avg_packed_qo_len <= 16
+        return 16;
+      }
+    } else {
+      // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout
+      return 64;
+    }
+  }
+}
+
 template <typename IdType>
 inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h,
                                    uint32_t total_num_rows, uint32_t batch_size,
@@ -459,7 +480,7 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h,
     // the CUDA graph is created fixes the maximum number of tokens.
     const uint64_t max_seq_len = total_num_rows - batch_size + 1;
     uint64_t max_qo_len = uint64_t(max_seq_len) * gqa_group_size;
-    cta_tile_q = FA2DetermineCtaTileQ(max_qo_len, head_dim);
+    cta_tile_q = DetermineCtaTileQ(max_qo_len, head_dim);
 
     // Find an upper bound for the number of tiles, derived from the total
     // number of rows and the batch size. The sum of qo lengths rounded
@@ -472,7 +493,7 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h,
       sum_packed_qo_len += packed_qo_len_arr[i];
     }
     const int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size;
-    cta_tile_q = FA2DetermineCtaTileQ(avg_packed_qo_len, head_dim);
+    cta_tile_q = DetermineCtaTileQ(avg_packed_qo_len, head_dim);
 
     total_num_tiles_q = 0;
     for (uint32_t i = 0; i < batch_size; ++i) {
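For reference, the heuristic introduced in this diff picks a 128-row CTA tile when the average packed query length exceeds 64 and `head_dim < 256`; otherwise it falls back to 64- or 16-row tiles on Ampere and newer, and always to 64-row tiles on Turing. Below is a minimal, hedged sketch of that decision logic as a standalone host program: the function name `DetermineCtaTileQSketch` and the hard-coded compute capabilities are illustrative only (the real scheduler queries the device via `GetCudaComputeCapability()`), and the packed length is `qo_len * gqa_group_size` as in the call sites above.

```cpp
#include <cstdint>
#include <cstdio>
#include <utility>

// Standalone illustration of the CTA tile-size heuristic added in this diff.
// The compute capability is passed in explicitly for the sketch; the real
// code obtains it from GetCudaComputeCapability().
inline uint32_t DetermineCtaTileQSketch(int64_t avg_packed_qo_len, uint32_t head_dim,
                                        std::pair<int, int> compute_capacity) {
  if (avg_packed_qo_len > 64 && head_dim < 256) {
    return 128;
  }
  if (compute_capacity.first >= 8) {
    // Ampere or newer: shrink the tile for short query segments.
    return avg_packed_qo_len > 16 ? 64 : 16;
  }
  // Turing: not enough shared memory for the 1x4 warp layout, so stay at 64.
  return 64;
}

int main() {
  const std::pair<int, int> sm80{8, 0};  // e.g. A100
  const std::pair<int, int> sm75{7, 5};  // e.g. T4
  // Long prefill with head_dim 128 -> 128-row tiles.
  std::printf("%u\n", DetermineCtaTileQSketch(512, 128, sm80));  // prints 128
  // Decode-like workload: 1 token per request, gqa_group_size 8 -> packed length 8 -> 16-row tiles.
  std::printf("%u\n", DetermineCtaTileQSketch(8, 128, sm80));    // prints 16
  // The same workload on Turing falls back to 64-row tiles.
  std::printf("%u\n", DetermineCtaTileQSketch(8, 128, sm75));    // prints 64
  return 0;
}
```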