
Commit 82fd8c7

refactor: remove page_size from template parameters for prefill kernels (#306)
Similar to #301, this PR removes `page_size` from the template parameters so that prefill kernels can support any `page_size` (previously we only supported a fixed set of sizes such as 1, 4, 8, or 16). This also reduces binary size and speeds up compilation.
1 parent 955dfc5 commit 82fd8c7

19 files changed: +146 −241 lines
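
The gist of the refactor, sketched below in plain C++ with simplified signatures (these are not the actual flashinfer declarations): a compile-time `page_size` forces one kernel instantiation per supported value, while a runtime fast-division object lets a single instantiation handle any page size.

```cpp
#include <cstdint>

// Before (sketch): page_size is a template parameter, so the build must
// enumerate every supported value and compile a kernel for each.
template <uint32_t page_size>
void attend_old(uint32_t idx) {
  uint32_t page = idx / page_size;   // folded into shifts at compile time
  uint32_t entry = idx % page_size;
  (void)page; (void)entry;
}

// After (sketch): page_size travels as a runtime value wrapped in a
// fast-division helper (cf. fastdiv.cuh), so one compiled kernel serves
// all page sizes, shrinking both the binary and the build matrix.
struct fastdiv_like {
  uint32_t d;
  void divmod(uint32_t n, uint32_t& q, uint32_t& r) const { q = n / d; r = n - q * d; }
};
void attend_new(uint32_t idx, const fastdiv_like& page_size) {
  uint32_t page, entry;
  page_size.divmod(idx, page, entry);  // one mul-high + subtract in the real helper
  (void)page; (void)entry;
}
```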

CMakeLists.txt

Lines changed: 26 additions & 31 deletions
@@ -37,7 +37,6 @@ flashinfer_option(FLASHINFER_TVM_SOURCE_DIR "The path to tvm for building tvm bi
 
 # The following configurations can impact the binary
 # size of the generated library
-flashinfer_option(FLASHINFER_GEN_PAGE_SIZES "Prefill page sizes to enable" 1 16 32)
 flashinfer_option(FLASHINFER_GEN_HEAD_DIMS "Head dims to enable" 64 128 256)
 flashinfer_option(FLASHINFER_GEN_KV_LAYOUTS "KV layouts to enable" 0 1)
 flashinfer_option(FLASHINFER_GEN_LOGITS_POST_HOOKS "Logits post hooks" 0 1)
@@ -80,7 +79,6 @@ if(FLASHINFER_ENABLE_BF16)
 endif(FLASHINFER_ENABLE_BF16)
 
 # generate kernel inst
-set (PAGE_SIZES ${FLASHINFER_GEN_PAGE_SIZES})
 set (HEAD_DIMS ${FLASHINFER_GEN_HEAD_DIMS})
 set (LOGITS_POST_HOOKS ${FLASHINFER_GEN_LOGITS_POST_HOOKS})
 set (KV_LAYOUTS ${FLASHINFER_GEN_KV_LAYOUTS})
@@ -103,7 +101,6 @@ if(FLASHINFER_ENABLE_BF16)
 endif(FLASHINFER_ENABLE_BF16)
 
 # log options
-message(STATUS "FLASHINFER_PAGE_SIZES=${PAGE_SIZES}")
 message(STATUS "FLASHINFER_HEAD_DIMS=${HEAD_DIMS}")
 message(STATUS "FLASHINFER_KV_LAYOUTS=${KV_LAYOUTS}")
 message(STATUS "FLASHINFER_POS_ENCODING_MODES=${POS_ENCODING_MODES}")
@@ -115,7 +112,7 @@ file(MAKE_DIRECTORY ${PROJECT_SOURCE_DIR}/src/generated)
 set(dispatch_inc_file ${PROJECT_SOURCE_DIR}/src/dispatch.inc)
 add_custom_command(
   OUTPUT ${dispatch_inc_file}
-  COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_dispatch_inc.py --path ${PROJECT_SOURCE_DIR}/src/dispatch.inc --head_dims ${HEAD_DIMS} --page_sizes ${FLASHINFER_GEN_PAGE_SIZES} --logits_post_hooks ${LOGITS_POST_HOOKS} --kv_layouts ${KV_LAYOUTS} --pos_encoding_modes ${POS_ENCODING_MODES} --allow_fp16_qk_reductions ${ALLOW_FP16_QK_REDUCTIONS} --mask_modes ${MASK_MODES}
+  COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_dispatch_inc.py --path ${PROJECT_SOURCE_DIR}/src/dispatch.inc --head_dims ${HEAD_DIMS} --logits_post_hooks ${LOGITS_POST_HOOKS} --kv_layouts ${KV_LAYOUTS} --pos_encoding_modes ${POS_ENCODING_MODES} --allow_fp16_qk_reductions ${ALLOW_FP16_QK_REDUCTIONS} --mask_modes ${MASK_MODES}
   DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_dispatch_inc.py
   COMMENT "Generating additional source file ${generated_dispatch_inc}"
   VERBATIM
@@ -249,33 +246,31 @@ foreach(head_dim IN LISTS HEAD_DIMS)
 endforeach(head_dim)
 
 # batch paged prefill kernel inst generation
-foreach(page_size IN LISTS PAGE_SIZES)
-  foreach(head_dim IN LISTS HEAD_DIMS)
-    foreach(logits_post_hook IN LISTS LOGITS_POST_HOOKS)
-      foreach(kv_layout IN LISTS KV_LAYOUTS)
-        foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES)
-          foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS)
-            foreach(mask_mode IN LISTS MASK_MODES)
-              foreach(dtype IN LISTS PREFILL_DTYPES)
-                foreach(idtype IN LISTS IDTYPES)
-                  set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_page_${page_size}_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
-                  add_custom_command(
-                    OUTPUT ${generated_kernel_src}
-                    COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py ${generated_kernel_src}
-                    DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py
-                    COMMENT "Generating additional source file ${generated_kernel_src}"
-                    VERBATIM
-                  )
-                  list(APPEND batch_paged_prefill_kernels_src ${generated_kernel_src})
-                endforeach(idtype)
-              endforeach(dtype)
-            endforeach(mask_mode)
-          endforeach(allow_fp16_qk_reduction)
-        endforeach(pos_encoding_mode)
-      endforeach(kv_layout)
-    endforeach(logits_post_hook)
-  endforeach(head_dim)
-endforeach(page_size)
+foreach(head_dim IN LISTS HEAD_DIMS)
+  foreach(logits_post_hook IN LISTS LOGITS_POST_HOOKS)
+    foreach(kv_layout IN LISTS KV_LAYOUTS)
+      foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES)
+        foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS)
+          foreach(mask_mode IN LISTS MASK_MODES)
+            foreach(dtype IN LISTS PREFILL_DTYPES)
+              foreach(idtype IN LISTS IDTYPES)
+                set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
+                add_custom_command(
+                  OUTPUT ${generated_kernel_src}
+                  COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py ${generated_kernel_src}
+                  DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py
+                  COMMENT "Generating additional source file ${generated_kernel_src}"
+                  VERBATIM
+                )
+                list(APPEND batch_paged_prefill_kernels_src ${generated_kernel_src})
+              endforeach(idtype)
+            endforeach(dtype)
+          endforeach(mask_mode)
+        endforeach(allow_fp16_qk_reduction)
+      endforeach(pos_encoding_mode)
+    endforeach(kv_layout)
+  endforeach(logits_post_hook)
+endforeach(head_dim)
 
 # batch ragged prefill kernel inst generation
 foreach(head_dim IN LISTS HEAD_DIMS)
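
For a sense of scale, dropping the `page_size` loop removes one whole axis from the instantiation matrix. The sketch below is illustrative arithmetic only; the mask and fp16-reduction list sizes are assumptions, not values shown in this diff:

```cpp
#include <cstdio>

// With 3 page sizes enumerated before (1, 16, 32), the batch paged prefill
// kernel count per (dtype, idtype) pair shrinks by exactly that factor.
int main() {
  const int page_sizes = 3, head_dims = 3, hooks = 2, layouts = 2,
            posenc = 3, fp16red = 2, masks = 3;  // masks/fp16red assumed
  const int rest = head_dims * hooks * layouts * posenc * fp16red * masks;
  printf("before: %d instantiations per (dtype, idtype)\n", page_sizes * rest);
  printf("after:  %d instantiations per (dtype, idtype)\n", rest);
  return 0;
}
```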

cmake/config.cmake

Lines changed: 0 additions & 1 deletion
@@ -23,7 +23,6 @@ set(FLASHINFER_DISTRIBUTED ON)
 # The following configurations can impact the binary
 # size of the generated library
 set(FLASHINFER_GEN_LOGITS_POST_HOOKS 0)
-set(FLASHINFER_GEN_PAGE_SIZES 1 16 32)
 set(FLASHINFER_GEN_HEAD_DIMS 64 128 256)
 set(FLASHINFER_GEN_KV_LAYOUTS 0 1)
 set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1 2)

include/flashinfer/attention/decode.cuh

Lines changed: 9 additions & 11 deletions
@@ -601,9 +601,10 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
   static_assert(num_stages_smem <= bdx);
 #pragma unroll
   for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
-    k_ptrs_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] = paged_kv.protective_get_k_ptr(
-        cur_page_indptr_begin + (((j * bdz + tz) * bdy + ty) * bdx + tx) / paged_kv.page_size,
-        kv_head_idx, (((j * bdz + tz) * bdy + ty) * bdx + tx) % paged_kv.page_size, 0, last_indptr);
+    uint32_t q, r;
+    paged_kv.page_size.divmod(((j * bdz + tz) * bdy + ty) * bdx + tx, q, r);
+    k_ptrs_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] =
+        paged_kv.protective_get_k_ptr(cur_page_indptr_begin + q, kv_head_idx, r, 0, last_indptr);
   }
   block.sync();
 
@@ -643,15 +644,12 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
   if ((iter + num_stages_smem) % bdx == 0) {
 #pragma unroll
     for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
+      uint32_t q, r;
+      paged_kv.page_size.divmod(((iter + num_stages_smem) * tile_size_per_bdx * bdy * bdz +
+                                 ((j * bdz + tz) * bdy + ty) * bdx + tx),
+                                q, r);
       k_ptrs_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] = paged_kv.protective_get_k_ptr(
-          cur_page_indptr_begin + ((iter + num_stages_smem) * tile_size_per_bdx * bdy * bdz +
-                                   ((j * bdz + tz) * bdy + ty) * bdx + tx) /
-                                      paged_kv.page_size,
-          kv_head_idx,
-          ((iter + num_stages_smem) * tile_size_per_bdx * bdy * bdz +
-           ((j * bdz + tz) * bdy + ty) * bdx + tx) %
-              paged_kv.page_size,
-          0, last_indptr);
+          cur_page_indptr_begin + q, kv_head_idx, r, 0, last_indptr);
     }
   }
   // compute qk
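
The pattern above in isolation, as a host-side sketch (the helper `locate` is hypothetical, not library code): one runtime divmod turns a flat KV element index into a page offset plus an in-page slot, which is what `paged_kv.page_size.divmod(...)` computes with a precomputed magic multiplier instead of a hardware divide.

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical helper mirroring the kernel's indexing: elem_idx is the flat
// position of a KV element within a request; page_size is now a runtime value.
void locate(uint32_t elem_idx, uint32_t page_size, uint32_t indptr_begin,
            uint32_t& page, uint32_t& entry) {
  page = indptr_begin + elem_idx / page_size;  // q in the kernel
  entry = elem_idx % page_size;                // r in the kernel
}

int main() {
  uint32_t page, entry;
  locate(/*elem_idx=*/37, /*page_size=*/5, /*indptr_begin=*/100, page, entry);
  printf("page=%u entry=%u\n", page, entry);   // page=107 entry=2
  return 0;
}
```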

include/flashinfer/attention/handler.cuh

Lines changed: 1 addition & 1 deletion
@@ -133,7 +133,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
   FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
       &num_blocks_per_sm, partition_kv_kernel, num_threads, smem_size));
   max_grid_size = num_blocks_per_sm * num_sm;
-  if (batch_size * num_kv_heads >= num_sm) {
+  if (batch_size * num_kv_heads >= max_grid_size) {
     tmp_size = 0;
     new_batch_size = batch_size;
   } else {
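
This one-character-class change is a bug fix rather than part of the page_size refactor: the split-KV decision should compare the natural grid (`batch_size * num_kv_heads`) against the grid the GPU can actually keep resident, not against the SM count alone. A sketch with assumed occupancy numbers:

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative numbers only (108 SMs with 4 resident blocks per SM is an
// assumption, roughly an A100-class part): the old `>= num_sm` test saw
// 128 >= 108 and skipped KV partitioning, leaving ~300 block slots idle.
int main() {
  uint32_t num_sm = 108, num_blocks_per_sm = 4;
  uint32_t max_grid_size = num_blocks_per_sm * num_sm;   // 432
  uint32_t batch_size = 16, num_kv_heads = 8;            // natural grid: 128
  bool split_kv = batch_size * num_kv_heads < max_grid_size;
  printf("split_kv=%d\n", split_kv);  // 1: partition long KV to fill the GPU
  return 0;
}
```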

include/flashinfer/attention/prefill.cuh

Lines changed: 41 additions & 68 deletions
@@ -23,9 +23,6 @@
 #endif
 #include <cuda_runtime.h>
 
-#include <optional>
-#include <tuple>
-
 #include "../cp_async.cuh"
 #include "../fastdiv.cuh"
 #include "../layout.cuh"
@@ -175,65 +172,41 @@ __device__ __forceinline__ void produce_kv(smem_t smem, uint32_t* smem_offset, T
   *smem_offset -= num_frags_z * 16 * channel_size_128b_in;
 }
 
-template <bool produce_v, uint32_t page_size, uint32_t num_warps, uint32_t num_frags_y,
-          uint32_t num_frags_z, PageStorage page_storage, QKVLayout kv_layout, typename DType,
-          typename IdType>
+template <bool produce_v, uint32_t num_warps, uint32_t num_frags_y, uint32_t num_frags_z,
+          PageStorage page_storage, QKVLayout kv_layout, typename DType, typename IdType>
 __device__ __forceinline__ void page_produce_kv(
     smem_t smem, uint32_t* smem_offset,
     paged_kv_t<page_storage, kv_layout, DType, IdType>& paged_kv, const uint32_t kv_idx_base,
-    const uint32_t page_iter_base, const uint32_t kv_len, const IdType last_indptr) {
+    const uint32_t packed_page_iter_base, const uint32_t kv_len, const IdType last_indptr) {
   constexpr SharedMemFillMode fill_mode =
       produce_v ? SharedMemFillMode::kFillZero : SharedMemFillMode::kNoFill;
   constexpr uint32_t head_dim = num_frags_y * 16;
   constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b<DType>();
   const uint32_t tx = threadIdx.x, ty = threadIdx.y;
   const uint32_t kv_head_idx = blockIdx.z;
   uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8;
-  if constexpr (page_size % 4 == 0) {
-#pragma unroll
-    for (uint32_t i = 0; i < num_frags_z * 4 / num_warps; ++i) {
-      const uint32_t page_iter = page_iter_base + (4 * num_warps * i + ty * 4) / page_size;
-      const uint32_t entry_idx = (4 * num_warps * i + ty * 4) % page_size + tx / 8;
-      DType* gptr =
-          produce_v
-              ? paged_kv.protective_get_v_ptr(page_iter, kv_head_idx, entry_idx,
-                                              (tx % 8) * num_elems_per_128b<DType>(), last_indptr)
-              : paged_kv.protective_get_k_ptr(page_iter, kv_head_idx, entry_idx,
-                                              (tx % 8) * num_elems_per_128b<DType>(), last_indptr);
-#pragma unroll
-      for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
-        smem.load_128b_async<fill_mode>(*smem_offset, gptr, kv_idx < kv_len);
-        *smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j);
-        gptr += 8 * num_elems_per_128b<DType>();
-      }
-      kv_idx += num_warps * 4;
-      *smem_offset = smem.advance_offset_by_row<num_warps * 4, channel_size_128b_in>(*smem_offset) -
-                     2 * num_frags_y;
-    }
-    *smem_offset -= num_frags_z * 16 * channel_size_128b_in;
-  } else {
-#pragma unroll
-    for (uint32_t i = 0; i < num_frags_z * 4 / num_warps; ++i) {
-      const uint32_t page_iter = page_iter_base + (4 * num_warps * i + ty * 4 + tx / 8) / page_size;
-      const uint32_t entry_idx = (4 * num_warps * i + ty * 4 + tx / 8) % page_size;
-      DType* gptr =
-          produce_v
-              ? paged_kv.protective_get_v_ptr(page_iter, kv_head_idx, entry_idx,
-                                              (tx % 8) * num_elems_per_128b<DType>(), last_indptr)
-              : paged_kv.protective_get_k_ptr(page_iter, kv_head_idx, entry_idx,
-                                              (tx % 8) * num_elems_per_128b<DType>(), last_indptr);
-#pragma unroll
-      for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
-        smem.load_128b_async<fill_mode>(*smem_offset, gptr, kv_idx < kv_len);
-        *smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j);
-        gptr += 8 * num_elems_per_128b<DType>();
-      }
-      kv_idx += num_warps * 4;
-      *smem_offset = smem.advance_offset_by_row<num_warps * 4, channel_size_128b_in>(*smem_offset) -
-                     2 * num_frags_y;
-    }
-    *smem_offset -= num_frags_z * 16 * channel_size_128b_in;
-  }
+#pragma unroll
+  for (uint32_t i = 0; i < num_frags_z * 4 / num_warps; ++i) {
+    uint32_t page_iter, entry_idx;
+    paged_kv.page_size.divmod(packed_page_iter_base + ty * 4 + tx / 8 + 4 * num_warps * i,
+                              page_iter, entry_idx);
+    DType* gptr =
+        produce_v
+            ? paged_kv.protective_get_v_ptr(page_iter, kv_head_idx, entry_idx,
+                                            (tx % 8) * num_elems_per_128b<DType>(), last_indptr)
+            : paged_kv.protective_get_k_ptr(page_iter, kv_head_idx, entry_idx,
+                                            (tx % 8) * num_elems_per_128b<DType>(), last_indptr);
+#pragma unroll
+    for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
+      smem.load_128b_async<fill_mode>(*smem_offset, gptr, kv_idx < kv_len);
+      *smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j);
+      gptr += 8 * num_elems_per_128b<DType>();
+    }
+    kv_idx += num_warps * 4;
+    *smem_offset = smem.advance_offset_by_row<num_warps * 4, channel_size_128b_in>(*smem_offset) -
+                   2 * num_frags_y;
+  }
+  *smem_offset -= num_frags_z * 16 * channel_size_128b_in;
 }
 
 template <uint32_t num_frags_y>
@@ -1342,10 +1315,10 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel(
   }
 }
 
-template <LogitsPostHook logits_post_hook, uint32_t page_size, MaskMode mask_mode,
-          PosEncodingMode pos_encoding_mode, uint32_t num_frags_x, uint32_t num_frags_y,
-          uint32_t num_frags_z, uint32_t num_warps, PageStorage page_storage, QKVLayout kv_layout,
-          typename DTypeIn, typename DTypeQKAccum, typename DTypeOut, typename IdType>
+template <LogitsPostHook logits_post_hook, MaskMode mask_mode, PosEncodingMode pos_encoding_mode,
+          uint32_t num_frags_x, uint32_t num_frags_y, uint32_t num_frags_z, uint32_t num_warps,
+          PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeQKAccum,
+          typename DTypeOut, typename IdType>
 __global__ void BatchPrefillWithPagedKVCacheKernel(
     IdType* __restrict__ request_indices, IdType* __restrict__ tile_indices,
     DTypeIn* __restrict__ q, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
@@ -1448,12 +1421,12 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
       smem_t::get_permuted_offset<channel_size_128b_in>(ty * 4 + tx / 8, tx % 8);
   const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size];
 
-  uint32_t page_iter_base = paged_kv.indptr[request_idx];
-  page_produce_kv<false, page_size, num_warps, num_frags_y, num_frags_z>(
-      k_smem, &kv_smem_offset_w, paged_kv, 0, page_iter_base, kv_len, last_indptr);
+  uint32_t packed_page_iter_base = paged_kv.indptr[request_idx] * paged_kv.page_size;
+  page_produce_kv<false, num_warps, num_frags_y, num_frags_z>(
+      k_smem, &kv_smem_offset_w, paged_kv, 0, packed_page_iter_base, kv_len, last_indptr);
   cp_async::commit_group();
-  page_produce_kv<true, page_size, num_warps, num_frags_y, num_frags_z>(
-      v_smem, &kv_smem_offset_w, paged_kv, 0, page_iter_base, kv_len, last_indptr);
+  page_produce_kv<true, num_warps, num_frags_y, num_frags_z>(
+      v_smem, &kv_smem_offset_w, paged_kv, 0, packed_page_iter_base, kv_len, last_indptr);
   cp_async::commit_group();
 
   const uint32_t num_iterations = ceil_div(
@@ -1508,10 +1481,10 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
     update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(s_frag, o_frag, m, d);
 
     block.sync();
-    page_iter_base += 16 * num_frags_z / page_size;
-    page_produce_kv<false, page_size, num_warps, num_frags_y, num_frags_z>(
-        k_smem, &kv_smem_offset_w, paged_kv, (iter + 1) * 16 * num_frags_z, page_iter_base, kv_len,
-        last_indptr);
+    packed_page_iter_base += 16 * num_frags_z;
+    page_produce_kv<false, num_warps, num_frags_y, num_frags_z>(
+        k_smem, &kv_smem_offset_w, paged_kv, (iter + 1) * 16 * num_frags_z, packed_page_iter_base,
+        kv_len, last_indptr);
     cp_async::commit_group();
     cp_async::wait_group<1>();
     block.sync();
@@ -1521,9 +1494,9 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
                          o_frag, d);
 
     block.sync();
-    page_produce_kv<true, page_size, num_warps, num_frags_y, num_frags_z>(
-        v_smem, &kv_smem_offset_w, paged_kv, (iter + 1) * 16 * num_frags_z, page_iter_base, kv_len,
-        last_indptr);
+    page_produce_kv<true, num_warps, num_frags_y, num_frags_z>(
+        v_smem, &kv_smem_offset_w, paged_kv, (iter + 1) * 16 * num_frags_z, packed_page_iter_base,
+        kv_len, last_indptr);
     cp_async::commit_group();
   }
   cp_async::wait_group<0>();
@@ -1776,7 +1749,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
   return cudaSuccess;
 }
 
-template <PageStorage page_storage, uint32_t num_frags_x, uint32_t PAGE_SIZE, uint32_t HEAD_DIM,
+template <PageStorage page_storage, uint32_t num_frags_x, uint32_t HEAD_DIM,
           LogitsPostHook LOGITS_POST_HOOK, QKVLayout kv_layout, PosEncodingMode pos_encoding_mode,
           bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE, typename DTypeIn, typename DTypeOut,
           typename IdType>
@@ -1831,8 +1804,8 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(
     throw std::invalid_argument(err_msg.str());
   } else {
     auto kernel = BatchPrefillWithPagedKVCacheKernel<
-        LOGITS_POST_HOOK, PAGE_SIZE, MASK_MODE, pos_encoding_mode, num_frags_x, num_frags_y,
-        num_frags_z, num_warps, page_storage, kv_layout, DTypeIn, DTypeQKAccum, DTypeOut, IdType>;
+        LOGITS_POST_HOOK, MASK_MODE, pos_encoding_mode, num_frags_x, num_frags_y, num_frags_z,
+        num_warps, page_storage, kv_layout, DTypeIn, DTypeQKAccum, DTypeOut, IdType>;
     uint32_t smem_size =
         (num_frags_x * num_warps + num_frags_z * 2) * 16 * HEAD_DIM * sizeof(DTypeIn);
     FLASHINFER_CUDA_CALL(
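
The key idea in `page_produce_kv` above, restated as a host-side sketch (illustrative values, not kernel code): tracking one packed element counter, seeded with `indptr[request_idx] * page_size` and advanced by a whole tile per iteration, replaces the old page-plus-remainder bookkeeping. That is what allowed the compile-time `page_size % 4 == 0` specialization to be deleted.

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t page_size = 7;      // arbitrary runtime page size
  const uint32_t indptr_begin = 3;   // first page of this request
  const uint32_t tile = 16;          // KV elements consumed per iteration

  uint32_t packed = indptr_begin * page_size;  // packed_page_iter_base
  for (uint32_t iter = 0; iter < 3; ++iter) {
    for (uint32_t off = 0; off < tile; ++off) {
      uint32_t page_iter = (packed + off) / page_size;  // fastdiv divmod in the kernel
      uint32_t entry_idx = (packed + off) % page_size;
      if (off == 0)
        printf("iter %u starts at page %u, entry %u\n", iter, page_iter, entry_idx);
    }
    packed += tile;  // advance one tile; no per-page bookkeeping, any page_size works
  }
  return 0;
}
```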

include/flashinfer/fastdiv.cuh

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@
  */
 #ifndef FLASHINFER_FASTDIV_CUH_
 #define FLASHINFER_FASTDIV_CUH_
+#include <cstdint>
 
 namespace flashinfer {
 
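For reference, the trick behind a `uint_fastdiv`-style type, sketched with Lemire's reciprocal method. This is a simplified stand-in, not the contents of fastdiv.cuh, and it leans on the GCC/Clang `__uint128_t` extension:

```cpp
#include <cstdint>
#include <cstdio>

// Simplified sketch: precompute a 64-bit "magic" reciprocal of the divisor
// once on the host; divmod then costs a multiply-high plus a subtract,
// which is why a runtime page_size stays cheap inside the kernels.
struct simple_fastdiv {
  uint32_t d;
  uint64_t m;  // ceil(2^64 / d); valid for all 32-bit numerators when d > 1
  explicit simple_fastdiv(uint32_t divisor) : d(divisor), m(~uint64_t(0) / divisor + 1) {}
  void divmod(uint32_t n, uint32_t& q, uint32_t& r) const {
    if (d == 1) { q = n; r = 0; return; }
    q = uint32_t((__uint128_t(m) * n) >> 64);  // multiply-high replaces division
    r = n - q * d;
  }
};

int main() {
  simple_fastdiv page_size(80);   // any page size, fixed once at runtime
  uint32_t q, r;
  page_size.divmod(1234, q, r);
  printf("q=%u r=%u\n", q, r);    // q=15 r=34, i.e. 1234 = 15*80 + 34
  return 0;
}
```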