Skip to content

Commit 5cff40c

Browse files
committed
Add canonicalization for aten.add.tensor op
1 parent e143a34 commit 5cff40c

File tree

4 files changed

+140
-50
lines changed

4 files changed

+140
-50
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 50 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -790,55 +790,6 @@ def Torch_AtenBitwiseNot_Op : Torch_Op<"aten.bitwise_not_", [
790790
}];
791791
}
792792

793-
def Torch_AtenAddTensorOp : Torch_Op<"aten.add.Tensor", [
794-
AllowsTypeRefinement,
795-
HasValueSemantics,
796-
ReadOnly
797-
]> {
798-
let summary = "Generated op for `aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`";
799-
let arguments = (ins
800-
AnyTorchTensorType:$self,
801-
AnyTorchTensorType:$other,
802-
AnyTorchScalarType:$alpha
803-
);
804-
let results = (outs
805-
AnyTorchTensorType:$result
806-
);
807-
let hasCustomAssemblyFormat = 1;
808-
let extraClassDefinition = [{
809-
ParseResult AtenAddTensorOp::parse(OpAsmParser &parser, OperationState &result) {
810-
return parseDefaultTorchOp(parser, result, 3, 1);
811-
}
812-
void AtenAddTensorOp::print(OpAsmPrinter &printer) {
813-
printDefaultTorchOp(printer, *this, 3, 1);
814-
}
815-
}];
816-
}
817-
818-
def Torch_AtenAdd_TensorOp : Torch_Op<"aten.add_.Tensor", [
819-
IsTrailingUnderscoreInplaceVariant,
820-
AllowsTypeRefinement
821-
]> {
822-
let summary = "Generated op for `aten::add_.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`";
823-
let arguments = (ins
824-
AnyTorchTensorType:$self,
825-
AnyTorchTensorType:$other,
826-
AnyTorchScalarType:$alpha
827-
);
828-
let results = (outs
829-
AnyTorchTensorType:$result
830-
);
831-
let hasCustomAssemblyFormat = 1;
832-
let extraClassDefinition = [{
833-
ParseResult AtenAdd_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
834-
return parseDefaultTorchOp(parser, result, 3, 1);
835-
}
836-
void AtenAdd_TensorOp::print(OpAsmPrinter &printer) {
837-
printDefaultTorchOp(printer, *this, 3, 1);
838-
}
839-
}];
840-
}
841-
842793
def Torch_AtenSubTensorOp : Torch_Op<"aten.sub.Tensor", [
843794
AllowsTypeRefinement,
844795
HasValueSemantics,
@@ -2439,6 +2390,56 @@ def Torch_AtenZero_Op : Torch_Op<"aten.zero_", [
24392390
}];
24402391
}
24412392

2393+
def Torch_AtenAddTensorOp : Torch_Op<"aten.add.Tensor", [
2394+
AllowsTypeRefinement,
2395+
HasValueSemantics,
2396+
ReadOnly
2397+
]> {
2398+
let summary = "Generated op for `aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`";
2399+
let arguments = (ins
2400+
AnyTorchTensorType:$self,
2401+
AnyTorchTensorType:$other,
2402+
AnyTorchScalarType:$alpha
2403+
);
2404+
let results = (outs
2405+
AnyTorchTensorType:$result
2406+
);
2407+
let hasCustomAssemblyFormat = 1;
2408+
let extraClassDefinition = [{
2409+
ParseResult AtenAddTensorOp::parse(OpAsmParser &parser, OperationState &result) {
2410+
return parseDefaultTorchOp(parser, result, 3, 1);
2411+
}
2412+
void AtenAddTensorOp::print(OpAsmPrinter &printer) {
2413+
printDefaultTorchOp(printer, *this, 3, 1);
2414+
}
2415+
}];
2416+
let hasCanonicalizer = 1;
2417+
}
2418+
2419+
def Torch_AtenAdd_TensorOp : Torch_Op<"aten.add_.Tensor", [
2420+
IsTrailingUnderscoreInplaceVariant,
2421+
AllowsTypeRefinement
2422+
]> {
2423+
let summary = "Generated op for `aten::add_.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`";
2424+
let arguments = (ins
2425+
AnyTorchTensorType:$self,
2426+
AnyTorchTensorType:$other,
2427+
AnyTorchScalarType:$alpha
2428+
);
2429+
let results = (outs
2430+
AnyTorchTensorType:$result
2431+
);
2432+
let hasCustomAssemblyFormat = 1;
2433+
let extraClassDefinition = [{
2434+
ParseResult AtenAdd_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
2435+
return parseDefaultTorchOp(parser, result, 3, 1);
2436+
}
2437+
void AtenAdd_TensorOp::print(OpAsmPrinter &printer) {
2438+
printDefaultTorchOp(printer, *this, 3, 1);
2439+
}
2440+
}];
2441+
}
2442+
24422443
def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [
24432444
AllowsTypeRefinement,
24442445
HasValueSemantics,

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,29 @@ static FloatAttr getF64FloatAttr(MLIRContext *context, double value) {
9898
return FloatAttr::get(Float64Type::get(context), value);
9999
}
100100

101+
static Value getScalarValue(Value input, Location loc,
102+
PatternRewriter &rewriter) {
103+
Value scalar = nullptr;
104+
if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
105+
if (valueTensorLiteralOp &&
106+
getTensorRank(valueTensorLiteralOp.getResult()) == 0) {
107+
auto tensorType =
108+
valueTensorLiteralOp.value().getType().cast<RankedTensorType>();
109+
if (tensorType.getElementType().isa<mlir::IntegerType>()) {
110+
auto val = valueTensorLiteralOp.value()
111+
.cast<DenseElementsAttr>()
112+
.getSplatValue<int64_t>();
113+
scalar = rewriter.create<Torch::ConstantIntOp>(
114+
loc, rewriter.getI64IntegerAttr(val));
115+
}
116+
}
117+
} else if (auto primNumToTensorScalarOp =
118+
input.getDefiningOp<PrimNumToTensorScalarOp>()) {
119+
scalar = primNumToTensorScalarOp.a();
120+
}
121+
return scalar;
122+
}
123+
101124
//===----------------------------------------------------------------------===//
102125
// MethodOp
103126
//===----------------------------------------------------------------------===//
@@ -763,6 +786,38 @@ void AtenLenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
763786
});
764787
}
765788

