using namespace cute;
using namespace cutlass::fmha::kernel;

- template <bool v>
+ template <bool v>
struct IsPersistent {
  static const bool value = v;
};
@@ -54,31 +54,28 @@ struct MlaSm100 {

  // H K (D_latent D_rope) B
  using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;
-
+
  using StrideQ = cute::tuple<int64_t, _1, int64_t>;  // H D B
  using StrideK = cute::tuple<int64_t, _1, int64_t>;  // K D B
  using StrideO = StrideK;                            // H D B
  using StrideLSE = cute::tuple<_1, int>;             // H B

-   using TileScheduler = std::conditional_t<
-       PersistenceOption::value,
-       Sm100MlaPersistentTileScheduler,
-       Sm100MlaIndividualTileScheduler>;
+   using TileScheduler = std::conditional_t<PersistenceOption::value,
+                                            Sm100MlaPersistentTileScheduler,
+                                            Sm100MlaIndividualTileScheduler>;

  using FmhaKernel =
      cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
-           TileShape, Element, ElementAcc, ElementOut, ElementAcc,
-           TileScheduler, /*kIsCpAsync=*/true>;
+           TileShape, Element, ElementAcc, ElementOut, ElementAcc, TileScheduler,
+           /*kIsCpAsync=*/true>;

  using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
};

-
template <typename T>
- typename T::Fmha::Arguments args_from_options(at::Tensor const& out,
-                                               at::Tensor const& q_nope_and_q_pe,
-                                               at::Tensor const& kv_c_and_k_pe_cache,
-                                               at::Tensor const& seq_lens,
-                                               at::Tensor const& page_table) {
+ typename T::Fmha::Arguments args_from_options(
+     at::Tensor const& out, at::Tensor const& q_nope_and_q_pe,
+     at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens,
+     at::Tensor const& page_table) {
  cutlass::KernelHardwareInfo hw_info;
  hw_info.device_id = q_nope_and_q_pe.device().index();
  hw_info.sm_count =
@@ -92,8 +89,8 @@ typename T::Fmha::Arguments args_from_options(at::Tensor const& out,
  int max_seq_len = page_size * page_count_per_seq;
  using TileShapeH = typename T::TileShapeH;
  using TileShapeD = typename T::TileShapeD;
-   auto problem_shape = cute::make_tuple(
-       TileShapeH{}, max_seq_len, TileShapeD{}, batches);
+   auto problem_shape =
+       cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches);

  auto [H, K, D, B] = problem_shape;
  auto [D_latent, D_rope] = D;
@@ -108,66 +105,55 @@ typename T::Fmha::Arguments args_from_options(at::Tensor const& out,
  using StrideO = typename T::StrideO;
  using StrideLSE = typename T::StrideLSE;

-   StrideQ stride_Q = cute::make_tuple(
-       static_cast<int64_t>(0 + D_latent + D_rope),
-       _1{},
-       static_cast<int64_t>(H * (0 + D_latent + D_rope)));
-   StrideK stride_C = cute::make_tuple(
-       static_cast<int64_t>(0 + D_latent + D_rope),
-       _1{},
-       static_cast<int64_t>(page_size * (D_latent + D_rope)));
+   StrideQ stride_Q =
+       cute::make_tuple(static_cast<int64_t>(0 + D_latent + D_rope), _1{},
+                        static_cast<int64_t>(H * (0 + D_latent + D_rope)));
+   StrideK stride_C =
+       cute::make_tuple(static_cast<int64_t>(0 + D_latent + D_rope), _1{},
+                        static_cast<int64_t>(page_size * (D_latent + D_rope)));
  StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
  StrideLSE stride_LSE = cute::make_tuple(_1{}, 0 + H);
-   StrideO stride_O = cute::make_tuple(
-       static_cast<int64_t>(0 + D_latent),
-       _1{},
-       static_cast<int64_t>(0 + H * D_latent));
+   StrideO stride_O =
+       cute::make_tuple(static_cast<int64_t>(0 + D_latent), _1{},
+                        static_cast<int64_t>(0 + H * D_latent));

  using Element = typename T::Element;
  using ElementOut = typename T::ElementOut;
  using ElementAcc = typename T::ElementAcc;
  auto Q_ptr = static_cast<Element*>(q_nope_and_q_pe.data_ptr());
  auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
  typename T::Fmha::Arguments arguments{
-       problem_shape,
-       { scale,
-         Q_ptr, stride_Q,
-         Q_ptr + D_latent, stride_Q,
-         C_ptr, stride_C,
-         C_ptr + D_latent, stride_C,
-         static_cast<int*>(seq_lens.data_ptr()),
-         static_cast<int*>(page_table.data_ptr()), stride_PT,
-         page_count_total, page_size},
-       { static_cast<ElementOut*>(out.data_ptr()), stride_O,
-         // static_cast<ElementAcc*>(lse.data_ptr()), stride_LSE},
-         static_cast<ElementAcc*>(nullptr), stride_LSE},
-       hw_info,
-       -1,       // split_kv
-       nullptr,  // is_var_split_kv
+       problem_shape,
+       {scale, Q_ptr, stride_Q, Q_ptr + D_latent, stride_Q, C_ptr, stride_C,
+        C_ptr + D_latent, stride_C, static_cast<int*>(seq_lens.data_ptr()),
+        static_cast<int*>(page_table.data_ptr()), stride_PT, page_count_total,
+        page_size},
+       {static_cast<ElementOut*>(out.data_ptr()), stride_O,
+        static_cast<ElementAcc*>(nullptr), stride_LSE},
+       hw_info,
+       -1,       // split_kv
+       nullptr,  // is_var_split_kv
  };
  // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
-   // split_kv automatically based on batch size and sequence length to balance
+   // split_kv automatically based on batch size and sequence length to balance
  // workload across available SMs. Consider using var_split_kv for manual
  // control if needed.
  T::Fmha::set_split_kv(arguments);
  return arguments;
}

template <typename Element>
- void runMla(at::Tensor const& out,
-             at::Tensor const& q_nope_and_q_pe,
-             at::Tensor const& kv_c_and_k_pe_cache,
-             at::Tensor const& seq_lens,
-             at::Tensor const& page_table,
-             cudaStream_t stream) {
+ void runMla(at::Tensor const& out, at::Tensor const& q_nope_and_q_pe,
+             at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens,
+             at::Tensor const& page_table, cudaStream_t stream) {
  using MlaSm100Type = MlaSm100<Element>;
  typename MlaSm100Type::Fmha fmha;
-   auto arguments =
-       args_from_options<MlaSm100Type>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache,
-                                       seq_lens, page_table);
+   auto arguments = args_from_options<MlaSm100Type>(
+       out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table);
  size_t workspace_size = MlaSm100Type::Fmha::get_workspace_size(arguments);
-   auto const workspace_options =
-       torch::TensorOptions().dtype(torch::kUInt8).device(q_nope_and_q_pe.device());
+   auto const workspace_options = torch::TensorOptions()
+                                      .dtype(torch::kUInt8)
+                                      .device(q_nope_and_q_pe.device());
  auto workspace = torch::empty(workspace_size, workspace_options);

  CUTLASS_CHECK(fmha.can_implement(arguments));
@@ -182,20 +168,20 @@ void cutlass_mla_decode_sm100a(torch::Tensor const& out,
                               torch::Tensor const& kv_c_and_k_pe_cache,
                               torch::Tensor const& seq_lens,
                               torch::Tensor const& page_table) {
-   auto in_dtype = q_nope_and_q_pe.dtype();
-   at::cuda::CUDAGuard device_guard{(char)q_nope_and_q_pe.get_device()};
-   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(
-       q_nope_and_q_pe.get_device());
-   if (in_dtype == at::ScalarType::Half) {
-     runMla<cutlass::half_t>(
-         out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, stream);
-   } else if (in_dtype == at::ScalarType::BFloat16) {
-     runMla<cutlass::bfloat16_t>(
-         out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, stream);
-   } else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
-     runMla<cutlass::float_e4m3_t>(
-         out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, stream);
-   } else {
-     TORCH_CHECK(false, "Unsupported input data type of MLA");
-   }
+   auto in_dtype = q_nope_and_q_pe.dtype();
+   at::cuda::CUDAGuard device_guard{(char)q_nope_and_q_pe.get_device()};
+   const cudaStream_t stream =
+       at::cuda::getCurrentCUDAStream(q_nope_and_q_pe.get_device());
+   if (in_dtype == at::ScalarType::Half) {
+     runMla<cutlass::half_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens,
+                             page_table, stream);
+   } else if (in_dtype == at::ScalarType::BFloat16) {
+     runMla<cutlass::bfloat16_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache,
+                                 seq_lens, page_table, stream);
+   } else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
+     runMla<cutlass::float_e4m3_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache,
+                                   seq_lens, page_table, stream);
+   } else {
+     TORCH_CHECK(false, "Unsupported input data type of MLA");
+   }
}
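
For reference, a minimal caller sketch (not part of this diff): tensor shapes follow the stride layouts built in args_from_options, while the head count, head dimensions, page geometry, and the helper name example_mla_decode_call are assumptions made only for illustration.

// Hypothetical usage sketch; H, D_latent, D_rope, and the paging constants
// below are assumed values, not taken from this diff.
#include <torch/torch.h>

// Entry point exposed by this file (signature as in the diff above).
void cutlass_mla_decode_sm100a(torch::Tensor const& out,
                               torch::Tensor const& q_nope_and_q_pe,
                               torch::Tensor const& kv_c_and_k_pe_cache,
                               torch::Tensor const& seq_lens,
                               torch::Tensor const& page_table);

void example_mla_decode_call() {
  const int64_t B = 2, H = 128, D_latent = 512, D_rope = 64;  // assumed config
  const int64_t page_size = 128, pages_per_seq = 8;           // assumed paging
  auto opts =
      torch::TensorOptions().dtype(torch::kBFloat16).device(torch::kCUDA);

  // q_nope_and_q_pe: [B, H, D_latent + D_rope], matching stride_Q
  auto q = torch::randn({B, H, D_latent + D_rope}, opts);
  // kv_c_and_k_pe_cache: [num_pages, page_size, D_latent + D_rope], matching stride_C
  auto kv = torch::randn({B * pages_per_seq, page_size, D_latent + D_rope}, opts);
  // out: [B, H, D_latent], matching stride_O
  auto out = torch::empty({B, H, D_latent}, opts);
  // Per-sequence lengths and page table as int32, as expected by the
  // static_cast<int*> accesses in args_from_options.
  auto seq_lens = torch::full({B}, 512, opts.dtype(torch::kInt32));
  auto page_table = torch::arange(B * pages_per_seq, opts.dtype(torch::kInt32))
                        .reshape({B, pages_per_seq});

  cutlass_mla_decode_sm100a(out, q, kv, seq_lens, page_table);
}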