Skip to content

Commit c6c40c5

Browse files
committed
Add canonicalization for aten.add.Tensor op
1 parent a34dad2 commit c6c40c5

File tree

4 files changed

+137
-50
lines changed

4 files changed

+137
-50
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 50 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -790,55 +790,6 @@ def Torch_AtenBitwiseNot_Op : Torch_Op<"aten.bitwise_not_", [
790790
}];
791791
}
792792

793-
def Torch_AtenAddTensorOp : Torch_Op<"aten.add.Tensor", [
794-
AllowsTypeRefinement,
795-
HasValueSemantics,
796-
ReadOnly
797-
]> {
798-
let summary = "Generated op for `aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`";
799-
let arguments = (ins
800-
AnyTorchTensorType:$self,
801-
AnyTorchTensorType:$other,
802-
AnyTorchScalarType:$alpha
803-
);
804-
let results = (outs
805-
AnyTorchTensorType:$result
806-
);
807-
let hasCustomAssemblyFormat = 1;
808-
let extraClassDefinition = [{
809-
ParseResult AtenAddTensorOp::parse(OpAsmParser &parser, OperationState &result) {
810-
return parseDefaultTorchOp(parser, result, 3, 1);
811-
}
812-
void AtenAddTensorOp::print(OpAsmPrinter &printer) {
813-
printDefaultTorchOp(printer, *this, 3, 1);
814-
}
815-
}];
816-
}
817-
818-
def Torch_AtenAdd_TensorOp : Torch_Op<"aten.add_.Tensor", [
819-
IsTrailingUnderscoreInplaceVariant,
820-
AllowsTypeRefinement
821-
]> {
822-
let summary = "Generated op for `aten::add_.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`";
823-
let arguments = (ins
824-
AnyTorchTensorType:$self,
825-
AnyTorchTensorType:$other,
826-
AnyTorchScalarType:$alpha
827-
);
828-
let results = (outs
829-
AnyTorchTensorType:$result
830-
);
831-
let hasCustomAssemblyFormat = 1;
832-
let extraClassDefinition = [{
833-
ParseResult AtenAdd_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
834-
return parseDefaultTorchOp(parser, result, 3, 1);
835-
}
836-
void AtenAdd_TensorOp::print(OpAsmPrinter &printer) {
837-
printDefaultTorchOp(printer, *this, 3, 1);
838-
}
839-
}];
840-
}
841-
842793
def Torch_AtenSubTensorOp : Torch_Op<"aten.sub.Tensor", [
843794
AllowsTypeRefinement,
844795
HasValueSemantics,
@@ -2439,6 +2390,56 @@ def Torch_AtenZero_Op : Torch_Op<"aten.zero_", [
24392390
}];
24402391
}
24412392

2393+
def Torch_AtenAddTensorOp : Torch_Op<"aten.add.Tensor", [
2394+
AllowsTypeRefinement,
2395+
HasValueSemantics,
2396+
ReadOnly
2397+
]> {
2398+
let summary = "Generated op for `aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`";
2399+
let arguments = (ins
2400+
AnyTorchTensorType:$self,
2401+
AnyTorchTensorType:$other,
2402+
AnyTorchScalarType:$alpha
2403+
);
2404+
let results = (outs
2405+
AnyTorchTensorType:$result
2406+
);
2407+
let hasCustomAssemblyFormat = 1;
2408+
let extraClassDefinition = [{
2409+
ParseResult AtenAddTensorOp::parse(OpAsmParser &parser, OperationState &result) {
2410+
return parseDefaultTorchOp(parser, result, 3, 1);
2411+
}
2412+
void AtenAddTensorOp::print(OpAsmPrinter &printer) {
2413+
printDefaultTorchOp(printer, *this, 3, 1);
2414+
}
2415+
}];
2416+
let hasCanonicalizer = 1;
2417+
}
2418+
2419+
def Torch_AtenAdd_TensorOp : Torch_Op<"aten.add_.Tensor", [
2420+
IsTrailingUnderscoreInplaceVariant,
2421+
AllowsTypeRefinement
2422+
]> {
2423+
let summary = "Generated op for `aten::add_.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`";
2424+
let arguments = (ins
2425+
AnyTorchTensorType:$self,
2426+
AnyTorchTensorType:$other,
2427+
AnyTorchScalarType:$alpha
2428+
);
2429+
let results = (outs
2430+
AnyTorchTensorType:$result
2431+
);
2432+
let hasCustomAssemblyFormat = 1;
2433+
let extraClassDefinition = [{
2434+
ParseResult AtenAdd_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
2435+
return parseDefaultTorchOp(parser, result, 3, 1);
2436+
}
2437+
void AtenAdd_TensorOp::print(OpAsmPrinter &printer) {
2438+
printDefaultTorchOp(printer, *this, 3, 1);
2439+
}
2440+
}];
2441+
}
2442+
24422443
def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [
24432444
AllowsTypeRefinement,
24442445
HasValueSemantics,

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Dialect/Func/IR/FuncOps.h"
1313
#include "mlir/IR/Builders.h"
1414
#include "mlir/IR/BuiltinOps.h"
15+
#include "mlir/IR/Operation.h"
1516
#include "mlir/IR/PatternMatch.h"
1617
#include "mlir/IR/TypeUtilities.h"
1718
#include "mlir/Support/LLVM.h"
@@ -98,6 +99,27 @@ static FloatAttr getF64FloatAttr(MLIRContext *context, double value) {
9899
return FloatAttr::get(Float64Type::get(context), value);
99100
}
100101

102+
static Value getScalarValue(Value input, Location loc,
103+
PatternRewriter &rewriter) {
104+
Value scalar = nullptr;
105+
if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
106+
if (valueTensorLiteralOp &&
107+
getTensorRank(valueTensorLiteralOp.getResult()) == 0) {
108+
auto val = valueTensorLiteralOp.value()
109+
.cast<DenseElementsAttr>()
110+
.getSplatValue<int64_t>();
111+
scalar = rewriter.create<Torch::ConstantIntOp>(
112+
loc, rewriter.getI64IntegerAttr(val));
113+
return scalar;
114+
}
115+
} else if (auto primNumToTensorScalarOp =
116+
input.getDefiningOp<PrimNumToTensorScalarOp>()) {
117+
scalar = primNumToTensorScalarOp.a();
118+
return scalar;
119+
}
120+
return scalar;
121+
}
122+
101123
//===----------------------------------------------------------------------===//
102124
// MethodOp
103125
//===----------------------------------------------------------------------===//
@@ -763,6 +785,36 @@ void AtenLenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
763785
});
764786
}
765787

