Commit b20ac92

Format

Signed-off-by: kaixih <[email protected]>
1 parent b8a7d56 commit b20ac92

File tree (3 files changed: +83, -84 lines):

  csrc/attention/mla/cutlass_mla_kernels.cu
  tests/kernels/test_cutlass_mla_decode.py
  vllm/_custom_ops.py

Diff for: csrc/attention/mla/cutlass_mla_kernels.cu (+57, -71)

@@ -37,7 +37,7 @@
 using namespace cute;
 using namespace cutlass::fmha::kernel;
 
-template<bool v>
+template <bool v>
 struct IsPersistent {
   static const bool value = v;
 };
@@ -54,31 +54,28 @@ struct MlaSm100 {
 
   // H K (D_latent D_rope) B
   using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;
-
+
   using StrideQ = cute::tuple<int64_t, _1, int64_t>;  // H D B
   using StrideK = cute::tuple<int64_t, _1, int64_t>;  // K D B
   using StrideO = StrideK;  // H D B
   using StrideLSE = cute::tuple<_1, int>;  // H B
 
-  using TileScheduler = std::conditional_t<
-      PersistenceOption::value,
-      Sm100MlaPersistentTileScheduler,
-      Sm100MlaIndividualTileScheduler>;
+  using TileScheduler = std::conditional_t<PersistenceOption::value,
+                                           Sm100MlaPersistentTileScheduler,
+                                           Sm100MlaIndividualTileScheduler>;
 
   using FmhaKernel =
       cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
-          TileShape, Element, ElementAcc, ElementOut, ElementAcc,
-          TileScheduler, /*kIsCpAsync=*/true>;
+          TileShape, Element, ElementAcc, ElementOut, ElementAcc, TileScheduler,
+          /*kIsCpAsync=*/true>;
   using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
 };
 
-
 template <typename T>
-typename T::Fmha::Arguments args_from_options(at::Tensor const& out,
-                                              at::Tensor const& q_nope_and_q_pe,
-                                              at::Tensor const& kv_c_and_k_pe_cache,
-                                              at::Tensor const& seq_lens,
-                                              at::Tensor const& page_table) {
+typename T::Fmha::Arguments args_from_options(
+    at::Tensor const& out, at::Tensor const& q_nope_and_q_pe,
+    at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens,
+    at::Tensor const& page_table) {
   cutlass::KernelHardwareInfo hw_info;
   hw_info.device_id = q_nope_and_q_pe.device().index();
   hw_info.sm_count =
@@ -92,8 +89,8 @@ typename T::Fmha::Arguments args_from_options(at::Tensor const& out,
   int max_seq_len = page_size * page_count_per_seq;
   using TileShapeH = typename T::TileShapeH;
   using TileShapeD = typename T::TileShapeD;
-  auto problem_shape = cute::make_tuple(
-      TileShapeH{}, max_seq_len, TileShapeD{}, batches);
+  auto problem_shape =
+      cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches);
 
   auto [H, K, D, B] = problem_shape;
   auto [D_latent, D_rope] = D;
@@ -108,66 +105,55 @@
   using StrideO = typename T::StrideO;
   using StrideLSE = typename T::StrideLSE;
 
-  StrideQ stride_Q = cute::make_tuple(
-      static_cast<int64_t>(0 + D_latent + D_rope),
-      _1{},
-      static_cast<int64_t>(H * (0 + D_latent + D_rope)));
-  StrideK stride_C = cute::make_tuple(
-      static_cast<int64_t>(0 + D_latent + D_rope),
-      _1{},
-      static_cast<int64_t>(page_size * (D_latent + D_rope)));
+  StrideQ stride_Q =
+      cute::make_tuple(static_cast<int64_t>(0 + D_latent + D_rope), _1{},
+                       static_cast<int64_t>(H * (0 + D_latent + D_rope)));
+  StrideK stride_C =
+      cute::make_tuple(static_cast<int64_t>(0 + D_latent + D_rope), _1{},
+                       static_cast<int64_t>(page_size * (D_latent + D_rope)));
   StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
   StrideLSE stride_LSE = cute::make_tuple(_1{}, 0 + H);
-  StrideO stride_O = cute::make_tuple(
-      static_cast<int64_t>(0 + D_latent),
-      _1{},
-      static_cast<int64_t>(0 + H * D_latent));
+  StrideO stride_O =
+      cute::make_tuple(static_cast<int64_t>(0 + D_latent), _1{},
+                       static_cast<int64_t>(0 + H * D_latent));
 
   using Element = typename T::Element;
   using ElementOut = typename T::ElementOut;
   using ElementAcc = typename T::ElementAcc;
   auto Q_ptr = static_cast<Element*>(q_nope_and_q_pe.data_ptr());
   auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
   typename T::Fmha::Arguments arguments{
-      problem_shape,
-      { scale,
-        Q_ptr, stride_Q,
-        Q_ptr + D_latent, stride_Q,
-        C_ptr, stride_C,
-        C_ptr + D_latent, stride_C,
-        static_cast<int*>(seq_lens.data_ptr()),
-        static_cast<int*>(page_table.data_ptr()), stride_PT,
-        page_count_total, page_size},
-      { static_cast<ElementOut*>(out.data_ptr()), stride_O,
-        // static_cast<ElementAcc*>(lse.data_ptr()), stride_LSE},
-        static_cast<ElementAcc*>(nullptr), stride_LSE},
-      hw_info,
-      -1,  // split_kv
-      nullptr,  // is_var_split_kv
+      problem_shape,
+      {scale, Q_ptr, stride_Q, Q_ptr + D_latent, stride_Q, C_ptr, stride_C,
+       C_ptr + D_latent, stride_C, static_cast<int*>(seq_lens.data_ptr()),
+       static_cast<int*>(page_table.data_ptr()), stride_PT, page_count_total,
+       page_size},
+      {static_cast<ElementOut*>(out.data_ptr()), stride_O,
+       static_cast<ElementAcc*>(nullptr), stride_LSE},
+      hw_info,
+      -1,       // split_kv
+      nullptr,  // is_var_split_kv
   };
   // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
-  // split_kv automatically based on batch size and sequence length to balance
+  // split_kv automatically based on batch size and sequence length to balance
   // workload across available SMs. Consider using var_split_kv for manual
   // control if needed.
   T::Fmha::set_split_kv(arguments);
   return arguments;
 }
 
 template <typename Element>
-void runMla(at::Tensor const& out,
-            at::Tensor const& q_nope_and_q_pe,
-            at::Tensor const& kv_c_and_k_pe_cache,
-            at::Tensor const& seq_lens,
-            at::Tensor const& page_table,
-            cudaStream_t stream) {
+void runMla(at::Tensor const& out, at::Tensor const& q_nope_and_q_pe,
+            at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens,
+            at::Tensor const& page_table, cudaStream_t stream) {
   using MlaSm100Type = MlaSm100<Element>;
   typename MlaSm100Type::Fmha fmha;
-  auto arguments =
-      args_from_options<MlaSm100Type>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache,
-                                      seq_lens, page_table);
+  auto arguments = args_from_options<MlaSm100Type>(
+      out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table);
   size_t workspace_size = MlaSm100Type::Fmha::get_workspace_size(arguments);
-  auto const workspace_options =
-      torch::TensorOptions().dtype(torch::kUInt8).device(q_nope_and_q_pe.device());
+  auto const workspace_options = torch::TensorOptions()
+                                     .dtype(torch::kUInt8)
+                                     .device(q_nope_and_q_pe.device());
   auto workspace = torch::empty(workspace_size, workspace_options);
 
   CUTLASS_CHECK(fmha.can_implement(arguments));
@@ -182,20 +168,20 @@ void cutlass_mla_decode_sm100a(torch::Tensor const& out,
                                torch::Tensor const& kv_c_and_k_pe_cache,
                                torch::Tensor const& seq_lens,
                                torch::Tensor const& page_table) {
-  auto in_dtype = q_nope_and_q_pe.dtype();
-  at::cuda::CUDAGuard device_guard{(char)q_nope_and_q_pe.get_device()};
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(
-      q_nope_and_q_pe.get_device());
-  if (in_dtype == at::ScalarType::Half) {
-    runMla<cutlass::half_t>(
-        out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, stream);
-  } else if (in_dtype == at::ScalarType::BFloat16) {
-    runMla<cutlass::bfloat16_t>(
-        out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, stream);
-  } else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
-    runMla<cutlass::float_e4m3_t>(
-        out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, stream);
-  } else {
-    TORCH_CHECK(false, "Unsupported input data type of MLA");
-  }
+  auto in_dtype = q_nope_and_q_pe.dtype();
+  at::cuda::CUDAGuard device_guard{(char)q_nope_and_q_pe.get_device()};
+  const cudaStream_t stream =
+      at::cuda::getCurrentCUDAStream(q_nope_and_q_pe.get_device());
+  if (in_dtype == at::ScalarType::Half) {
+    runMla<cutlass::half_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens,
+                            page_table, stream);
+  } else if (in_dtype == at::ScalarType::BFloat16) {
+    runMla<cutlass::bfloat16_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache,
+                                seq_lens, page_table, stream);
+  } else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
+    runMla<cutlass::float_e4m3_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache,
+                                  seq_lens, page_table, stream);
+  } else {
+    TORCH_CHECK(false, "Unsupported input data type of MLA");
+  }
 }
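
Note on the TODO above: with split_kv=-1 and is_var_split_kv=false,
T::Fmha::set_split_kv chooses how many ways to split the KV dimension so the
launched CTAs can cover the available SMs. The snippet below is only a
hypothetical Python illustration of that balancing idea; auto_split_kv and its
parameters are invented for exposition, and the real heuristic lives inside
CUTLASS's set_split_kv.

import math

def auto_split_kv(batches: int, max_seq_len: int, sm_count: int,
                  tile_kv: int = 128) -> int:
    # Before splitting there is roughly one CTA per batch entry
    # (all 128 query heads fit in a single head tile).
    ctas = batches
    if ctas >= sm_count:
        return 1  # the grid already fills the GPU; no split needed
    max_split = math.ceil(max_seq_len / tile_kv)  # can't split finer than one KV tile
    wanted = math.ceil(sm_count / ctas)  # splits needed to occupy every SM
    return max(1, min(wanted, max_split))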

Diff for: tests/kernels/test_cutlass_mla_decode.py (+11, -5)
@@ -8,8 +8,10 @@
 from vllm.platforms import current_platform
 
 if not current_platform.has_device_capability(100):
-    pytest.skip(reason="Cutlass MLA Requires compute capability of 10 or above.",
-                allow_module_level=True)
+    pytest.skip(
+        reason="Cutlass MLA Requires compute capability of 10 or above.",
+        allow_module_level=True)
+
 
 def ref_mla(
     out: Tensor,  # (bs, num_heads, v_head_dim)
4042

4143
return out
4244

45+
4346
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
4447
@pytest.mark.parametrize("mean_seq_len", [128, 1024, 4096])
4548
@pytest.mark.parametrize("bs", [1, 2, 4])
4649
@pytest.mark.parametrize("varlen", [False, True])
4750
@pytest.mark.parametrize("block_size", [16, 128])
48-
def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int, varlen: bool, block_size: int):
51+
def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int,
52+
varlen: bool, block_size: int):
4953
torch.set_default_dtype(dtype)
5054
torch.set_default_device('cuda')
5155
torch.manual_seed(42)
5256

5357
d = 576
5458
h_q = 128
5559
dv = 512
56-
60+
5761
q_nope_dim = 128
5862
q_pe_dim = 64
5963
scale = (q_nope_dim + q_pe_dim)**(-0.5)
@@ -66,7 +70,9 @@ def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int, varlen: bool, block_size: int):
     block_num = (max_seq_len + block_size - 1) // block_size
 
     q = torch.randn(bs, h_q, d)
-    block_table = torch.randint(0, bs * block_num, (bs, block_num), dtype=torch.int32)
+    block_table = torch.randint(0,
+                                bs * block_num, (bs, block_num),
+                                dtype=torch.int32)
 
     kv_cache = torch.randn(block_table.numel(), block_size, d)
 
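Note: the test fixes d = 576 and dv = 512 (a 512-dim latent plus a 64-dim rope
part) and derives block_num by ceiling division, so the paged cache holds
bs * block_num pages. A quick shape check for one parameter point, assuming the
varlen=False case (the seq_lens handling sits outside the hunks shown):

bs, block_size, max_seq_len = 4, 128, 1024
d, dv, h_q = 576, 512, 128
block_num = (max_seq_len + block_size - 1) // block_size  # ceil-div, as in the test
assert d == dv + 64  # latent dim + rope dim
# q: (bs, h_q, d); block_table: (bs, block_num); kv_cache: (bs*block_num, block_size, d)
assert (block_num, bs * block_num) == (8, 32)
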
Diff for: vllm/_custom_ops.py (+15, -8)
@@ -1432,10 +1432,11 @@ def flash_mla_with_kvcache(
     )
     return out, softmax_lse
 
+
 def cutlass_mla_decode(q_nope_and_q_pe: torch.Tensor,
-        kv_c_and_k_pe_cache: torch.Tensor,
-        seq_lens: torch.Tensor,
-        page_table: torch.Tensor) -> torch.Tensor:
+                       kv_c_and_k_pe_cache: torch.Tensor,
+                       seq_lens: torch.Tensor,
+                       page_table: torch.Tensor) -> torch.Tensor:
     assert not current_platform.is_rocm()
     assert q_nope_and_q_pe.ndim == 3, f"q_nope_and_q_pe must be a 3D tensor, but got {q_nope_and_q_pe.ndim}"
     assert kv_c_and_k_pe_cache.ndim == 3, f"kv_c_and_k_pe_cache must be a 3D tensor, but got {kv_c_and_k_pe_cache.ndim}"
@@ -1446,13 +1447,17 @@ def cutlass_mla_decode(q_nope_and_q_pe: torch.Tensor,
     D_rope = 64
     assert D_q == D_ckv and D_q == D_latent + D_rope, (
         f"D_q must be equal to D_ckv and D_q must be equal to D_latent + D_rope, "
-        f"but got D_q = {D_q}, D_ckv = {D_ckv}, D_latent = {D_latent}, D_rope = {D_rope}")
+        f"but got D_q = {D_q}, D_ckv = {D_ckv}, D_latent = {D_latent}, D_rope = {D_rope}"
+    )
     assert H == 128, f"H must be 128, but got {H}"
-    assert PAGE_SIZE > 0 and (PAGE_SIZE & (PAGE_SIZE - 1)) == 0, f"PAGE_SIZE must be a power of 2, but got {PAGE_SIZE}"
-
+    assert PAGE_SIZE > 0 and (
+        PAGE_SIZE & (PAGE_SIZE - 1)
+    ) == 0, f"PAGE_SIZE must be a power of 2, but got {PAGE_SIZE}"
+
     # TODO(kaixih@nvidia): support fp8
     assert q_nope_and_q_pe.dtype in (torch.float16, torch.bfloat16), (
-        f'q_nope_and_q_pe.dtype needs to be fp16 or bf16 but got {q_nope_and_q_pe.dtype}.')
+        f'q_nope_and_q_pe.dtype needs to be fp16 or bf16 but got {q_nope_and_q_pe.dtype}.'
+    )
     assert kv_c_and_k_pe_cache.dtype == q_nope_and_q_pe.dtype, (
         f'kv_c_and_k_pe_cache.dtype needs to be the same as q_nope_and_q_pe.dtype, '
         f'but got {kv_c_and_k_pe_cache.dtype}.')
@@ -1461,7 +1466,9 @@ def cutlass_mla_decode(q_nope_and_q_pe: torch.Tensor,
     assert page_table.dtype == torch.int32, (
         f'page_table.dtype needs to be int32 but got {page_table.dtype}.')
 
-    out = torch.empty((B_q, H, D_latent), device=q_nope_and_q_pe.device, dtype=q_nope_and_q_pe.dtype)
+    out = torch.empty((B_q, H, D_latent),
+                      device=q_nope_and_q_pe.device,
+                      dtype=q_nope_and_q_pe.dtype)
 
     torch.ops._C.cutlass_mla_decode(out, q_nope_and_q_pe, kv_c_and_k_pe_cache,
                                     seq_lens, page_table)
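
Note: a minimal usage sketch for cutlass_mla_decode with shapes chosen to
satisfy the asserts above. It assumes a vllm build with the _C extension and an
SM100-class GPU (per the capability check in the test file); tensor contents
are random, so the output is only shape-checked here.

import torch
from vllm import _custom_ops as ops

B, H, D_latent, D_rope = 2, 128, 512, 64  # H must be 128; D_q = 512 + 64 = 576
page_size, pages_per_seq = 128, 8  # PAGE_SIZE must be a power of 2
q = torch.randn(B, H, D_latent + D_rope, dtype=torch.bfloat16, device="cuda")
kv_cache = torch.randn(B * pages_per_seq, page_size, D_latent + D_rope,
                       dtype=torch.bfloat16, device="cuda")
seq_lens = torch.full((B,), 1024, dtype=torch.int32, device="cuda")
page_table = torch.arange(B * pages_per_seq, dtype=torch.int32,
                          device="cuda").reshape(B, pages_per_seq)
out = ops.cutlass_mla_decode(q, kv_cache, seq_lens, page_table)
assert out.shape == (B, H, D_latent)  # (B, H, 512)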
