Commit 85049f8

Support cutlass MLA
Signed-off-by: kaixih <[email protected]>
1 parent d4bfc23 commit 85049f8

File tree

7 files changed (+397, -2 lines)


CMakeLists.txt

Lines changed: 22 additions & 2 deletions
@@ -289,7 +289,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
       "csrc/quantization/fp4/nvfp4_quant_entry.cu"
       "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
       "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
-      "csrc/cutlass_extensions/common.cpp")
+      "csrc/cutlass_extensions/common.cpp"
+      "csrc/attention/mla/cutlass_mla_entry.cu")

   set_gencode_flags_for_srcs(
     SRCS "${VLLM_EXT_SRC}"
@@ -462,7 +463,26 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     set(FP4_ARCHS)
   endif()

-  #
+  # CUTLASS MLA Archs and flags
+  cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND MLA_ARCHS)
+    set(SRCS
+        "csrc/attention/mla/cutlass_mla_kernels.cu")
+    set_gencode_flags_for_srcs(
+      SRCS "${SRCS}"
+      CUDA_ARCHS "${MLA_ARCHS}")
+    list(APPEND VLLM_EXT_SRC "${SRCS}")
+    list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MLA=1")
+    # Add MLA-specific include directories only to MLA source files
+    set_source_files_properties(${SRCS}
+      PROPERTIES INCLUDE_DIRECTORIES "${CUTLASS_INCLUDE_DIR}/../examples/77_blackwell_fmha;${CUTLASS_INCLUDE_DIR}/../examples/common")
+    message(STATUS "Building CUTLASS MLA for archs: ${MLA_ARCHS}")
+  else()
+    message(STATUS "Not building CUTLASS MLA as no compatible archs were found.")
+    # clear MLA_ARCHS
+    set(MLA_ARCHS)
+  endif()
+
   # CUTLASS MoE kernels

   # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works
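Note on the gating above: the MLA kernel is built only when the CUDA toolkit is newer than 12.8 and a 10.0a (Blackwell) arch is present in CUDA_ARCHS; otherwise ENABLE_CUTLASS_MLA stays undefined and the entry stub below reports the op as not compiled. A minimal, hypothetical runtime check that mirrors this gating (not part of the commit):

import torch

# Mirrors the build-time condition: expect an SM 10.x device and a recent CUDA.
major, minor = torch.cuda.get_device_capability()
print(f"compute capability: {major}.{minor}, CUDA: {torch.version.cuda}")
if major < 10:
    print("CUTLASS MLA was likely not compiled for this device.")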
csrc/attention/mla/cutlass_mla_entry.cu

Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/all.h>

#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA
void cutlass_mla_decode_sm100a(torch::Tensor const& out,
                               torch::Tensor const& q_nope_and_q_pe,
                               torch::Tensor const& kv_c_and_k_pe_cache,
                               torch::Tensor const& seq_lens,
                               torch::Tensor const& page_table);
#endif

void cutlass_mla_decode(torch::Tensor const& out,
                        torch::Tensor const& q_nope_and_q_pe,
                        torch::Tensor const& kv_c_and_k_pe_cache,
                        torch::Tensor const& seq_lens,
                        torch::Tensor const& page_table) {
#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA
  return cutlass_mla_decode_sm100a(out, q_nope_and_q_pe, kv_c_and_k_pe_cache,
                                   seq_lens, page_table);
#endif
  TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled cutlass MLA");
}
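The entry file exposes a single cutlass_mla_decode symbol regardless of the build: when ENABLE_CUTLASS_MLA is off, the call falls through to TORCH_CHECK_NOT_IMPLEMENTED. A hedged sketch of how a Python caller might guard against that (not part of this commit; it assumes TORCH_CHECK_NOT_IMPLEMENTED surfaces as NotImplementedError, as PyTorch typically maps it, and that some non-CUTLASS fallback exists elsewhere):

import vllm._custom_ops as ops

def mla_decode_or_fallback(q, kv_cache, seq_lens, page_table, fallback_fn):
    # cutlass_mla_decode returns the output tensor (see the test below);
    # on builds without ENABLE_CUTLASS_MLA the op is expected to raise.
    try:
        return ops.cutlass_mla_decode(q, kv_cache, seq_lens, page_table)
    except NotImplementedError:
        # fallback_fn is a hypothetical non-CUTLASS MLA implementation
        return fallback_fn(q, kv_cache, seq_lens, page_table)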
csrc/attention/mla/cutlass_mla_kernels.cu

Lines changed: 186 additions & 0 deletions

@@ -0,0 +1,186 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cute/tensor.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/kernel_hardware_info.h"

#include "device/sm100_mla.hpp"
#include "kernel/sm100_mla_tile_scheduler.hpp"

#define CUTLASS_CHECK(status)                         \
  {                                                   \
    cutlass::Status error = status;                   \
    TORCH_CHECK(error == cutlass::Status::kSuccess,   \
                cutlassGetStatusString(error));       \
  }

using namespace cute;
using namespace cutlass::fmha::kernel;

template <bool v>
struct IsPersistent {
  static const bool value = v;
};

template <typename T, typename PersistenceOption = IsPersistent<true>>
struct MlaSm100 {
  using Element = T;
  using ElementAcc = float;
  using ElementOut = T;

  using TileShape = Shape<_128, _128, Shape<_512, _64>>;
  using TileShapeH = cute::tuple_element_t<0, TileShape>;
  using TileShapeD = cute::tuple_element_t<2, TileShape>;

  // H K (D_latent D_rope) B
  using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;

  using StrideQ = cute::tuple<int64_t, _1, int64_t>;  // H D B
  using StrideK = cute::tuple<int64_t, _1, int64_t>;  // K D B
  using StrideO = StrideK;                            // H D B
  using StrideLSE = cute::tuple<_1, int>;             // H B

  using TileScheduler = std::conditional_t<PersistenceOption::value,
                                           Sm100MlaPersistentTileScheduler,
                                           Sm100MlaIndividualTileScheduler>;

  using FmhaKernel =
      cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
          TileShape, Element, ElementAcc, ElementOut, ElementAcc, TileScheduler,
          /*kIsCpAsync=*/true>;
  using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
};

template <typename T>
typename T::Fmha::Arguments args_from_options(
    at::Tensor const& out, at::Tensor const& q_nope_and_q_pe,
    at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens,
    at::Tensor const& page_table) {
  cutlass::KernelHardwareInfo hw_info;
  hw_info.device_id = q_nope_and_q_pe.device().index();
  hw_info.sm_count =
      cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
          hw_info.device_id);

  int batches = q_nope_and_q_pe.sizes()[0];
  int page_count_per_seq = page_table.sizes()[1];
  int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
  int page_size = kv_c_and_k_pe_cache.sizes()[1];
  int max_seq_len = page_size * page_count_per_seq;
  using TileShapeH = typename T::TileShapeH;
  using TileShapeD = typename T::TileShapeD;
  auto problem_shape =
      cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches);

  auto [H, K, D, B] = problem_shape;
  auto [D_latent, D_rope] = D;

  // the scale is based on the non-absorbed sizes, change as appropriate
  // we can't determine this parameter from the info we have, it's an input
  int D_non_latent = 128;
  float scale = 1.0 / sqrt(1.0 * (D_non_latent + D_rope));

  using StrideQ = typename T::StrideQ;
  using StrideK = typename T::StrideK;
  using StrideO = typename T::StrideO;
  using StrideLSE = typename T::StrideLSE;

  StrideQ stride_Q =
      cute::make_tuple(static_cast<int64_t>(0 + D_latent + D_rope), _1{},
                       static_cast<int64_t>(H * (0 + D_latent + D_rope)));
  StrideK stride_C =
      cute::make_tuple(static_cast<int64_t>(0 + D_latent + D_rope), _1{},
                       static_cast<int64_t>(page_size * (D_latent + D_rope)));
  StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
  StrideLSE stride_LSE = cute::make_tuple(_1{}, 0 + H);
  StrideO stride_O = cute::make_tuple(static_cast<int64_t>(0 + D_latent), _1{},
                                      static_cast<int64_t>(0 + H * D_latent));

  using Element = typename T::Element;
  using ElementOut = typename T::ElementOut;
  using ElementAcc = typename T::ElementAcc;
  auto Q_ptr = static_cast<Element*>(q_nope_and_q_pe.data_ptr());
  auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
  typename T::Fmha::Arguments arguments{
      problem_shape,
      {scale, Q_ptr, stride_Q, Q_ptr + D_latent, stride_Q, C_ptr, stride_C,
       C_ptr + D_latent, stride_C, static_cast<int*>(seq_lens.data_ptr()),
       static_cast<int*>(page_table.data_ptr()), stride_PT, page_count_total,
       page_size},
      {static_cast<ElementOut*>(out.data_ptr()), stride_O,
       static_cast<ElementAcc*>(nullptr), stride_LSE},
      hw_info,
      -1,       // split_kv
      nullptr,  // is_var_split_kv
  };
  // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
  // split_kv automatically based on batch size and sequence length to balance
  // workload across available SMs. Consider using var_split_kv for manual
  // control if needed.
  T::Fmha::set_split_kv(arguments);
  return arguments;
}

template <typename Element>
void runMla(at::Tensor const& out, at::Tensor const& q_nope_and_q_pe,
            at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens,
            at::Tensor const& page_table, cudaStream_t stream) {
  using MlaSm100Type = MlaSm100<Element>;
  typename MlaSm100Type::Fmha fmha;
  auto arguments = args_from_options<MlaSm100Type>(
      out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table);
  size_t workspace_size = MlaSm100Type::Fmha::get_workspace_size(arguments);
  auto const workspace_options = torch::TensorOptions()
                                     .dtype(torch::kUInt8)
                                     .device(q_nope_and_q_pe.device());
  auto workspace = torch::empty(workspace_size, workspace_options);

  CUTLASS_CHECK(fmha.can_implement(arguments));

  CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream));

  CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream));
}

void cutlass_mla_decode_sm100a(torch::Tensor const& out,
                               torch::Tensor const& q_nope_and_q_pe,
                               torch::Tensor const& kv_c_and_k_pe_cache,
                               torch::Tensor const& seq_lens,
                               torch::Tensor const& page_table) {
  auto in_dtype = q_nope_and_q_pe.dtype();
  at::cuda::CUDAGuard device_guard{(char)q_nope_and_q_pe.get_device()};
  const cudaStream_t stream =
      at::cuda::getCurrentCUDAStream(q_nope_and_q_pe.get_device());
  if (in_dtype == at::ScalarType::Half) {
    runMla<cutlass::half_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens,
                            page_table, stream);
  } else if (in_dtype == at::ScalarType::BFloat16) {
    runMla<cutlass::bfloat16_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache,
                                seq_lens, page_table, stream);
  } else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
    runMla<cutlass::float_e4m3_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache,
                                  seq_lens, page_table, stream);
  } else {
    TORCH_CHECK(false, "Unsupported input data type of MLA");
  }
}
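For orientation: args_from_options derives all layout information from the tensor shapes. Queries and the paged KV cache share a packed last dim of D_latent + D_rope, and the output carries only the latent part. A small sketch of the strides it builds, using sizes taken from the test below rather than from the kernel itself:

# Assumed sizes from the test below: 512 latent dims, 64 RoPE dims,
# 128 query heads, 128-token pages.
D_latent, D_rope, H, page_size = 512, 64, 128, 128
d = D_latent + D_rope                    # 576: packed q_nope+q_pe / kv_c+k_pe
stride_Q = (d, 1, H * d)                 # per-head, per-dim, per-batch
stride_C = (d, 1, page_size * d)         # per-token, per-dim, per-page
stride_O = (D_latent, 1, H * D_latent)   # output keeps only the latent dims
print(stride_Q, stride_C, stride_O)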

csrc/ops.h

Lines changed: 6 additions & 0 deletions
@@ -119,6 +119,12 @@ void advance_step_flashinfer(
     torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
     torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);

+void cutlass_mla_decode(torch::Tensor const& out,
+                        torch::Tensor const& q_nope_and_q_pe,
+                        torch::Tensor const& kv_c_and_k_pe_cache,
+                        torch::Tensor const& seq_lens,
+                        torch::Tensor const& page_table);
+
 torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);

 #ifndef USE_ROCM

csrc/torch_bindings.cpp

Lines changed: 7 additions & 0 deletions
@@ -115,6 +115,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       ") -> ()");
   ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);

+  // Compute MLA decode using cutlass.
+  ops.def(
+      "cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe,"
+      " Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
+      " Tensor page_table) -> ()");
+  ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
+
   // Layernorm
   // Apply Root Mean Square (RMS) Normalization to the input tensor.
   ops.def(
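The registered schema mutates out in place and returns nothing; the Python wrapper exercised by the test below allocates the output and returns it. A hedged usage sketch (tensor sizes chosen for illustration only, not prescribed by the commit):

import torch
import vllm._custom_ops as ops

bs, h_q, d, dv, block_size, num_blocks = 2, 128, 576, 512, 128, 8
q = torch.randn(bs, h_q, d, dtype=torch.bfloat16, device="cuda")
kv_cache = torch.randn(num_blocks, block_size, d, dtype=torch.bfloat16, device="cuda")
seq_lens = torch.full((bs,), 256, dtype=torch.int32, device="cuda")
page_table = torch.randint(0, num_blocks, (bs, num_blocks // bs),
                           dtype=torch.int32, device="cuda")
out = ops.cutlass_mla_decode(q, kv_cache, seq_lens, page_table)
print(out.shape)  # expected: (bs, h_q, dv) == (2, 128, 512)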
Lines changed: 89 additions & 0 deletions

@@ -0,0 +1,89 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
import torch.nn.functional as F
from torch import Tensor

import vllm._custom_ops as ops
from vllm.platforms import current_platform

if not current_platform.has_device_capability(100):
    pytest.skip(
        reason="Cutlass MLA Requires compute capability of 10 or above.",
        allow_module_level=True)


def ref_mla(
        out: Tensor,  # (bs, num_heads, v_head_dim)
        query: Tensor,  # (bs, num_heads, head_dim)
        kv_cache: Tensor,  # (num_blocks, block_size, head_dim)
        scale: float,
        block_tables: Tensor,  # (bs, max_num_blocks)
        seq_lens: Tensor,  # (bs,)
):
    bs, num_heads, v_head_dim = out.shape
    head_dim = query.shape[2]

    for i in range(bs):
        # gather and flatten KV-cache
        kv = kv_cache[
            block_tables[i]]  # (max_num_blocks, block_size, head_dim)
        kv = kv.view(1, -1,
                     head_dim)[:, :seq_lens[i]]  # (1, seq_len, head_dim)
        v = kv[:, :, :v_head_dim]

        q = query[i].view(num_heads, 1, head_dim)
        o = F.scaled_dot_product_attention(q,
                                           kv,
                                           v,
                                           scale=scale,
                                           enable_gqa=True)
        out[i] = o.view(num_heads, v_head_dim)

    return out


@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("mean_seq_len", [128, 1024, 4096])
@pytest.mark.parametrize("bs", [1, 2, 4])
@pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.parametrize("block_size", [16, 64, 128])
def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int,
                            varlen: bool, block_size: int):
    torch.set_default_dtype(dtype)
    torch.set_default_device('cuda')
    torch.manual_seed(42)

    d = 576
    h_q = 128
    dv = 512

    q_nope_dim = 128
    q_pe_dim = 64
    scale = (q_nope_dim + q_pe_dim)**(-0.5)
    if varlen:
        seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2)
        seq_lens = seq_lens.clip(2).to(torch.int32)
    else:
        seq_lens = torch.full((bs, ), mean_seq_len, dtype=torch.int32)
    max_seq_len = seq_lens.max().item()
    block_num = (max_seq_len + block_size - 1) // block_size

    # Pad block_num so that small blocks can be packed into full 128-sized
    # CUTLASS tiles. One 128-wide tile can hold (128 // block_size) small
    # blocks.
    pack_factor = 128 // block_size
    block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor

    q = torch.randn(bs, h_q, d)
    block_table = torch.randint(0,
                                bs * block_num, (bs, block_num),
                                dtype=torch.int32)

    kv_cache = torch.randn(block_table.numel(), block_size, d)

    out_ref = q.new_zeros(bs, h_q, dv)
    ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
    out = ops.cutlass_mla_decode(q, kv_cache, seq_lens, block_table)

    torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)
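The reference implementation leans on the MLA property that keys and values come from the same cached vector: keys use the full 576-dim entry (latent plus RoPE), values reuse only the first 512 latent dims, so a single scaled_dot_product_attention call with enable_gqa=True reproduces the decode. A tiny standalone illustration of that slicing (shapes chosen arbitrarily, mirroring the test above):

import torch
import torch.nn.functional as F

kv = torch.randn(1, 7, 576)    # one shared KV "head", 7 cached tokens, packed dim
q = torch.randn(128, 1, 576)   # 128 query heads, one decode token
v = kv[..., :512]              # values are the latent slice of the same cache
o = F.scaled_dot_product_attention(q, kv, v, scale=(128 + 64) ** -0.5,
                                   enable_gqa=True)
print(o.shape)  # (128, 1, 512)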
