perf: refactor fa2 prefill template #776

Merged (12 commits, Feb 4, 2025)
2 changes: 1 addition & 1 deletion aot_build_utils/generate_batch_paged_prefill_inst.py
@@ -37,7 +37,7 @@ def get_cu_file_str(
dtype_out,
idtype,
):
cta_tile_q_choice = [128, 64, 16]
cta_tile_q_choice = [128, 64, 32, 16]

def get_insts(attention_variant, dtype_out):
return "\n".join(
2 changes: 1 addition & 1 deletion aot_build_utils/generate_batch_ragged_prefill_inst.py
@@ -37,7 +37,7 @@ def get_cu_file_str(
dtype_out,
idtype,
):
cta_tile_q_choice = [128, 64, 16]
cta_tile_q_choice = [128, 64, 32, 16]

def get_insts(attention_variant, dtype_out):
return "\n".join(
2 changes: 1 addition & 1 deletion csrc/batch_prefill_paged_kernel_inst.jinja
@@ -5,7 +5,7 @@ namespace flashinfer {

constexpr auto use_custom_mask = {{ mask_mode }} == MaskMode::kCustom;

{% for cta_tile_q in [16, 64, 128] %}
{% for cta_tile_q in [16, 32, 64, 128] %}
template cudaError_t BatchPrefillWithPagedKVCacheDispatched<
/*CTA_TILE_Q=*/{{cta_tile_q}}, {{head_dim_qk}}, {{head_dim_vo}}, {{pos_encoding_mode}}, {{use_fp16_qk_reduction}}, {{mask_mode}},
{{ variant_name }}, PagedParams>(PagedParams params, {{ dtype_o }}* tmp_v, float* tmp_s, cudaStream_t stream);
2 changes: 1 addition & 1 deletion csrc/batch_prefill_ragged_kernel_inst.jinja
@@ -5,7 +5,7 @@ namespace flashinfer {

constexpr auto use_custom_mask = {{ mask_mode }} == MaskMode::kCustom;

{% for cta_tile_q in [16, 64, 128] %}
{% for cta_tile_q in [16, 32, 64, 128] %}
template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<
/*CTA_TILE_Q=*/{{cta_tile_q}}, {{head_dim_qk}}, {{head_dim_vo}}, {{pos_encoding_mode}}, {{use_fp16_qk_reduction}}, {{mask_mode}},
{{ variant_name }}, RaggedParams>(RaggedParams params, {{ dtype_o }}* tmp_v, float* tmp_s, cudaStream_t stream);
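
Note on why the two Python generators, the two Jinja templates, and the dispatch macro in include/flashinfer/utils.cuh (further down) all gain a 32 entry together: CTA_TILE_Q is a compile-time template parameter, so every tile size the scheduler may pick needs both an explicit kernel instantiation and a matching switch case. Below is a minimal, self-contained sketch of that pattern; ToyPrefillRun and DispatchToyPrefill are hypothetical stand-ins, not FlashInfer symbols.

#include <cstdint>
#include <cstdio>
#include <stdexcept>

// Stand-in for a kernel entry point whose query tile size is a template parameter.
template <uint32_t CTA_TILE_Q>
void ToyPrefillRun(uint32_t num_qo_tiles) {
  std::printf("CTA_TILE_Q=%u over %u query tiles\n", CTA_TILE_Q, num_qo_tiles);
}

// Runtime tile size -> compile-time constant; mirrors the switch that the
// CTA_TILE_Q dispatch macro in utils.cuh expands to (the case 32 is new).
void DispatchToyPrefill(uint32_t cta_tile_q, uint32_t num_qo_tiles) {
  switch (cta_tile_q) {
    case 128: ToyPrefillRun<128>(num_qo_tiles); break;
    case 64:  ToyPrefillRun<64>(num_qo_tiles); break;
    case 32:  ToyPrefillRun<32>(num_qo_tiles); break;  // tile size added by this PR
    case 16:  ToyPrefillRun<16>(num_qo_tiles); break;
    default:  throw std::invalid_argument("unsupported CTA_TILE_Q");
  }
}

int main() {
  DispatchToyPrefill(/*cta_tile_q=*/32, /*num_qo_tiles=*/4);
  return 0;
}

A tile size listed in only some of these places would either fail to link (no instantiation) or be unreachable at dispatch time, which is why the lists are kept identical.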
1,510 changes: 751 additions & 759 deletions include/flashinfer/attention/prefill.cuh

Large diffs are not rendered by default.

25 changes: 2 additions & 23 deletions include/flashinfer/attention/scheduler.cuh
@@ -418,27 +418,6 @@ inline cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in
return cudaSuccess;
}

inline uint32_t DetermineCtaTileQ(int64_t avg_packed_qo_len, uint32_t head_dim) {
if (avg_packed_qo_len > 64 && head_dim < 256) {
return 128;
} else {
auto compute_capacity = GetCudaComputeCapability();
if (compute_capacity.first >= 8) {
// Ampere or newer
if (avg_packed_qo_len > 16) {
// avg_packed_qo_len <= 64
return 64;
} else {
// avg_packed_qo_len <= 16
return 16;
}
} else {
// NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout
return 64;
}
}
}

template <typename IdType>
inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h,
uint32_t total_num_rows, uint32_t batch_size,
@@ -480,7 +459,7 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h,
// the CUDA graph is created fixes the maximum number of tokens.
const uint64_t max_seq_len = total_num_rows - batch_size + 1;
uint64_t max_qo_len = uint64_t(max_seq_len) * gqa_group_size;
cta_tile_q = DetermineCtaTileQ(max_qo_len, head_dim);
cta_tile_q = FA2DetermineCtaTileQ(max_qo_len, head_dim);

// Find an upper bound for the number of tiles, derived from the total
// number of rows and the batch size. The sum of qo lengths rounded
@@ -493,7 +472,7 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h,
sum_packed_qo_len += packed_qo_len_arr[i];
}
const int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size;
cta_tile_q = DetermineCtaTileQ(avg_packed_qo_len, head_dim);
cta_tile_q = FA2DetermineCtaTileQ(avg_packed_qo_len, head_dim);

total_num_tiles_q = 0;
for (uint32_t i = 0; i < batch_size; ++i) {
4 changes: 2 additions & 2 deletions include/flashinfer/mma.cuh
@@ -480,7 +480,7 @@ __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, u
* \brief Use mma instructions to compute rowsum.
*/
template <typename DType>
__device__ __forceinline__ void rowsum_f8f8f32(float* d, DType* s) {
__device__ __forceinline__ void m16k32_rowsum_f8f8f32(float* d, DType* s) {
static_assert(sizeof(DType) == 1, "DType must be 8bit floating data type");
uint32_t* s_u32 = (uint32_t*)(s);
#if defined(FLASHINFER_MMA_F8F8F32_M16N8K32_ENABLED)
@@ -519,7 +519,7 @@ __device__ __forceinline__ void rowsum_f8f8f32(float* d, DType* s) {
* \brief Use mma instructions to compute rowsum.
*/
template <typename DType>
__device__ __forceinline__ void rowsum_f16f16f32(float* d, DType* s) {
__device__ __forceinline__ void m16k16_rowsum_f16f16f32(float* d, DType* s) {
static_assert(sizeof(DType) == 2, "DType must be 16bit floating data type");
uint32_t* s_u32 = (uint32_t*)(s);
#if defined(FLASHINFER_MMA_F16F16F32_M16N8K16_ENABLED)
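
The new prefixes record the operand tile each rowsum consumes: the fp8 path is built on m16n8k32 MMAs and the fp16 path on m16n8k16 MMAs (per the FLASHINFER_MMA_*_ENABLED guards above), so the names now read as a 16x32 fp8 row tile and a 16x16 fp16 row tile. Both shapes cover the same number of bytes per row; a trivial compile-time check of that arithmetic, purely for illustration:

#include <cstdint>

// 32 one-byte (fp8) elements per row vs 16 two-byte (fp16) elements per row:
// both rowsum variants read 32 bytes per row of the source fragment.
static_assert(32 * sizeof(uint8_t) == 16 * sizeof(uint16_t),
              "m16k32 fp8 and m16k16 fp16 tiles have equal bytes per row");

int main() { return 0; }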
29 changes: 29 additions & 0 deletions include/flashinfer/utils.cuh
@@ -108,6 +108,11 @@
__VA_ARGS__ \
break; \
} \
case 32: { \
constexpr uint32_t CTA_TILE_Q = 32; \
__VA_ARGS__ \
break; \
} \
case 16: { \
constexpr uint32_t CTA_TILE_Q = 16; \
__VA_ARGS__ \
@@ -290,6 +295,30 @@ inline void DebugPrintCUDAArray(T* device_ptr, size_t size, std::string prefix =
std::cout << std::endl;
}

inline uint32_t FA2DetermineCtaTileQ(int64_t avg_packed_qo_len, uint32_t head_dim) {
if (avg_packed_qo_len > 64 && head_dim < 256) {
return 128;
} else {
auto compute_capacity = GetCudaComputeCapability();
if (compute_capacity.first >= 8) {
// Ampere or newer
if (avg_packed_qo_len > 32) {
// avg_packed_qo_len <= 64
return 64;
} else if (avg_packed_qo_len > 16) {
// avg_packed_qo_len <= 32
return 32;
} else {
// avg_packed_qo_len <= 16
return 16;
}
} else {
// NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout
return 64;
}
}
}

/*!
* \brief Return x - y if x > y, otherwise return 0.
*/
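
For a concrete feel for the new tier, here is a host-only sketch of the selection thresholds on the sm80+ path (the compute-capability query and the head_dim >= 256 case are left out, and the sample lengths and GQA group size are made up for illustration). The only behavioral change relative to the DetermineCtaTileQ removed from scheduler.cuh is that average packed lengths in (16, 32] now select a 32-row tile instead of 64.

#include <cstdint>
#include <cstdio>

// Mirrors FA2DetermineCtaTileQ for head_dim < 256 on Ampere or newer;
// the real function additionally falls back to 64 on pre-sm80 GPUs.
uint32_t DetermineCtaTileQSm80(int64_t avg_packed_qo_len) {
  if (avg_packed_qo_len > 64) return 128;
  if (avg_packed_qo_len > 32) return 64;
  if (avg_packed_qo_len > 16) return 32;  // tier added by this PR
  return 16;
}

int main() {
  // Hypothetical batch: per-request qo lengths packed by a GQA group size of 4,
  // mirroring how max_qo_len is scaled by gqa_group_size in PrefillSplitQOKVIndptr.
  const int64_t qo_len[] = {3, 5, 8, 4};
  const int64_t gqa_group_size = 4;
  const int64_t batch_size = 4;
  int64_t sum_packed_qo_len = 0;
  for (int64_t len : qo_len) sum_packed_qo_len += len * gqa_group_size;
  const int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size;  // 80 / 4 = 20
  std::printf("avg_packed_qo_len=%lld -> CTA_TILE_Q=%u (was 64 before this change)\n",
              static_cast<long long>(avg_packed_qo_len),
              DetermineCtaTileQSm80(avg_packed_qo_len));
  return 0;
}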
19 changes: 9 additions & 10 deletions tests/test_jit_example.py
@@ -561,7 +561,6 @@ def test_batch_prefill_sm90_flash_sigmoid():
sigmoid_bias = 0.25

o = wrapper.run(q, k, v, logits_scale, sigmoid_bias)
print(o)
wrapper_paged = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
float_workspace_buffer, kv_layout="NHD", backend="fa3", jit_args=jit_args
)
@@ -696,7 +695,7 @@ def test_sm90_debug_print_logits():

template <int NUM_ROWS_PER_THREAD>
__device__ auto GetAttentionUpdater() {
return OnlineSoftmax<NUM_ROWS_PER_THREAD, /*WITH_SCALE*/true>(sm_scale_log2);
return OnlineSoftmax<NUM_ROWS_PER_THREAD, /*WITH_SCALE*/false>(sm_scale_log2);
}


@@ -753,12 +752,12 @@ def test_sm90_debug_print_logits():


if __name__ == "__main__":
# test_single_decode_mask()
# test_flash_sigmoid()
# test_dump_logits()
# test_debug_print_logits()
# test_sm90_debug_print_logits()
# test_batch_decode_flash_sigmoid(False)
# test_batch_decode_flash_sigmoid(True)
# test_batch_prefill_flash_sigmoid()
test_single_decode_mask()
test_flash_sigmoid()
test_dump_logits()
test_debug_print_logits()
test_sm90_debug_print_logits()
test_batch_decode_flash_sigmoid(False)
test_batch_decode_flash_sigmoid(True)
test_batch_prefill_flash_sigmoid()
test_batch_prefill_sm90_flash_sigmoid()