@@ -1338,6 +1338,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg
1338
1338
constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps_x * 16 ;
1339
1339
const uint32_t qo_len = q_indptr[request_idx + 1 ] - q_indptr[request_idx],
1340
1340
kv_len = kv_indptr[request_idx + 1 ] - kv_indptr[request_idx];
1341
+ const uint32_t kv_len_safe = kv_len > 0 ? kv_len : 1 ;
1341
1342
const uint32_t window_left = (maybe_window_left >= 0 ) ? maybe_window_left : kv_len;
1342
1343
const uint32_t max_chunk_size = partition_kv ? kv_chunk_size : kv_len;
1343
1344
const uint32_t chunk_start = partition_kv ? kv_tile_idx * max_chunk_size : 0 ;
@@ -1558,7 +1559,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg
1558
1559
// normalize d
1559
1560
normalize_d<num_frags_x, num_frags_y>(o_frag, m, d);
1560
1561
1561
- const uint32_t num_kv_chunks = ceil_div ( max (kv_len, 1 ), kv_chunk_size) ;
1562
+ const uint32_t num_kv_chunks = (kv_len_safe + kv_chunk_size - 1 ) / kv_chunk_size;
1562
1563
1563
1564
// write back
1564
1565
write_o_reg_gmem<num_warps_x, num_warps_z, num_frags_x, num_frags_y>(
@@ -1632,6 +1633,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage
1632
1633
1 ) * paged_kv.page_size +
1633
1634
paged_kv.last_page_len [request_idx]
1634
1635
: 0 ;
1636
+ const uint32_t kv_len_safe = kv_len > 0 ? kv_len : 1 ;
1635
1637
const uint32_t window_left = (maybe_window_left >= 0 ) ? maybe_window_left : kv_len;
1636
1638
const uint32_t max_chunk_size = partition_kv ? kv_chunk_size : kv_len;
1637
1639
const uint32_t chunk_start = partition_kv ? kv_tile_idx * max_chunk_size : 0 ;
@@ -1872,7 +1874,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage
1872
1874
// normalize d
1873
1875
normalize_d<num_frags_x, num_frags_y>(o_frag, m, d);
1874
1876
1875
- const uint32_t num_kv_chunks = ceil_div ( max (kv_len, 1 ), kv_chunk_size) ;
1877
+ const uint32_t num_kv_chunks = (kv_len_safe + kv_chunk_size - 1 ) / kv_chunk_size;
1876
1878
1877
1879
// write_back
1878
1880
write_o_reg_gmem<num_warps_x, num_warps_z, num_frags_x, num_frags_y>(
0 commit comments