
Commit c935795

add native_dropout and related ops pattern (#1211)
1 parent 41aa562 commit c935795

File tree

6 files changed, +207 −3 lines

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 2 additions & 1 deletion

@@ -5313,9 +5313,10 @@ def Torch_AtenOnesLikeOp : Torch_Op<"aten.ones_like", [
 }
 
 def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [
+    NoSideEffect,
     AllowsTypeRefinement,
     HasValueSemantics,
-    ReadOnly
+    ReadOnly,
 ]> {
   let summary = "Generated op for `aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)`";
   let arguments = (ins

lib/Conversion/TorchToMhlo/Basic.cpp

Lines changed: 103 additions & 0 deletions
@@ -71,6 +71,25 @@ class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> {
 };
 } // namespace
 
+// ConvertAtenUnaryConvertOp legalizes general unary ops into mhlo::ConvertOp.
+namespace {
+template <typename AtenOpT>
+class ConvertAtenUnaryConvertOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
+        op,
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()),
+        adaptor.self());
+    return success();
+  }
+};
+} // namespace
+
 // aten.ones & aten.zeros
 // Ref: Error checking based on the Torch to TOSA lowering
 namespace {
@@ -307,6 +326,9 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
                std::is_same<AtenOpT, AtenGtScalarOp>()) {
       compareDirectionAttr = mhlo::ComparisonDirectionAttr::get(
           op->getContext(), mhlo::ComparisonDirection::GT);
+    } else if (std::is_same<AtenOpT, AtenGeScalarOp>()) {
+      compareDirectionAttr = mhlo::ComparisonDirectionAttr::get(
+          op->getContext(), mhlo::ComparisonDirection::GE);
     } else if (std::is_same<AtenOpT, AtenEqTensorOp>() ||
                std::is_same<AtenOpT, AtenEqScalarOp>()) {
       compareDirectionAttr = mhlo::ComparisonDirectionAttr::get(
@@ -980,6 +1002,75 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
 }
 } // namespace
 
+// AtenSizeIntOp
+namespace {
+template <>
+LogicalResult ConvertAtenOp<AtenSizeIntOp>::matchAndRewrite(
+    AtenSizeIntOp op,
+    OpAdaptor adaptor,
+    ConversionPatternRewriter& rewriter) const {
+  // Not a tensor type.
+  auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
+  if (!selfType)
+    return op.emitError("Only tensor types are currently supported");
+  auto dim = rewriter.create<arith::IndexCastOp>(
+      op.getLoc(), rewriter.getIndexType(), adaptor.dim());
+  auto dimSize = rewriter.create<tensor::DimOp>(
+      op.getLoc(), rewriter.getIndexType(), adaptor.self(), dim);
+
+  rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
+      op, getTypeConverter()->convertType(op.getType()), dimSize);
+
+  return success();
+}
+} // namespace
+
+// ValsemVariantAtenUniformOp
+namespace {
+template <>
+LogicalResult ConvertAtenOp<ValsemVariantAtenUniformOp>::matchAndRewrite(
+    ValsemVariantAtenUniformOp op,
+    OpAdaptor adaptor,
+    ConversionPatternRewriter& rewriter) const {
+  auto inputTy = adaptor.self().getType().template cast<RankedTensorType>();
+  auto loc = op.getLoc();
+  if (!inputTy) {
+    op.emitError("input should be ranked tensor type.");
+  }
+  auto definingOp = op.self().getDefiningOp();
+  auto shape = definingOp->getOperand(0);
+  SmallVector<Value, 4> dimSizes;
+  getListConstructElements(shape, dimSizes);
+  std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value& dSize) {
+    dSize = rewriter.create<torch::TorchConversion::ToI64Op>(loc, dSize).getResult();
+    return dSize;
+  });
+
+  auto mhloShape =
+      rewriter.create<tensor::FromElementsOp>(op.getLoc(), dimSizes);
+
+  double fromDoubleValue, toDoubleValue;
+  if (!matchPattern(op.from(), m_TorchConstantFloat(&fromDoubleValue))) {
+    op.emitError("operand #1 should be scalar");
+  }
+  if (!matchPattern(op.to(), m_TorchConstantFloat(&toDoubleValue))) {
+    op.emitError("operand #2 should be scalar");
+  }
+  Value fromTensor = rewriter.create<mhlo::ConstantOp>(
+      op.getLoc(),
+      rewriter.getFloatAttr(inputTy.getElementType(), fromDoubleValue));
+  Value toTensor = rewriter.create<mhlo::ConstantOp>(
+      op.getLoc(),
+      rewriter.getFloatAttr(inputTy.getElementType(), toDoubleValue));
+
+  auto outType = getTypeConverter()
+                     ->convertType(op.getType())
+                     .template dyn_cast<TensorType>();
+  rewriter.replaceOpWithNewOp<mhlo::RngOp>(
+      op, inputTy, fromTensor, toTensor, mhloShape, mhlo::RngDistribution::UNIFORM);
+  return success();
+}
+}
 void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target) {
@@ -1005,6 +1096,15 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
   INSERT_UNARY_FPONLY_PATTERN(AtenNegOp, mhlo::NegOp);
 #undef INSERT_UNARY_FPONLY_PATTERN
 
+#define INSERT_UNARY_CONVERT_PATTERN(AtenOp)                                   \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenUnaryConvertOp<AtenOp>>(typeConverter,               \
+                                                  context);
+  INSERT_UNARY_CONVERT_PATTERN(AtenContiguousOp);
+  INSERT_UNARY_CONVERT_PATTERN(AtenToDtypeOp);
+  INSERT_UNARY_CONVERT_PATTERN(AtenTypeAsOp);
+#undef INSERT_UNARY_CONVERT_PATTERN
+
 #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal)                          \
   target.addIllegalOp<AtenOp>();                                               \
   patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter,      \
@@ -1038,6 +1138,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
 
   INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp);
   INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp);
+  INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp);
   INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp);
   INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp);
   INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp);
@@ -1063,5 +1164,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
 
   INSERT_ATENOP_PATTERN(AtenBatchNormOp);
   INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
+  INSERT_ATENOP_PATTERN(AtenSizeIntOp);
+  INSERT_ATENOP_PATTERN(ValsemVariantAtenUniformOp);
 #undef INSERT_ATENOP_PATTERN
 }
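
Note on the new ValsemVariantAtenUniformOp lowering above: it only fires when `from` and `to` are compile-time Torch float constants, rebuilds the result shape from the size list feeding the op, and emits mhlo::RngOp with the UNIFORM distribution. A minimal host-side sketch of the sampling behaviour that lowering targets; the RNG engine, buffer layout, and function name here are illustrative assumptions, not part of the commit:

#include <random>
#include <vector>

// Every element is drawn i.i.d. from the half-open interval [from, to),
// mirroring mhlo.rng {rng_distribution = UNIFORM} on the lowered tensor.
void uniformFill(std::vector<float> &buffer, double from, double to,
                 std::mt19937 &gen) {
  std::uniform_real_distribution<float> dist(static_cast<float>(from),
                                             static_cast<float>(to));
  for (float &v : buffer)
    v = dist(gen); // corresponds to one element of the mhlo.rng result
}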

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 43 additions & 0 deletions
@@ -1155,6 +1155,47 @@ class DecomposeAtenDropoutOp : public OpRewritePattern<AtenDropoutOp> {
 };
 } // namespace
 
+namespace {
+class DecomposeAtenNativeDropoutOp : public OpRewritePattern<AtenNativeDropoutOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenNativeDropoutOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    Value input = op.input();
+    Value prob = op.p();
+    bool train = false;
+    if (!matchPattern(op.train(), m_TorchConstantBool(&train)))
+      return rewriter.notifyMatchFailure(op, "train must be a boolean constant");
+
+    BaseTensorType inputType = input.getType().cast<BaseTensorType>();
+    if (!train) {
+      // TODO(yancey.yx): supports inference mode
+      return op.emitError(
+          "native_dropout does not support argument train is false");
+    }
+    if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>())
+      return rewriter.notifyMatchFailure(
+          op, "only support floating type input for training mode");
+    Value noneVal = rewriter.create<ConstantNoneOp>(loc);
+    Value floatOne =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
+    Value oneMinusP = rewriter.create<AtenSubFloatOp>(loc, floatOne, prob);
+    Value boolMask = rewriter.create<ValsemVariantAtenBernoulliFloatOp>(
+        loc, inputType, input, oneMinusP, /*generator=*/noneVal);
+    Value maskedInput =
+        rewriter.create<AtenMulTensorOp>(loc, inputType, boolMask, input);
+    Value output =
+        rewriter.create<AtenMulScalarOp>(loc, inputType, maskedInput, oneMinusP);
+    Value one =
+        rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    boolMask = rewriter.create<AtenGeScalarOp>(
+        loc, op.getResult(1).getType(), boolMask, one);
+    rewriter.replaceOp(op, {output, boolMask});
+    return success();
+  }
+};
+} // namespace
 // Decompose aten.var into: aten.var.dim op.
 namespace {
 class DecomposeAtenVarOp : public OpRewritePattern<AtenVarOp> {
@@ -2596,6 +2637,8 @@ class DecomposeComplexOpsPass
     patterns.add<DecomposeAten_ToCopyOp>(context);
     target.addIllegalOp<Aten_ToCopyOp>();
     patterns.add<DecomposeAtenDropoutOp>(context);
+    patterns.add<DecomposeAtenNativeDropoutOp>(context);
+    target.addIllegalOp<AtenNativeDropoutOp>();
     target.addIllegalOp<AtenDropoutOp>();
     target.addIllegalOp<AtenNewEmptyOp>();
     patterns.add<DecomposeAtenNewEmptyOp>(context);
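
For readers following the decomposition above, here is a scalar sketch of the op sequence DecomposeAtenNativeDropoutOp emits for train=true. The function name and the std::random machinery are illustrative assumptions; the pattern itself builds the corresponding torch ops on tensors:

#include <cstddef>
#include <random>
#include <utility>
#include <vector>

std::pair<std::vector<float>, std::vector<bool>>
nativeDropoutSketch(const std::vector<float> &input, double p,
                    std::mt19937 &gen) {
  // ValsemVariantAtenBernoulliFloatOp with probability (1 - p) of keeping.
  std::bernoulli_distribution keep(1.0 - p);
  std::vector<float> output(input.size());
  std::vector<bool> mask(input.size());
  for (std::size_t i = 0; i < input.size(); ++i) {
    float m = keep(gen) ? 1.0f : 0.0f;
    // AtenMulTensorOp (input * mask) followed by AtenMulScalarOp (* (1 - p)).
    output[i] = input[i] * m * static_cast<float>(1.0 - p);
    // AtenGeScalarOp recovers the boolean mask as (mask >= 1).
    mask[i] = m >= 1.0f;
  }
  return {output, mask};
}

The trailing aten.ge.Scalar comparison is why this commit also teaches ConvertAtenCompareOp the GE direction and registers AtenGeScalarOp as a binary compare pattern in Basic.cpp.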

lib/Dialect/TorchConversion/Transforms/Passes.cpp

Lines changed: 2 additions & 0 deletions
@@ -139,6 +139,8 @@ void TorchConversion::createTorchBackendToMhloBackendPipeline(
       TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass());
 
   pm.addNestedPass<func::FuncOp>(createConvertTorchToMhloPass());
+  pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass());
+  pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
 
   if (options.optimize) {
     // Clean up any non-canonical code introduced above..
Lines changed: 47 additions & 0 deletions

@@ -0,0 +1,47 @@
+// RUN: torch-mlir-opt < %s --torch-function-to-torch-backend-pipeline --torch-backend-to-mhlo-backend-pipeline -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: func.func @torch.aten.native_dropout.train(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: f64) -> (tensor<?x?xf32>, tensor<?x?xi1>) {
+// CHECK: %[[T0:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
+// CHECK: %[[CST_0:.*]] = arith.constant 1 : index
+// CHECK: %[[CST_1:.*]] = arith.constant 0 : index
+// CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f64>
+// CHECK: %[[T2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f64>
+// CHECK: %[[CST_2:.*]] = arith.constant 1.000000e+00 : f64
+// CHECK: %[[CST_3:.*]] = arith.subf %[[CST_2]], %[[ARG1]] : f64
+// CHECK: %[[T3:.*]] = tensor.from_elements %[[CST_3]] : tensor<1xf64>
+// CHECK: %[[T4:.*]] = "mhlo.reshape"(%[[T3]]) : (tensor<1xf64>) -> tensor<f64>
+// CHECK: %[[T5:.*]] = mhlo.convert(%[[ARG0]]) : (tensor<?x?xf32>) -> tensor<?x?xf64>
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[T5]], %[[CST_1]] : tensor<?x?xf64>
+// CHECK: %[[CST_I64_0:.*]] = arith.index_cast %[[DIM_0]] : index to i64
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T5]], %[[CST_0]] : tensor<?x?xf64>
+// CHECK: %[[CST_I64_1:.*]] = arith.index_cast %[[DIM_1]] : index to i64
+// CHECK: %[[T6:.*]] = tensor.from_elements %[[CST_I64_0]], %[[CST_I64_1]] : tensor<2xi64>
+// CHECK: %[[T7:.*]] = "mhlo.rng"(%[[T2]], %[[T1]], %[[T6]]) {rng_distribution = #mhlo.rng_distribution<UNIFORM>} : (tensor<f64>, tensor<f64>, tensor<2xi64>) -> tensor<?x?xf64>
+// CHECK: %[[T8:.*]] = shape.shape_of %[[T7]] : tensor<?x?xf64> -> tensor<2xindex>
+// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T4]], %[[T8]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f64>, tensor<2xindex>) -> tensor<?x?xf64>
+// CHECK: %[[T10:.*]] = mhlo.compare LT, %[[T7]], %[[T9]], FLOAT : (tensor<?x?xf64>, tensor<?x?xf64>) -> tensor<?x?xi1>
+// CHECK: %[[T11:.*]] = mhlo.convert(%[[T10]]) : (tensor<?x?xi1>) -> tensor<?x?xf32>
+// CHECK: %[[T12:.*]] = shape.shape_of %[[T11]] : tensor<?x?xf32> -> tensor<2xindex>
+// CHECK: %[[T13:.*]] = shape.shape_of %[[ARG0]] : tensor<?x?xf32> -> tensor<2xindex>
+// CHECK: %[[T14:.*]] = shape.cstr_broadcastable %[[T12]], %[[T13]] : tensor<2xindex>, tensor<2xindex>
+// CHECK: %[[T15:.*]] = shape.assuming %[[T14]] -> (tensor<?x?xf32>) {
+// CHECK: %[[T16:.*]] = shape.broadcast %[[T12]], %[[T13]] : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex>
+// CHECK: %[[T17:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T11]], %[[T16]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+// CHECK: %[[T18:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[T16]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+// CHECK: %[[T19:.*]] = mhlo.multiply %[[T17]], %[[T18]] : tensor<?x?xf32>
+// CHECK: shape.assuming_yield %[[T19]] : tensor<?x?xf32>
+// CHECK: }
+// CHECK: %[[T20:.*]] = mhlo.convert(%[[T3]]) : (tensor<1xf64>) -> tensor<1xf32>
+// CHECK: %[[T21:.*]] = "mhlo.reshape"(%[[T20]]) : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T22:.*]] = shape.shape_of %[[T15]] : tensor<?x?xf32> -> tensor<2xindex>
+// CHECK: %[[T23:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T21]], %[[T22]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<2xindex>) -> tensor<?x?xf32>
+// CHECK: %[[T24:.*]] = mhlo.multiply %[[T15]], %[[T23]] : tensor<?x?xf32>
+// CHECK: %[[T25:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T12]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<2xindex>) -> tensor<?x?xf32>
+// CHECK: %[[T26:.*]] = mhlo.compare GE, %[[T11]], %[[T25]], FLOAT : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
+// CHECK: return %[[T24]], %[[T26]] : tensor<?x?xf32>, tensor<?x?xi1>
+func.func @torch.aten.native_dropout.train(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.float) -> (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>) {
+  %bool_true = torch.constant.bool true
+  %result0, %result1 = torch.aten.native_dropout %arg0, %arg1, %bool_true: !torch.vtensor<[?,?],f32>, !torch.float, !torch.bool -> !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>
+  return %result0, %result1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>
+}

test/Conversion/TorchToMhlo/view_like.mlir

Lines changed: 10 additions & 2 deletions
@@ -360,9 +360,17 @@ func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3,?,?],f32> -> tensor<2x3x?x?xf32>
 // CHECK: %[[INTneg1:.*]] = torch.constant.int -1
 // CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[C1_I64:.*]] = torch_c.to_i64 %[[INT1]]
 // CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[T1:.*]] = torch.aten.size.int %[[ARG0]], %[[INT0]] : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int
-// CHECK: %[[T2:.*]] = torch.aten.size.int %[[ARG0]], %[[INT1]] : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int
+// CHECK: %[[C2_I64:.*]] = torch_c.to_i64 %[[INT0]]
+// CHECK: %[[INDEX_1:.*]] = arith.index_cast %[[C2_I64]] : i64 to index
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[INDEX_1]] : tensor<2x3x?x?xf32>
+// CHECK: %[[DIM_I64_1:.*]] = arith.index_cast %[[DIM_1]] : index to i64
+// CHECK: %[[T1:.*]] = torch_c.from_i64 %[[DIM_I64_1]]
+// CHECK: %[[INDEX_2:.*]] = arith.index_cast %[[C1_I64]] : i64 to index
+// CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[INDEX_2]] : tensor<2x3x?x?xf32>
+// CHECK: %[[DIM_I64_2:.*]] = arith.index_cast %[[DIM_2]] : index to i64
+// CHECK: %[[T2:.*]] = torch_c.from_i64 %[[DIM_I64_2]]
 // CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]], %[[INTneg1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T4:.*]] = torch_c.to_i64 %[[T1]]
 // CHECK: %[[T5:.*]] = torch_c.to_i64 %[[T2]]
