Skip to content

Commit 989dbfa

Browse files
authored
perf: fix the iteration bound of SWA in FA2 prefill template (#714)
We forgot to divide the packed row index by group_size when computing the sliding-window iteration bound, which made the bound larger than its actual value and slowed down execution. Thanks to @Ying1123 for spotting this bug.
1 parent 0f80329 commit 989dbfa

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

include/flashinfer/attention/prefill.cuh

+3-3
Original file line numberDiff line numberDiff line change
@@ -1226,7 +1226,7 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void SinglePrefillWithKV
12261226
16 * NUM_WARPS_KV * NUM_MMA_KV);
12271227

12281228
const uint32_t window_iteration =
1229-
ceil_div(sub_if_greater_or_zero(kv_len + (bx + 1) * num_rows_per_cta,
1229+
ceil_div(sub_if_greater_or_zero(kv_len + (bx + 1) * num_rows_per_cta / group_size,
12301230
qo_len + window_left + chunk_start),
12311231
(16 * NUM_WARPS_KV * NUM_MMA_KV));
12321232

@@ -1652,7 +1652,7 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithRag
16521652
16 * NUM_WARPS_KV * NUM_MMA_KV);
16531653

16541654
const uint32_t window_iteration =
1655-
ceil_div(sub_if_greater_or_zero(kv_len + (qo_tile_idx + 1) * num_rows_per_cta,
1655+
ceil_div(sub_if_greater_or_zero(kv_len + (qo_tile_idx + 1) * num_rows_per_cta / group_size,
16561656
qo_len + window_left + chunk_start),
16571657
(16 * NUM_WARPS_KV * NUM_MMA_KV));
16581658

@@ -1980,7 +1980,7 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPag
19801980
16 * NUM_WARPS_KV * NUM_MMA_KV);
19811981

19821982
const uint32_t window_iteration =
1983-
ceil_div(sub_if_greater_or_zero(kv_len + (qo_tile_idx + 1) * num_rows_per_cta,
1983+
ceil_div(sub_if_greater_or_zero(kv_len + (qo_tile_idx + 1) * num_rows_per_cta / group_size,
19841984
qo_len + window_left + chunk_start),
19851985
(16 * NUM_WARPS_KV * NUM_MMA_KV));
19861986

0 commit comments

Comments
 (0)