#include "core/registration.h"

#include <torch/library.h>
+ #include <torch/version.h>

// Note on op signatures:
// The X_meta signatures are for the meta functions corresponding to op X.

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops
+ //
+
+ // The default behavior in PyTorch 2.6 is "requires_contiguous", so we need
+ // to override this for many GEMMs with the following tag. Otherwise,
+ // torch.compile will force all input tensors to be contiguous(), which
+ // will break many custom ops that require column-major weight matrices.
+ // TODO: remove this for PyTorch 2.8, when the default is planned to switch
+ // to match exact eager-mode strides.
+ at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
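// Illustrative aside (not part of the diff above): why contiguous() breaks
// column-major weights. A weight exposed as a transposed view keeps its data
// column-major; forcing it contiguous silently reorders the memory. A minimal
// libtorch sketch (tensor names are hypothetical):
//
//   at::Tensor w  = at::randn({4, 8});  // row-major, strides {8, 1}
//   at::Tensor wt = w.t();              // column-major view, strides {1, 8}
//   at::Tensor wc = wt.contiguous();    // row-major copy, strides {4, 1}
//
// A kernel written against wt's strides would then read wc's memory in the
// wrong order, which is what tagging the op with needs_fixed_stride_order
// avoids.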

ops.def("weak_ref_tensor(Tensor input) -> Tensor");
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
@@ -163,25 +173,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def(
"aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, "
"Tensor scales, int[] codebook_partition_sizes, Tensor? bias) "
- "-> Tensor");
+ "-> Tensor",
+ {stride_tag});
ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);

// Decompression method for AQLM.
ops.def(
"aqlm_dequant(Tensor codes, Tensor codebooks, "
- "int[] codebook_partition_sizes) -> Tensor");
+ "int[] codebook_partition_sizes) -> Tensor",
+ {stride_tag});
ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);

// Quantized GEMM for AWQ.
ops.def(
"awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
- "Tensor _zeros, SymInt split_k_iters) -> Tensor");
+ "Tensor _zeros, SymInt split_k_iters) -> Tensor",
+ {stride_tag});
ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);

// Dequantization for AWQ.
ops.def(
"awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
- "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor");
+ "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor",
+ {stride_tag});
ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);

// Note about marlin kernel 'workspace' arguments:
@@ -202,15 +216,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def(
"marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor! workspace, SymInt size_m, SymInt size_n, SymInt size_k) -> "
- "Tensor");
+ "Tensor",
+ {stride_tag});
// conditionally compiled so impl in source file
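// Illustrative aside (not part of the diff above): "conditionally compiled"
// means the schema above is always defined, while the CUDA implementation is
// registered from the kernel's own translation unit only when that kernel is
// built. Assuming an impl-registration helper analogous to the
// TORCH_LIBRARY_EXPAND macro used here, such a registration might look like:
//
//   TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, ops) {
//     ops.impl("marlin_gemm", &marlin_gemm);
//   }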

// Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
ops.def(
"gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
"Tensor b_scales, Tensor workspace, "
"int b_q_type, "
- "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor");
+ "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor",
+ {stride_tag});
// conditionally compiled so impl in source file

// Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
@@ -236,7 +252,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? channel_scales,"
"Tensor? token_scales,"
"str? schedule"
- ") -> Tensor");
+ ") -> Tensor",
+ {stride_tag});
ops.def(
"machete_prepack_B("
"Tensor B,"
@@ -255,7 +272,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
"int b_q_type, "
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
- "bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
+ "bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor",
+ {stride_tag});
// conditionally compiled so impl registration is in source file

// gptq_marlin repack from GPTQ.
@@ -291,30 +309,34 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def(
"fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, "
- "SymInt size_k) -> Tensor");
+ "SymInt size_k) -> Tensor",
+ {stride_tag});
// conditionally compiled so impl registration is in source file

// marlin_qqq_gemm for QQQ.
ops.def(
"marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
"Tensor s_tok, Tensor s_ch, Tensor s_group, "
"Tensor! workspace, SymInt size_m, SymInt size_n, "
- "SymInt size_k) -> Tensor");
+ "SymInt size_k) -> Tensor",
+ {stride_tag});
// conditionally compiled so impl registration is in source file

// CUTLASS nvfp4 block scaled GEMM
ops.def(
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
"Tensor block_scale_a, Tensor block_scale_b,"
- "Tensor alpha) -> ()");
+ "Tensor alpha) -> ()",
+ {stride_tag});
ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);

// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias
ops.def(
"cutlass_scaled_mm(Tensor! out, Tensor a,"
"Tensor b, Tensor a_scales,"
- "Tensor b_scales, Tensor? bias) -> ()");
+ "Tensor b_scales, Tensor? bias) -> ()",
+ {stride_tag});
ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);

// CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
@@ -323,7 +345,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
"Tensor b, Tensor a_scales,"
"Tensor b_scales, Tensor azp_adj,"
- "Tensor? azp, Tensor? bias) -> ()");
+ "Tensor? azp, Tensor? bias) -> ()",
+ {stride_tag});
ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);

// Check if cutlass scaled_mm is supported for CUDA devices of the given
@@ -351,7 +374,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"cutlass_scaled_sparse_mm(Tensor! out, Tensor a,"
"Tensor bt_nzs,"
"Tensor bt_meta, Tensor a_scales,"
- "Tensor b_scales, Tensor? bias) -> ()");
+ "Tensor b_scales, Tensor? bias) -> ()",
+ {stride_tag});
ops.impl("cutlass_scaled_sparse_mm", torch::kCUDA, &cutlass_scaled_sparse_mm);

// CUTLASS sparse matrix compressor
@@ -407,7 +431,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def(
"gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
"Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) "
- "-> Tensor");
+ "-> Tensor",
+ {stride_tag});
ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);

// Post processing for GPTQ.