@@ -181,8 +181,8 @@ __device__ __forceinline__ void load_kv(
     typename KTraits::SharedStorage* smem_storage, typename KTraits::DTypeKV* ckv,
     typename KTraits::DTypeKV* kpe, typename KTraits::IdType* indices, const uint32_t ckv_stride_n,
     const uint32_t ckv_stride_page, const uint32_t kpe_stride_n, const uint32_t kpe_stride_page,
-    const uint32_t kv_bound, const uint32_t packed_block_iter_base, const uint_fastdiv& block_size,
-    const uint32_t stage_idx) {
+    const uint32_t packed_kv_bound, const uint32_t packed_block_iter_base,
+    const uint_fastdiv& block_size, const uint32_t stage_idx) {
   using DTypeKV = typename KTraits::DTypeKV;
   constexpr uint32_t UPCAST_STRIDE_CKV = KTraits::UPCAST_STRIDE_CKV;
   constexpr uint32_t UPCAST_STRIDE_KPE = KTraits::UPCAST_STRIDE_KPE;
@@ -198,20 +198,23 @@ __device__ __forceinline__ void load_kv(
   if constexpr (KTraits::NUM_MMA_KV == 1) {
     if (warpgroup_idx == 0) {
       uint32_t q, r;
-      block_size.divmod(packed_block_iter_base + lane_idx / 8 + warp_idx_in_wg * 4, q, r);
-
-      DTypeKV* ckv_ptr = ckv + (q < kv_bound ? indices[q] : 0) * ckv_stride_page +
+      uint32_t packed_block_iter =
+          packed_block_iter_base + lane_idx / 8 + warp_idx_in_wg * 4;
+      block_size.divmod(packed_block_iter, q, r);
+      DTypeKV* ckv_ptr = ckv +
+                         (packed_block_iter < packed_kv_bound ? indices[q] : 0) * ckv_stride_page +
                          r * ckv_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
-      DTypeKV* kpe_ptr = kpe + (q < kv_bound ? indices[q] : 0) * kpe_stride_page +
+      DTypeKV* kpe_ptr = kpe +
+                         (packed_block_iter < packed_kv_bound ? indices[q] : 0) * kpe_stride_page +
                          r * kpe_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();

 #pragma unroll
       for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_CKV / 4; ++mma_d) {
         uint32_t ckv_smem_offset_w = ckv_smem.template get_permuted_offset<UPCAST_STRIDE_CKV>(
             warp_idx_in_wg * 4 + lane_idx / 8, 8 * mma_d + lane_idx % 8);
         ckv_smem.load_128b_async<SharedMemFillMode::kFillZero>(ckv_smem_offset_w, ckv_ptr,
-                                                               q < kv_bound);
+                                                               packed_block_iter < packed_kv_bound);
         ckv_ptr += 8 * upcast_size<DTypeKV>();
       }

@@ -220,22 +223,23 @@ __device__ __forceinline__ void load_kv(
         uint32_t kpe_smem_offset_w = kpe_smem.template get_permuted_offset<UPCAST_STRIDE_KPE>(
             warp_idx_in_wg * 4 + lane_idx / 8, 8 * mma_d + lane_idx % 8);
         kpe_smem.load_128b_async<SharedMemFillMode::kFillZero>(kpe_smem_offset_w, kpe_ptr,
-                                                               q < kv_bound);
+                                                               packed_block_iter < packed_kv_bound);
         kpe_ptr += 8 * upcast_size<DTypeKV>();
       }
     }
   } else {
 #pragma unroll
     for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV / 2; ++mma_kv) {
       uint32_t q, r;
-      block_size.divmod(packed_block_iter_base + lane_idx / 8 + (warpgroup_idx + mma_kv * 2) * 16 +
-                            warp_idx_in_wg * 4,
-                        q, r);
-
-      DTypeKV* ckv_ptr = ckv + (q < kv_bound ? indices[q] : 0) * ckv_stride_page +
+      uint32_t packed_block_iter = packed_block_iter_base + lane_idx / 8 +
+                                   (warpgroup_idx + mma_kv * 2) * 16 + warp_idx_in_wg * 4;
+      block_size.divmod(packed_block_iter, q, r);
+      DTypeKV* ckv_ptr = ckv +
+                         (packed_block_iter < packed_kv_bound ? indices[q] : 0) * ckv_stride_page +
                          r * ckv_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
-      DTypeKV* kpe_ptr = kpe + (q < kv_bound ? indices[q] : 0) * kpe_stride_page +
+      DTypeKV* kpe_ptr = kpe +
+                         (packed_block_iter < packed_kv_bound ? indices[q] : 0) * kpe_stride_page +
                          r * kpe_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();

 #pragma unroll
@@ -244,7 +248,7 @@ __device__ __forceinline__ void load_kv(
             32 * mma_kv + warpgroup_idx * 16 + warp_idx_in_wg * 4 + lane_idx / 8,
             8 * mma_d + lane_idx % 8);
         ckv_smem.load_128b_async<SharedMemFillMode::kFillZero>(ckv_smem_offset_w, ckv_ptr,
-                                                               q < kv_bound);
+                                                               packed_block_iter < packed_kv_bound);
         ckv_ptr += 8 * upcast_size<DTypeKV>();
       }

@@ -254,7 +258,7 @@ __device__ __forceinline__ void load_kv(
             32 * mma_kv + warpgroup_idx * 16 + warp_idx_in_wg * 4 + lane_idx / 8,
             8 * mma_d + lane_idx % 8);
         kpe_smem.load_128b_async<SharedMemFillMode::kFillZero>(kpe_smem_offset_w, kpe_ptr,
-                                                               q < kv_bound);
+                                                               packed_block_iter < packed_kv_bound);
         kpe_ptr += 8 * upcast_size<DTypeKV>();
       }
     }
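
Side note on the predicate change above (reviewer sketch, not part of the diff): `q < kv_bound` compared the divmod quotient, i.e. the page index, against a page-count bound, so every entry of a partially filled last page passed the check. `packed_block_iter < packed_kv_bound` compares the packed entry index against `kv_indptr * block_size + kv_len`, so the tail of that last page is masked, both for the `kFillZero` async copy and for the `indices[q]` select. A minimal host-side illustration, using plain `/` and `%` in place of `uint_fastdiv::divmod` and made-up sizes:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t block_size = 64;  // KV entries per page
  const uint32_t kv_indptr = 10;   // first page of this request in the page table
  const uint32_t kv_len = 100;     // valid KV entries; the second page is only partially full

  const uint32_t kv_bound = kv_indptr + (kv_len + block_size - 1) / block_size;  // old: in pages
  const uint32_t packed_kv_bound = kv_indptr * block_size + kv_len;              // new: in entries

  const uint32_t samples[] = {kv_indptr * block_size + 99,    // last valid entry
                              kv_indptr * block_size + 100,   // first entry past kv_len
                              kv_indptr * block_size + 127};  // last entry of the partial page
  for (uint32_t packed_block_iter : samples) {
    const uint32_t q = packed_block_iter / block_size;  // page index (divmod quotient)
    const uint32_t r = packed_block_iter % block_size;  // offset inside the page (remainder)
    std::printf("entry %u -> page %u offset %u | old: %d  new: %d\n", packed_block_iter, q, r,
                int(q < kv_bound), int(packed_block_iter < packed_kv_bound));
  }
  return 0;
}
```

With these numbers the old predicate stays true for all three samples (page 11 < 12), while the new one is false for the two entries past `kv_len`.
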
@@ -863,17 +867,17 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKe
     uint32_t block_iter_base = kv_indptr * block_size + kv_start;
     // last kv tile
     __syncthreads();
-    uint32_t kv_bound = kv_indptr + (kv_len + block_size - 1) / block_size;  // ceil_div
+    uint32_t packed_kv_bound = kv_indptr * block_size + kv_len;
     load_kv<KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
-                     kpe_stride_n, kpe_stride_page, kv_bound,
+                     kpe_stride_n, kpe_stride_page, packed_kv_bound,
                      block_iter_base + kv_tile_idx * CTA_TILE_KV, block_size,
                      kv_tile_idx % NUM_STAGES);
     cp_async::commit_group();
 #pragma unroll
     for (int stage_idx = 1; stage_idx < NUM_STAGES; ++stage_idx) {
       if (kv_tile_idx - stage_idx >= 0) {
         load_kv<KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
-                         kpe_stride_n, kpe_stride_page, kv_bound,
+                         kpe_stride_n, kpe_stride_page, packed_kv_bound,
                          block_iter_base + (kv_tile_idx - stage_idx) * CTA_TILE_KV, block_size,
                          (kv_tile_idx - stage_idx) % NUM_STAGES);
         cp_async::commit_group();
@@ -903,7 +907,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKe
       if (kv_tile_idx - NUM_STAGES >= 0) {
         __syncthreads();
         load_kv<KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
-                         kpe_stride_n, kpe_stride_page, kv_bound,
+                         kpe_stride_n, kpe_stride_page, packed_kv_bound,
                          block_iter_base + (kv_tile_idx - NUM_STAGES) * CTA_TILE_KV, block_size,
                          (kv_tile_idx - NUM_STAGES) % NUM_STAGES);
         cp_async::commit_group();
@@ -927,7 +931,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKe

       __syncthreads();
       load_kv<KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
-                       kpe_stride_n, kpe_stride_page, kv_bound,
+                       kpe_stride_n, kpe_stride_page, packed_kv_bound,
                        block_iter_base + (kv_tile_idx - NUM_STAGES) * CTA_TILE_KV, block_size,
                        (kv_tile_idx - NUM_STAGES) % NUM_STAGES);
       cp_async::commit_group();
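
On the kernel-side hunks: `block_iter_base = kv_indptr * block_size + kv_start` (kept as context) and the new `packed_kv_bound = kv_indptr * block_size + kv_len` are both measured in KV entries, so the `block_iter_base + kv_tile_idx * CTA_TILE_KV` argument handed to `load_kv` is directly comparable against the bound inside the loader, with no extra scaling by `block_size`. A toy host-side consistency check; the tile loop and its bounds are illustrative, not the kernel's actual scheduling:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // Illustrative values only; CTA_TILE_KV, kv_indptr, kv_start and kv_len are made up.
  const uint32_t block_size = 64, CTA_TILE_KV = 32;
  const uint32_t kv_indptr = 10, kv_start = 0, kv_len = 100;

  const uint32_t block_iter_base = kv_indptr * block_size + kv_start;  // entries, as in the hunk
  const uint32_t packed_kv_bound = kv_indptr * block_size + kv_len;    // entries, replaces ceil_div

  // Every per-tile base stays in entry units, so the per-lane comparison
  // packed_block_iter < packed_kv_bound inside load_kv needs no further scaling.
  const uint32_t num_tiles = (kv_len - kv_start + CTA_TILE_KV - 1) / CTA_TILE_KV;
  for (uint32_t kv_tile_idx = 0; kv_tile_idx < num_tiles; ++kv_tile_idx) {
    const uint32_t tile_base = block_iter_base + kv_tile_idx * CTA_TILE_KV;
    assert(tile_base < packed_kv_bound);  // at least the first entry of each issued tile is valid
  }
  return 0;
}
```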