perf: prefetch page indices for mla kernel #991

Merged 2 commits on Mar 31, 2025
include/flashinfer/attention/mla.cuh (2 changes: 1 addition & 1 deletion)

@@ -629,7 +629,7 @@ __device__ void DevicePersistentMergeStates(
typename KTraits::IdType* merge_partial_stride, typename KTraits::DTypeO* partial_o,
float* partial_lse, typename KTraits::DTypeO* final_o, float* final_lse,
const uint32_t o_stride_n, const uint32_t o_stride_h, const uint_fastdiv& num_heads) {
-  constexpr uint32_t VEC_SIZE = 4;  // partial o has data type float
+  constexpr uint32_t VEC_SIZE = 8;  // partial o has data type float
constexpr uint32_t NUM_THRS_PER_ROW = KTraits::HEAD_DIM_CKV / VEC_SIZE;
constexpr uint32_t ROWS_PER_ITERATION = (KTraits::NUM_THREADS) / NUM_THRS_PER_ROW;
const uint32_t cta_idx = (gridDim.x * blockIdx.y + blockIdx.x);
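Context for this one-liner: partial_o is stored as float, so raising VEC_SIZE from 4 to 8 widens each thread's access from 16 B to 32 B, halves NUM_THRS_PER_ROW, and doubles ROWS_PER_ITERATION. A minimal host-side sketch of that arithmetic (HEAD_DIM_CKV and NUM_THREADS below are assumed example values, not read from KTraits):

#include <cstdint>
#include <cstdio>

int main() {
  constexpr uint32_t HEAD_DIM_CKV = 512;  // assumption: example head dimension
  constexpr uint32_t NUM_THREADS = 1024;  // assumption: example CTA size
  const uint32_t vec_sizes[] = {4, 8};
  for (uint32_t vec_size : vec_sizes) {
    const uint32_t thrs_per_row = HEAD_DIM_CKV / vec_size;      // threads covering one row
    const uint32_t rows_per_iter = NUM_THREADS / thrs_per_row;  // rows merged per iteration
    std::printf("VEC_SIZE=%u: %u threads/row, %u rows/iter, %zu B per thread access\n",
                vec_size, thrs_per_row, rows_per_iter, vec_size * sizeof(float));
  }
  return 0;
}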
include/flashinfer/attention/mla_hopper.cuh (79 changes: 56 additions & 23 deletions)

@@ -170,20 +170,47 @@ __device__ __forceinline__ void load_q(
}
}

+template <typename KTraits>
+__device__ __forceinline__ void prefetch_offset(
+    const uint32_t packed_block_iter_base, const uint32_t packed_kv_bound,
+    const uint32_t ckv_stride_page, const uint32_t ckv_stride_n, const uint32_t kpe_stride_page,
+    const uint32_t kpe_stride_n, const uint_fastdiv& block_size, typename KTraits::IdType* indices,
+    int64_t (*ckv_offset)[2], int64_t (*kpe_offset)[2]) {
+  using DTypeKV = typename KTraits::DTypeKV;
+  const uint32_t lane_idx = cutlass::canonical_lane_idx();
+  const uint32_t warp_idx_in_wg = cutlass::canonical_warp_idx() % 4;
+#pragma unroll
+  for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV / 2; ++mma_kv) {
+#pragma unroll
+    for (uint32_t j = 0; j < 2; ++j) {
+      uint32_t q, r;
+      uint32_t packed_block_iter =
+          packed_block_iter_base + lane_idx / 8 + (j + mma_kv * 2) * 16 + warp_idx_in_wg * 4;
+      block_size.divmod(packed_block_iter, q, r);
+      ckv_offset[mma_kv][j] =
+          (packed_block_iter < packed_kv_bound ? indices[q] : 0) * ckv_stride_page +
+          r * ckv_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
+      kpe_offset[mma_kv][j] =
+          (packed_block_iter < packed_kv_bound ? indices[q] : 0) * kpe_stride_page +
+          r * kpe_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
+    }
+  }
+}

template <bool predicate, typename KTraits>
-__device__ __forceinline__ void load_kv(
-    typename KTraits::SharedStorage* smem_storage, typename KTraits::DTypeKV* ckv,
-    typename KTraits::DTypeKV* kpe, typename KTraits::IdType* indices, const uint32_t ckv_stride_n,
-    const uint32_t ckv_stride_page, const uint32_t kpe_stride_n, const uint32_t kpe_stride_page,
-    const uint32_t packed_kv_bound, const uint32_t packed_block_iter_base,
-    const uint_fastdiv& block_size, const uint32_t stage_idx) {
+__device__ __forceinline__ void load_kv(typename KTraits::SharedStorage* smem_storage,
+                                        typename KTraits::DTypeKV* ckv,
+                                        typename KTraits::DTypeKV* kpe,
+                                        const uint32_t packed_kv_bound,
+                                        const uint32_t packed_block_iter_base,
+                                        const uint32_t stage_idx, int64_t (*ckv_offset)[2],
+                                        int64_t (*kpe_offset)[2]) {
using DTypeKV = typename KTraits::DTypeKV;
constexpr uint32_t UPCAST_STRIDE_CKV = KTraits::UPCAST_STRIDE_CKV;
constexpr uint32_t UPCAST_STRIDE_KPE = KTraits::UPCAST_STRIDE_KPE;
constexpr uint32_t NUM_MMA_D_CKV = KTraits::NUM_MMA_D_CKV;
constexpr uint32_t NUM_MMA_D_KPE = KTraits::NUM_MMA_D_KPE;
const uint32_t lane_idx = cutlass::canonical_lane_idx();
const uint32_t warp_group_idx = cutlass::canonical_warp_group_idx();
const uint32_t warp_idx_in_wg = cutlass::canonical_warp_idx() % 4;

smem_t<KTraits::SWIZZLE_MODE_CKV> ckv_smem(smem_storage->kv_o_smem[stage_idx].ckv);
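The core of the new helper is the block_size.divmod call: a packed KV position is split into a page-table slot q and an in-page row r, and the page index gathered from indices[q] is folded into a cached element offset. uint_fastdiv exists so that this division by a runtime-invariant block size compiles to a multiply-and-shift rather than a hardware divide. A scalar restatement of what gets stored per slot (a hypothetical standalone helper mirroring the code above, not part of the PR):

// Hypothetical helper restating prefetch_offset's per-slot computation.
template <typename IdType>
__device__ __forceinline__ int64_t gather_offset(uint32_t pos, uint32_t pos_bound,
                                                 const uint_fastdiv& block_size,
                                                 const IdType* indices, uint32_t stride_page,
                                                 uint32_t stride_n, uint32_t col) {
  uint32_t q, r;
  block_size.divmod(pos, q, r);  // q = pos / block_size, r = pos % block_size
  // Out-of-bound positions fall back to page 0 so the address stays in-range.
  const int64_t page = (pos < pos_bound) ? static_cast<int64_t>(indices[q]) : 0;
  return page * stride_page + static_cast<int64_t>(r) * stride_n + col;
}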
@@ -193,17 +220,11 @@ __device__ __forceinline__ void load_kv(
for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV / 2; ++mma_kv) {
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
-      uint32_t q, r;
-      uint32_t packed_block_iter =
-          packed_block_iter_base + lane_idx / 8 + (j + mma_kv * 2) * 16 + warp_idx_in_wg * 4;
-      block_size.divmod(packed_block_iter, q, r);
-
-      DTypeKV* ckv_ptr = ckv +
-                         (packed_block_iter < packed_kv_bound ? indices[q] : 0) * ckv_stride_page +
-                         r * ckv_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
-      DTypeKV* kpe_ptr = kpe +
-                         (packed_block_iter < packed_kv_bound ? indices[q] : 0) * kpe_stride_page +
-                         r * kpe_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
+      DTypeKV* ckv_ptr = ckv + ckv_offset[mma_kv][j];
+      DTypeKV* kpe_ptr = kpe + kpe_offset[mma_kv][j];
uint32_t ckv_smem_offset_w = get_swizzle_offset<KTraits::SWIZZLE_MODE_CKV, UPCAST_STRIDE_CKV>(
32 * mma_kv + j * 16 + warp_idx_in_wg * 4 + lane_idx / 8, 8 * 0 + lane_idx % 8);
uint32_t kpe_smem_offset_w = get_swizzle_offset<KTraits::SWIZZLE_MODE_KPE, UPCAST_STRIDE_KPE>(
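With the gather arithmetic hoisted into prefetch_offset, the copy loop in load_kv reduces to "base pointer plus cached offset" feeding an async copy into swizzled shared memory. A self-contained sketch of that shape, using libcu++'s cuda::memcpy_async in place of FlashInfer's cp.async wrappers (all names here are illustrative, not the kernel's):

#include <cuda/pipeline>
#include <cstdint>

// Addresses arrive precomputed, so the loop body is pure copy issue:
// no divmod and no page-table lookup on the critical path.
template <int N, typename T>
__device__ void gather_rows(T* smem, const T* gmem_base, const int64_t (&offsets)[N],
                            uint32_t elems_per_copy,
                            cuda::pipeline<cuda::thread_scope_thread>& pipe) {
  pipe.producer_acquire();
#pragma unroll
  for (int i = 0; i < N; ++i) {
    cuda::memcpy_async(smem + i * elems_per_copy, gmem_base + offsets[i],
                       sizeof(T) * elems_per_copy, pipe);
  }
  pipe.producer_commit();
}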
@@ -657,6 +678,9 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
PipelineState smem_pipe_write_kv = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState smem_pipe_read_kv;

+  int64_t ckv_offset[KTraits::NUM_MMA_KV / 2][2];
+  int64_t kpe_offset[KTraits::NUM_MMA_KV / 2][2];

#pragma unroll 1
for (IdType work_idx = work_indptr[blockIdx.y]; work_idx < work_indptr[blockIdx.y + 1];
++work_idx) {
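These two arrays are the cost side of the trade: they sit in each producer thread's registers across the persistent work loop. A back-of-envelope estimate with an assumed NUM_MMA_KV of 8 (an example, not the kernel's configured value):

#include <cstdint>
#include <cstdio>

int main() {
  constexpr uint32_t NUM_MMA_KV = 8;                          // assumption: example trait value
  constexpr uint32_t slots_per_array = (NUM_MMA_KV / 2) * 2;  // shape [NUM_MMA_KV/2][2]
  constexpr uint32_t bytes = 2 * slots_per_array * sizeof(int64_t);  // ckv + kpe arrays
  std::printf("%u bytes (~%u 32-bit registers) of cached offsets per thread\n", bytes, bytes / 4);
  return 0;
}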
@@ -681,15 +705,20 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop

const uint32_t block_iter_base = kv_indptr * block_size + kv_start;

+    prefetch_offset<KTraits>(block_iter_base + kv_tile_idx * CTA_TILE_KV, packed_kv_bound,
+                             ckv_stride_page, ckv_stride_n, kpe_stride_page, kpe_stride_n,
+                             block_size, kv_indices, ckv_offset, kpe_offset);
if (has_kv) {
pipeline_kv.producer_acquire(smem_pipe_write_kv);
-      load_kv<true, KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
-                             kpe_stride_n, kpe_stride_page, packed_kv_bound,
-                             block_iter_base + kv_tile_idx * CTA_TILE_KV, block_size,
-                             smem_pipe_write_kv.index());
+      load_kv<true, KTraits>(&smem_storage, ckv, kpe, packed_kv_bound,
+                             block_iter_base + kv_tile_idx * CTA_TILE_KV,
+                             smem_pipe_write_kv.index(), ckv_offset, kpe_offset);
pipeline_kv.producer_commit(smem_pipe_write_kv, cutlass::arch::cpasync_barrier_arrive);
kv_tile_idx -= 1;
++smem_pipe_write_kv;
+      prefetch_offset<KTraits>(block_iter_base + kv_tile_idx * CTA_TILE_KV, packed_kv_bound,
+                               ckv_stride_page, ckv_stride_n, kpe_stride_page, kpe_stride_n,
+                               block_size, kv_indices, ckv_offset, kpe_offset);
}

pipeline_q.producer_acquire(smem_pipe_write_q);
@@ -703,10 +732,14 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
#pragma unroll 1
for (; kv_tile_idx >= 0; --kv_tile_idx) {
pipeline_kv.producer_acquire(smem_pipe_write_kv);
-      load_kv<false, KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
-                              kpe_stride_n, kpe_stride_page, packed_kv_bound,
-                              block_iter_base + kv_tile_idx * CTA_TILE_KV, block_size,
-                              smem_pipe_write_kv.index());
+      load_kv<false, KTraits>(&smem_storage, ckv, kpe, packed_kv_bound,
+                              block_iter_base + kv_tile_idx * CTA_TILE_KV,
+                              smem_pipe_write_kv.index(), ckv_offset, kpe_offset);
+      if (kv_tile_idx > 0) {
+        prefetch_offset<KTraits>(block_iter_base + (kv_tile_idx - 1) * CTA_TILE_KV,
+                                 packed_kv_bound, ckv_stride_page, ckv_stride_n, kpe_stride_page,
+                                 kpe_stride_n, block_size, kv_indices, ckv_offset, kpe_offset);
+      }
pipeline_kv.producer_commit(smem_pipe_write_kv, cutlass::arch::cpasync_barrier_arrive);
++smem_pipe_write_kv;

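Taken together, the prologue and the main loop form a one-tile-ahead software pipeline: offsets for tile k+1 are computed right after tile k's copies are issued, so the divmod and page-index gather are hidden behind in-flight cp.async traffic instead of sitting on the critical path. A stubbed C++ skeleton of that shape (all names hypothetical):

#include <cstdio>

static void compute_offsets(int tile) { std::printf("offsets for tile %d\n", tile); }
static void issue_async_copies(int tile) { std::printf("copies for tile %d\n", tile); }

int main() {
  const int num_tiles = 4;
  compute_offsets(0);          // prologue: first tile's page-table gather
  for (int k = 0; k < num_tiles; ++k) {
    issue_async_copies(k);     // async loads consume precomputed offsets
    if (k + 1 < num_tiles) {
      compute_offsets(k + 1);  // overlap next tile's divmod + gather with the copies
    }
  }
  return 0;
}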
Expand Down