@@ -1804,7 +1804,8 @@ cudaError_t SinglePrefillWithKVCacheDispatched(
1804
1804
FLASHINFER_CUDA_CALL (cudaDeviceGetAttribute (
1805
1805
&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id));
1806
1806
// we expect each sm execute two threadblocks
1807
- const int max_smem_per_threadblock = max_smem_per_sm / 2 ;
1807
+ const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof (DTypeIn) * 16 ) ? 2 : 1 ;
1808
+ const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm;
1808
1809
1809
1810
constexpr uint32_t num_warps_x = get_num_warps_x<WARP_LAYOUT>();
1810
1811
constexpr uint32_t num_warps_z = get_num_warps_z<WARP_LAYOUT>();
@@ -1949,7 +1950,8 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
1949
1950
FLASHINFER_CUDA_CALL (cudaDeviceGetAttribute (&max_smem_per_sm,
1950
1951
cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id));
1951
1952
// we expect each sm execute two threadblocks
1952
- const int max_smem_per_threadblock = max_smem_per_sm / 2 ;
1953
+ const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof (DTypeIn) * 16 ) ? 2 : 1 ;
1954
+ const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm;
1953
1955
1954
1956
const uint32_t max_num_frags_z_reg =
1955
1957
(HEAD_DIM >= 128 && num_frags_x == 2 && pos_encoding_mode == PosEncodingMode::kRoPELlama &&
@@ -2089,7 +2091,8 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(
2089
2091
FLASHINFER_CUDA_CALL (cudaDeviceGetAttribute (&max_smem_per_sm,
2090
2092
cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id));
2091
2093
// we expect each sm execute two threadblocks
2092
- const int max_smem_per_threadblock = max_smem_per_sm / 2 ;
2094
+ const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof (DTypeIn) * 16 ) ? 2 : 1 ;
2095
+ const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm;
2093
2096
2094
2097
const uint32_t max_num_frags_z_reg =
2095
2098
(HEAD_DIM >= 128 && num_frags_x == 2 && pos_encoding_mode == PosEncodingMode::kRoPELlama &&
0 commit comments