Skip to content

Commit 83bab99

Browse files
authored
bugfix: drop CTA_TILE_Q=32 (#785)
#776 added CTA_TILE_Q=32 but it produces incorrect result.
1 parent 2d2e13a commit 83bab99

5 files changed

+5
-13
lines changed

aot_build_utils/generate_batch_paged_prefill_inst.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def get_cu_file_str(
3737
dtype_out,
3838
idtype,
3939
):
40-
cta_tile_q_choice = [128, 64, 32, 16]
40+
cta_tile_q_choice = [128, 64, 16]
4141

4242
def get_insts(attention_variant, dtype_out):
4343
return "\n".join(

aot_build_utils/generate_batch_ragged_prefill_inst.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def get_cu_file_str(
3737
dtype_out,
3838
idtype,
3939
):
40-
cta_tile_q_choice = [128, 64, 32, 16]
40+
cta_tile_q_choice = [128, 64, 16]
4141

4242
def get_insts(attention_variant, dtype_out):
4343
return "\n".join(

csrc/batch_prefill_paged_kernel_inst.jinja

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ namespace flashinfer {
55

66
constexpr auto use_custom_mask = {{ mask_mode }} == MaskMode::kCustom;
77

8-
{% for cta_tile_q in [16, 32, 64, 128] %}
8+
{% for cta_tile_q in [16, 64, 128] %}
99
template cudaError_t BatchPrefillWithPagedKVCacheDispatched<
1010
/*CTA_TILE_Q=*/{{cta_tile_q}}, {{head_dim_qk}}, {{head_dim_vo}}, {{pos_encoding_mode}}, {{use_fp16_qk_reduction}}, {{mask_mode}},
1111
{{ variant_name }}, PagedParams>(PagedParams params, {{ dtype_o }}* tmp_v, float* tmp_s, cudaStream_t stream);

csrc/batch_prefill_ragged_kernel_inst.jinja

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ namespace flashinfer {
55

66
constexpr auto use_custom_mask = {{ mask_mode }} == MaskMode::kCustom;
77

8-
{% for cta_tile_q in [16, 32, 64, 128] %}
8+
{% for cta_tile_q in [16, 64, 128] %}
99
template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<
1010
/*CTA_TILE_Q=*/{{cta_tile_q}}, {{head_dim_qk}}, {{head_dim_vo}}, {{pos_encoding_mode}}, {{use_fp16_qk_reduction}}, {{mask_mode}},
1111
{{ variant_name }}, RaggedParams>(RaggedParams params, {{ dtype_o }}* tmp_v, float* tmp_s, cudaStream_t stream);

include/flashinfer/utils.cuh

+1-9
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,6 @@
108108
__VA_ARGS__ \
109109
break; \
110110
} \
111-
case 32: { \
112-
constexpr uint32_t CTA_TILE_Q = 32; \
113-
__VA_ARGS__ \
114-
break; \
115-
} \
116111
case 16: { \
117112
constexpr uint32_t CTA_TILE_Q = 16; \
118113
__VA_ARGS__ \
@@ -302,12 +297,9 @@ inline uint32_t FA2DetermineCtaTileQ(int64_t avg_packed_qo_len, uint32_t head_di
302297
auto compute_capacity = GetCudaComputeCapability();
303298
if (compute_capacity.first >= 8) {
304299
// Ampere or newer
305-
if (avg_packed_qo_len > 32) {
300+
if (avg_packed_qo_len > 16) {
306301
// avg_packed_qo_len <= 64
307302
return 64;
308-
} else if (avg_packed_qo_len > 16) {
309-
// avg_packed_qo_len <= 32
310-
return 32;
311303
} else {
312304
// avg_packed_qo_len <= 16
313305
return 16;

0 commit comments

Comments
 (0)