Skip to content

Commit 5f0159e

Browse files
authored
fix: resolve cu121 compile wired issue (#446)
cc @yzh119 @Ying1123 @Yard1 @comaniac
2 parents 838d050 + 2740a02 commit 5f0159e

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

include/flashinfer/attention/prefill.cuh

+4-2
Original file line numberDiff line numberDiff line change
@@ -1338,6 +1338,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg
13381338
constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps_x * 16;
13391339
const uint32_t qo_len = q_indptr[request_idx + 1] - q_indptr[request_idx],
13401340
kv_len = kv_indptr[request_idx + 1] - kv_indptr[request_idx];
1341+
const uint32_t kv_len_safe = kv_len > 0 ? kv_len : 1;
13411342
const uint32_t window_left = (maybe_window_left >= 0) ? maybe_window_left : kv_len;
13421343
const uint32_t max_chunk_size = partition_kv ? kv_chunk_size : kv_len;
13431344
const uint32_t chunk_start = partition_kv ? kv_tile_idx * max_chunk_size : 0;
@@ -1558,7 +1559,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg
15581559
// normalize d
15591560
normalize_d<num_frags_x, num_frags_y>(o_frag, m, d);
15601561

1561-
const uint32_t num_kv_chunks = ceil_div(max(kv_len, 1), kv_chunk_size);
1562+
const uint32_t num_kv_chunks = (kv_len_safe + kv_chunk_size - 1) / kv_chunk_size;
15621563

15631564
// write back
15641565
write_o_reg_gmem<num_warps_x, num_warps_z, num_frags_x, num_frags_y>(
@@ -1632,6 +1633,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage
16321633
1) * paged_kv.page_size +
16331634
paged_kv.last_page_len[request_idx]
16341635
: 0;
1636+
const uint32_t kv_len_safe = kv_len > 0 ? kv_len : 1;
16351637
const uint32_t window_left = (maybe_window_left >= 0) ? maybe_window_left : kv_len;
16361638
const uint32_t max_chunk_size = partition_kv ? kv_chunk_size : kv_len;
16371639
const uint32_t chunk_start = partition_kv ? kv_tile_idx * max_chunk_size : 0;
@@ -1872,7 +1874,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage
18721874
// normalize d
18731875
normalize_d<num_frags_x, num_frags_y>(o_frag, m, d);
18741876

1875-
const uint32_t num_kv_chunks = ceil_div(max(kv_len, 1), kv_chunk_size);
1877+
const uint32_t num_kv_chunks = (kv_len_safe + kv_chunk_size - 1) / kv_chunk_size;
18761878

18771879
// write_back
18781880
write_o_reg_gmem<num_warps_x, num_warps_z, num_frags_x, num_frags_y>(

0 commit comments

Comments
 (0)