Commit fc03772

Authored Feb 4, 2025
perf: refactor fa2 prefill template (#776)
This PR refactors the FA2-based prefill template. It includes the following changes:
1. Use KernelTraits for all constexpr parameters and data types.
2. Use a SharedStorage class as a clean interface for shared-memory management (a sketch of the KernelTraits/SharedStorage pattern appears below).
3. Unlock `CTA_TILE_Q=32`.
We also tried `CTA_TILE_Q=8` (a half-mma optimization for GQA decoding with a low group ratio, <=8), but the performance improvement was marginal (<1%) and it complicated the codebase, so we did not include that feature in this PR.
1 parent 0ca046a · commit fc03772

9 files changed: +797 −798 lines
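For context, the KernelTraits/SharedStorage pattern referenced in items 1 and 2 typically looks like the sketch below. All names, members, and tile parameters here (FA2KernelTraitsSketch, SharedStorageSketch, the example sizes) are illustrative assumptions, not the actual definitions in include/flashinfer/attention/prefill.cuh.

// Illustrative sketch only; names and members are hypothetical,
// not the flashinfer definitions.
#include <cstdint>
#include <cuda_fp16.h>

template <uint32_t CTA_TILE_Q_, uint32_t CTA_TILE_KV_, uint32_t HEAD_DIM_,
          typename DTypeQ_, typename DTypeKV_, typename DTypeO_>
struct FA2KernelTraitsSketch {
  // 1. All compile-time constants and data types live in a single traits type.
  static constexpr uint32_t CTA_TILE_Q = CTA_TILE_Q_;
  static constexpr uint32_t CTA_TILE_KV = CTA_TILE_KV_;
  static constexpr uint32_t HEAD_DIM = HEAD_DIM_;
  using DTypeQ = DTypeQ_;
  using DTypeKV = DTypeKV_;
  using DTypeO = DTypeO_;
};

// 2. Shared-memory layout derived from the traits; the kernel asks this
//    struct for buffers instead of computing raw byte offsets inline.
template <typename KTraits>
struct SharedStorageSketch {
  alignas(16) typename KTraits::DTypeQ q_smem[KTraits::CTA_TILE_Q * KTraits::HEAD_DIM];
  alignas(16) typename KTraits::DTypeKV k_smem[KTraits::CTA_TILE_KV * KTraits::HEAD_DIM];
  alignas(16) typename KTraits::DTypeKV v_smem[KTraits::CTA_TILE_KV * KTraits::HEAD_DIM];
};

template <typename KTraits>
__global__ void fa2_prefill_sketch(/* params elided */) {
  extern __shared__ uint8_t smem_raw[];
  auto& smem = *reinterpret_cast<SharedStorageSketch<KTraits>*>(smem_raw);
  // ... load the Q tile into smem.q_smem, then loop over KV tiles through
  // smem.k_smem / smem.v_smem ...
  (void)smem;
}

// 3. The newly unlocked tile size becomes just another traits instantiation.
using TraitsQ32 = FA2KernelTraitsSketch<32, 64, 128, __half, __half, __half>;
static_assert(TraitsQ32::CTA_TILE_Q == 32, "CTA_TILE_Q=32 is now a valid choice");

The point of the pattern is that every kernel-level constant and dtype is carried by one traits type, and the shared-memory layout is computed from those traits in one place rather than through offsets scattered across the kernel.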
 

‎aot_build_utils/generate_batch_paged_prefill_inst.py

+1 −1

@@ -37,7 +37,7 @@ def get_cu_file_str(
     dtype_out,
     idtype,
 ):
-    cta_tile_q_choice = [128, 64, 16]
+    cta_tile_q_choice = [128, 64, 32, 16]
 
     def get_insts(attention_variant, dtype_out):
         return "\n".join(

‎aot_build_utils/generate_batch_ragged_prefill_inst.py

+1 −1

@@ -37,7 +37,7 @@ def get_cu_file_str(
     dtype_out,
     idtype,
 ):
-    cta_tile_q_choice = [128, 64, 16]
+    cta_tile_q_choice = [128, 64, 32, 16]
 
     def get_insts(attention_variant, dtype_out):
         return "\n".join(

‎csrc/batch_prefill_paged_kernel_inst.jinja

+1 −1

@@ -5,7 +5,7 @@ namespace flashinfer {
 
 constexpr auto use_custom_mask = {{ mask_mode }} == MaskMode::kCustom;
 
-{% for cta_tile_q in [16, 64, 128] %}
+{% for cta_tile_q in [16, 32, 64, 128] %}
 template cudaError_t BatchPrefillWithPagedKVCacheDispatched<
     /*CTA_TILE_Q=*/{{cta_tile_q}}, {{head_dim_qk}}, {{head_dim_vo}}, {{pos_encoding_mode}}, {{use_fp16_qk_reduction}}, {{mask_mode}},
     {{ variant_name }}, PagedParams>(PagedParams params, {{ dtype_o }}* tmp_v, float* tmp_s, cudaStream_t stream);

‎csrc/batch_prefill_ragged_kernel_inst.jinja

+1 −1

@@ -5,7 +5,7 @@ namespace flashinfer {
 
 constexpr auto use_custom_mask = {{ mask_mode }} == MaskMode::kCustom;
 
-{% for cta_tile_q in [16, 64, 128] %}
+{% for cta_tile_q in [16, 32, 64, 128] %}
 template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<
     /*CTA_TILE_Q=*/{{cta_tile_q}}, {{head_dim_qk}}, {{head_dim_vo}}, {{pos_encoding_mode}}, {{use_fp16_qk_reduction}}, {{mask_mode}},
     {{ variant_name }}, RaggedParams>(RaggedParams params, {{ dtype_o }}* tmp_v, float* tmp_s, cudaStream_t stream);
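For reference, each pass of that Jinja loop emits one explicit instantiation, so adding 32 to the list produces one extra instantiation per combination of the remaining placeholders. With hypothetical values for those placeholders (128-dim heads, no positional encoding, no fp16 QK reduction, causal mask, a variant named AttentionVariant, half output), the newly added line would render roughly as:

template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<
    /*CTA_TILE_Q=*/32, 128, 128, PosEncodingMode::kNone, false, MaskMode::kCausal,
    AttentionVariant, RaggedParams>(RaggedParams params, half* tmp_v, float* tmp_s, cudaStream_t stream);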

‎include/flashinfer/attention/prefill.cuh

+751 −759
Large diffs are not rendered by default.

‎include/flashinfer/attention/scheduler.cuh

+2 −23

@@ -418,27 +418,6 @@ inline cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in
   return cudaSuccess;
 }
 
-inline uint32_t DetermineCtaTileQ(int64_t avg_packed_qo_len, uint32_t head_dim) {
-  if (avg_packed_qo_len > 64 && head_dim < 256) {
-    return 128;
-  } else {
-    auto compute_capacity = GetCudaComputeCapability();
-    if (compute_capacity.first >= 8) {
-      // Ampere or newer
-      if (avg_packed_qo_len > 16) {
-        // avg_packed_qo_len <= 64
-        return 64;
-      } else {
-        // avg_packed_qo_len <= 16
-        return 16;
-      }
-    } else {
-      // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout
-      return 64;
-    }
-  }
-}
-
 template <typename IdType>
 inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h,
                                    uint32_t total_num_rows, uint32_t batch_size,
@@ -480,7 +459,7 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h,
     // the CUDA graph is created fixes the maximum number of tokens.
     const uint64_t max_seq_len = total_num_rows - batch_size + 1;
     uint64_t max_qo_len = uint64_t(max_seq_len) * gqa_group_size;
-    cta_tile_q = DetermineCtaTileQ(max_qo_len, head_dim);
+    cta_tile_q = FA2DetermineCtaTileQ(max_qo_len, head_dim);
 
     // Find an upper bound for the number of tiles, derived from the total
     // number of rows and the batch size. The sum of qo lengths rounded
@@ -493,7 +472,7 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h,
       sum_packed_qo_len += packed_qo_len_arr[i];
     }
     const int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size;
-    cta_tile_q = DetermineCtaTileQ(avg_packed_qo_len, head_dim);
+    cta_tile_q = FA2DetermineCtaTileQ(avg_packed_qo_len, head_dim);
 
     total_num_tiles_q = 0;
     for (uint32_t i = 0; i < batch_size; ++i) {
‎include/flashinfer/mma.cuh

+2 −2

@@ -480,7 +480,7 @@ __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, u
  * \brief Use mma instructions to compute rowsum.
  */
 template <typename DType>
-__device__ __forceinline__ void rowsum_f8f8f32(float* d, DType* s) {
+__device__ __forceinline__ void m16k32_rowsum_f8f8f32(float* d, DType* s) {
   static_assert(sizeof(DType) == 1, "DType must be 8bit floating data type");
   uint32_t* s_u32 = (uint32_t*)(s);
 #if defined(FLASHINFER_MMA_F8F8F32_M16N8K32_ENABLED)
@@ -519,7 +519,7 @@ __device__ __forceinline__ void rowsum_f8f8f32(float* d, DType* s) {
  * \brief Use mma instructions to compute rowsum.
  */
 template <typename DType>
-__device__ __forceinline__ void rowsum_f16f16f32(float* d, DType* s) {
+__device__ __forceinline__ void m16k16_rowsum_f16f16f32(float* d, DType* s) {
   static_assert(sizeof(DType) == 2, "DType must be 16bit floating data type");
   uint32_t* s_u32 = (uint32_t*)(s);
 #if defined(FLASHINFER_MMA_F16F16F32_M16N8K16_ENABLED)
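The new prefixes encode the MMA tile shape each rowsum covers: 8-bit inputs reduce k=32 elements per row (m16k32, matching M16N8K32), 16-bit inputs reduce k=16 (m16k16, matching M16N8K16), and both produce 16 per-row float sums. As a semantic reference only (a dense row-major tile instead of the warp-distributed register fragments the real mma path uses, and assuming accumulator semantics for d), the reduction amounts to:

// Semantic reference, not the PTX mma implementation: s is a dense 16 x K
// row-major tile, d holds 16 per-row float accumulators (assumed semantics).
template <typename DType, int K>
void rowsum_reference(float* d, const DType* s) {
  for (int row = 0; row < 16; ++row) {
    float acc = 0.f;
    for (int k = 0; k < K; ++k) {
      acc += static_cast<float>(s[row * K + k]);
    }
    d[row] += acc;
  }
}
// m16k32_rowsum_f8f8f32 corresponds to K = 32 (1-byte inputs),
// m16k16_rowsum_f16f16f32 to K = 16 (2-byte inputs).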

‎include/flashinfer/utils.cuh

+29 −0

@@ -108,6 +108,11 @@
       __VA_ARGS__                           \
       break;                                \
     }                                       \
+    case 32: {                              \
+      constexpr uint32_t CTA_TILE_Q = 32;   \
+      __VA_ARGS__                           \
+      break;                                \
+    }                                       \
     case 16: {                              \
       constexpr uint32_t CTA_TILE_Q = 16;   \
       __VA_ARGS__                           \
@@ -290,6 +295,30 @@ inline void DebugPrintCUDAArray(T* device_ptr, size_t size, std::string prefix =
   std::cout << std::endl;
 }
 
+inline uint32_t FA2DetermineCtaTileQ(int64_t avg_packed_qo_len, uint32_t head_dim) {
+  if (avg_packed_qo_len > 64 && head_dim < 256) {
+    return 128;
+  } else {
+    auto compute_capacity = GetCudaComputeCapability();
+    if (compute_capacity.first >= 8) {
+      // Ampere or newer
+      if (avg_packed_qo_len > 32) {
+        // avg_packed_qo_len <= 64
+        return 64;
+      } else if (avg_packed_qo_len > 16) {
+        // avg_packed_qo_len <= 32
+        return 32;
+      } else {
+        // avg_packed_qo_len <= 16
+        return 16;
+      }
+    } else {
+      // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout
+      return 64;
+    }
+  }
+}
+
 /*!
  * \brief Return x - y if x > y, otherwise return 0.
  */
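Read on an Ampere-or-newer GPU with head_dim < 256, the heuristic now has four tiers: a packed query length above 64 selects CTA_TILE_Q=128, (32, 64] selects 64, (16, 32] selects the newly unlocked 32, and lengths of 16 or less select 16. The host-side sketch below mirrors that branch structure for illustration only; it hard-codes the compute-capability >= 8 path instead of calling GetCudaComputeCapability(), so it is not the library function itself.

#include <cstdint>
#include <cstdio>

// Mirror of the new tile-Q heuristic for the compute-capability >= 8 path.
// Illustration only; the real FA2DetermineCtaTileQ additionally falls back
// to 64 on pre-Ampere GPUs.
uint32_t fa2_tile_q_sm80_sketch(int64_t avg_packed_qo_len, uint32_t head_dim) {
  if (avg_packed_qo_len > 64 && head_dim < 256) return 128;
  if (avg_packed_qo_len > 32) return 64;
  if (avg_packed_qo_len > 16) return 32;  // newly reachable tier
  return 16;
}

int main() {
  const int64_t sample_lens[] = {8, 24, 48, 200};  // packed qo lengths
  for (int64_t len : sample_lens) {
    // prints 16, 32, 64, 128 for the four sample lengths (head_dim = 128)
    std::printf("avg_packed_qo_len=%lld -> CTA_TILE_Q=%u\n",
                static_cast<long long>(len), fa2_tile_q_sm80_sketch(len, /*head_dim=*/128));
  }
  return 0;
}

For example, GQA decoding with group size 8 and one query token per request packs to a length of 8 and keeps the 16 tile, while group size 8 with 4 query tokens packs to 32 and now lands on the 32 tile instead of 64.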

‎tests/test_jit_example.py

+9 −10

@@ -561,7 +561,6 @@ def test_batch_prefill_sm90_flash_sigmoid():
     sigmoid_bias = 0.25
 
     o = wrapper.run(q, k, v, logits_scale, sigmoid_bias)
-    print(o)
     wrapper_paged = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
         float_workspace_buffer, kv_layout="NHD", backend="fa3", jit_args=jit_args
     )
@@ -696,7 +695,7 @@ def test_sm90_debug_print_logits():
 
 template <int NUM_ROWS_PER_THREAD>
 __device__ auto GetAttentionUpdater() {
-  return OnlineSoftmax<NUM_ROWS_PER_THREAD, /*WITH_SCALE*/true>(sm_scale_log2);
+  return OnlineSoftmax<NUM_ROWS_PER_THREAD, /*WITH_SCALE*/false>(sm_scale_log2);
 }
 
 
@@ -753,12 +752,12 @@ def test_sm90_debug_print_logits():
 
 
 if __name__ == "__main__":
-    # test_single_decode_mask()
-    # test_flash_sigmoid()
-    # test_dump_logits()
-    # test_debug_print_logits()
-    # test_sm90_debug_print_logits()
-    # test_batch_decode_flash_sigmoid(False)
-    # test_batch_decode_flash_sigmoid(True)
-    # test_batch_prefill_flash_sigmoid()
+    test_single_decode_mask()
+    test_flash_sigmoid()
+    test_dump_logits()
+    test_debug_print_logits()
+    test_sm90_debug_print_logits()
+    test_batch_decode_flash_sigmoid(False)
+    test_batch_decode_flash_sigmoid(True)
+    test_batch_prefill_flash_sigmoid()
     test_batch_prefill_sm90_flash_sigmoid()
