
Commit 0ed1ce8

perf: tweak the pipeline design of mla kernel (#901)
1. Defer the barrier sync for `p_smem`.
2. Change the unroll count from 1 to 2.

We found that synchronizing the two consumer warpgroups in the qk stage still carried significant overhead; using only one warpgroup for qk resolves the issue.
1 parent e4a68e4 commit 0ed1ce8
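
For context, the first change moves a consumer named-barrier sync from the producer side of `p_smem` (the end of `write_p_rmem_smem`) to the consumer side (the start of `compute_mla_pv`), so warps only block at the point where the data is actually read. Below is a minimal, self-contained sketch of that pattern; the `barrier_sync` wrapper and the `deferred_sync_demo` kernel are hypothetical illustrations of the idea, not FlashInfer's actual code.

#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for a barrier_sync helper: a named barrier
// (PTX `bar.sync id, count`) that synchronizes exactly `num_threads`
// threads on barrier `barrier_id`, independently of __syncthreads()
// (which uses barrier 0).
__device__ __forceinline__ void barrier_sync(uint32_t num_threads, uint32_t barrier_id) {
  asm volatile("bar.sync %0, %1;" ::"r"(barrier_id), "r"(num_threads));
}

// Toy kernel showing the "defer the sync to the consumer" idea: the
// producer warp fills shared memory and does NOT sync right after its
// stores; the barrier sits at the first point the data is read, so
// independent work between store and read overlaps with other warps
// reaching the barrier.
__global__ void deferred_sync_demo(float* out) {
  __shared__ float p_smem[32];
  const uint32_t tid = threadIdx.x;

  // Produce: warp 0 fills p_smem (no barrier here; it was deferred).
  if (tid < 32) p_smem[tid] = static_cast<float>(tid);

  // ... independent per-thread work could be scheduled here ...

  // Consume: sync all 64 threads on named barrier 1 just before reading.
  barrier_sync(64, 1);
  out[tid] = p_smem[tid % 32] * 2.0f;
}

int main() {
  float* out;
  cudaMallocManaged(&out, 64 * sizeof(float));
  deferred_sync_demo<<<1, 64>>>(out);
  cudaDeviceSynchronize();
  printf("out[33] = %.1f\n", out[33]);  // expect 2.0 (p_smem[1] * 2)
  cudaFree(out);
  return 0;
}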

File tree

1 file changed: +2 −2 lines changed


include/flashinfer/attention/mla_hopper.cuh (+2 −2)
@@ -275,6 +275,7 @@ __device__ __forceinline__ void compute_mla_qk(typename KTraits::SharedStorage*
 template <typename KTraits>
 __device__ __forceinline__ void compute_mla_pv(typename KTraits::SharedStorage* smem_storage,
                                                const uint32_t stage_idx, float* o_frag) {
+  barrier_sync(KTraits::NUM_MMA_THREADS, NamedBarriers::kConsumerSync);
   const uint32_t lane_idx = cutlass::canonical_lane_idx();
   const uint32_t warp_idx_in_wg = cutlass::canonical_warp_idx() % 4;
   const uint32_t warp_group_idx = cutlass::canonical_warp_group_idx();
@@ -400,7 +401,6 @@ __device__ __forceinline__ void write_p_rmem_smem(typename KTraits::SharedStorag
                           (warp_group_idx - 1) * NUM_MMA_KV + mma_kv * 2 + lane_idx / 16);
     p_smem.stmatrix_m8n8x4(p_smem_offset_w, p_frag + mma_kv * 4);
   }
-  barrier_sync(KTraits::NUM_MMA_THREADS, NamedBarriers::kConsumerSync);
 }

 template <typename KTraits>
@@ -780,7 +780,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
   }

   // loop without mask
-#pragma unroll 1
+#pragma unroll 2
   for (; kv_tile_idx > NUM_STAGES; --kv_tile_idx) {
     auto smem_pipe_read_kv_cur = smem_pipe_read_kv;
     ++smem_pipe_read_kv;
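
A note on the second change: `#pragma unroll 1` disables unrolling, while `#pragma unroll 2` asks the compiler to interleave two copies of the loop body, giving the scheduler more independent instructions per kv tile to hide MMA and shared-memory latency. A tiny illustrative device function (not FlashInfer's code), just to show the hint's placement:

// Illustrative only: with `#pragma unroll 2` the compiler emits two
// interleaved copies of the loop body, as in the no-mask kv loop above.
__device__ float unroll_demo(const float* x, int n) {
  float acc = 0.f;
#pragma unroll 2
  for (int i = 0; i < n; ++i) {
    acc += x[i];
  }
  return acc;
}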
