
Commit 106e6fc

yzh119 and abcdabcd987 authored
perf: memory efficient deepseek mla fused page-attention kernel (#804)
**Description:** This PR implements a memory-efficient fused DeepSeek MLA PageAttention kernel for decode, prefill, and chunked-prefill operations. After matrix absorption, MLA can be described as a multi-query attention (MQA) kernel in which all query heads share the same K/V cache, with special head dimensions: `head_dim_qk=576, head_dim_vo=512`.

**Background:** For DeepSeek v2/v3, the large `head_dim` (512) makes it challenging to store the output tensor in registers when using tensor cores. A previous approach ([#551](#551)) split `head_dim` across `gridDim.y`, which led to two main issues:

- **Re-computation overhead:** Multiple blocks had to redundantly compute the Q*K operation.
- **Memory access latency:** Using multiple blocks prevented shared-memory usage, forcing reliance on slower L2 for KV-cache accesses.

**New Design:** We use head-group fusion to increase the operational intensity of the kernel ([appendix A in the paper](https://arxiv.org/pdf/2501.01005)). To address the large head-dimension issue, we redesign the kernel (diagram below) to use two warp groups (WG1 and WG2, each with 4 warps) per CTA:

<img width="688" alt="image" src="https://github.com/user-attachments/assets/be684e8d-eaa7-41fb-8160-35e121711d8d" />

- **QK computation:** Each warp group processes half of the `CTA_TILE_KV` dimension, eliminating redundant computation.
- **Shared-memory broadcast:** Local QK results are written to shared memory (reusing the `kpe` buffer) to efficiently broadcast data for the PV computation.
- **PV computation:** The head dimension is split between the warp groups, with each computing half.

The maximum `CTA_TILE_Q` (bounded by the register file) is 64; in that case the maximum `CTA_TILE_KV` (bounded by the shared-memory limit) is 64 on H100 and 32 on A100 when the number of pipeline stages is set to 2. For a large `num_local_heads` such as 128 (no TP and no MTP), we create a thread-block cluster of size 2 and dispatch the upper 64 heads and the lower 64 heads to the two SMs in the cluster; software-managed multicasting (useful for large page sizes) is left to a later PR.

## Benchmark results

Decoding memory bandwidth on H100 SXM5 (peak bandwidth = 3352 GB/s):

```
Config: batch_size=768, seq_len=1024, num_heads=16
Memory bandwidth: 2002.11 GB/s
Config: batch_size=768, seq_len=1024, num_heads=32
Memory bandwidth: 2035.59 GB/s
Config: batch_size=768, seq_len=1024, num_heads=64
Memory bandwidth: 2082.20 GB/s
Config: batch_size=768, seq_len=2048, num_heads=16
Memory bandwidth: 2064.97 GB/s
Config: batch_size=768, seq_len=2048, num_heads=32
Memory bandwidth: 2080.99 GB/s
Config: batch_size=768, seq_len=2048, num_heads=64
Memory bandwidth: 2082.78 GB/s
```

---------

Co-authored-by: Lequn Chen <[email protected]>
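To make the MQA formulation in the description concrete, below is a minimal, unfused PyTorch reference sketch (not part of this PR) of one decode step after matrix absorption. The shapes follow the description (`head_dim_qk = 512 + 64 = 576`, `head_dim_vo = 512`); batching, paging, and the warp-group tiling of the real kernel are intentionally omitted.

```python
import torch

# Illustrative sketch for a single request: num_heads query heads attend to one
# shared compressed KV cache (MQA). Shapes follow the PR description; this is a
# reference computation, not the fused kernel.
num_heads, kv_len = 16, 1024
head_dim_ckv, head_dim_kpe = 512, 64  # head_dim_vo = 512, head_dim_qk = 576

q_nope = torch.randn(num_heads, head_dim_ckv, dtype=torch.half, device="cuda")
q_pe = torch.randn(num_heads, head_dim_kpe, dtype=torch.half, device="cuda")
ckv = torch.randn(kv_len, head_dim_ckv, dtype=torch.half, device="cuda")  # shared "K" and "V"
kpe = torch.randn(kv_len, head_dim_kpe, dtype=torch.half, device="cuda")  # rope part of K

sm_scale = 1.0 / (head_dim_ckv + head_dim_kpe) ** 0.5
# QK over the concatenated 576-dim key: [num_heads, kv_len]
logits = (q_nope @ ckv.T + q_pe @ kpe.T).float() * sm_scale
p = torch.softmax(logits, dim=-1).half()
# PV over the 512-dim value (the compressed KV itself): [num_heads, 512]
o = p @ ckv
```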
1 parent 8b91e95 commit 106e6fc

17 files changed: +1918 −247 lines changed

benchmarks/bench_deepseek_mla.py

+82
@@ -0,0 +1,82 @@
```python
"""
Copyright (c) 2024 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import torch
import triton

import flashinfer


def bench_deepseek_mla_decode(batch_size, seq_len, num_heads):
    head_dim_ckv = 512
    head_dim_kpe = 64
    page_size = 1
    q_nope = torch.randn(
        batch_size * 1, num_heads, head_dim_ckv, dtype=torch.half, device="cuda"
    )
    q_pe = torch.zeros(
        batch_size * 1, num_heads, head_dim_kpe, dtype=torch.half, device="cuda"
    )
    ckv = torch.randn(
        batch_size * seq_len, 1, head_dim_ckv, dtype=torch.half, device="cuda"
    )
    kpe = torch.zeros(
        batch_size * seq_len, 1, head_dim_kpe, dtype=torch.half, device="cuda"
    )
    sm_scale = 1.0 / ((head_dim_ckv + head_dim_kpe) ** 0.5)
    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)
    wrapper = flashinfer.mla.BatchMLAPageAttentionWrapper(
        workspace_buffer, backend="fa2"
    )
    q_indptr = torch.arange(0, batch_size + 1).to(0).int()
    kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * seq_len
    kv_indices = torch.arange(0, batch_size * seq_len).to(0).int()
    kv_lens = torch.full((batch_size,), seq_len, dtype=torch.int32).to(0)
    wrapper.plan(
        q_indptr,
        kv_indptr,
        kv_indices,
        kv_lens,
        num_heads,
        head_dim_ckv,
        head_dim_kpe,
        page_size,
        False,  # causal
        sm_scale,
        q_nope.dtype,
        ckv.dtype,
    )
    o = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=False)

    ms = triton.testing.do_bench(
        lambda: wrapper.run(q_nope, q_pe, ckv, kpe),
        warmup=100,
        rep=1000,
    )

    io = sum([_.numel() * _.element_size() for _ in [q_nope, q_pe, ckv, kpe, o]])

    print(f"Config: batch_size={batch_size}, seq_len={seq_len}, num_heads={num_heads}")
    print(f"Memory bandwidth: {io * 1e-6 / ms:.2f} GB/s")


if __name__ == "__main__":
    bench_deepseek_mla_decode(768, 1024, 16)
    bench_deepseek_mla_decode(768, 1024, 32)
    bench_deepseek_mla_decode(768, 1024, 64)
    bench_deepseek_mla_decode(768, 2048, 16)
    bench_deepseek_mla_decode(768, 2048, 32)
    bench_deepseek_mla_decode(768, 2048, 64)
```
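As a sanity check on the reported numbers, the `io`-based bandwidth formula in this script can be evaluated by hand. For the largest decode config the `ckv` reads dominate the traffic, and the reported ~2082 GB/s corresponds to roughly 62% of the 3352 GB/s H100 SXM5 peak. A back-of-the-envelope sketch (not output of the script above):

```python
# Traffic estimate for batch_size=768, seq_len=1024, num_heads=64, fp16 (2 bytes/element).
batch, seq_len, heads = 768, 1024, 64
bytes_per = 2
q_nope = batch * heads * 512 * bytes_per    #  ~50 MB
q_pe   = batch * heads * 64 * bytes_per     #   ~6 MB
ckv    = batch * seq_len * 512 * bytes_per  # ~805 MB (dominant term)
kpe    = batch * seq_len * 64 * bytes_per   # ~101 MB
o      = batch * heads * 512 * bytes_per    #  ~50 MB
io_bytes = q_nope + q_pe + ckv + kpe + o
print(io_bytes / 1e9, "GB moved per call")  # ~1.01 GB
print(2082.20 / 3352 * 100, "% of peak")    # ~62% of H100 SXM5 HBM bandwidth
```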

csrc/batch_mla_config.jinja

+24
@@ -0,0 +1,24 @@
```cpp
#pragma once
#include <flashinfer/page.cuh>
#include <flashinfer/math.cuh>
#include <flashinfer/layout.cuh>
#include <flashinfer/utils.cuh>
#include <flashinfer/pos_enc.cuh>
#include <flashinfer/fastdiv.cuh>
#include <flashinfer/attention/variant_helper.cuh>
#include <flashinfer/attention/mla_params.cuh>

using namespace flashinfer;

using DTypeQ = {{ dtype_q }};
using DTypeKV = {{ dtype_kv }};
using DTypeO = {{ dtype_o }};
using IdType = {{ dtype_idx }};
constexpr int HEAD_DIM_CKV = {{ head_dim_ckv }};
constexpr int HEAD_DIM_KPE = {{ head_dim_kpe }};

#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, ...) \
  DISPATCH_MASK_MODE(mask_mode, MASK_MODE, {                    \
    using Params = MLAParams<DTypeQ, DTypeKV, DTypeO, IdType>;  \
    __VA_ARGS__();                                              \
  })
```
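For illustration only, the template above can be rendered with `jinja2` to see what a specialized config header looks like. The concrete type strings (`half`, `int32_t`) and head dimensions below are assumptions chosen to match the DeepSeek decode benchmark, not necessarily the exact strings the JIT generator emits:

```python
# Hypothetical rendering of csrc/batch_mla_config.jinja (requires a repo checkout and jinja2).
import jinja2

with open("csrc/batch_mla_config.jinja") as f:
    template = jinja2.Template(f.read())

print(
    template.render(
        dtype_q="half",       # assumed fp16 type string
        dtype_kv="half",
        dtype_o="half",
        dtype_idx="int32_t",  # assumed index type string
        head_dim_ckv=512,
        head_dim_kpe=64,
    )
)
```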

csrc/batch_mla_plan.cu

+51
@@ -0,0 +1,51 @@
```cpp
/*
 * Copyright (c) 2025 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <flashinfer/attention/scheduler.cuh>
#include <optional>

#include "batch_mla_config.inc"
#include "pytorch_extension_utils.h"

using namespace flashinfer;

std::vector<int64_t> BatchMLAPageAttentionPlan(at::Tensor float_workspace_buffer,
                                               at::Tensor int_workspace_buffer,
                                               at::Tensor page_locked_int_workspace_buffer,
                                               at::Tensor qo_indptr, at::Tensor kv_indptr,
                                               at::Tensor kv_len, unsigned int num_heads,
                                               unsigned int head_dim_o, bool causal,
                                               int64_t cuda_stream) {
  size_t float_workspace_size_in_bytes =
      float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
  size_t int_workspace_size_in_bytes =
      int_workspace_buffer.size(0) * int_workspace_buffer.element_size();

  MLAPlanInfo plan_info;

  int batch_size = kv_len.size(0);

  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  cudaError_t status =
      MLAPlan(float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
              int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
              int_workspace_size_in_bytes, plan_info, static_cast<IdType*>(qo_indptr.data_ptr()),
              static_cast<IdType*>(kv_indptr.data_ptr()), static_cast<IdType*>(kv_len.data_ptr()),
              batch_size, num_heads, head_dim_o, causal, stream);

  TORCH_CHECK(status == cudaSuccess, "Failed to plan MLA, error: ", cudaGetErrorString(status));

  return plan_info.ToVector();
}
```
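The planner above consumes ragged index arrays (`qo_indptr`, `kv_indptr`, `kv_len`) rather than padded tensors. A small illustrative sketch of how these arrays relate on the Python side, assuming `page_size=1` as in the decode benchmark and variable KV lengths per request:

```python
import torch

# Three decode requests (one query token each) with different KV lengths.
kv_lens = torch.tensor([5, 3, 7], dtype=torch.int32, device="cuda")
batch_size = kv_lens.numel()

# One query row per request -> q_indptr = [0, 1, 2, 3].
q_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda")

# With page_size=1 each KV token occupies one page; kv_indptr is the exclusive
# prefix sum of per-request page counts, and kv_indices enumerates the pages.
kv_indptr = torch.cat(
    [
        torch.zeros(1, dtype=torch.int32, device="cuda"),
        torch.cumsum(kv_lens, dim=0).int(),
    ]
)
kv_indices = torch.arange(int(kv_lens.sum()), dtype=torch.int32, device="cuda")

print(q_indptr.tolist())   # [0, 1, 2, 3]
print(kv_indptr.tolist())  # [0, 5, 8, 15]
print(kv_indices.numel())  # 15 page ids in total
```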

csrc/batch_mla_pybind.cu

+37
@@ -0,0 +1,37 @@
```cpp
/*
 * Copyright (c) 2025 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "batch_mla_config.inc"
#include "pytorch_extension_utils.h"

std::vector<int64_t> BatchMLAPageAttentionPlan(at::Tensor float_workspace_buffer,
                                               at::Tensor int_workspace_buffer,
                                               at::Tensor page_locked_int_workspace_buffer,
                                               at::Tensor qo_indptr, at::Tensor kv_indptr,
                                               at::Tensor kv_len, unsigned int num_heads,
                                               unsigned int head_dim_o, bool causal,
                                               int64_t cuda_stream);

void BatchMLAPageAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
                              std::vector<int64_t> plan_info_vec, at::Tensor q_nope,
                              at::Tensor q_pe, at::Tensor ckv_cache, at::Tensor kpe_cache,
                              at::Tensor kv_indices, at::Tensor o,
                              std::optional<at::Tensor> maybe_lse, int mask_mode_code,
                              int num_heads, int page_size, float sm_scale, int64_t cuda_stream);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("plan", &BatchMLAPageAttentionPlan, "Batch MLA Page Attention Plan");
  m.def("run", &BatchMLAPageAttentionRun, "Batch MLA Page Attention Run");
}
```

csrc/batch_mla_run.cu

+114
@@ -0,0 +1,114 @@
```cpp
/*
 * Copyright (c) 2025 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <driver_types.h>

#include <flashinfer/attention/mla_fa2.cuh>
#include <flashinfer/attention/scheduler.cuh>
#include <flashinfer/fastdiv.cuh>
#include <optional>

#include "batch_mla_config.inc"
#include "pytorch_extension_utils.h"

using namespace flashinfer;

void BatchMLAPageAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
                              std::vector<int64_t> plan_info_vec, at::Tensor q_nope,
                              at::Tensor q_pe, at::Tensor ckv_cache, at::Tensor kpe_cache,
                              at::Tensor kv_indices, at::Tensor o,
                              std::optional<at::Tensor> maybe_lse, int mask_mode_code,
                              int num_heads, int page_size, float sm_scale, int64_t cuda_stream) {
  // q_nope: [n, num_heads, head_dim_ckv]
  // q_pe: [n, num_heads, head_dim_kpe]
  // ckv_cache: [num_pages, page_size, head_dim_ckv]
  // kpe_cache: [num_pages, page_size, head_dim_kpe]
  MLAPlanInfo plan_info;
  plan_info.FromVector(plan_info_vec);

  auto device = q_nope.device();

  void* float_buffer_ptr = float_workspace_buffer.data_ptr();
  void* int_buffer_ptr = int_workspace_buffer.data_ptr();

  const MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);

  auto q_scalar_type = q_nope.scalar_type();
  auto kv_scalar_type = ckv_cache.scalar_type();

  unsigned int q_nope_stride_n = q_nope.stride(0);
  unsigned int q_nope_stride_h = q_nope.stride(1);
  unsigned int q_pe_stride_n = q_pe.stride(0);
  unsigned int q_pe_stride_h = q_pe.stride(1);
  unsigned int ckv_stride_page = ckv_cache.stride(0);
  unsigned int ckv_stride_n = ckv_cache.stride(1);
  unsigned int kpe_stride_page = kpe_cache.stride(0);
  unsigned int kpe_stride_n = kpe_cache.stride(1);
  unsigned int o_stride_n = o.stride(0);
  unsigned int o_stride_h = o.stride(1);

  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);

  DISPATCH_context(
      DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, [&] {
        Params params;

        params.q_nope = static_cast<DTypeQ*>(q_nope.data_ptr());
        params.q_pe = static_cast<DTypeQ*>(q_pe.data_ptr());
        params.ckv = static_cast<DTypeKV*>(ckv_cache.data_ptr());
        params.kpe = static_cast<DTypeKV*>(kpe_cache.data_ptr());
        params.kv_indices = static_cast<IdType*>(kv_indices.data_ptr());

        params.q_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.q_indptr_offset);
        params.kv_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_indptr_offset);
        params.kv_indices = static_cast<IdType*>(kv_indices.data_ptr());
        params.q_len = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.q_len_offset);
        params.kv_len = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_len_offset);
        params.q_start = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.q_start_offset);
        params.kv_start = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_start_offset);
        params.kv_end = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_end_offset);
        params.work_indptr =
            GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.work_indptr_offset);
        params.final_o = static_cast<DTypeO*>(o.data_ptr());
        params.final_lse =
            maybe_lse.has_value() ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr;
        params.partial_o =
            GetPtrFromBaseOffset<float>(float_buffer_ptr, plan_info.partial_o_offset);
        params.partial_lse =
            GetPtrFromBaseOffset<float>(float_buffer_ptr, plan_info.partial_lse_offset);

        params.num_heads = uint_fastdiv(num_heads);
        params.block_size = uint_fastdiv(page_size);

        params.q_nope_stride_n = q_nope_stride_n;
        params.q_nope_stride_h = q_nope_stride_h;
        params.q_pe_stride_n = q_pe_stride_n;
        params.q_pe_stride_h = q_pe_stride_h;
        params.ckv_stride_page = ckv_stride_page;
        params.ckv_stride_n = ckv_stride_n;
        params.kpe_stride_page = kpe_stride_page;
        params.kpe_stride_n = kpe_stride_n;
        params.o_stride_n = o_stride_n;
        params.o_stride_h = o_stride_h;

        params.sm_scale = sm_scale;

        cudaError_t status = mla::BatchMLAPageAttention<MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE>(
            params, plan_info.num_blks_x, plan_info.num_blks_y, stream);

        TORCH_CHECK(status == cudaSuccess,
                    "Failed to run MLA, error: ", cudaGetErrorString(status));
      });
}
```
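The PR description states that the same kernel covers prefill and chunked prefill as well as decode. Below is a hedged sketch of a causal, prefill-style call through the same Python wrapper used in the decode benchmark; the multi-token query layout, the `causal=True` flag, and the `return_lse=True` return shape are assumptions based on the plan/run signatures above, not a tested configuration from this PR:

```python
import torch
import flashinfer

batch_size, qo_len, kv_len, num_heads = 4, 128, 1024, 16
head_dim_ckv, head_dim_kpe, page_size = 512, 64, 1

q_nope = torch.randn(batch_size * qo_len, num_heads, head_dim_ckv, dtype=torch.half, device="cuda")
q_pe = torch.randn(batch_size * qo_len, num_heads, head_dim_kpe, dtype=torch.half, device="cuda")
ckv = torch.randn(batch_size * kv_len, 1, head_dim_ckv, dtype=torch.half, device="cuda")
kpe = torch.randn(batch_size * kv_len, 1, head_dim_kpe, dtype=torch.half, device="cuda")

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
wrapper = flashinfer.mla.BatchMLAPageAttentionWrapper(workspace, backend="fa2")

# qo_len query tokens per request; page_size=1 paging as in the decode benchmark.
q_indptr = (torch.arange(batch_size + 1, device="cuda") * qo_len).int()
kv_indptr = (torch.arange(batch_size + 1, device="cuda") * kv_len).int()
kv_indices = torch.arange(batch_size * kv_len, device="cuda").int()
kv_lens = torch.full((batch_size,), kv_len, dtype=torch.int32, device="cuda")

wrapper.plan(
    q_indptr, kv_indptr, kv_indices, kv_lens,
    num_heads, head_dim_ckv, head_dim_kpe, page_size,
    True,  # causal: queries are masked against later KV positions
    1.0 / (head_dim_ckv + head_dim_kpe) ** 0.5,
    q_nope.dtype, ckv.dtype,
)
# Assumed: return_lse=True yields (output, log-sum-exp), matching maybe_lse in the binding above.
o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True)
```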

flashinfer/__init__.py

+2-2
```diff
@@ -14,6 +14,7 @@
 limitations under the License.
 """
 
+from ._build_meta import __version__ as __version__
 from .activation import gelu_and_mul as gelu_and_mul
 from .activation import gelu_tanh_and_mul as gelu_tanh_and_mul
 from .activation import silu_and_mul as silu_and_mul
@@ -41,6 +42,7 @@
 from .decode import single_decode_with_kv_cache as single_decode_with_kv_cache
 from .gemm import SegmentGEMMWrapper as SegmentGEMMWrapper
 from .gemm import bmm_fp8 as bmm_fp8
+from .mla import BatchMLAPageAttentionWrapper as BatchMLAPageAttentionWrapper
 from .norm import fused_add_rmsnorm as fused_add_rmsnorm
 from .norm import gemma_fused_add_rmsnorm as gemma_fused_add_rmsnorm
 from .norm import gemma_rmsnorm as gemma_rmsnorm
@@ -87,5 +89,3 @@
 from .sampling import top_p_renorm_probs as top_p_renorm_probs
 from .sampling import top_p_sampling_from_probs as top_p_sampling_from_probs
 from .sparse import BlockSparseAttentionWrapper as BlockSparseAttentionWrapper
-
-from ._build_meta import __version__ as __version__
```
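With this re-export, the new wrapper is reachable from the package root as well as from `flashinfer.mla`:

```python
import flashinfer

# Both names refer to the same class after this change.
assert flashinfer.BatchMLAPageAttentionWrapper is flashinfer.mla.BatchMLAPageAttentionWrapper
```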

flashinfer/jit/__init__.py

+2
```diff
@@ -19,6 +19,7 @@
 from .activation import get_act_and_mul_cu_str as get_act_and_mul_cu_str
 from .attention import gen_batch_decode_mla_module as gen_batch_decode_mla_module
 from .attention import gen_batch_decode_module as gen_batch_decode_module
+from .attention import gen_batch_mla_module as gen_batch_mla_module
 from .attention import gen_batch_prefill_module as gen_batch_prefill_module
 from .attention import (
     gen_customize_batch_decode_module as gen_customize_batch_decode_module,
@@ -36,6 +37,7 @@
 from .attention import gen_single_prefill_module as gen_single_prefill_module
 from .attention import get_batch_decode_mla_uri as get_batch_decode_mla_uri
 from .attention import get_batch_decode_uri as get_batch_decode_uri
+from .attention import get_batch_mla_uri as get_batch_mla_uri
 from .attention import get_batch_prefill_uri as get_batch_prefill_uri
 from .attention import get_single_decode_uri as get_single_decode_uri
 from .attention import get_single_prefill_uri as get_single_prefill_uri
```
