
Commit c38308f

Add lowering for _convolution.deprecated (#1259)
Parent: 99fb4c8

7 files changed: +166 −9 lines

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

+34
@@ -3444,6 +3444,40 @@ def Torch_Aten_ConvolutionOp : Torch_Op<"aten._convolution", [
   }];
 }
 
+def Torch_Aten_ConvolutionDeprecatedOp : Torch_Op<"aten._convolution.deprecated", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchListOfTorchIntType:$padding,
+    AnyTorchListOfTorchIntType:$dilation,
+    Torch_BoolType:$transposed,
+    AnyTorchListOfTorchIntType:$output_padding,
+    Torch_IntType:$groups,
+    Torch_BoolType:$benchmark,
+    Torch_BoolType:$deterministic,
+    Torch_BoolType:$cudnn_enabled
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult Aten_ConvolutionDeprecatedOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 12, 1);
+    }
+    void Aten_ConvolutionDeprecatedOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 12, 1);
+    }
+  }];
+}
+
 def Torch_AtenFlipOp : Torch_Op<"aten.flip", [
   AllowsTypeRefinement,
   HasValueSemantics,
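This ODS block is machine-generated from torch_ods_gen.py (see below); `parseDefaultTorchOp(parser, result, 12, 1)` reflects the op's 12 operands and 1 result. For orientation: the 12-argument form of `_convolution` (without the trailing `allow_tf32` flag of the newer 13-argument overload) matches the deprecated overload's schema, which is presumably why the new e2e tests call `torch.ops.aten._convolution` directly. A minimal eager-mode sketch, assuming a stock PyTorch install:

import torch

x = torch.randn(1, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)
# 12 arguments (no allow_tf32) -> matches the aten::_convolution.deprecated
# schema; the benchmark/deterministic/cudnn_enabled flags do not change the math.
y = torch.ops.aten._convolution(x, w, None, [1, 1], [0, 0], [1, 1],
                                False, [0, 0], 1, False, False, False)
print(y.shape)  # torch.Size([1, 4, 6, 6])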

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

+10 −7

@@ -927,13 +927,14 @@ class DecomposeAtenConvolutionOverrideableOp
 };
 } // namespace
 
-// Decompose aten.convolution_overrideable to aten.convolution
+// Decompose aten._convolution-like to aten.convolution
 namespace {
-class DecomposeAten_ConvolutionOp
-    : public OpRewritePattern<Aten_ConvolutionOp> {
+template <typename ConvolutionLikeOp>
+class DecomposeAten_ConvolutionLikeOp
+    : public OpRewritePattern<ConvolutionLikeOp> {
 public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(Aten_ConvolutionOp op,
+  using OpRewritePattern<ConvolutionLikeOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(ConvolutionLikeOp op,
                                 PatternRewriter &rewriter) const override {
 
     rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
@@ -2542,8 +2543,10 @@ class DecomposeComplexOpsPass
     patterns.add<DecomposeAtenNativeBatchNormOp>(context);
     target.addIllegalOp<AtenConvolutionOverrideableOp>();
     patterns.add<DecomposeAtenConvolutionOverrideableOp>(context);
-    target.addIllegalOp<Aten_ConvolutionOp>();
-    patterns.add<DecomposeAten_ConvolutionOp>(context);
+    target.addIllegalOp<Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp>();
+    patterns.add<DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionOp>,
+                 DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionDeprecatedOp>>(
+        context);
     target.addIllegalOp<AtenConv2dOp>();
     patterns.add<DecomposeAtenConv2dOp>(context);
     patterns.add<DecomposeAtenArangeOp>(context);
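The templated pattern works because both `aten._convolution` variants share their first nine operands with `aten.convolution`; the trailing flags (`benchmark`, `deterministic`, `cudnn_enabled`, plus `allow_tf32` on the newer overload) only steer backend selection, so the rewrite can drop them. A Python-level sketch of the equivalence the decomposition relies on (illustrative, not the pass itself; it assumes `torch.convolution`, the public binding of `aten::convolution`):

import torch

def decomposed_convolution_deprecated(x, w, b, stride, padding, dilation,
                                      transposed, output_padding, groups,
                                      benchmark, deterministic, cudnn_enabled):
    # The three flags only pick an implementation, so they can be dropped
    # and the remaining operands forwarded to aten.convolution unchanged.
    return torch.convolution(x, w, b, stride, padding, dilation,
                             transposed, output_padding, groups)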

lib/Dialect/Torch/Transforms/RefineTypes.cpp

+1 −1

@@ -712,7 +712,7 @@ void TypeAnalysis::visitOperation(Operation *op,
 
   // Promote the two dtypes assuming non-zero rank.
   if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
-          Aten_ConvolutionOp, AtenConvolutionOverrideableOp>(op)) {
+          Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp, AtenConvolutionOverrideableOp>(op)) {
     auto knowledge =
         ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
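Adding the new op to this `isa<>` list gives it the same result-dtype rule as the other matmul- and convolution-like ops: promote the input and weight dtypes, assuming non-zero rank. The same rule is observable from Python:

import torch
# The result dtype of a conv-like op is the promotion of its operand
# dtypes, e.g. float16 conv float32 -> float32.
assert torch.promote_types(torch.float16, torch.float32) == torch.float32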

lib/Dialect/Torch/Transforms/ShapeLibrary.cpp

+4
@@ -6341,6 +6341,10 @@ module {
     %0 = call @"__torch_mlir_shape_fn.aten.convolution"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>
     return %0 : !torch.list<int>
   }
+  func.func @"__torch_mlir_shape_fn.aten._convolution.deprecated"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool) -> !torch.list<int> {
+    %0 = call @"__torch_mlir_shape_fn.aten.convolution"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>
+    return %0 : !torch.list<int>
+  }
   func.func @"__torch_mlir_shape_fn.aten.flip"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
    return %arg0 : !torch.list<int>
   }
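The deprecated variant's shape function simply forwards its first nine arguments to `__torch_mlir_shape_fn.aten.convolution`, which applies the standard convolution size arithmetic. A hedged restatement of that per-dimension formula in Python (not the library's exact code):

def conv_out_size(in_size: int, kernel: int, stride: int,
                  padding: int, dilation: int) -> int:
    # Non-transposed convolution output size along one spatial dimension.
    return (in_size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

# Matches the e2e tests below: 10x10 input, 2x2 kernel, stride 3, padding 2.
assert conv_out_size(10, 2, 3, 2, 1) == 5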

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py

+4 −1

@@ -940,7 +940,10 @@ def aten〇convolution(input: List[int], weight: List[int], bias: Optional[List[
 
 def aten〇_convolution(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> List[int]:
     return aten〇convolution(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)
-
+
+def aten〇_convolution〇deprecated(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> List[int]:
+    return aten〇convolution(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)
+
 def aten〇flip(self: List[int], dims: List[int]) -> List[int]:
     return self
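A note on naming: the `〇` character stands in for separators that are illegal in Python identifiers, so `aten〇_convolution〇deprecated` corresponds to the JIT operator `aten::_convolution.deprecated`. A hypothetical sketch of that de-mangling (the real helper lives in the build_tools, not shown here):

def demangle(py_name: str) -> str:
    # Hypothetical illustration: the first 〇 separates the namespace,
    # later ones encode '.' in the overload name.
    ns, _, rest = py_name.partition("〇")
    return ns + "::" + rest.replace("〇", ".")

assert demangle("aten〇_convolution〇deprecated") == "aten::_convolution.deprecated"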

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

+1
@@ -337,6 +337,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
     emit("aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
     emit("aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)")
+    emit("aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)")
     emit("aten::flip : (Tensor, int[]) -> (Tensor)")
     emit(
         "aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"

python/torch_mlir_e2e_test/test_suite/conv.py

+112
@@ -406,6 +406,118 @@ def forward(self, inputVec, weight):
 def _Convolution2DTF32Module_basic(module, tu: TestUtils):
     module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
 
+class _ConvolutionDeprecated2DAllFalseModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, inputVec, weight):
+        return torch.ops.aten._convolution(inputVec,
+                                           weight,
+                                           bias=None,
+                                           stride=[3, 3],
+                                           padding=[2, 2],
+                                           dilation=[1, 1],
+                                           transposed=False,
+                                           output_padding=[0, 0],
+                                           groups=1,
+                                           benchmark=False,
+                                           deterministic=False,
+                                           cudnn_enabled=False)
+
+@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DAllFalseModule())
+def _ConvolutionDeprecated2DAllFalseModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
+
+class _ConvolutionDeprecated2DBenchmarkModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, inputVec, weight):
+        return torch.ops.aten._convolution(inputVec,
+                                           weight,
+                                           bias=None,
+                                           stride=[3, 3],
+                                           padding=[2, 2],
+                                           dilation=[1, 1],
+                                           transposed=False,
+                                           output_padding=[0, 0],
+                                           groups=1,
+                                           benchmark=True,
+                                           deterministic=False,
+                                           cudnn_enabled=False)
+
+@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DBenchmarkModule())
+def _ConvolutionDeprecated2DBenchmarkModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
+
+class _ConvolutionDeprecated2DDeterministicModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, inputVec, weight):
+        return torch.ops.aten._convolution(inputVec,
+                                           weight,
+                                           bias=None,
+                                           stride=[3, 3],
+                                           padding=[2, 2],
+                                           dilation=[1, 1],
+                                           transposed=False,
+                                           output_padding=[0, 0],
+                                           groups=1,
+                                           benchmark=False,
+                                           deterministic=True,
+                                           cudnn_enabled=False)
+
+@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DDeterministicModule())
+def _ConvolutionDeprecated2DDeterministicModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
+
+class _ConvolutionDeprecated2DCudnnModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, inputVec, weight):
+        return torch.ops.aten._convolution(inputVec,
+                                           weight,
+                                           bias=None,
+                                           stride=[3, 3],
+                                           padding=[2, 2],
+                                           dilation=[1, 1],
+                                           transposed=False,
+                                           output_padding=[0, 0],
+                                           groups=1,
+                                           benchmark=False,
+                                           deterministic=False,
+                                           cudnn_enabled=True)
+
+@register_test_case(module_factory=lambda: _ConvolutionDeprecated2DCudnnModule())
+def _ConvolutionDeprecated2DCudnnModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2))
+
 class ConvolutionModule2DGroups(torch.nn.Module):
     def __init__(self):
         super().__init__()
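Each of the four new tests flips one flag but computes the same convolution, so in eager mode every variant should agree with `torch.nn.functional.conv2d`. A quick sanity check, runnable outside the e2e harness (CPU tensors; the flags only select a backend implementation):

import torch

x = torch.randn(3, 3, 10, 10)
w = torch.randn(3, 3, 2, 2)
ref = torch.nn.functional.conv2d(x, w, bias=None, stride=[3, 3],
                                 padding=[2, 2], dilation=[1, 1], groups=1)
# One tuple per test: (benchmark, deterministic, cudnn_enabled).
for flags in [(False, False, False), (True, False, False),
              (False, True, False), (False, False, True)]:
    out = torch.ops.aten._convolution(x, w, None, [3, 3], [2, 2], [1, 1],
                                      False, [0, 0], 1, *flags)
    torch.testing.assert_close(out, ref)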
