Skip to content

Commit cdac577

Browse files
authored
bugfix: Fix invalid kernel configuration for sm86 (#385)
Related issue: vllm-project/vllm#6395
1 parent 457a0ae commit cdac577

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

include/flashinfer/attention/prefill.cuh

+6-3
Original file line numberDiff line numberDiff line change
@@ -1804,7 +1804,8 @@ cudaError_t SinglePrefillWithKVCacheDispatched(
18041804
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(
18051805
&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id));
18061806
// we expect each SM to execute two threadblocks when its shared memory is large enough, otherwise one
1807-
const int max_smem_per_threadblock = max_smem_per_sm / 2;
1807+
const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof(DTypeIn) * 16) ? 2: 1;
1808+
const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm;
18081809

18091810
constexpr uint32_t num_warps_x = get_num_warps_x<WARP_LAYOUT>();
18101811
constexpr uint32_t num_warps_z = get_num_warps_z<WARP_LAYOUT>();
@@ -1949,7 +1950,8 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
19491950
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&max_smem_per_sm,
19501951
cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id));
19511952
// we expect each SM to execute two threadblocks when its shared memory is large enough, otherwise one
1952-
const int max_smem_per_threadblock = max_smem_per_sm / 2;
1953+
const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof(DTypeIn) * 16) ? 2: 1;
1954+
const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm;
19531955

19541956
const uint32_t max_num_frags_z_reg =
19551957
(HEAD_DIM >= 128 && num_frags_x == 2 && pos_encoding_mode == PosEncodingMode::kRoPELlama &&
@@ -2089,7 +2091,8 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(
20892091
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&max_smem_per_sm,
20902092
cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id));
20912093
// we expect each SM to execute two threadblocks when its shared memory is large enough, otherwise one
2092-
const int max_smem_per_threadblock = max_smem_per_sm / 2;
2094+
const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof(DTypeIn) * 16) ? 2: 1;
2095+
const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm;
20932096

20942097
const uint32_t max_num_frags_z_reg =
20952098
(HEAD_DIM >= 128 && num_frags_x == 2 && pos_encoding_mode == PosEncodingMode::kRoPELlama &&

0 commit comments

Comments
 (0)