789+
//===----------------------------------------------------------------------===//
790+
// AtenAddTensorOp
791+
//===----------------------------------------------------------------------===//
792+
793+
void AtenAddTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
794+
MLIRContext *context) {
795+
patterns.add(+[](AtenAddTensorOp op, PatternRewriter &rewriter) {
796+
// The lhs and rhs of the add.tensor op should be 0d tensors for the
797+
// canonicalization to be carried out.
798+
// `aten.add.tensor(self, other, alpha)` is canonicalized to
799+
// `aten.add.int(self, aten.mul.int(other, alpha))`.
800+
801+
Value lhs = getScalarValue(op.self(), op.getLoc(), rewriter);
802+
if (!lhs)
803+
return rewriter.notifyMatchFailure(op, "lhs scalar is empty");
804+
if (!lhs.getType().isa<Torch::IntType>())
805+
return rewriter.notifyMatchFailure(op, "lhs scalar is not IntType");
806+
807+
Value rhs = getScalarValue(op.other(), op.getLoc(), rewriter);
808+
if (!rhs)
809+
return rewriter.notifyMatchFailure(op, "rhs scalar is empty");
810+
if (!rhs.getType().isa<Torch::IntType>())
811+
return rewriter.notifyMatchFailure(op, "rhs scalar is not IntType");
812+
813+
Value mul = rewriter.create<AtenMulIntOp>(op->getLoc(), rhs, op.alpha());
814+
Value add = rewriter.create<AtenAddIntOp>(op->getLoc(), lhs, mul);
815+
rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(
816+
op, op.self().getType(), add);
817+
return success();
818+
});
819+
}
820+
766821
//===----------------------------------------------------------------------===//
767822
// AtenSizeOp
768823
//===----------------------------------------------------------------------===//

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,6 @@ def emit_with_mutating_variants(key, **kwargs):
255255
"aten::floor : (Tensor) -> (Tensor)",
256256
"aten::ceil : (Tensor) -> (Tensor)",
257257
"aten::bitwise_not : (Tensor) -> (Tensor)",
258-
"aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",
259258
"aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",
260259
"aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)",
261260
"aten::div.Tensor : (Tensor, Tensor) -> (Tensor)",
@@ -294,6 +293,7 @@ def emit_with_mutating_variants(key, **kwargs):
294293
emit_with_mutating_variants(key)
295294
# Elementwise tensor compute ops that don't have the standard mutating
296295
# variants.
296+
emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
297297
emit("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
298298
emit("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
299299
emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")

test/Dialect/Torch/canonicalize.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1267,3 +1267,37 @@ func.func @torch.aten.Bool.int$fold_cst() -> !torch.bool {
12671267
%1 = torch.aten.Bool.int %int : !torch.int -> !torch.bool
12681268
return %1 : !torch.bool
12691269
}
1270+
1271+
// CHECK-LABEL: func.func @torch.aten.add.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
1272+
// CHECK: %[[INT6:.*]] = torch.constant.int 6
1273+
// CHECK: %[[INT0:.*]] = torch.constant.int 0
1274+
// CHECK: %[[INT2:.*]] = torch.constant.int 2
1275+
// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT0]] : !torch.int -> !torch.vtensor<[],si64>
1276+
// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64>
1277+
// CHECK: %[[PR3:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
1278+
// CHECK: return %[[PR3]] : !torch.vtensor<[],si64>
1279+
func.func @torch.aten.add.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
1280+
%int0 = torch.constant.int 0
1281+
%int2 = torch.constant.int 2
1282+
%int3 = torch.constant.int 3
1283+
%0 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
1284+
%1 = torch.prim.NumToTensor.Scalar %int2 : !torch.int -> !torch.vtensor<[],si64>
1285+
%2 = torch.aten.add.Tensor %0, %1, %int3 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64>
1286+
return %2 : !torch.vtensor<[],si64>
1287+
}
1288+
1289+
// CHECK-LABEL: @torch.aten.add.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
1290+
// CHECK: %[[INT6:.*]] = torch.constant.int 6
1291+
// CHECK: %[[INT2:.*]] = torch.constant.int 2
1292+
// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64>
1293+
// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
1294+
// CHECK: return %[[PR2]] : !torch.vtensor<[],si64>
1295+
func.func @torch.aten.add.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
1296+
%int0 = torch.constant.int 0
1297+
%int2 = torch.constant.int 2
1298+
%int3 = torch.constant.int 3
1299+
%0 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
1300+
%1 = torch.prim.NumToTensor.Scalar %int2 : !torch.int -> !torch.vtensor<[],si64>
1301+
%2 = torch.aten.add.Tensor %0, %1, %int3 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64>
1302+
return %2 : !torch.vtensor<[],si64>
1303+
}

0 commit comments

Comments
 (0)