
Commit 1f20b72

[Torch Dialect] add canonicalize for aten.min.other (#2452)
1 parent cc4a5d9 commit 1f20b72
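
This change registers the `aten::min.other` overload as a Torch dialect op, adds shape and dtype inference functions for it, and canonicalizes it into the existing `aten.minimum` op, so no backend needs a dedicated lowering. The rewrite is sound because `Tensor.min(other)` computes the same elementwise minimum as `torch.minimum`. A quick illustrative check in plain PyTorch (not part of the commit):

import torch

# Tensor.min(other) with a tensor argument is the elementwise minimum,
# i.e. the same computation as torch.minimum.
x = torch.rand(3, 5)
y = torch.rand(3, 5)
assert torch.equal(x.min(y), torch.minimum(x, y))

# The equivalence also holds for integer tensors, which is what the new
# ElementwiseMinOtherIntModule e2e test exercises.
xi = torch.randint(10, (3, 5))
yi = torch.randint(10, (3, 5))
assert torch.equal(xi.min(yi), torch.minimum(xi, yi))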

7 files changed: +103 -0 lines

e2e_testing/xfail_sets.py

Lines changed: 2 additions & 0 deletions
@@ -934,6 +934,8 @@
     "ElementwiseBinaryStaticShapeModule_basic",
     "ElementwiseMinimumModule_basic",
     "ElementwiseMinimumIntModule_basic",
+    "ElementwiseMinOtherIntModule_basic",
+    "ElementwiseMinOtherModule_basic",
     "ElementwiseMaximumModule_basic",
     "ElementwiseMaximumIntModule_basic",
     "ElementwiseMaxOtherIntModule_basic",

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 25 additions & 0 deletions
@@ -8967,6 +8967,31 @@ def Torch_AtenMinOp : Torch_Op<"aten.min", [
   }];
 }
 
+def Torch_AtenMinOtherOp : Torch_Op<"aten.min.other", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::min.other : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenMinOtherOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenMinOtherOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+  let hasCanonicalizer = 1;
+}
+
 def Torch_AtenMinDimOp : Torch_Op<"aten.min.dim", [
     AllowsTypeRefinement,
     HasValueSemantics,
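
GeneratedTorchOps.td is generated rather than hand-written; the definition above corresponds to the `emit(...)` call added to torch_ods_gen.py later in this diff, and `let hasCanonicalizer = 1;` declares the `getCanonicalizationPatterns` hook whose body is supplied in TorchOps.cpp.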

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 14 additions & 0 deletions
@@ -930,6 +930,20 @@ void AtenLenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   });
 }
 
+//===----------------------------------------------------------------------===//
+// AtenMinOtherOp
+//===----------------------------------------------------------------------===//
+
+void AtenMinOtherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                 MLIRContext *context) {
+  // `aten.min.other` -> `aten.minimum`
+  patterns.add(+[](AtenMinOtherOp op, PatternRewriter &rewriter) {
+    rewriter.replaceOpWithNewOp<AtenMinimumOp>(op, op.getType(), op.getSelf(),
+                                               op.getOther());
+    return success();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenMaxOtherOp
 //===----------------------------------------------------------------------===//
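
The pattern applies unconditionally and always succeeds, so canonicalization replaces every `aten.min.other` with `aten.minimum`, and backends only ever have to lower `aten.minimum`.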

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 8 additions & 0 deletions
@@ -6570,6 +6570,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.min.other\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.max\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -10266,6 +10270,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.min.other\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %0 = call @\"__torch_mlir_dtype_fn.aten.minimum\"(%arg0, %arg1) : (!torch.tuple<int, int>, !torch.tuple<int, int>) -> !torch.int\n"
+"    return %0 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.max\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py

Lines changed: 7 additions & 0 deletions
@@ -312,6 +312,9 @@ def aten〇all〡shape(self: List[int]) -> List[int]:
 def aten〇min〡shape(self: List[int]) -> List[int]:
     return []
 
+def aten〇min〇other〡shape(self: List[int], other: List[int]) -> List[int]:
+    return upstream_shape_functions.broadcast(self, other)
+
 def aten〇max〡shape(self: List[int]) -> List[int]:
     return []
 
@@ -3059,6 +3062,10 @@ def aten〇min〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+@check_dtype_function(_check_two_tensor_op())
+def aten〇min〇other〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
+    return aten〇minimum〡dtype(self_rank_dtype, other_rank_dtype)
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
 def aten〇max〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
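
The new shape function defers to the upstream `broadcast` helper, so the result shape is the standard broadcast of the two operand shapes, and the dtype function simply reuses `aten.minimum`'s promotion rule; the serialized MLIR in AbstractInterpLibrary.cpp above is regenerated from these Python definitions. A minimal standalone sketch of what broadcasting computes (illustrative only; the real helper is `torch.jit._shape_functions.broadcast`):

from typing import List

def broadcast_sketch(a: List[int], b: List[int]) -> List[int]:
    # Align both shapes from the right; a size-1 dimension (or a missing
    # leading dimension) stretches to match the other shape.
    ndim = max(len(a), len(b))
    out: List[int] = []
    for i in range(ndim):
        ia, ib = len(a) - ndim + i, len(b) - ndim + i
        da = a[ia] if ia >= 0 else 1
        db = b[ib] if ib >= 0 else 1
        assert da == db or da == 1 or db == 1, "shapes are not broadcastable"
        out.append(max(da, db))
    return out

# Both e2e tests below use matching [3, 5] operands, so the result is [3, 5].
assert broadcast_sketch([3, 5], [3, 5]) == [3, 5]
assert broadcast_sketch([3, 1], [5]) == [3, 5]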

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
@@ -567,6 +567,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)")
     emit("aten::amax : (Tensor, int[], bool) -> (Tensor)")
     emit("aten::min : (Tensor) -> (Tensor)")
+    emit("aten::min.other : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
     emit("aten::min.dim : (Tensor, int, bool) -> (Tensor, Tensor)")
     emit("aten::amin : (Tensor, int[], bool) -> (Tensor)")
     emit("aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True)
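
Passing `has_canonicalizer=True` is what makes the generated ODS above carry `let hasCanonicalizer = 1;`; the generator only declares the hook, and the pattern itself is the one added to TorchOps.cpp.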

python/torch_mlir_e2e_test/test_suite/elementwise.py

Lines changed: 46 additions & 0 deletions
@@ -653,6 +653,52 @@ def ElementwiseMinimumIntModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseMinOtherModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x, y):
+        return x.min(y)
+
+
+@register_test_case(module_factory=lambda: ElementwiseMinOtherModule())
+def ElementwiseMinOtherModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5), tu.rand(3, 5))
+
+
+# ==============================================================================
+
+
+class ElementwiseMinOtherIntModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x, y):
+        return x.min(y)
+
+
+@register_test_case(module_factory=lambda: ElementwiseMinOtherIntModule())
+def ElementwiseMinOtherIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 5, high=10), tu.randint(3, 5, high=10))
+
+
+# ==============================================================================
+
+
 class ElementwiseMaximumModule(torch.nn.Module):
 
     def __init__(self):
