@@ -68,8 +68,8 @@ template <RotaryMode rotary_mode, uint32_t vec_size, uint32_t bdx, uint32_t tile
__device__ __forceinline__ void compute_qk(const T* smem, uint32_t compute_stage_idx,
                                           const vec_t<float, vec_size>& q_vec,
                                           const vec_t<float, vec_size>& freq, uint32_t kv_idx_base,
-                                          uint32_t iter_base, uint32_t iter_bound, float sm_scale,
-                                          float* s, state_t<vec_size>& st) {
+                                          uint32_t iter_base, uint32_t iter_bound, float* s,
+                                          state_t<vec_size>& st) {
  uint32_t tx = threadIdx.x, tz = threadIdx.z;
  float m_prev = st.m;
#pragma unroll
@@ -86,7 +86,7 @@ __device__ __forceinline__ void compute_qk(const T* smem, uint32_t compute_stage
    s[j] = 0.f;
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
-      s[j] += q_vec[i] * k_vec[i] * sm_scale;
+      s[j] += q_vec[i] * k_vec[i];
    }
#pragma unroll
    for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) {
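The signature change above works because the q·k dot product is linear in q: scaling the query vector once by `sm_scale` is mathematically equivalent to multiplying every partial product by it, and it removes one multiply from the innermost loop. A minimal host-side C++ sketch of the two formulations (the function names and the plain-array interface are illustrative only, not code from this repository):

```cpp
#include <cstdint>

// Old formulation: the softmax scale is applied to every q[i] * k[i] term.
float dot_scale_per_term(const float* q, const float* k, uint32_t head_dim, float sm_scale) {
  float s = 0.f;
  for (uint32_t i = 0; i < head_dim; ++i) s += q[i] * k[i] * sm_scale;
  return s;
}

// New formulation: q is pre-scaled once, so the inner loop is a plain dot product.
float dot_prescaled_q(float* q, const float* k, uint32_t head_dim, float sm_scale) {
  for (uint32_t i = 0; i < head_dim; ++i) q[i] *= sm_scale;  // done once per query vector
  float s = 0.f;
  for (uint32_t i = 0; i < head_dim; ++i) s += q[i] * k[i];
  return s;
}
```

In the kernels below, the pre-scaling cost is paid once per query head, whereas the old per-term multiply was paid for every key position visited by `compute_qk`.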
@@ -240,6 +240,11 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn*
    // do not apply rotary embedding to q matrix
    q_vec.cast_load(q + info.get_qo_elem_offset(0, qo_head_idx, tx * vec_size));
  }
+  // multiply q_vec by sm_scale
+#pragma unroll
+  for (uint32_t i = 0; i < vec_size; ++i) {
+    q_vec[i] *= sm_scale;
+  }
  block.sync();

  uint32_t chunk_start = kv_chunk_idx * kv_chunk_size;
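The loop added here is a one-time pass over the `vec_size` elements each thread holds. As a quick standalone check that the two orderings produce the same attention score up to floating-point rounding (the head size, fill values, and tolerance below are arbitrary choices for illustration, not taken from the repository's tests):

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

int main() {
  const uint32_t head_dim = 128;
  const float sm_scale = 1.f / std::sqrt(static_cast<float>(head_dim));
  std::vector<float> q(head_dim), k(head_dim);
  for (uint32_t i = 0; i < head_dim; ++i) {
    q[i] = 0.01f * static_cast<float>(i);
    k[i] = 0.02f * static_cast<float>(head_dim - i);
  }
  float scale_per_term = 0.f, prescaled_q = 0.f;
  for (uint32_t i = 0; i < head_dim; ++i) scale_per_term += q[i] * k[i] * sm_scale;  // old ordering
  for (uint32_t i = 0; i < head_dim; ++i) prescaled_q += (q[i] * sm_scale) * k[i];   // new ordering
  assert(std::fabs(scale_per_term - prescaled_q) < 1e-4f);  // equal up to rounding
  return 0;
}
```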
@@ -286,8 +291,8 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn*
    block.sync();
    compute_qk<rotary_mode, vec_size, bdx, bdy * tile_size_per_bdx>(
        k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, stage_idx, q_vec,
-        freq, consumer_kv_idx_base, iter * bdy * tile_size_per_bdx * bdz, kv_chunk_size, sm_scale,
-        s, st_local);
+        freq, consumer_kv_idx_base, iter * bdy * tile_size_per_bdx * bdz, kv_chunk_size, s,
+        st_local);
    block.sync();
    // load k
    for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
@@ -385,6 +390,10 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel(
    q_vec.cast_load(q + batch_idx * num_qo_heads * head_dim +
                    info.get_qo_elem_offset(0, qo_head_idx, tx * vec_size));
  }
+#pragma unroll
+  for (uint32_t i = 0; i < vec_size; ++i) {
+    q_vec[i] *= sm_scale;
+  }
  block.sync();

  // preload k tiles and v tiles
@@ -421,7 +430,7 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel(
    block.sync();
    compute_qk<rotary_mode, vec_size, bdx, bdy>(k_smem + (stage_idx * bdz + tz) * bdy * head_dim,
                                                stage_idx, q_vec, freq, consumer_kv_idx_base,
-                                                iter * bdy * bdz, seq_len, sm_scale, s, st_local);
+                                                iter * bdy * bdz, seq_len, s, st_local);
    block.sync();
    // load k
    cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
@@ -551,6 +560,10 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
    // do not apply rotary embedding to q matrix
    q_vec.cast_load(q + (mapped_batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
  }
+#pragma unroll
+  for (uint32_t i = 0; i < vec_size; ++i) {
+    q_vec[i] *= sm_scale;
+  }
  block.sync();

  // preload k/v tiles
@@ -622,7 +635,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
        freq,
        (paged_kv.rope_pos_offset == nullptr ? 0 : paged_kv.rope_pos_offset[mapped_batch_idx]) +
            cur_chunk_start + iter * tile_size_per_bdx * bdy * bdz,
-        iter * tile_size_per_bdx * bdy * bdz, kv_chunk_len, sm_scale, s, st);
+        iter * tile_size_per_bdx * bdy * bdz, kv_chunk_len, s, st);
    block.sync();

#pragma unroll