#endif
#include <cuda_runtime.h>

- #include <optional>
- #include <tuple>
-
#include "../cp_async.cuh"
#include "../fastdiv.cuh"
#include "../layout.cuh"
@@ -175,65 +172,41 @@ __device__ __forceinline__ void produce_kv(smem_t smem, uint32_t* smem_offset, T
  *smem_offset -= num_frags_z * 16 * channel_size_128b_in;
}

- template <bool produce_v, uint32_t page_size, uint32_t num_warps, uint32_t num_frags_y,
-           uint32_t num_frags_z, PageStorage page_storage, QKVLayout kv_layout, typename DType,
-           typename IdType>
+ template <bool produce_v, uint32_t num_warps, uint32_t num_frags_y, uint32_t num_frags_z,
+           PageStorage page_storage, QKVLayout kv_layout, typename DType, typename IdType>
__device__ __forceinline__ void page_produce_kv(
    smem_t smem, uint32_t* smem_offset,
    paged_kv_t<page_storage, kv_layout, DType, IdType>& paged_kv, const uint32_t kv_idx_base,
-     const uint32_t page_iter_base, const uint32_t kv_len, const IdType last_indptr) {
+     const uint32_t packed_page_iter_base, const uint32_t kv_len, const IdType last_indptr) {
  constexpr SharedMemFillMode fill_mode =
      produce_v ? SharedMemFillMode::kFillZero : SharedMemFillMode::kNoFill;
  constexpr uint32_t head_dim = num_frags_y * 16;
  constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b<DType>();
  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
  const uint32_t kv_head_idx = blockIdx.z;
  uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8;
-   if constexpr (page_size % 4 == 0) {
- #pragma unroll
-     for (uint32_t i = 0; i < num_frags_z * 4 / num_warps; ++i) {
-       const uint32_t page_iter = page_iter_base + (4 * num_warps * i + ty * 4) / page_size;
-       const uint32_t entry_idx = (4 * num_warps * i + ty * 4) % page_size + tx / 8;
-       DType* gptr =
-           produce_v
-               ? paged_kv.protective_get_v_ptr(page_iter, kv_head_idx, entry_idx,
-                                               (tx % 8) * num_elems_per_128b<DType>(), last_indptr)
-               : paged_kv.protective_get_k_ptr(page_iter, kv_head_idx, entry_idx,
-                                               (tx % 8) * num_elems_per_128b<DType>(), last_indptr);
- #pragma unroll
-       for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
-         smem.load_128b_async<fill_mode>(*smem_offset, gptr, kv_idx < kv_len);
-         *smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j);
-         gptr += 8 * num_elems_per_128b<DType>();
-       }
-       kv_idx += num_warps * 4;
-       *smem_offset = smem.advance_offset_by_row<num_warps * 4, channel_size_128b_in>(*smem_offset) -
-                      2 * num_frags_y;
-     }
-     *smem_offset -= num_frags_z * 16 * channel_size_128b_in;
-   } else {
#pragma unroll
-     for (uint32_t i = 0; i < num_frags_z * 4 / num_warps; ++i) {
-       const uint32_t page_iter = page_iter_base + (4 * num_warps * i + ty * 4 + tx / 8) / page_size;
-       const uint32_t entry_idx = (4 * num_warps * i + ty * 4 + tx / 8) % page_size;
-       DType* gptr =
-           produce_v
-               ? paged_kv.protective_get_v_ptr(page_iter, kv_head_idx, entry_idx,
-                                               (tx % 8) * num_elems_per_128b<DType>(), last_indptr)
-               : paged_kv.protective_get_k_ptr(page_iter, kv_head_idx, entry_idx,
-                                               (tx % 8) * num_elems_per_128b<DType>(), last_indptr);
- #pragma unroll
-       for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
-         smem.load_128b_async<fill_mode>(*smem_offset, gptr, kv_idx < kv_len);
-         *smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j);
-         gptr += 8 * num_elems_per_128b<DType>();
-       }
-       kv_idx += num_warps * 4;
-       *smem_offset = smem.advance_offset_by_row<num_warps * 4, channel_size_128b_in>(*smem_offset) -
-                      2 * num_frags_y;
+   for (uint32_t i = 0; i < num_frags_z * 4 / num_warps; ++i) {
+     uint32_t page_iter, entry_idx;
+     paged_kv.page_size.divmod(packed_page_iter_base + ty * 4 + tx / 8 + 4 * num_warps * i,
+                               page_iter, entry_idx);
+     DType* gptr =
+         produce_v
+             ? paged_kv.protective_get_v_ptr(page_iter, kv_head_idx, entry_idx,
+                                             (tx % 8) * num_elems_per_128b<DType>(), last_indptr)
+             : paged_kv.protective_get_k_ptr(page_iter, kv_head_idx, entry_idx,
+                                             (tx % 8) * num_elems_per_128b<DType>(), last_indptr);
+ #pragma unroll
+     for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
+       smem.load_128b_async<fill_mode>(*smem_offset, gptr, kv_idx < kv_len);
+       *smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j);
+       gptr += 8 * num_elems_per_128b<DType>();
    }
-     *smem_offset -= num_frags_z * 16 * channel_size_128b_in;
+     kv_idx += num_warps * 4;
+     *smem_offset = smem.advance_offset_by_row<num_warps * 4, channel_size_128b_in>(*smem_offset) -
+                    2 * num_frags_y;
  }
+   *smem_offset -= num_frags_z * 16 * channel_size_128b_in;
}

template <uint32_t num_frags_y>
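Note on the indexing introduced above: the single loop unpacks a flat (packed) KV position into a page-table slot and an in-page row via paged_kv.page_size.divmod(...), which is what allows page_size to become a runtime quantity instead of a template parameter. The host-side C++ sketch below uses illustrative names, with plain / and % standing in for the fast-division helper, to show the mapping that replaces the two page_size-specialized branches:

#include <cstdint>
#include <cstdio>

// Illustrative only: unpack a packed KV position (page-table slot * page_size + in-page row)
// into the (page_iter, entry_idx) pair that page_produce_kv passes to protective_get_k_ptr/_v_ptr.
static void unpack_kv_position(uint32_t packed_pos, uint32_t page_size,
                               uint32_t& page_iter, uint32_t& entry_idx) {
  page_iter = packed_pos / page_size;  // which entry of the page table to read
  entry_idx = packed_pos % page_size;  // row within that page
}

int main() {
  const uint32_t page_size = 5;  // no longer needs to satisfy page_size % 4 == 0
  for (uint32_t pos = 0; pos < 12; ++pos) {
    uint32_t page_iter, entry_idx;
    unpack_kv_position(pos, page_size, page_iter, entry_idx);
    std::printf("packed=%2u -> page_iter=%u entry_idx=%u\n", pos, page_iter, entry_idx);
  }
  return 0;
}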
@@ -1342,10 +1315,10 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel(
  }
}

- template <LogitsPostHook logits_post_hook, uint32_t page_size, MaskMode mask_mode,
-           PosEncodingMode pos_encoding_mode, uint32_t num_frags_x, uint32_t num_frags_y,
-           uint32_t num_frags_z, uint32_t num_warps, PageStorage page_storage, QKVLayout kv_layout,
-           typename DTypeIn, typename DTypeQKAccum, typename DTypeOut, typename IdType>
+ template <LogitsPostHook logits_post_hook, MaskMode mask_mode, PosEncodingMode pos_encoding_mode,
+           uint32_t num_frags_x, uint32_t num_frags_y, uint32_t num_frags_z, uint32_t num_warps,
+           PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeQKAccum,
+           typename DTypeOut, typename IdType>
__global__ void BatchPrefillWithPagedKVCacheKernel(
    IdType* __restrict__ request_indices, IdType* __restrict__ tile_indices,
    DTypeIn* __restrict__ q, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
@@ -1448,12 +1421,12 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
      smem_t::get_permuted_offset<channel_size_128b_in>(ty * 4 + tx / 8, tx % 8);
  const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size];

-   uint32_t page_iter_base = paged_kv.indptr[request_idx];
-   page_produce_kv<false, page_size, num_warps, num_frags_y, num_frags_z>(
-       k_smem, &kv_smem_offset_w, paged_kv, 0, page_iter_base, kv_len, last_indptr);
+   uint32_t packed_page_iter_base = paged_kv.indptr[request_idx] * paged_kv.page_size;
+   page_produce_kv<false, num_warps, num_frags_y, num_frags_z>(
+       k_smem, &kv_smem_offset_w, paged_kv, 0, packed_page_iter_base, kv_len, last_indptr);
  cp_async::commit_group();
-   page_produce_kv<true, page_size, num_warps, num_frags_y, num_frags_z>(
-       v_smem, &kv_smem_offset_w, paged_kv, 0, page_iter_base, kv_len, last_indptr);
+   page_produce_kv<true, num_warps, num_frags_y, num_frags_z>(
+       v_smem, &kv_smem_offset_w, paged_kv, 0, packed_page_iter_base, kv_len, last_indptr);
  cp_async::commit_group();

  const uint32_t num_iterations = ceil_div(
@@ -1508,10 +1481,10 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
    update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(s_frag, o_frag, m, d);

    block.sync();
-     page_iter_base += 16 * num_frags_z / page_size;
-     page_produce_kv<false, page_size, num_warps, num_frags_y, num_frags_z>(
-         k_smem, &kv_smem_offset_w, paged_kv, (iter + 1) * 16 * num_frags_z, page_iter_base, kv_len,
-         last_indptr);
+     packed_page_iter_base += 16 * num_frags_z;
+     page_produce_kv<false, num_warps, num_frags_y, num_frags_z>(
+         k_smem, &kv_smem_offset_w, paged_kv, (iter + 1) * 16 * num_frags_z, packed_page_iter_base,
+         kv_len, last_indptr);
    cp_async::commit_group();
    cp_async::wait_group<1>();
    block.sync();
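The caller-side change pairs with the divmod path: the iterator is now kept in units of KV rows, initialized as paged_kv.indptr[request_idx] * paged_kv.page_size and advanced by 16 * num_frags_z rows per iteration, with the split into (page, row) deferred to page_produce_kv. A minimal sketch of that arithmetic, using illustrative values and plain / and %:

#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative values; page_size need not divide the 16 * num_frags_z tile height.
  const uint32_t page_size = 3, num_frags_z = 2, indptr_request = 7;
  uint32_t packed = indptr_request * page_size;  // first KV row owned by this request
  for (uint32_t iter = 0; iter < 4; ++iter) {
    const uint32_t page_iter = packed / page_size;  // page-table slot of the tile's first row
    const uint32_t entry_idx = packed % page_size;  // row within that page
    std::printf("iter=%u packed=%3u -> page_iter=%u entry_idx=%u\n", iter, packed, page_iter,
                entry_idx);
    packed += 16 * num_frags_z;  // advance by one KV tile worth of rows
  }
  return 0;
}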
@@ -1521,9 +1494,9 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
                              o_frag, d);

    block.sync();
-     page_produce_kv<true, page_size, num_warps, num_frags_y, num_frags_z>(
-         v_smem, &kv_smem_offset_w, paged_kv, (iter + 1) * 16 * num_frags_z, page_iter_base, kv_len,
-         last_indptr);
+     page_produce_kv<true, num_warps, num_frags_y, num_frags_z>(
+         v_smem, &kv_smem_offset_w, paged_kv, (iter + 1) * 16 * num_frags_z, packed_page_iter_base,
+         kv_len, last_indptr);
    cp_async::commit_group();
  }
  cp_async::wait_group<0>();
@@ -1776,7 +1749,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
  return cudaSuccess;
}

- template <PageStorage page_storage, uint32_t num_frags_x, uint32_t PAGE_SIZE, uint32_t HEAD_DIM,
+ template <PageStorage page_storage, uint32_t num_frags_x, uint32_t HEAD_DIM,
          LogitsPostHook LOGITS_POST_HOOK, QKVLayout kv_layout, PosEncodingMode pos_encoding_mode,
          bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE, typename DTypeIn, typename DTypeOut,
          typename IdType>
@@ -1831,8 +1804,8 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(
    throw std::invalid_argument(err_msg.str());
  } else {
    auto kernel = BatchPrefillWithPagedKVCacheKernel<
-         LOGITS_POST_HOOK, PAGE_SIZE, MASK_MODE, pos_encoding_mode, num_frags_x, num_frags_y,
-         num_frags_z, num_warps, page_storage, kv_layout, DTypeIn, DTypeQKAccum, DTypeOut, IdType>;
+         LOGITS_POST_HOOK, MASK_MODE, pos_encoding_mode, num_frags_x, num_frags_y, num_frags_z,
+         num_warps, page_storage, kv_layout, DTypeIn, DTypeQKAccum, DTypeOut, IdType>;
    uint32_t smem_size =
        (num_frags_x * num_warps + num_frags_z * 2) * 16 * HEAD_DIM * sizeof(DTypeIn);
    FLASHINFER_CUDA_CALL(