Skip to content

Commit 660c559

Browse files
authored
perf: multiply q by sm_scale in decode kernels (#144)
The same optimization was used in our prefill attention kernels; this PR applies it to the decode attention kernels as well.
1 parent 5f70697 commit 660c559

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

include/flashinfer/attention/decode.cuh

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ template <RotaryMode rotary_mode, uint32_t vec_size, uint32_t bdx, uint32_t tile
6868
__device__ __forceinline__ void compute_qk(const T* smem, uint32_t compute_stage_idx,
6969
const vec_t<float, vec_size>& q_vec,
7070
const vec_t<float, vec_size>& freq, uint32_t kv_idx_base,
71-
uint32_t iter_base, uint32_t iter_bound, float sm_scale,
72-
float* s, state_t<vec_size>& st) {
71+
uint32_t iter_base, uint32_t iter_bound, float* s,
72+
state_t<vec_size>& st) {
7373
uint32_t tx = threadIdx.x, tz = threadIdx.z;
7474
float m_prev = st.m;
7575
#pragma unroll
@@ -86,7 +86,7 @@ __device__ __forceinline__ void compute_qk(const T* smem, uint32_t compute_stage
8686
s[j] = 0.f;
8787
#pragma unroll
8888
for (uint32_t i = 0; i < vec_size; ++i) {
89-
s[j] += q_vec[i] * k_vec[i] * sm_scale;
89+
s[j] += q_vec[i] * k_vec[i];
9090
}
9191
#pragma unroll
9292
for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) {
@@ -240,6 +240,11 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn*
240240
// do not apply rotary embedding to q matrix
241241
q_vec.cast_load(q + info.get_qo_elem_offset(0, qo_head_idx, tx * vec_size));
242242
}
243+
// multiply q_vec by sm_scale
244+
#pragma unroll
245+
for (uint32_t i = 0; i < vec_size; ++i) {
246+
q_vec[i] *= sm_scale;
247+
}
243248
block.sync();
244249

245250
uint32_t chunk_start = kv_chunk_idx * kv_chunk_size;
@@ -286,8 +291,8 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn*
286291
block.sync();
287292
compute_qk<rotary_mode, vec_size, bdx, bdy * tile_size_per_bdx>(
288293
k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, stage_idx, q_vec,
289-
freq, consumer_kv_idx_base, iter * bdy * tile_size_per_bdx * bdz, kv_chunk_size, sm_scale,
290-
s, st_local);
294+
freq, consumer_kv_idx_base, iter * bdy * tile_size_per_bdx * bdz, kv_chunk_size, s,
295+
st_local);
291296
block.sync();
292297
// load k
293298
for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
@@ -385,6 +390,10 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel(
385390
q_vec.cast_load(q + batch_idx * num_qo_heads * head_dim +
386391
info.get_qo_elem_offset(0, qo_head_idx, tx * vec_size));
387392
}
393+
#pragma unroll
394+
for (uint32_t i = 0; i < vec_size; ++i) {
395+
q_vec[i] *= sm_scale;
396+
}
388397
block.sync();
389398

390399
// preload k tiles and v tiles
@@ -421,7 +430,7 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel(
421430
block.sync();
422431
compute_qk<rotary_mode, vec_size, bdx, bdy>(k_smem + (stage_idx * bdz + tz) * bdy * head_dim,
423432
stage_idx, q_vec, freq, consumer_kv_idx_base,
424-
iter * bdy * bdz, seq_len, sm_scale, s, st_local);
433+
iter * bdy * bdz, seq_len, s, st_local);
425434
block.sync();
426435
// load k
427436
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
@@ -551,6 +560,10 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
551560
// do not apply rotary embedding to q matrix
552561
q_vec.cast_load(q + (mapped_batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
553562
}
563+
#pragma unroll
564+
for (uint32_t i = 0; i < vec_size; ++i) {
565+
q_vec[i] *= sm_scale;
566+
}
554567
block.sync();
555568

556569
// preload k/v tiles
@@ -622,7 +635,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
622635
freq,
623636
(paged_kv.rope_pos_offset == nullptr ? 0 : paged_kv.rope_pos_offset[mapped_batch_idx]) +
624637
cur_chunk_start + iter * tile_size_per_bdx * bdy * bdz,
625-
iter * tile_size_per_bdx * bdy * bdz, kv_chunk_len, sm_scale, s, st);
638+
iter * tile_size_per_bdx * bdy * bdz, kv_chunk_len, s, st);
626639
block.sync();
627640

628641
#pragma unroll

0 commit comments

Comments
 (0)