788+
//===----------------------------------------------------------------------===//
789+
// AtenAddTensorOp
790+
//===----------------------------------------------------------------------===//
791+
792+
void AtenAddTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
793+
MLIRContext *context) {
794+
patterns.add(+[](AtenAddTensorOp op, PatternRewriter &rewriter) {
795+
// The lhs and rhs of the add.tensor op should be 0d tensors for the
796+
// canonicalization to be carried out.
797+
// `aten.add.tensor(self, other, alpha)` is canonicalized to
798+
// `aten.add.int(self, aten.mul.int(other, alpha))`.
799+
800+
Value lhs = getScalarValue(op.self(), op.getLoc(), rewriter);
801+
Value rhs = getScalarValue(op.other(), op.getLoc(), rewriter);
802+
803+
if (!lhs || !rhs)
804+
return failure();
805+
806+
if (!lhs.getType().isa<Torch::IntType>() ||
807+
!rhs.getType().isa<Torch::IntType>())
808+
return failure();
809+
810+
Value mul = rewriter.create<AtenMulIntOp>(op->getLoc(), rhs, op.alpha());
811+
Value add = rewriter.create<AtenAddIntOp>(op->getLoc(), lhs, mul);
812+
rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(
813+
op, op.self().getType(), add);
814+
return success();
815+
});
816+
}
817+
766818
//===----------------------------------------------------------------------===//
767819
// AtenSizeOp
768820
//===----------------------------------------------------------------------===//

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,6 @@ def emit_with_mutating_variants(key, **kwargs):
255255
"aten::floor : (Tensor) -> (Tensor)",
256256
"aten::ceil : (Tensor) -> (Tensor)",
257257
"aten::bitwise_not : (Tensor) -> (Tensor)",
258-
"aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",
259258
"aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",
260259
"aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)",
261260
"aten::div.Tensor : (Tensor, Tensor) -> (Tensor)",
@@ -294,6 +293,7 @@ def emit_with_mutating_variants(key, **kwargs):
294293
emit_with_mutating_variants(key)
295294
# Elementwise tensor compute ops that don't have the standard mutating
296295
# variants.
296+
emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
297297
emit("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
298298
emit("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
299299
emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")

test/Dialect/Torch/canonicalize.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1267,3 +1267,37 @@ func.func @torch.aten.Bool.int$fold_cst() -> !torch.bool {
12671267
%1 = torch.aten.Bool.int %int : !torch.int -> !torch.bool
12681268
return %1 : !torch.bool
12691269
}
1270+
1271+
// CHECK-LABEL: func.func @torch.aten.add.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
1272+
// CHECK: %[[INT6:.*]] = torch.constant.int 6
1273+
// CHECK: %[[INT2:.*]] = torch.constant.int 2
1274+
// CHECK: %[[INT0:.*]] = torch.constant.int 0
1275+
// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT0]] : !torch.int -> !torch.vtensor<[],si64>
1276+
// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64>
1277+
// CHECK: %[[PR3:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
1278+
// CHECK: return %[[PR3]] : !torch.vtensor<[],si64>
1279+
func.func @torch.aten.add.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
1280+
%int0 = torch.constant.int 0
1281+
%int2 = torch.constant.int 2
1282+
%int3 = torch.constant.int 3
1283+
%0 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
1284+
%1 = torch.prim.NumToTensor.Scalar %int2 : !torch.int -> !torch.vtensor<[],si64>
1285+
%2 = torch.aten.add.Tensor %0, %1, %int3 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64>
1286+
return %2 : !torch.vtensor<[],si64>
1287+
}
1288+
1289+
// CHECK-LABEL: @torch.aten.add.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
1290+
// CHECK: %[[INT6:.*]] = torch.constant.int 6
1291+
// CHECK: %[[INT2:.*]] = torch.constant.int 2
1292+
// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64>
1293+
// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
1294+
// CHECK: return %[[PR2]] : !torch.vtensor<[],si64>
1295+
func.func @torch.aten.add.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
1296+
%int0 = torch.constant.int 0
1297+
%int2 = torch.constant.int 2
1298+
%int3 = torch.constant.int 3
1299+
%0 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
1300+
%1 = torch.prim.NumToTensor.Scalar %int2 : !torch.int -> !torch.vtensor<[],si64>
1301+
%2 = torch.aten.add.Tensor %0, %1, %int3 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64>
1302+
return %2 : !torch.vtensor<[],si64>
1303+
}

0 commit comments

Comments
 (0)