
Commit 504b990 (parent 23413e0)

fix rope logic in mla decoding (#793)
Co-authored-by: pankajroark <[email protected]>

As titled, this unblocks the FlashInfer integration. E2E testing is functioning properly.

cc @yzh119 @pankajroark @merrymercy @Ying1123 @ispobock

```bash
python3 tests/test_mla_decode_kernel.py
```

```
Now use MLA decode kernel!
2025-02-06 22:55:31,946 - INFO - flashinfer.jit: Loading JIT ops: batch_decode_mla_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_head_dim_qk_512_head_dim_vo_512_use_swa_False_use_logits_cap_False
/usr/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py:1964: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py:1964: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
2025-02-06 22:55:31,960 - INFO - flashinfer.jit: Finished loading JIT ops: batch_decode_mla_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_head_dim_qk_512_head_dim_vo_512_use_swa_False_use_logits_cap_False
cos_use_torch_f32 = 1.0
wmape_use_torch_f32 = 1.4899706573821664e-05
mse_use_torch_f32 = 0.004270492121577263
cos_use_torch_f16 = 0.999683678150177
wmape_use_torch_f16 = 0.020623904841957166
mse_use_torch_f16 = 5391.00048828125
cos_use_flashinfer = 0.9999864101409912
wmape_use_flashinfer = 0.004352144090863914
mse_use_flashinfer = 231.20518493652344
```
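For context, the `cos`, `wmape`, and `mse` values printed above compare the decode output against a float32 torch reference. A minimal sketch of how such metrics are typically computed; the helper below is illustrative, not the test's actual code:

```python
import torch

def report_metrics(out: torch.Tensor, ref: torch.Tensor) -> None:
    """Compare a kernel output against a reference implementation."""
    o, r = out.double().flatten(), ref.double().flatten()
    cos = torch.nn.functional.cosine_similarity(o, r, dim=0)  # 1.0 means identical direction
    wmape = (o - r).abs().sum() / r.abs().sum()               # weighted mean absolute percentage error
    mse = torch.mean((o - r) ** 2)                            # mean squared error
    print(f"cos = {cos.item()}, wmape = {wmape.item()}, mse = {mse.item()}")
```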

File tree

2 files changed (+15 -5 lines)

include/flashinfer/attention/decode.cuh (+5 -5)

```diff
@@ -794,8 +794,8 @@ __device__ __forceinline__ void compute_qk_and_update_local_stat_mla(
     ckv_vec.cast_load(ckv_smem + j * head_dim_ckv + tx * vec_size_ckv);

     vec_t<float, vec_size_kpe> kpe_vec;
-    kpe_vec = vec_apply_llama_rope_interleave<vec_size_kpe, bdx>(kpe_smem + j * head_dim_kpe, freq,
-                                                                 kv_idx_base + tz * tile_size + j);
+    kpe_vec.cast_load(kpe_smem + j * head_dim_kpe + tx * vec_size_kpe);
+
     s[j] = 0.f;
 #pragma unroll
     for (uint32_t i = 0; i < vec_size_ckv; ++i) {
@@ -920,9 +920,9 @@ __global__ void BatchDecodeWithPagedKVCacheKernelMLA(Params params) {
       q_nope_vec[i].cast_load(q_nope +
                               (mapped_batch_idx * num_qo_heads + qo_head_idx[i]) * head_dim_ckv +
                               tx * vec_size_ckv);
-      q_pe_vec[i] = vec_apply_llama_rope_interleave<vec_size_kpe, bdx>(
-          q_pe + (mapped_batch_idx * num_qo_heads + qo_head_idx[i]) * head_dim_kpe, freq,
-          q_rope_offset_val);
+      q_pe_vec[i].cast_load(q_pe +
+                            (mapped_batch_idx * num_qo_heads + qo_head_idx[i]) * head_dim_kpe +
+                            tx * vec_size_kpe);
     }
   }
```

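The kernel-side change removes the in-kernel calls to `vec_apply_llama_rope_interleave` and replaces them with plain `cast_load`s: the decode kernel now reads `q_pe` and the rope part of the KV cache as-is, so rotary embeddings must be applied before the kernel is invoked (see the test change below). For reference, a rough Python sketch of the interleaved LLaMA-style rotation that the removed device function computed on the fly; names and shapes here are illustrative only:

```python
import torch

def rope_interleave(x: torch.Tensor, pos: int, theta: float = 1e4) -> torch.Tensor:
    """Rotate adjacent pairs (x[2i], x[2i+1]) by angle pos * theta^(-2i/d)."""
    d = x.shape[-1]
    freqs = theta ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)  # one frequency per pair
    angles = pos * freqs
    cos, sin = torch.cos(angles), torch.sin(angles)
    even, odd = x[..., 0::2].float(), x[..., 1::2].float()
    out = torch.empty_like(x, dtype=torch.float32)
    out[..., 0::2] = even * cos - odd * sin
    out[..., 1::2] = odd * cos + even * sin
    return out.to(x.dtype)
```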
tests/test_mla_decode_kernel.py (+10 -0)

```diff
@@ -316,6 +316,16 @@ def run_proof_of_concept(
             raise ValueError(
                 "For simplicity, kv_len should be multiple of page_size."
             )
+        freqs_cis = precompute_freqs_cis(
+            self.qk_rope_head_dim, kv_len, self.rope_theta, use_scaled=False
+        ).to(k_pe_cache.device)
+        q_pe, k_pe_cache = apply_rotary_emb(
+            q_pe.unsqueeze(1).repeat(1, kv_len, 1, 1),
+            k_pe_cache.unsqueeze(2),
+            freqs_cis,
+        )
+        q_pe = q_pe[:, -1:, :, :].squeeze(1).contiguous()
+        k_pe_cache = k_pe_cache.squeeze(2)
         num_pages_per_seq = kv_len // page_size
         total_num_pages = num_pages_per_seq * bsz
```

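The test-side change pre-rotates `q_pe` and `k_pe_cache` with `precompute_freqs_cis` / `apply_rotary_emb` before the kernel call, mirroring the removal of in-kernel RoPE above. Those helpers presumably follow the usual LLaMA complex-number formulation; a self-contained sketch under that assumption (the real helpers in the test module may differ in detail, and the `use_scaled` frequency scaling is omitted here):

```python
import torch

def precompute_freqs_cis(dim: int, seq_len: int, theta: float, use_scaled: bool = False) -> torch.Tensor:
    # Complex rotations e^{i * pos * theta^(-2k/dim)}; `use_scaled` scaling is not implemented in this sketch.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = torch.outer(torch.arange(seq_len, dtype=torch.float32), freqs)  # (seq_len, dim // 2)
    return torch.polar(torch.ones_like(angles), angles)

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    # Inputs are (bsz, seq_len, n_heads, head_dim); view the last dim as complex pairs and rotate.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    shape = [d if i in (1, xq_.ndim - 1) else 1 for i, d in enumerate(xq_.shape)]  # broadcast over batch/heads
    freqs_cis = freqs_cis.view(*shape)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)
```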