
Commit 2fa74bb

Revert "perf: refactor fa2 prefill template (flashinfer-ai#776)"
This reverts commit fc03772.
1 parent 088e81f · commit 2fa74bb

9 files changed: +798 -797 lines

aot_build_utils/generate_batch_paged_prefill_inst.py (+1 -1)

@@ -37,7 +37,7 @@ def get_cu_file_str(
     dtype_out,
     idtype,
 ):
-    cta_tile_q_choice = [128, 64, 32, 16]
+    cta_tile_q_choice = [128, 64, 16]

     def get_insts(attention_variant, dtype_out):
         return "\n".join(

aot_build_utils/generate_batch_ragged_prefill_inst.py (+1 -1)

@@ -37,7 +37,7 @@ def get_cu_file_str(
     dtype_out,
     idtype,
 ):
-    cta_tile_q_choice = [128, 64, 32, 16]
+    cta_tile_q_choice = [128, 64, 16]

     def get_insts(attention_variant, dtype_out):
         return "\n".join(

csrc/batch_prefill_paged_kernel_inst.jinja (+1 -1)

@@ -5,7 +5,7 @@ namespace flashinfer {

 constexpr auto use_custom_mask = {{ mask_mode }} == MaskMode::kCustom;

-{% for cta_tile_q in [16, 32, 64, 128] %}
+{% for cta_tile_q in [16, 64, 128] %}
 template cudaError_t BatchPrefillWithPagedKVCacheDispatched<
     /*CTA_TILE_Q=*/{{cta_tile_q}}, {{head_dim_qk}}, {{head_dim_vo}}, {{pos_encoding_mode}}, {{use_fp16_qk_reduction}}, {{mask_mode}},
     {{ variant_name }}, PagedParams>(PagedParams params, {{ dtype_o }}* tmp_v, float* tmp_s, cudaStream_t stream);

csrc/batch_prefill_ragged_kernel_inst.jinja (+1 -1)

@@ -5,7 +5,7 @@ namespace flashinfer {

 constexpr auto use_custom_mask = {{ mask_mode }} == MaskMode::kCustom;

-{% for cta_tile_q in [16, 32, 64, 128] %}
+{% for cta_tile_q in [16, 64, 128] %}
 template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<
     /*CTA_TILE_Q=*/{{cta_tile_q}}, {{head_dim_qk}}, {{head_dim_vo}}, {{pos_encoding_mode}}, {{use_fp16_qk_reduction}}, {{mask_mode}},
     {{ variant_name }}, RaggedParams>(RaggedParams params, {{ dtype_o }}* tmp_v, float* tmp_s, cudaStream_t stream);
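
Note: the .jinja templates do the same expansion at render time for JIT builds. A minimal sketch using the jinja2 package, with a trimmed template body and placeholder render values rather than the real build parameters:

# Sketch only: how the {% for cta_tile_q in [...] %} loop expands into one
# explicit instantiation per tile size. Requires the jinja2 package.
from jinja2 import Template

tmpl = Template(
    "{% for cta_tile_q in [16, 64, 128] %}"
    "template cudaError_t BatchPrefillWithPagedKVCacheDispatched<"
    "/*CTA_TILE_Q=*/{{ cta_tile_q }}, {{ head_dim_qk }}, {{ head_dim_vo }}, ...>"
    "(PagedParams params, {{ dtype_o }}* tmp_v, float* tmp_s, cudaStream_t stream);\n"
    "{% endfor %}"
)

# Placeholder values, not the real build parameters.
print(tmpl.render(head_dim_qk=128, head_dim_vo=128, dtype_o="half"))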

include/flashinfer/attention/prefill.cuh (+759 -751)

Large diffs are not rendered by default.

include/flashinfer/attention/scheduler.cuh (+23 -2)

@@ -418,6 +418,27 @@ inline cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in
   return cudaSuccess;
 }

+inline uint32_t DetermineCtaTileQ(int64_t avg_packed_qo_len, uint32_t head_dim) {
+  if (avg_packed_qo_len > 64 && head_dim < 256) {
+    return 128;
+  } else {
+    auto compute_capacity = GetCudaComputeCapability();
+    if (compute_capacity.first >= 8) {
+      // Ampere or newer
+      if (avg_packed_qo_len > 16) {
+        // avg_packed_qo_len <= 64
+        return 64;
+      } else {
+        // avg_packed_qo_len <= 16
+        return 16;
+      }
+    } else {
+      // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout
+      return 64;
+    }
+  }
+}
+
 template <typename IdType>
 inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h,
                                    uint32_t total_num_rows, uint32_t batch_size,
@@ -459,7 +480,7 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h,
     // the CUDA graph is created fixes the maximum number of tokens.
     const uint64_t max_seq_len = total_num_rows - batch_size + 1;
     uint64_t max_qo_len = uint64_t(max_seq_len) * gqa_group_size;
-    cta_tile_q = FA2DetermineCtaTileQ(max_qo_len, head_dim);
+    cta_tile_q = DetermineCtaTileQ(max_qo_len, head_dim);

     // Find an upper bound for the number of tiles, derived from the total
     // number of rows and the batch size. The sum of qo lengths rounded
@@ -472,7 +493,7 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h,
       sum_packed_qo_len += packed_qo_len_arr[i];
     }
     const int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size;
-    cta_tile_q = FA2DetermineCtaTileQ(avg_packed_qo_len, head_dim);
+    cta_tile_q = DetermineCtaTileQ(avg_packed_qo_len, head_dim);

     total_num_tiles_q = 0;
     for (uint32_t i = 0; i < batch_size; ++i) {
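
Note: for reference, the decision tree of the restored DetermineCtaTileQ, ported to Python for illustration (sm_major stands in for GetCudaComputeCapability().first, which is queried on the C++ side):

# Python port of the restored heuristic, for illustration only.
def determine_cta_tile_q(avg_packed_qo_len: int, head_dim: int, sm_major: int = 8) -> int:
    if avg_packed_qo_len > 64 and head_dim < 256:
        return 128
    if sm_major >= 8:  # Ampere or newer
        return 64 if avg_packed_qo_len > 16 else 16
    # Turing: not enough shared memory for the 1x4 warp layout
    return 64

# e.g. on an Ampere-class GPU with head_dim=128:
assert determine_cta_tile_q(8, 128) == 16
assert determine_cta_tile_q(48, 128) == 64
assert determine_cta_tile_q(256, 128) == 128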

include/flashinfer/mma.cuh (+2 -2)

@@ -480,7 +480,7 @@ __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, u
  * \brief Use mma instructions to compute rowsum.
  */
 template <typename DType>
-__device__ __forceinline__ void m16k32_rowsum_f8f8f32(float* d, DType* s) {
+__device__ __forceinline__ void rowsum_f8f8f32(float* d, DType* s) {
   static_assert(sizeof(DType) == 1, "DType must be 8bit floating data type");
   uint32_t* s_u32 = (uint32_t*)(s);
 #if defined(FLASHINFER_MMA_F8F8F32_M16N8K32_ENABLED)
@@ -519,7 +519,7 @@ __device__ __forceinline__ void m16k32_rowsum_f8f8f32(float* d, DType* s) {
  * \brief Use mma instructions to compute rowsum.
  */
 template <typename DType>
-__device__ __forceinline__ void m16k16_rowsum_f16f16f32(float* d, DType* s) {
+__device__ __forceinline__ void rowsum_f16f16f32(float* d, DType* s) {
   static_assert(sizeof(DType) == 2, "DType must be 16bit floating data type");
   uint32_t* s_u32 = (uint32_t*)(s);
 #if defined(FLASHINFER_MMA_F16F16F32_M16N8K16_ENABLED)
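
Note: the rename only drops the m16k32/m16k16 prefixes; the computation is unchanged: each call accumulates, for every row of a 16-row fragment, the sum over the K dimension into d. A NumPy sketch of that accumulation (the real device code does this with MMA instructions against an all-ones operand; shapes follow the fp8 m16k32 case and are illustrative):

# Illustration of what rowsum_f8f8f32 accumulates: per-row sums of a 16xK
# fragment added into d in fp32. Not the actual warp-level implementation.
import numpy as np

K = 32
s = np.random.uniform(-1, 1, size=(16, K)).astype(np.float32)  # stand-in for the fp8 fragment
d = np.zeros(16, dtype=np.float32)

d += s.sum(axis=1)  # what the rowsum produces for each of the 16 rows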

include/flashinfer/utils.cuh (-29)

@@ -108,11 +108,6 @@
       __VA_ARGS__                           \
       break;                                \
     }                                       \
-    case 32: {                              \
-      constexpr uint32_t CTA_TILE_Q = 32;   \
-      __VA_ARGS__                           \
-      break;                                \
-    }                                       \
     case 16: {                              \
       constexpr uint32_t CTA_TILE_Q = 16;   \
       __VA_ARGS__                           \
@@ -295,30 +290,6 @@ inline void DebugPrintCUDAArray(T* device_ptr, size_t size, std::string prefix =
   std::cout << std::endl;
 }

-inline uint32_t FA2DetermineCtaTileQ(int64_t avg_packed_qo_len, uint32_t head_dim) {
-  if (avg_packed_qo_len > 64 && head_dim < 256) {
-    return 128;
-  } else {
-    auto compute_capacity = GetCudaComputeCapability();
-    if (compute_capacity.first >= 8) {
-      // Ampere or newer
-      if (avg_packed_qo_len > 32) {
-        // avg_packed_qo_len <= 64
-        return 64;
-      } else if (avg_packed_qo_len > 16) {
-        // avg_packed_qo_len <= 32
-        return 32;
-      } else {
-        // avg_packed_qo_len <= 16
-        return 16;
-      }
-    } else {
-      // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout
-      return 64;
-    }
-  }
-}
-
 /*!
  * \brief Return x - y if x > y, otherwise return 0.
  */
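
Note: with the 32-row tile gone, the CTA_TILE_Q dispatch switch loses its case 32 branch and the FA2DetermineCtaTileQ helper is dropped in favor of the scheduler-side DetermineCtaTileQ. The two heuristics differ only for average packed query lengths in (16, 32], which previously mapped to a 32-row tile and now map to 64. A small comparison sketch in Python (ported for illustration, Ampere-or-newer path only):

# Side-by-side of the removed and restored heuristics, ported for illustration.
def old_tile(avg_len, head_dim):  # removed FA2DetermineCtaTileQ (Ampere path)
    if avg_len > 64 and head_dim < 256:
        return 128
    if avg_len > 32:
        return 64
    return 32 if avg_len > 16 else 16

def new_tile(avg_len, head_dim):  # restored DetermineCtaTileQ (Ampere path)
    if avg_len > 64 and head_dim < 256:
        return 128
    return 64 if avg_len > 16 else 16

for avg_len in (8, 24, 48, 96):
    print(avg_len, old_tile(avg_len, 128), new_tile(avg_len, 128))
# 8 -> 16/16, 24 -> 32/64, 48 -> 64/64, 96 -> 128/128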

tests/test_jit_example.py (+10 -9)

@@ -561,6 +561,7 @@ def test_batch_prefill_sm90_flash_sigmoid():
     sigmoid_bias = 0.25

     o = wrapper.run(q, k, v, logits_scale, sigmoid_bias)
+    print(o)
     wrapper_paged = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
         float_workspace_buffer, kv_layout="NHD", backend="fa3", jit_args=jit_args
     )
@@ -695,7 +696,7 @@ def test_sm90_debug_print_logits():

 template <int NUM_ROWS_PER_THREAD>
 __device__ auto GetAttentionUpdater() {
-  return OnlineSoftmax<NUM_ROWS_PER_THREAD, /*WITH_SCALE*/false>(sm_scale_log2);
+  return OnlineSoftmax<NUM_ROWS_PER_THREAD, /*WITH_SCALE*/true>(sm_scale_log2);
 }


@@ -752,12 +753,12 @@ def test_sm90_debug_print_logits():


 if __name__ == "__main__":
-    test_single_decode_mask()
-    test_flash_sigmoid()
-    test_dump_logits()
-    test_debug_print_logits()
-    test_sm90_debug_print_logits()
-    test_batch_decode_flash_sigmoid(False)
-    test_batch_decode_flash_sigmoid(True)
-    test_batch_prefill_flash_sigmoid()
+    # test_single_decode_mask()
+    # test_flash_sigmoid()
+    # test_dump_logits()
+    # test_debug_print_logits()
+    # test_sm90_debug_print_logits()
+    # test_batch_decode_flash_sigmoid(False)
+    # test_batch_decode_flash_sigmoid(True)
+    # test_batch_prefill_flash_sigmoid()
     test_batch_prefill_sm90_flash_sigmoid()
