
[NVIDIA] Support Cutlass MLA for Blackwell GPUs #16032

Open · wants to merge 5 commits into base: main

24 changes: 22 additions & 2 deletions CMakeLists.txt
@@ -289,7 +289,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
"csrc/cutlass_extensions/common.cpp")
"csrc/cutlass_extensions/common.cpp"
"csrc/attention/mla/cutlass_mla_entry.cu")

set_gencode_flags_for_srcs(
SRCS "${VLLM_EXT_SRC}"
@@ -462,7 +463,26 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(FP4_ARCHS)
endif()

#
# CUTLASS MLA Archs and flags
cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND MLA_ARCHS)
set(SRCS
"csrc/attention/mla/cutlass_mla_kernels.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${MLA_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MLA=1")
# Add MLA-specific include directories only to MLA source files
set_source_files_properties(${SRCS}
PROPERTIES INCLUDE_DIRECTORIES "${CUTLASS_INCLUDE_DIR}/../examples/77_blackwell_fmha;${CUTLASS_INCLUDE_DIR}/../examples/common")
message(STATUS "Building CUTLASS MLA for archs: ${MLA_ARCHS}")
else()
message(STATUS "Not building CUTLASS MLA as CUDA compiler version is not > 12.8 or no compatible archs were found.")
# clear MLA_ARCHS
set(MLA_ARCHS)
endif()

# CUTLASS MoE kernels

# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works
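
The guarded block above compiles the kernel only when the CUDA compiler is newer than 12.8 and "10.0a" survives the arch intersection, so the op can simply be absent at runtime. A minimal runtime probe, assuming only standard torch APIs (not part of this PR):

```python
import torch

def cutlass_mla_buildable() -> bool:
    """Best-effort mirror of the build-time gating above."""
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability()
    # cutlass_mla_kernels.cu is compiled only for SM100a (Blackwell).
    return major == 10
```
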
37 changes: 37 additions & 0 deletions csrc/attention/mla/cutlass_mla_entry.cu
@@ -0,0 +1,37 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <torch/all.h>

#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA
void cutlass_mla_decode_sm100a(torch::Tensor const& out,
torch::Tensor const& q_nope_and_q_pe,
torch::Tensor const& kv_c_and_k_pe_cache,
torch::Tensor const& seq_lens,
torch::Tensor const& page_table);
#endif

void cutlass_mla_decode(torch::Tensor const& out,
torch::Tensor const& q_nope_and_q_pe,
torch::Tensor const& kv_c_and_k_pe_cache,
torch::Tensor const& seq_lens,
torch::Tensor const& page_table) {
#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA
return cutlass_mla_decode_sm100a(out, q_nope_and_q_pe, kv_c_and_k_pe_cache,
seq_lens, page_table);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled cutlass MLA kernel");
}
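
If the wheel was built without ENABLE_CUTLASS_MLA, the fallthrough above fires TORCH_CHECK_NOT_IMPLEMENTED, which surfaces in Python as NotImplementedError. A hedged sketch of probing for the kernel via the `vllm._custom_ops` wrapper exercised by the test at the bottom of this PR:

```python
import vllm._custom_ops as ops

def mla_decode_or_fallback(q, kv_cache, seq_lens, page_table):
    try:
        return ops.cutlass_mla_decode(q, kv_cache, seq_lens, page_table)
    except NotImplementedError:
        # Built without ENABLE_CUTLASS_MLA (needs CUDA >= 12.8 and a
        # 10.0a target); dispatch to another MLA backend here instead.
        raise
```
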
186 changes: 186 additions & 0 deletions csrc/attention/mla/cutlass_mla_kernels.cu
@@ -0,0 +1,186 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cute/tensor.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/kernel_hardware_info.h"

#include "device/sm100_mla.hpp"
#include "kernel/sm100_mla_tile_scheduler.hpp"

#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
TORCH_CHECK(error == cutlass::Status::kSuccess, \
cutlassGetStatusString(error)); \
}

using namespace cute;
using namespace cutlass::fmha::kernel;

template <bool v>
struct IsPersistent {
static const bool value = v;
};

template <typename T, typename PersistenceOption = IsPersistent<true>>
struct MlaSm100 {
using Element = T;
using ElementAcc = float;
using ElementOut = T;

using TileShape = Shape<_128, _128, Shape<_512, _64>>;
using TileShapeH = cute::tuple_element_t<0, TileShape>;
using TileShapeD = cute::tuple_element_t<2, TileShape>;

// H K (D_latent D_rope) B
using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;

using StrideQ = cute::tuple<int64_t, _1, int64_t>; // H D B
using StrideK = cute::tuple<int64_t, _1, int64_t>; // K D B
using StrideO = StrideK; // H D B
using StrideLSE = cute::tuple<_1, int>; // H B

using TileScheduler = std::conditional_t<PersistenceOption::value,
Sm100MlaPersistentTileScheduler,
Sm100MlaIndividualTileScheduler>;

using FmhaKernel =
cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
TileShape, Element, ElementAcc, ElementOut, ElementAcc, TileScheduler,
/*kIsCpAsync=*/true>;
using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
};
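
The tile shape pins the query-head group at 128 and splits the 576-wide head dimension into a 512-wide latent part and a 64-wide RoPE part, stored concatenated. The tensor layout this implies, mirroring the test at the end of this PR (a sketch, not a new API):

```python
import torch

B, H = 4, 128                  # batch size; TileShapeH is fixed at _128
D_latent, D_rope = 512, 64     # TileShapeD is Shape<_512, _64>
D = D_latent + D_rope          # 576: q_nope and q_pe live side by side

q_nope_and_q_pe = torch.randn(B, H, D, device="cuda", dtype=torch.bfloat16)

# Paged KV cache: latent and rope parts are also concatenated per token.
num_pages, page_size = 64, 128
kv_c_and_k_pe_cache = torch.randn(num_pages, page_size, D,
                                  device="cuda", dtype=torch.bfloat16)
```
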

template <typename T>
typename T::Fmha::Arguments args_from_options(
at::Tensor const& out, at::Tensor const& q_nope_and_q_pe,
at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens,
at::Tensor const& page_table) {
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = q_nope_and_q_pe.device().index();
hw_info.sm_count =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
hw_info.device_id);

int batches = q_nope_and_q_pe.sizes()[0];
int page_count_per_seq = page_table.sizes()[1];
int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
int page_size = kv_c_and_k_pe_cache.sizes()[1];
int max_seq_len = page_size * page_count_per_seq;
using TileShapeH = typename T::TileShapeH;
using TileShapeD = typename T::TileShapeD;
auto problem_shape =
cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches);

auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;

// The scale is based on the non-absorbed head sizes; change as appropriate.
// We cannot recover D_non_latent from the tensors we receive, so it is
// effectively an input.
int D_non_latent = 128;
float scale = 1.0 / sqrt(1.0 * (D_non_latent + D_rope));

using StrideQ = typename T::StrideQ;
using StrideK = typename T::StrideK;
using StrideO = typename T::StrideO;
using StrideLSE = typename T::StrideLSE;

StrideQ stride_Q =
cute::make_tuple(static_cast<int64_t>(0 + D_latent + D_rope), _1{},
static_cast<int64_t>(H * (0 + D_latent + D_rope)));
StrideK stride_C =
cute::make_tuple(static_cast<int64_t>(0 + D_latent + D_rope), _1{},
static_cast<int64_t>(page_size * (D_latent + D_rope)));
StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
StrideLSE stride_LSE = cute::make_tuple(_1{}, 0 + H);
StrideO stride_O = cute::make_tuple(static_cast<int64_t>(0 + D_latent), _1{},
static_cast<int64_t>(0 + H * D_latent));

using Element = typename T::Element;
using ElementOut = typename T::ElementOut;
using ElementAcc = typename T::ElementAcc;
auto Q_ptr = static_cast<Element*>(q_nope_and_q_pe.data_ptr());
auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
typename T::Fmha::Arguments arguments{
problem_shape,
{scale, Q_ptr, stride_Q, Q_ptr + D_latent, stride_Q, C_ptr, stride_C,
C_ptr + D_latent, stride_C, static_cast<int*>(seq_lens.data_ptr()),
static_cast<int*>(page_table.data_ptr()), stride_PT, page_count_total,
page_size},
{static_cast<ElementOut*>(out.data_ptr()), stride_O,
static_cast<ElementAcc*>(nullptr), stride_LSE},
hw_info,
-1, // split_kv
nullptr, // is_var_split_kv
};
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
// split_kv automatically based on batch size and sequence length to balance
// workload across available SMs. Consider using var_split_kv for manual
// control if needed.
T::Fmha::set_split_kv(arguments);
return arguments;
}
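
Each stride tuple above reads as (row stride, unit stride across the head dim, batch/page stride). A plain-Python mirror of the arithmetic, under the shapes from the previous sketch:

```python
import math

H, page_size = 128, 128
D_latent, D_rope = 512, 64
D = D_latent + D_rope

stride_Q = (D, 1, H * D)                # per head / per element / per batch
stride_C = (D, 1, page_size * D)        # per row / per element / per page
stride_O = (D_latent, 1, H * D_latent)  # output keeps only the latent width

# With the non-absorbed head size hard-coded to 128 above:
scale = 1.0 / math.sqrt(128 + D_rope)   # 1 / sqrt(192)
```
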

template <typename Element>
void runMla(at::Tensor const& out, at::Tensor const& q_nope_and_q_pe,
at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens,
at::Tensor const& page_table, cudaStream_t stream) {
using MlaSm100Type = MlaSm100<Element>;
typename MlaSm100Type::Fmha fmha;
auto arguments = args_from_options<MlaSm100Type>(
out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table);
size_t workspace_size = MlaSm100Type::Fmha::get_workspace_size(arguments);
auto const workspace_options = torch::TensorOptions()
.dtype(torch::kUInt8)
.device(q_nope_and_q_pe.device());
auto workspace = torch::empty(workspace_size, workspace_options);

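// Standard three-phase CUTLASS device API: check that the arguments are
// supported, initialize kernel state and the workspace, then launch.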
CUTLASS_CHECK(fmha.can_implement(arguments));

CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream));

CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream));
}

void cutlass_mla_decode_sm100a(torch::Tensor const& out,
torch::Tensor const& q_nope_and_q_pe,
torch::Tensor const& kv_c_and_k_pe_cache,
torch::Tensor const& seq_lens,
torch::Tensor const& page_table) {
auto in_dtype = q_nope_and_q_pe.dtype();
at::cuda::CUDAGuard device_guard{(char)q_nope_and_q_pe.get_device()};
const cudaStream_t stream =
at::cuda::getCurrentCUDAStream(q_nope_and_q_pe.get_device());
if (in_dtype == at::ScalarType::Half) {
runMla<cutlass::half_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens,
page_table, stream);
} else if (in_dtype == at::ScalarType::BFloat16) {
runMla<cutlass::bfloat16_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache,
seq_lens, page_table, stream);
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
runMla<cutlass::float_e4m3_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache,
seq_lens, page_table, stream);
} else {
TORCH_CHECK(false, "Unsupported input data type for MLA");
}
}
6 changes: 6 additions & 0 deletions csrc/ops.h
@@ -119,6 +119,12 @@ void advance_step_flashinfer(
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);

void cutlass_mla_decode(torch::Tensor const& out,
torch::Tensor const& q_nope_and_q_pe,
torch::Tensor const& kv_c_and_k_pe_cache,
torch::Tensor const& seq_lens,
torch::Tensor const& page_table);

torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);

#ifndef USE_ROCM
7 changes: 7 additions & 0 deletions csrc/torch_bindings.cpp
@@ -115,6 +115,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
") -> ()");
ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);

// Compute MLA decode using cutlass.
ops.def(
"cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe,"
" Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
" Tensor page_table) -> ()");
ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);

// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
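
The schema marks only `out` as mutable (`Tensor! out`), so the raw op writes its result in place and returns nothing. A direct-call sketch, assuming vLLM's usual `_C` extension namespace (the `vllm._custom_ops` wrapper used in the test below is the normal entry point):

```python
import torch

B, H, D_latent, D_rope = 2, 128, 512, 64
D = D_latent + D_rope
num_pages, page_size = 16, 128

q = torch.randn(B, H, D, device="cuda", dtype=torch.bfloat16)
kv_cache = torch.randn(num_pages, page_size, D, device="cuda",
                       dtype=torch.bfloat16)
seq_lens = torch.full((B,), 1024, device="cuda", dtype=torch.int32)
page_table = torch.randint(0, num_pages, (B, num_pages), device="cuda",
                           dtype=torch.int32)

out = torch.empty(B, H, D_latent, device="cuda", dtype=torch.bfloat16)
torch.ops._C.cutlass_mla_decode(out, q, kv_cache, seq_lens, page_table)
```
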
89 changes: 89 additions & 0 deletions tests/kernels/test_cutlass_mla_decode.py
@@ -0,0 +1,89 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
import torch.nn.functional as F
from torch import Tensor

import vllm._custom_ops as ops
from vllm.platforms import current_platform

if not current_platform.has_device_capability(100):
pytest.skip(
reason="Cutlass MLA requires compute capability of 10 or above.",
allow_module_level=True)


def ref_mla(
out: Tensor, # (bs, num_heads, v_head_dim)
query: Tensor, # (bs, num_heads, head_dim)
kv_cache: Tensor, # (num_blocks, block_size, head_dim)
scale: float,
block_tables: Tensor, # (bs, max_num_blocks)
seq_lens: Tensor, # (bs,)
):
bs, num_heads, v_head_dim = out.shape
head_dim = query.shape[2]

for i in range(bs):
# gather and flatten KV-cache
kv = kv_cache[
block_tables[i]] # (max_num_blocks, block_size, head_dim)
kv = kv.view(1, -1,
head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim)
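# values are the first v_head_dim (latent) channels of each cache row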
v = kv[:, :, :v_head_dim]

q = query[i].view(num_heads, 1, head_dim)
o = F.scaled_dot_product_attention(q,
kv,
v,
scale=scale,
enable_gqa=True)
out[i] = o.view(num_heads, v_head_dim)

return out


@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("mean_seq_len", [128, 1024, 4096])
@pytest.mark.parametrize("bs", [1, 2, 4])
@pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.parametrize("block_size", [16, 64, 128])
def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int,
varlen: bool, block_size: int):
torch.set_default_dtype(dtype)
torch.set_default_device('cuda')
torch.manual_seed(42)

d = 576
h_q = 128
dv = 512

q_nope_dim = 128
q_pe_dim = 64
scale = (q_nope_dim + q_pe_dim)**(-0.5)
if varlen:
seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2)
seq_lens = seq_lens.clip(2).to(torch.int32)
else:
seq_lens = torch.full((bs, ), mean_seq_len, dtype=torch.int32)
max_seq_len = seq_lens.max().item()
block_num = (max_seq_len + block_size - 1) // block_size

# Pad block_num so that small blocks can be packed into full 128-sized
# CUTLASS tiles. One 128-wide tile holds (128 // block_size) small
# blocks, e.g. eight 16-token blocks, so block_num is rounded up to a
# multiple of pack_factor.
pack_factor = 128 // block_size
block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor

q = torch.randn(bs, h_q, d)
block_table = torch.randint(0,
bs * block_num, (bs, block_num),
dtype=torch.int32)

kv_cache = torch.randn(block_table.numel(), block_size, d)

out_ref = q.new_zeros(bs, h_q, dv)
ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
out = ops.cutlass_mla_decode(q, kv_cache, seq_lens, block_table)

torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)