@@ -71,6 +71,25 @@ class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> {
 };
 } // namespace
 
+// ConvertAtenUnaryConvertOp legalizes general unary ops (e.g. contiguous,
+// to.dtype, type_as) by lowering them to a single mhlo::ConvertOp on the
+// converted result type.
+namespace {
+template <typename AtenOpT>
+class ConvertAtenUnaryConvertOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
+        op,
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()),
+        adaptor.self());
+    return success();
+  }
+};
+} // namespace
+
 // aten.ones & aten.zeros
 // Ref: Error checking based on the Torch to TOSA lowering
 namespace {
@@ -307,6 +326,9 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
         std::is_same<AtenOpT, AtenGtScalarOp>()) {
       compareDirectionAttr = mhlo::ComparisonDirectionAttr::get(
           op->getContext(), mhlo::ComparisonDirection::GT);
+    } else if (std::is_same<AtenOpT, AtenGeScalarOp>()) {
+      compareDirectionAttr = mhlo::ComparisonDirectionAttr::get(
+          op->getContext(), mhlo::ComparisonDirection::GE);
     } else if (std::is_same<AtenOpT, AtenEqTensorOp>() ||
                std::is_same<AtenOpT, AtenEqScalarOp>()) {
       compareDirectionAttr = mhlo::ComparisonDirectionAttr::get(
@@ -980,6 +1002,75 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
 }
 } // namespace
 
+// AtenSizeIntOp
+namespace {
+template <>
+LogicalResult ConvertAtenOp<AtenSizeIntOp>::matchAndRewrite(
+    AtenSizeIntOp op,
+    OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // Not a tensor type.
+  auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
+  if (!selfType)
+    return op.emitError("Only tensor types are currently supported");
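+  // aten.size.int(self, dim): cast the Torch dim to an index, read the
+  // dynamic extent with tensor.dim, then cast the result back to the
+  // converted integer type.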
+  auto dim = rewriter.create<arith::IndexCastOp>(
+      op.getLoc(), rewriter.getIndexType(), adaptor.dim());
+  auto dimSize = rewriter.create<tensor::DimOp>(
+      op.getLoc(), rewriter.getIndexType(), adaptor.self(), dim);
+
+  rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
+      op, getTypeConverter()->convertType(op.getType()), dimSize);
+
+  return success();
+}
+} // namespace
+
+// ValsemVariantAtenUniformOp
+namespace {
+template <>
+LogicalResult ConvertAtenOp<ValsemVariantAtenUniformOp>::matchAndRewrite(
+    ValsemVariantAtenUniformOp op,
+    OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto inputTy =
+      adaptor.self().getType().template dyn_cast<RankedTensorType>();
+  auto loc = op.getLoc();
+  if (!inputTy)
+    return op.emitError("input should be a ranked tensor type.");
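+  // The Valsem uniform variant fills an existing tensor, so the target shape
+  // is recovered from the op that defined `self`; its first operand is
+  // expected to be the torch size list used to build that tensor. Each size
+  // is converted to an i64 SSA value.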
+  auto definingOp = op.self().getDefiningOp();
+  auto shape = definingOp->getOperand(0);
+  SmallVector<Value, 4> dimSizes;
+  getListConstructElements(shape, dimSizes);
+  std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) {
+    dSize = rewriter.create<torch::TorchConversion::ToI64Op>(loc, dSize)
+                .getResult();
+  });
+
+  auto mhloShape =
+      rewriter.create<tensor::FromElementsOp>(op.getLoc(), dimSizes);
+
+  double fromDoubleValue, toDoubleValue;
+  if (!matchPattern(op.from(), m_TorchConstantFloat(&fromDoubleValue)))
+    return op.emitError("operand #1 should be a scalar constant");
+  if (!matchPattern(op.to(), m_TorchConstantFloat(&toDoubleValue)))
+    return op.emitError("operand #2 should be a scalar constant");
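+  // Materialize the `from` / `to` bounds as constants with the input element
+  // type so they can be passed to mhlo.rng below.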
+  Value fromTensor = rewriter.create<mhlo::ConstantOp>(
+      op.getLoc(),
+      rewriter.getFloatAttr(inputTy.getElementType(), fromDoubleValue));
+  Value toTensor = rewriter.create<mhlo::ConstantOp>(
+      op.getLoc(),
+      rewriter.getFloatAttr(inputTy.getElementType(), toDoubleValue));
+
+  auto outType = getTypeConverter()
+                     ->convertType(op.getType())
+                     .template dyn_cast<TensorType>();
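+  // Emit mhlo.rng with a UNIFORM distribution over [from, to) at the
+  // dynamically assembled shape.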
+  rewriter.replaceOpWithNewOp<mhlo::RngOp>(
+      op, inputTy, fromTensor, toTensor, mhloShape,
+      mhlo::RngDistribution::UNIFORM);
+  return success();
+}
+} // namespace
 void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target) {
@@ -1005,6 +1096,15 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
   INSERT_UNARY_FPONLY_PATTERN(AtenNegOp, mhlo::NegOp);
 #undef INSERT_UNARY_FPONLY_PATTERN
 
+#define INSERT_UNARY_CONVERT_PATTERN(AtenOp)                                  \
+  target.addIllegalOp<AtenOp>();                                              \
+  patterns.add<ConvertAtenUnaryConvertOp<AtenOp>>(typeConverter,              \
+                                                  context);
+  INSERT_UNARY_CONVERT_PATTERN(AtenContiguousOp);
+  INSERT_UNARY_CONVERT_PATTERN(AtenToDtypeOp);
+  INSERT_UNARY_CONVERT_PATTERN(AtenTypeAsOp);
+#undef INSERT_UNARY_CONVERT_PATTERN
+
 #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal)                         \
   target.addIllegalOp<AtenOp>();                                              \
   patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter,     \
@@ -1038,6 +1138,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
 
   INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp);
   INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp);
+  INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp);
   INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp);
   INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp);
   INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp);
@@ -1063,5 +1164,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
 
   INSERT_ATENOP_PATTERN(AtenBatchNormOp);
   INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
+  INSERT_ATENOP_PATTERN(AtenSizeIntOp);
+  INSERT_ATENOP_PATTERN(ValsemVariantAtenUniformOp);
 #undef INSERT_ATENOP_PATTERN
 }