Skip to content

Commit 23413e0

Browse files
authored
bugfix: MLA decode should multiply sm_scale by math::log2e (#787)
1 parent 9569106 commit 23413e0

File tree

2 files changed

+2
-0
lines changed

2 files changed

+2
-0
lines changed

Diff for: flashinfer/jit/attention.py

+1
Original file line number | Diff line number | Diff line change
@@ -180,6 +180,7 @@ def gen_batch_decode_mla_module(
180 180
dtype_o,
181 181
dtype_idx,
182 182
head_dim,
183 +
head_dim,
183 184
use_sliding_window,
184 185
use_logits_soft_cap,
185 186
)

Diff for: include/flashinfer/attention/decode.cuh

+1
Original file line number | Diff line number | Diff line change
@@ -857,6 +857,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernelMLA(Params params) {
857 857
const float rope_rcp_scale = params.rope_rcp_scale;
858 858
const float rope_rcp_theta = params.rope_rcp_theta;
859 859
const bool partition_kv = params.partition_kv;
860 +
params.sm_scale *= math::log2e;
860 861

861 862
constexpr uint32_t head_dim_ckv = bdx * vec_size_ckv;
862 863
constexpr uint32_t head_dim_kpe = bdx * vec_size_kpe;

0 commit comments

Comments
 (0)