@@ -176,8 +176,8 @@ __device__ __forceinline__ void load_kv(
     typename KTraits::SharedStorage* smem_storage, typename KTraits::DTypeKV* ckv,
     typename KTraits::DTypeKV* kpe, typename KTraits::IdType* indices, const uint32_t ckv_stride_n,
     const uint32_t ckv_stride_page, const uint32_t kpe_stride_n, const uint32_t kpe_stride_page,
-    const uint32_t kv_bound, const uint32_t packed_block_iter_base, const uint_fastdiv& block_size,
-    const uint32_t stage_idx) {
+    const uint32_t packed_kv_bound, const uint32_t packed_block_iter_base,
+    const uint_fastdiv& block_size, const uint32_t stage_idx) {
   using DTypeKV = typename KTraits::DTypeKV;
   constexpr uint32_t UPCAST_STRIDE_CKV = KTraits::UPCAST_STRIDE_CKV;
   constexpr uint32_t UPCAST_STRIDE_KPE = KTraits::UPCAST_STRIDE_KPE;
@@ -195,13 +195,15 @@ __device__ __forceinline__ void load_kv(
 #pragma unroll
     for (uint32_t j = 0; j < 2; ++j) {
       uint32_t q, r;
+      uint32_t packed_block_iter =
+          packed_block_iter_base + lane_idx / 8 + (j + mma_kv * 2) * 16 + warp_idx_in_wg * 4;
+      block_size.divmod(packed_block_iter, q, r);

-      block_size.divmod(
-          packed_block_iter_base + lane_idx / 8 + (j + mma_kv * 2) * 16 + warp_idx_in_wg * 4, q, r);
-
-      DTypeKV* ckv_ptr = ckv + (q < kv_bound ? indices[q] : 0) * ckv_stride_page +
+      DTypeKV* ckv_ptr = ckv +
+                         (packed_block_iter < packed_kv_bound ? indices[q] : 0) * ckv_stride_page +
                          r * ckv_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
-      DTypeKV* kpe_ptr = kpe + (q < kv_bound ? indices[q] : 0) * kpe_stride_page +
+      DTypeKV* kpe_ptr = kpe +
+                         (packed_block_iter < packed_kv_bound ? indices[q] : 0) * kpe_stride_page +
                          r * kpe_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
       uint32_t ckv_smem_offset_w = get_swizzle_offset<KTraits::SWIZZLE_MODE_CKV, UPCAST_STRIDE_CKV>(
           32 * mma_kv + j * 16 + warp_idx_in_wg * 4 + lane_idx / 8, 8 * 0 + lane_idx % 8);
@@ -211,8 +213,8 @@ __device__ __forceinline__ void load_kv(
 #pragma unroll
       for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_CKV / 4; ++mma_d) {
         if constexpr (predicate) {
-          ckv_smem.load_128b_async<SharedMemFillMode::kFillZero>(ckv_smem_offset_w, ckv_ptr,
-                                                                 q < kv_bound);
+          ckv_smem.load_128b_async<SharedMemFillMode::kFillZero>(
+              ckv_smem_offset_w, ckv_ptr, packed_block_iter < packed_kv_bound);
         } else {
           ckv_smem.load_128b_async(ckv_smem_offset_w, ckv_ptr);
         }
@@ -223,8 +225,8 @@ __device__ __forceinline__ void load_kv(
 #pragma unroll
      for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_KPE / 4; ++mma_d) {
        if constexpr (predicate) {
-          kpe_smem.load_128b_async<SharedMemFillMode::kFillZero>(kpe_smem_offset_w, kpe_ptr,
-                                                                 q < kv_bound);
+          kpe_smem.load_128b_async<SharedMemFillMode::kFillZero>(
+              kpe_smem_offset_w, kpe_ptr, packed_block_iter < packed_kv_bound);
        } else {
          kpe_smem.load_128b_async(kpe_smem_offset_w, kpe_ptr);
        }
@@ -673,7 +675,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
     const uint32_t qo_upperbound =
         min(q_len, ceil_div(qo_packed_idx_base + KTraits::CTA_TILE_Q, num_heads));

-    uint32_t kv_bound = kv_indptr + (kv_len + block_size - 1) / block_size;
+    uint32_t packed_kv_bound = kv_indptr * block_size + kv_len;
     int kv_tile_idx =
         ceil_div(
             (CAUSAL ? min(kv_end, kv_len - q_len + (packed_qo_start + cluster_tile_q) / num_heads)
@@ -687,7 +689,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop

     PROFILER_EVENT_START(variant, ProfileEventType::kIssueLoadKV);
     load_kv<true, KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
-                           kpe_stride_n, kpe_stride_page, kv_bound,
+                           kpe_stride_n, kpe_stride_page, packed_kv_bound,
                            block_iter_base + kv_tile_idx * CTA_TILE_KV, block_size,
                            smem_pipe_write_kv.index());
     PROFILER_EVENT_END(variant, ProfileEventType::kIssueLoadKV);
@@ -715,7 +717,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
       pipeline_kv.producer_acquire(smem_pipe_write_kv);
       PROFILER_EVENT_START(variant, ProfileEventType::kIssueLoadKV);
       load_kv<false, KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
-                              kpe_stride_n, kpe_stride_page, kv_bound,
+                              kpe_stride_n, kpe_stride_page, packed_kv_bound,
                               block_iter_base + kv_tile_idx * CTA_TILE_KV, block_size,
                               smem_pipe_write_kv.index());
       PROFILER_EVENT_END(variant, ProfileEventType::kIssueLoadKV);
@@ -734,7 +736,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
      pipeline_kv.producer_acquire(smem_pipe_write_kv);
      PROFILER_EVENT_START(variant, ProfileEventType::kIssueLoadKV);
      load_kv<false, KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
-                             kpe_stride_n, kpe_stride_page, kv_bound,
+                             kpe_stride_n, kpe_stride_page, packed_kv_bound,
                              block_iter_base + kv_tile_idx * CTA_TILE_KV, block_size,
                              smem_pipe_write_kv.index());
      PROFILER_EVENT_END(variant, ProfileEventType::kIssueLoadKV);
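Note (added for review, not part of the diff): the change swaps a page-granular guard (`q < kv_bound`) for an element-granular one (`packed_block_iter < packed_kv_bound`). The standalone sketch below re-uses only the two bound formulas visible in the hunks above; every concrete value (block_size, kv_indptr, kv_len, the probed element) is made up, and the kernel's uint_fastdiv divmod is replaced by plain / and %.

// Standalone illustration (host code, not kernel code): contrasts the old
// page-granular guard with the new element-granular guard from this diff.
// All concrete values are hypothetical.
#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t block_size = 64;  // KV page size (hypothetical)
  const uint32_t kv_indptr = 10;   // first page slot of this request in `indices` (hypothetical)
  const uint32_t kv_len = 100;     // valid KV entries for this request (hypothetical)

  // Old bound: one past the last page slot, rounded up to page granularity.
  uint32_t kv_bound = kv_indptr + (kv_len + block_size - 1) / block_size;  // 12

  // New bound: one past the last valid element, in packed (element) units.
  uint32_t packed_kv_bound = kv_indptr * block_size + kv_len;  // 740

  // An element inside the last, partially filled page but past kv_len.
  uint32_t packed_block_iter = kv_indptr * block_size + 109;  // element 109 >= kv_len
  uint32_t q = packed_block_iter / block_size;  // page slot (divmod quotient in the kernel)
  uint32_t r = packed_block_iter % block_size;  // in-page offset (divmod remainder)

  printf("page slot q=%u, in-page offset r=%u\n", q, r);
  // 1: the old guard would let this load through.
  printf("old guard (q < kv_bound): %d\n", q < kv_bound);
  // 0: the new guard masks it, taking the kFillZero path instead.
  printf("new guard (packed_block_iter < packed_kv_bound): %d\n",
         packed_block_iter < packed_kv_bound);
  return 0;
}

Under these made-up numbers, elements 100..127 of the request sit inside the last, partially filled page: they pass the old page-granular check but fail the new element-granular one, which is the distinction the kFillZero predicate in load_kv now keys off.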