Commit 4e123a7

tlrmchlsmth authored and lulmer committed

[Kernel] Add needs_fixed_stride_order tag to most GEMMs (vllm-project#14306)

Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Louis Ulmer <[email protected]>

1 parent 047ee9d   commit 4e123a7

1 file changed: +40 -15 lines changed

csrc/torch_bindings.cpp

Lines changed: 40 additions & 15 deletions
@@ -4,6 +4,7 @@
 #include "core/registration.h"
 
 #include <torch/library.h>
+#include <torch/version.h>
 
 // Note on op signatures:
 // The X_meta signatures are for the meta functions corresponding to op X.
@@ -17,6 +18,15 @@
 
 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // vLLM custom ops
+  //
+
+  // The default behavior in PyTorch 2.6 is "requires_contiguous", so we need
+  // to override this for many GEMMs with the following tag. Otherwise,
+  // torch.compile will force all input tensors to be contiguous(), which
+  // will break many custom ops that require column-major weight matrices.
+  // TODO: remove this for PyTorch 2.8, when the default is planned to switch
+  // to match exact eager-mode strides.
+  at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
 
   ops.def("weak_ref_tensor(Tensor input) -> Tensor");
   ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
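
An aside for readers of the hunk above: the tag registration pattern is not vLLM-specific. Below is a minimal sketch of the same idea for a standalone extension; the toy_ext namespace and toy_gemm kernel are hypothetical and only illustrate how the tag is attached, mirroring the ops.def(..., {stride_tag}) changes in the hunks that follow.

#include <ATen/ATen.h>
#include <torch/library.h>

namespace {
// Hypothetical kernel: a real implementation would rely on b_col_major
// keeping its column-major strides rather than being made contiguous.
at::Tensor toy_gemm(const at::Tensor& a, const at::Tensor& b_col_major) {
  return at::matmul(a, b_col_major);
}
}  // namespace

TORCH_LIBRARY(toy_ext, ops) {
  // The tag is passed as the trailing argument to def(), which asks the
  // torch.compile stack to preserve the caller's strides for this op.
  ops.def("toy_gemm(Tensor a, Tensor b_col_major) -> Tensor",
          {at::Tag::needs_fixed_stride_order});
  ops.impl("toy_gemm", torch::kCUDA, &toy_gemm);
}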
@@ -163,25 +173,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, "
       "Tensor scales, int[] codebook_partition_sizes, Tensor? bias) "
-      "-> Tensor");
+      "-> Tensor",
+      {stride_tag});
   ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);
 
   // Decompression method for AQLM.
   ops.def(
       "aqlm_dequant(Tensor codes, Tensor codebooks, "
-      "int[] codebook_partition_sizes) -> Tensor");
+      "int[] codebook_partition_sizes) -> Tensor",
+      {stride_tag});
   ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);
 
   // Quantized GEMM for AWQ.
   ops.def(
       "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
-      "Tensor _zeros, SymInt split_k_iters) -> Tensor");
+      "Tensor _zeros, SymInt split_k_iters) -> Tensor",
+      {stride_tag});
   ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);
 
   // Dequantization for AWQ.
   ops.def(
       "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
-      "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor");
+      "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor",
+      {stride_tag});
   ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
 
   // Note about marlin kernel 'workspace' arguments:
@@ -202,15 +216,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
       "Tensor! workspace, SymInt size_m, SymInt size_n, SymInt size_k) -> "
-      "Tensor");
+      "Tensor",
+      {stride_tag});
   // conditionally compiled so impl in source file
 
   // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
   ops.def(
       "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
       "Tensor b_scales, Tensor workspace, "
       "int b_q_type, "
-      "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor");
+      "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor",
+      {stride_tag});
   // conditionally compiled so impl in source file
 
   // Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
@@ -236,7 +252,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor? channel_scales,"
       " Tensor? token_scales,"
       " str? schedule"
-      ") -> Tensor");
+      ") -> Tensor",
+      {stride_tag});
   ops.def(
       "machete_prepack_B("
       " Tensor B,"
@@ -255,7 +272,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
       "int b_q_type, "
       "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
-      "bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
+      "bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor",
+      {stride_tag});
   // conditionally compiled so impl registration is in source file
 
   // gptq_marlin repack from GPTQ.
@@ -291,30 +309,34 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
       "Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, "
-      "SymInt size_k) -> Tensor");
+      "SymInt size_k) -> Tensor",
+      {stride_tag});
   // conditionally compiled so impl registration is in source file
 
   // marlin_qqq_gemm for QQQ.
   ops.def(
       "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
       "Tensor s_tok, Tensor s_ch, Tensor s_group, "
       "Tensor! workspace, SymInt size_m, SymInt size_n, "
-      "SymInt size_k) -> Tensor");
+      "SymInt size_k) -> Tensor",
+      {stride_tag});
   // conditionally compiled so impl registration is in source file
 
   // CUTLASS nvfp4 block scaled GEMM
   ops.def(
       "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
       " Tensor block_scale_a, Tensor block_scale_b,"
-      " Tensor alpha) -> ()");
+      " Tensor alpha) -> ()",
+      {stride_tag});
   ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
 
   // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
   // quantization, as well as bias
   ops.def(
       "cutlass_scaled_mm(Tensor! out, Tensor a,"
       " Tensor b, Tensor a_scales,"
-      " Tensor b_scales, Tensor? bias) -> ()");
+      " Tensor b_scales, Tensor? bias) -> ()",
+      {stride_tag});
   ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
 
   // CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
@@ -323,7 +345,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
       " Tensor b, Tensor a_scales,"
       " Tensor b_scales, Tensor azp_adj,"
-      " Tensor? azp, Tensor? bias) -> ()");
+      " Tensor? azp, Tensor? bias) -> ()",
+      {stride_tag});
   ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);
 
   // Check if cutlass scaled_mm is supported for CUDA devices of the given
@@ -351,7 +374,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "cutlass_scaled_sparse_mm(Tensor! out, Tensor a,"
       " Tensor bt_nzs,"
       " Tensor bt_meta, Tensor a_scales,"
-      " Tensor b_scales, Tensor? bias) -> ()");
+      " Tensor b_scales, Tensor? bias) -> ()",
+      {stride_tag});
   ops.impl("cutlass_scaled_sparse_mm", torch::kCUDA, &cutlass_scaled_sparse_mm);
 
   // CUTLASS sparse matrix compressor
@@ -407,7 +431,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
       "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) "
-      "-> Tensor");
+      "-> Tensor",
+      {stride_tag});
   ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
 
   // Post processing for GPTQ.
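
For context on the comment added at the top of TORCH_LIBRARY_EXPAND: a column-major weight is typically a transposed view, and contiguous() replaces it with a row-major copy, which is exactly the layout change the tagged GEMMs cannot tolerate. The snippet below is an illustrative sketch (not part of this commit) showing the stride change:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // b is a row-major [K, N] weight; b.t() is a column-major [N, K] view.
  at::Tensor b = at::arange(12, at::kFloat).reshape({3, 4});
  at::Tensor b_col_major = b.t();                // shape [4, 3], strides [1, 4]
  at::Tensor forced = b_col_major.contiguous();  // shape [4, 3], strides [3, 1]

  std::cout << "view strides:   " << b_col_major.strides() << "\n"
            << "forced strides: " << forced.strides() << "\n";
  // Same logical values, different memory order: a kernel written against the
  // column-major layout would misinterpret `forced`, which is why these op
  // schemas now carry at::Tag::needs_fixed_stride_order.
  return 0;
}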
