diff --git a/CMakeLists.txt b/CMakeLists.txt
index 00340141c835..02b01d43ff68 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -23,7 +23,7 @@ endif()
 project(torch-mlir LANGUAGES CXX C)
 set(CMAKE_C_STANDARD 11)
-set(CMAKE_CXX_STANDARD 14)
+set(CMAKE_CXX_STANDARD 17)
 macro(torch_mlir_add_llvm_external_project name identifier location)
   message(STATUS "Adding LLVM external project ${name} (${identifier}) -> ${location}")
diff --git a/externals/llvm-project b/externals/llvm-project
index 02b3a358926e..061e0189a3da 160000
--- a/externals/llvm-project
+++ b/externals/llvm-project
@@ -1 +1 @@
-Subproject commit 02b3a358926e7bbcac9226cbecbfc3067c2ad61b
+Subproject commit 061e0189a3dab6b1831a80d489ff1b15ad93aafb
diff --git a/externals/mlir-hlo b/externals/mlir-hlo
index ad54b43c623c..0430519b7ebf 160000
--- a/externals/mlir-hlo
+++ b/externals/mlir-hlo
@@ -1 +1 @@
-Subproject commit ad54b43c623cc5ae69b0e90f395b3fba13ffa55a
+Subproject commit 0430519b7ebf11a3f44c469fce8b579561fa6052
diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h
index 46771dc72663..59be99c885e7 100644
--- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h
+++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h
@@ -54,12 +54,12 @@ class BaseTensorType : public Type {
   Type getOptionalDtype() const;
   /// Return true if this type has a list of sizes.
-  bool hasSizes() const { return getOptionalSizes().hasValue(); }
+  bool hasSizes() const { return getOptionalSizes().has_value(); }
   /// Get the list of sizes. Requires `hasSizes()`.
   ArrayRef getSizes() const {
     assert(hasSizes() && "must have sizes");
-    return getOptionalSizes().getValue();
+    return getOptionalSizes().value();
   }
   /// Return true if all sizes of this tensor are known.
diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td
index e0780395e7a4..aa820a320c0f 100644
--- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td
+++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td
@@ -49,8 +49,8 @@ class OptionalArrayRefParameter : AttrOrTypeParameter<
     "::llvm::Optional<::llvm::ArrayRef<" # arrayOf # ">>", desc> {
   let allocator = [{
-    if ($_self.hasValue()) {
-      $_dst.getValue() = $_allocator.copyInto($_self.getValue());
+    if ($_self.has_value()) {
+      $_dst.value() = $_allocator.copyInto($_self.value());
     }
   }];
 }
diff --git a/lib/CAPI/TorchTypes.cpp b/lib/CAPI/TorchTypes.cpp
index 6d72e7e1551c..7465fc06ef08 100644
--- a/lib/CAPI/TorchTypes.cpp
+++ b/lib/CAPI/TorchTypes.cpp
@@ -213,7 +213,8 @@ MlirType torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation(
 }
 MlirType torchMlirTorchNonValueTensorTypeGetFromAttribute(MlirAttribute attr) {
-  auto attrTensorType = unwrap(attr).getType().cast();
+  auto attrTensorType =
+      unwrap(attr).cast().getType().cast();
   return wrap(Torch::NonValueTensorType::get(attrTensorType.getContext(),
                                              attrTensorType.getShape(),
                                              attrTensorType.getElementType()));
diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index 57383095a115..6de17c254bd0 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -342,7 +342,7 @@ class ConvertAtenViewOp : public OpConversionPattern {
         continue;
       }
-      if (inferredDimension.hasValue()) {
+      if (inferredDimension.has_value()) {
         return rewriter.notifyMatchFailure(
             op, "at most one element in size list is allowed to be -1");
       }
@@ -363,7 +363,7 @@ class ConvertAtenViewOp : public OpConversionPattern {
     // then we don't need to analyze the static information of the input
     // shape since the reassociation of dimensions only requires rank
     // information.
-    if (inferredDimension.hasValue() && outputShape.size() > 1) {
+    if (inferredDimension.has_value() && outputShape.size() > 1) {
       if (llvm::count(outputShape, kUnknownSize) != 1 ||
           llvm::count(inputShape, kUnknownSize) != 0) {
         return rewriter.notifyMatchFailure(
@@ -585,14 +585,14 @@ class ConvertAtenViewOp : public OpConversionPattern {
       collapsedInput = rewriter
                            .create(
                                loc, adjustedResultType,
-                               expandedInput.hasValue() ? expandedInput.value()
-                                                        : castedInput,
+                               expandedInput.has_value() ? expandedInput.value()
+                                                         : castedInput,
                                outputAssociations)
                            .result();
     }
-    Value result = collapsedInput.hasValue() ? collapsedInput.value()
-                                             : expandedInput.value();
+    Value result = collapsedInput.has_value() ? collapsedInput.value()
+                                              : expandedInput.value();
     rewriter.replaceOpWithNewOp(op, resultType, result);
     return success();
   }
diff --git a/lib/Conversion/TorchToMhlo/Basic.cpp b/lib/Conversion/TorchToMhlo/Basic.cpp
index 6529d70746d2..8392357dd132 100644
--- a/lib/Conversion/TorchToMhlo/Basic.cpp
+++ b/lib/Conversion/TorchToMhlo/Basic.cpp
@@ -119,7 +119,7 @@ class ConvertAtenConstPatternOp : public OpConversionPattern {
     SmallVector values(size, fillVal);
     auto constOp =
-        mhlo::getConstTensor(rewriter, op, values, shape).getValue();
+        mhlo::getConstTensor(rewriter, op, values, shape).value();
     rewriter.replaceOpWithNewOp(op, outType, constOp);
     return success();
@@ -884,7 +884,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
       op->getLoc(), mhloBatchNormOutTy, input,
       mhlo::getConstTensor(rewriter, op, llvm::makeArrayRef(inputFlattenShape),
                            {static_cast(inputFlattenShape.size())})
-          .getValue());
+          .value());
   // Generate "scale" and "offset" Value for mhlo.BatchNormTrainingOp.
   SmallVector zeroConstVec(
@@ -920,19 +920,19 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
       op->getLoc(), outputTy, batchNormTrainingResult.getResult(0),
       mhlo::getConstTensor(rewriter, op, outputTy.getShape(),
                            {static_cast(outputTy.getShape().size())})
-          .getValue());
+          .value());
   auto mean = rewriter.create(
       op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(1),
       mhlo::getConstTensor(
           rewriter, op, outputMeanOrVarTy.getShape(),
           {static_cast(outputMeanOrVarTy.getShape().size())})
-          .getValue());
+          .value());
   auto var = rewriter.create(
       op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(2),
       mhlo::getConstTensor(
           rewriter, op, outputMeanOrVarTy.getShape(),
           {static_cast(outputMeanOrVarTy.getShape().size())})
-          .getValue());
+          .value());
   // Apply affine transform: output x weight + bias [element-wise]
   auto bcastedWeight = mhlo::promoteAndBroadcast(rewriter, weight, outputTy);
diff --git a/lib/Conversion/TorchToMhlo/Pooling.cpp b/lib/Conversion/TorchToMhlo/Pooling.cpp
index ab28e98aeca8..3c74e23dc7b9 100644
--- a/lib/Conversion/TorchToMhlo/Pooling.cpp
+++ b/lib/Conversion/TorchToMhlo/Pooling.cpp
@@ -314,8 +314,7 @@ LogicalResult ConvertAtenPoolingOp::matchAndRewrite(
                                             initIndexTensor, inputShapeTensor)
                               .getResult();
-  Value initIdx =
-      mhlo::getConstTensor(rewriter, op, {0}, {}).getValue();
+  Value initIdx = mhlo::getConstTensor(rewriter, op, {0}, {}).value();
   auto reduceWindowOp = rewriter.create(
       op->getLoc(), mlir::TypeRange{outValTy, outIdxTy},
@@ -491,7 +490,7 @@ LogicalResult ConvertAtenPoolingOp::matchAndRewrite(
   if (countIncludePad) {
     Value divisor = mhlo::getConstTensor(
                         rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
-                        .getValue();
+                        .value();
     divisor = mhlo::promoteType(rewriter, divisor, outTy);
     DenseIntElementsAttr bcastDimensions;
     rewriter.replaceOpWithNewOp(
@@ -501,7 +500,7 @@ LogicalResult ConvertAtenPoolingOp::matchAndRewrite(
   // Use another mhlo.ReduceWindowOp to get the divisor
   Value windowSizeConst =
-      mhlo::getConstTensor(rewriter, op, {1.0}, {}).getValue();
+      mhlo::getConstTensor(rewriter, op, {1.0}, {}).value();
   windowSizeConst = mhlo::promoteType(rewriter, windowSizeConst, outTy);
   auto inputShapeVec = *mhlo::getDimSizesOfTensor(rewriter, op, input);
   auto inputShapeTensor = rewriter.create(
diff --git a/lib/Conversion/TorchToMhlo/Reduction.cpp b/lib/Conversion/TorchToMhlo/Reduction.cpp
index cf6d5ef9d5aa..7f8ab24a38f4 100644
--- a/lib/Conversion/TorchToMhlo/Reduction.cpp
+++ b/lib/Conversion/TorchToMhlo/Reduction.cpp
@@ -87,7 +87,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
   if (!initValue)
     return llvm::None;
   Value initIndex =
-      mhlo::getConstTensor(rewriter, op, {0}, {}).getValue();
+      mhlo::getConstTensor(rewriter, op, {0}, {}).value();
   DenseIntElementsAttr dimensions = DenseIntElementsAttr::get(
       RankedTensorType::get({}, rewriter.getI64Type()), dim);
@@ -224,7 +224,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite(
   }
   auto inputShapeVec = *inputShapeInfo;
   auto mhloReduceResults =
-      getMaxInDim(rewriter, op, input, inputShapeVec, dim).getValue();
+      getMaxInDim(rewriter, op, input, inputShapeVec, dim).value();
   if (keepDim) {
     auto outShapeVec = inputShapeVec;
@@ -301,7 +301,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite(
   }
   auto inputShapeVec = *inputShapeInfo;
   auto mhloReduceResults =
-      getMaxInDim(rewriter, op, input, inputShapeVec, dim).getValue();
+      getMaxInDim(rewriter, op, input, inputShapeVec, dim).value();
   if (keepDim) {
     auto outShapeVec = inputShapeVec;
diff --git a/lib/Conversion/TorchToStd/TorchToStd.cpp b/lib/Conversion/TorchToStd/TorchToStd.cpp
index 974c1c0d2a64..d79918893322 100644
--- a/lib/Conversion/TorchToStd/TorchToStd.cpp
+++ b/lib/Conversion/TorchToStd/TorchToStd.cpp
@@ -178,7 +178,7 @@ class ConvertTorchTensorLiteralOp
         }));
     return success();
   }
-  if (auto elements = op.valueAttr().dyn_cast()) {
+  if (auto elements = op.valueAttr().dyn_cast()) {
     if (auto type = elements.getType().dyn_cast()) {
       if (auto intType = type.getElementType().dyn_cast()) {
         Type builtinTensorElemTy =
@@ -186,8 +186,7 @@ class ConvertTorchTensorLiteralOp
         auto shapedType =
             RankedTensorType::get(type.getShape(), builtinTensorElemTy);
         rewriter.replaceOpWithNewOp(
-            op, OpaqueElementsAttr::get(elements.getDialect(), shapedType,
-                                        elements.getValue()));
+            op, DenseElementsAttr::get(shapedType, elements.getValues()));
         return success();
       }
     }
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 770e76945aff..558ac82a1ab9 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -148,7 +148,7 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
   if (dtype.isa()) {
     tosaTensor = tosa::getConstTensor(
                      rewriter, op, (isFloat ? doubleValue : intValue), dshape)
-                     .getValue();
+                     .value();
   } else if (auto intType = dtype.dyn_cast()) {
     auto w = intType.getWidth();
     if (w != 32 && w != 64)
@@ -165,7 +165,7 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
       int32_t d = isFloat ? static_cast(doubleValue) : static_cast(intValue);
       tosaTensor =
-          tosa::getConstTensor(rewriter, op, {d}, dshape).getValue();
+          tosa::getConstTensor(rewriter, op, {d}, dshape).value();
     } else if (w == 64) {
       if (!isInValidRange(isFloat, doubleValue, isInt, intValue)) {
         return rewriter.notifyMatchFailure(
@@ -174,7 +174,7 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
       }
       int64_t d = (isFloat ? static_cast(doubleValue) : intValue);
       tosaTensor =
-          tosa::getConstTensor(rewriter, op, {d}, dshape).getValue();
+          tosa::getConstTensor(rewriter, op, {d}, dshape).value();
     }
   } else {
     return rewriter.notifyMatchFailure(op, "Usupported element type");
@@ -592,7 +592,7 @@ class ConvertAtenReductionOp : public OpConversionPattern {
   // TBD - support dtype casting.
-  rewriter.replaceOp(op, {result.getValue()});
+  rewriter.replaceOp(op, {result.value()});
   return success();
 }
@@ -1222,7 +1222,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern {
               op->getLoc(),
               OpConversionPattern::getTypeConverter()
                   ->convertType(transposedLhsType),
-              rankBroadcastedLhs, transposedLhsDimsConst.getValue())
+              rankBroadcastedLhs, transposedLhsDimsConst.value())
           .getResult();
   }
@@ -1301,7 +1301,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern {
               op->getLoc(),
               OpConversionPattern::getTypeConverter()
                   ->convertType(transposedRhsType),
-              rankBroadcastedRhs, transposedRhsDimsConst.getValue())
+              rankBroadcastedRhs, transposedRhsDimsConst.value())
           .getResult();
   }
@@ -1452,14 +1452,13 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern {
     auto transposedOpType =
         RankedTensorType::get(transposedOpShape, outputElemTy);
-    output =
-        rewriter
-            .create(
-                op->getLoc(),
-                OpConversionPattern::getTypeConverter()
-                    ->convertType(transposedOpType),
-                reshapedOp.getResult(), transposedOpShapeConst.getValue())
-            .getResult();
+    output = rewriter
+                 .create(
+                     op->getLoc(),
+                     OpConversionPattern::getTypeConverter()
+                         ->convertType(transposedOpType),
+                     reshapedOp.getResult(), transposedOpShapeConst.value())
+                 .getResult();
   } else {
     output = reshapedOp.getResult();
@@ -1646,7 +1645,7 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp {
         op->getLoc(),
         OpConversionPattern::getTypeConverter()->convertType(
             transposedRhsType),
-        rhs, transposedRhsShapeConst.getValue());
+        rhs, transposedRhsShapeConst.value());
     Value matmulOutput;
     if (failed(
@@ -1759,12 +1758,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
       SmallVector zeroVec(weightShape[0], 0);
       bias = tosa::getConstTensor(
                  rewriter, op, zeroVec, {static_cast(weightShape[0])})
-                 .getValue();
+                 .value();
     } else {
       SmallVector zeroVec(weightShape[0], 0);
       bias = tosa::getConstTensor(rewriter, op, zeroVec,
                                   {static_cast(weightShape[0])})
-                 .getValue();
+                 .value();
     }
   } else {
     if (!bias.getType().cast())
@@ -1808,7 +1807,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
           .create(
               op->getLoc(),
               getTypeConverter()->convertType(transposedInputType), input,
-              nchwToNhwcTransposeConst.getValue())
+              nchwToNhwcTransposeConst.value())
           .getResult();
   SmallVector transposedWeightShape(
@@ -1820,7 +1819,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
           .create(
               op->getLoc(),
               getTypeConverter()->convertType(transposedWeightType), weight,
-              nchwToNhwcTransposeConst.getValue())
+              nchwToNhwcTransposeConst.value())
          .getResult();
   int64_t outputHDim, outputWDim;
@@ -1867,7 +1866,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
           .create(
               op->getLoc(),
               getTypeConverter()->convertType(transposedOutputType),
-              convOpResult, nhwcToNchwTransposeConst.getValue())
+              convOpResult, nhwcToNchwTransposeConst.value())
           .getResult();
   Value rescaledResult = transposedOutput;
@@ -2146,7 +2145,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   auto elemCntConst =
       tosa::getConstTensor(rewriter, op.getOperation(),
                            {static_cast(elemCnt)}, {1})
-          .getValue();
+          .value();
   Value elemCntRcp = rewriter.create(
       op.getLoc(), elemCntConst.getType(), elemCntConst);
@@ -2313,7 +2312,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   rewriter.replaceOpWithNewOp(
       op, getTypeConverter()->convertType(op.getType()), adaptor.self(),
-      transposeDimsConst.getValue());
+      transposeDimsConst.value());
   return success();
 }
@@ -2333,7 +2332,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   SmallVector ln2Shape(selfType.getRank(), 1);
   auto ln2Op =
       tosa::getConstTensor(rewriter, op, {0.69314718056}, ln2Shape)
-          .getValue();
+          .value();
   auto rcpOp =
       rewriter.create(op.getLoc(), ln2Op.getType(), ln2Op);
@@ -2523,24 +2522,24 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter,
   auto outType = x.getType().cast();
   auto loc = op->getLoc();
   auto absX = rewriter.create(loc, outType, x);
-  auto zero = tosa::getConstTensor(rewriter, op, 0, {}).getValue();
-  auto one = tosa::getConstTensor(rewriter, op, 1, {}).getValue();
+  auto zero = tosa::getConstTensor(rewriter, op, 0, {}).value();
+  auto one = tosa::getConstTensor(rewriter, op, 1, {}).value();
-  auto a1 = tosa::getConstTensor(rewriter, op, 0.278393, {}).getValue();
+  auto a1 = tosa::getConstTensor(rewriter, op, 0.278393, {}).value();
   auto a1X = rewriter.create(loc, outType, a1, absX, /*shift=*/0);
   auto sum = rewriter.create(loc, outType, a1X, one);
-  auto a2 = tosa::getConstTensor(rewriter, op, 0.230389, {}).getValue();
+  auto a2 = tosa::getConstTensor(rewriter, op, 0.230389, {}).value();
   auto x2 = rewriter.create(loc, outType, absX, absX, /*shift=*/0);
   auto a2X = rewriter.create(loc, outType, a2, x2, /*shift=*/0);
   sum = rewriter.create(loc, outType, sum, a2X);
-  auto a3 = tosa::getConstTensor(rewriter, op, 0.000972, {}).getValue();
+  auto a3 = tosa::getConstTensor(rewriter, op, 0.000972, {}).value();
   auto x3 = rewriter.create(loc, outType, x2, absX, /*shift=*/0);
   auto a3X = rewriter.create(loc, outType, a3, x3, /*shift=*/0);
   sum = rewriter.create(loc, outType, sum, a3X);
-  auto a4 = tosa::getConstTensor(rewriter, op, 0.078108, {}).getValue();
+  auto a4 = tosa::getConstTensor(rewriter, op, 0.078108, {}).value();
   auto x4 = rewriter.create(loc, outType, x3, absX, /*shift=*/0);
   auto a4X = rewriter.create(loc, outType, a4, x4, /*shift=*/0);
   sum = rewriter.create(loc, outType, sum, a4X);
@@ -2564,8 +2563,8 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter,
 static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter,
                                 Operation *op, Value x) {
-  auto zero = tosa::getConstTensor(rewriter, op, 0, {}).getValue();
-  auto one = tosa::getConstTensor(rewriter, op, 1, {}).getValue();
+  auto zero = tosa::getConstTensor(rewriter, op, 0, {}).value();
+  auto one = tosa::getConstTensor(rewriter, op, 1, {}).value();
   auto loc = op->getLoc();
   // buildNormalCdf, mean = zero, sigma = one
@@ -2574,12 +2573,12 @@ static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter,
   Value xMinusMean = rewriter.create(loc, outType, x, mean);
   // rsqrt of 2
   Value rsqrt2 =
-      tosa::getConstTensor(rewriter, op, 0.70710678, {}).getValue();
+      tosa::getConstTensor(rewriter, op, 0.70710678, {}).value();
   Value erfArg = rewriter.create(loc, outType, xMinusMean, rsqrt2,
                                  /*shift=*/0);
   Value erf = approximateErfOp(rewriter, op, erfArg);
   Value erfPlus1 = rewriter.create(loc, outType, one, erf);
-  Value oneHalf = tosa::getConstTensor(rewriter, op, 0.5, {}).getValue();
+  Value oneHalf = tosa::getConstTensor(rewriter, op, 0.5, {}).value();
   Value normalCdf = rewriter.create(loc, outType, oneHalf, erfPlus1,
                                     /*shift=*/0);
   return normalCdf;
@@ -2651,10 +2650,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   const double kAlpha = cstAlpha0 * cstAlpha1;
   Value kAlphaHalf =
-      tosa::getConstTensor(rewriter, op, kAlpha * oneHalf, {})
-          .getValue();
+      tosa::getConstTensor(rewriter, op, kAlpha * oneHalf, {}).value();
   Value negOneHalf =
-      tosa::getConstTensor(rewriter, op, -0.5, {}).getValue();
+      tosa::getConstTensor(rewriter, op, -0.5, {}).value();
   Value inputSquared = rewriter.create(
       loc, selfType,
       adaptor.self(), adaptor.self(), /*shift=*/0);
   Value negHalfInputSquared = rewriter.create(
@@ -2810,7 +2808,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   rewriter.replaceOpWithNewOp(
       op, getTypeConverter()->convertType(op.getType()), adaptor.self(),
-      transposeDimsConst.getValue());
+      transposeDimsConst.value());
   return success();
 }
@@ -2992,7 +2990,7 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern {
         RankedTensorType::get(transposedInputShape, inputElemTy);
     return rewriter
         .create(op->getLoc(), transposedInputType, input,
-                transposeDimsConst.getValue())
+                transposeDimsConst.value())
         .getResult();
   }
@@ -3319,7 +3317,7 @@ class ConvertAtenConstPatternOp : public OpConversionPattern {
     SmallVector values(size, fillVal);
     auto constOp =
-        tosa::getConstTensor(rewriter, op, values, shape).getValue();
+        tosa::getConstTensor(rewriter, op, values, shape).value();
     rewriter.replaceOpWithNewOp(op, outType, constOp);
diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
index 9b1b61cd5338..bcc29c0195fe 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
@@ -297,13 +297,13 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
       reduce_element_type, input_is_qtype, input_scale, input_zp, output_scale,
       output_zp);
-  if (!val.hasValue())
+  if (!val.has_value())
     return llvm::None;
   if (!input_is_qtype) {
     Value div_const = getTosaConstTensorSingleF32(rewriter, op, div_scale);
     return CreateOpAndInfer(rewriter, op->getLoc(), output_type,
-                            val.getValue(), div_const, 0)
+                            val.value(), div_const, 0)
         .getResult();
   }
diff --git a/lib/Dialect/Torch/IR/TorchDialect.cpp b/lib/Dialect/Torch/IR/TorchDialect.cpp
index bae66e8b07b3..835be031644b 100644
--- a/lib/Dialect/Torch/IR/TorchDialect.cpp
+++ b/lib/Dialect/Torch/IR/TorchDialect.cpp
@@ -65,7 +65,7 @@ Type Torch::parseTorchDialectType(AsmParser &parser) {
   StringRef mnemonic;
   Type genType;
   auto parseResult = generatedTypeParser(parser, &mnemonic, genType);
-  if (parseResult.hasValue())
+  if (parseResult.has_value())
     return genType;
   parser.emitError(typeLoc) << "unknown type `" << mnemonic << "` in dialect `"
                             << TorchDialect::getDialectNamespace() << "`";
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index b6a21b7b884d..34f656a81d96 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -247,7 +247,7 @@ LogicalResult ClassTypeOp::verify() {
 //===----------------------------------------------------------------------===//
 OperandRange PrimLoopOp::getSuccessorEntryOperands(Optional index) {
-  assert(index.hasValue() && index.value() == 0);
+  assert(index.has_value() && index.value() == 0);
   return iterArgsInit();
 }
@@ -256,7 +256,7 @@ void PrimLoopOp::getSuccessorRegions(
     SmallVectorImpl &regions) {
   (void)operands;
-  if (!index.hasValue()) {
+  if (!index.has_value()) {
     regions.emplace_back(&region(), region().getArguments().slice(1));
     return;
   }
@@ -328,7 +328,7 @@ void PrimIfOp::getSuccessorRegions(Optional index,
                                    ArrayRef operands,
                                    SmallVectorImpl &regions) {
   // The `then` and the `else` region branch back to the parent operation.
-  if (index.hasValue()) {
+  if (index.has_value()) {
     regions.push_back(RegionSuccessor(getResults()));
     return;
   }
@@ -536,7 +536,7 @@ OpFoldResult Aten__RangeLengthOp::fold(ArrayRef operands) {
   // r[i] = lo + step*i such that i >= 0 and r[i] < hi
   // So maximize `i` such that lo + step * i < hi
   // ==> i == ceildiv(hi - lo, step)
-  return IntegerAttr::get(lo.getType(),
+  return IntegerAttr::get(lo.cast().getType(),
                           llvm::APIntOps::RoundingSDiv(hiInt - loInt, stepInt,
                                                        APInt::Rounding::UP));
 }
@@ -554,7 +554,8 @@ OpFoldResult Aten__DeriveIndexOp::fold(ArrayRef operands) {
   auto indexInt = index.dyn_cast_or_null().getValue();
   auto startInt = start.dyn_cast_or_null().getValue();
   auto stepInt = step.dyn_cast_or_null().getValue();
-  return IntegerAttr::get(index.getType(), startInt + stepInt * indexInt);
+  return IntegerAttr::get(index.cast().getType(),
+                          startInt + stepInt * indexInt);
 }
 //===----------------------------------------------------------------------===//
@@ -1903,7 +1904,7 @@ void ShapeCalculateOp::getSuccessorRegions(
     SmallVectorImpl &regions) {
   (void)operands;
-  if (!index.hasValue()) {
+  if (!index.has_value()) {
     // First thing the op does is branch into the shape calculation.
     regions.emplace_back(&shapeCalculation());
     return;
diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp
index 10e2008adf18..bff1a2e8910c 100644
--- a/lib/Dialect/Torch/IR/TorchTypes.cpp
+++ b/lib/Dialect/Torch/IR/TorchTypes.cpp
@@ -236,7 +236,7 @@ Type parseTensorType(MLIRContext *context, AsmParser &parser,
   }
   int64_t size;
   auto optionalInt = parser.parseOptionalInteger(size);
-  if (optionalInt.hasValue()) {
+  if (optionalInt.has_value()) {
     if (failed(*optionalInt))
       return Type();
     sizes.push_back(size);
diff --git a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp
index 6ca39f37993c..a337d8b86fbf 100644
--- a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp
+++ b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp
@@ -648,7 +648,7 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
         monomorphization.argInstances[0].instance.getDefiningOp(),
         monomorphization.func);
   }
-  if (linkageInfo.hasValue()) {
+  if (linkageInfo.has_value()) {
     // It's a method.
     newFunc.setVisibility(linkageInfo->isPrivate ?
                               SymbolTable::Visibility::Private
diff --git a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
index de189928321b..9a31c30ed4dc 100644
--- a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
+++ b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
@@ -123,8 +123,8 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
                   PatternRewriter &rewriter) {
     DenseMap originalReturnTypes;
-    if (ops.returnOp.hasValue()) {
-      auto returnOp = ops.returnOp.getValue();
+    if (ops.returnOp.has_value()) {
+      auto returnOp = ops.returnOp.value();
       for (auto operand : llvm::enumerate(returnOp->getOperands())) {
         auto type = operand.value().getType();
         if (!type.isa())
@@ -160,8 +160,8 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock
         result.setType(resultType.getWithValueSemantics());
       });
     }
-    if (ops.returnOp.hasValue()) {
-      auto returnOp = ops.returnOp.getValue();
+    if (ops.returnOp.has_value()) {
+      auto returnOp = ops.returnOp.value();
       for (int i = 0, e = returnOp->getNumOperands(); i < e; i++) {
         OpOperand &operand = returnOp->getOpOperand(i);
         auto it = originalReturnTypes.find(i);
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 60fa8692f18e..9163b442ebd8 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -310,15 +310,15 @@ struct ValueKnowledge {
                                const ValueKnowledge &rhs) {
     Optional knowledge = meetTypes(lhs, rhs);
-    if (!knowledge.hasValue())
+    if (!knowledge.has_value())
       return None;
-    ValueKnowledge result = knowledge.getValue();
+    ValueKnowledge result = knowledge.value();
     Optional optional = meetOptionalKnowledge(lhs.optional, rhs.optional);
-    if (!optional.hasValue())
+    if (!optional.has_value())
       return None;
-    result.optional = optional.getValue();
+    result.optional = optional.value();
     return result;
   }
@@ -517,13 +517,13 @@ updateResultTypeState(const ValueKnowledge *tensor,
                       Optional rankIsNonZero,
                       const torch_upstream::ResultTypeState &inState,
                       bool skipRankCheck = false) {
-  if (!rankIsNonZero.hasValue() && !skipRankCheck)
+  if (!rankIsNonZero.has_value() && !skipRankCheck)
     return torch_upstream::ResultTypeState{};
   assert(tensor->dtype && "tensor.dtype must be not none");
   torch_upstream::ResultTypeState new_state = inState;
   torch_upstream::ScalarType current = getScalarTypeForType(tensor->dtype);
-  if (skipRankCheck || rankIsNonZero.getValue())
+  if (skipRankCheck || rankIsNonZero.value())
     new_state.dimResult = promote_skip_undefined(inState.dimResult, current);
   else
     new_state.zeroResult = promote_skip_undefined(inState.zeroResult, current);
@@ -1124,8 +1124,8 @@ void TypeAnalysis::incorporateKnowledge(Value v,
                                         const ValueKnowledge &knowledge) {
   auto updatedKnowledge = ValueKnowledge::meet(
       knowledge, ValueKnowledge::getPessimisticValueState(v));
-  assert(updatedKnowledge.hasValue() && "IR has contradictory type!");
-  getLatticeElement(v)->join(updatedKnowledge.getValue());
+  assert(updatedKnowledge.has_value() && "IR has contradictory type!");
+  getLatticeElement(v)->join(updatedKnowledge.value());
 }
 void TypeAnalysis::visitAtenLinearOp(AtenLinearOp op,
@@ -1169,9 +1169,9 @@ void TypeAnalysis::visitAtenArangeLikeOpHelper(Operation *op,
   // `dtype` is inferred to be the default dtype, see
   // `torch.get_default_dtype`. Otherwise, the `dtype` is inferred to
   // be `torch.int64`
-  if ((start.hasValue() && (*start).getType().isa()) ||
+  if ((start.has_value() && (*start).getType().isa()) ||
       end.getType().isa() ||
-      (step.hasValue() && (*step).getType().isa())) {
+      (step.has_value() && (*step).getType().isa())) {
     // TODO: Should get the dtype from torch.get_default_dtype().
     // For now, use float32 which is the initial default dtype.
     knowledge.dtype = Float32Type::get(op->getContext());
@@ -1263,7 +1263,7 @@ void TypeAnalysis::visitConstantTensorAllocOp(OpTy op,
       ValueKnowledge::getTensorPessimisticValueState(op->getContext());
   if (!dataType)
     dataType = Torch::FloatType::get(op->getContext());
-  fillInDTypeGivenDTypeAndDataType(knowledge, op.dtype(), dataType.getValue());
+  fillInDTypeGivenDTypeAndDataType(knowledge, op.dtype(), dataType.value());
   incorporateKnowledge(op.getResult(), knowledge);
 }
@@ -1333,11 +1333,11 @@ void TypeAnalysis::visitAtenCatOp(AtenCatOp op,
       }));
   for (auto tensor : tensors) {
     auto newDtype = meetElementTypes(knowledge.dtype, tensor.dtype);
-    if (!newDtype.hasValue()) {
+    if (!newDtype.has_value()) {
       incorporateKnowledge(op.getResult(), knowledge);
       return;
     }
-    knowledge.dtype = newDtype.getValue();
+    knowledge.dtype = newDtype.value();
   }
   incorporateKnowledge(op.getResult(), knowledge);
 }
diff --git a/test/Conversion/TorchToMhlo/pooling.mlir b/test/Conversion/TorchToMhlo/pooling.mlir
index ab057522ba2f..b976e473ed7d 100644
--- a/test/Conversion/TorchToMhlo/pooling.mlir
+++ b/test/Conversion/TorchToMhlo/pooling.mlir
@@ -98,9 +98,9 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
 // CHECK: %[[VAL_18:.*]] = mhlo.constant dense<0> : tensor
 // CHECK: %[[VAL_19:.*]]:2 = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_17]], %[[VAL_6]], %[[VAL_18]]) ({
 // CHECK: ^bb0(%[[IVAL_0:.*]]: tensor, %[[IVAL_1:.*]]: tensor, %[[IVAL_2:.*]]: tensor, %[[IVAL_3:.*]]: tensor):
-// CHECK: %[[IVAL_4:.*]] = "mhlo.compare"(%[[IVAL_0]], %[[IVAL_2]]) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor, tensor) -> tensor
+// CHECK: %[[IVAL_4:.*]] = mhlo.compare GE, %[[IVAL_0]], %[[IVAL_2]], FLOAT : (tensor, tensor) -> tensor
 // CHECK: %[[IVAL_5:.*]] = "mhlo.select"(%[[IVAL_4]], %[[IVAL_0]], %[[IVAL_2]]) : (tensor, tensor, tensor) -> tensor
-// CHECK: %[[IVAL_6:.*]] = "mhlo.compare"(%[[IVAL_0]], %[[IVAL_2]]) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor, tensor) -> tensor
+// CHECK: %[[IVAL_6:.*]] = mhlo.compare EQ, %[[IVAL_0]], %[[IVAL_2]], FLOAT : (tensor, tensor) -> tensor
 // CHECK: %[[IVAL_7:.*]] = mhlo.minimum %[[IVAL_1]], %[[IVAL_3]] : tensor
 // CHECK: %[[IVAL_8:.*]] = "mhlo.select"(%[[IVAL_4]], %[[IVAL_1]], %[[IVAL_3]]) : (tensor, tensor, tensor) -> tensor
 // CHECK: %[[IVAL_9:.*]] = "mhlo.select"(%[[IVAL_6]], %[[IVAL_7]], %[[IVAL_8]]) : (tensor, tensor, tensor) -> tensor
@@ -215,4 +215,4 @@ func.func @torch.aten.avg_pool2d$count_include_pad(%arg0: !torch.vtensor<[?,?,?,
   %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list
   %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %true, %none : !torch.vtensor<[?,?,?,?],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32>
   return %3 : !torch.vtensor<[?,?,?,?],f32>
-}
\ No newline at end of file
+}
diff --git a/test/Conversion/TorchToMhlo/reduction.mlir b/test/Conversion/TorchToMhlo/reduction.mlir
index 7af580e97dea..21f50e677c2c 100644
--- a/test/Conversion/TorchToMhlo/reduction.mlir
+++ b/test/Conversion/TorchToMhlo/reduction.mlir
@@ -17,9 +17,9 @@
 // CHECK: %[[VAL_9:.*]] = "mhlo.dynamic_iota"(%[[VAL_8]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor
 // CHECK: %[[VAL_10:.*]]:2 = mhlo.reduce(%[[VAL_1]] init: %[[VAL_6]]), (%[[VAL_9]] init: %[[VAL_7]]) across dimensions = [1] : (tensor, tensor, tensor, tensor) -> (tensor, tensor)
 // CHECK: reducer(%[[VAL_11:.*]]: tensor, %[[VAL_13:.*]]: tensor) (%[[VAL_12:.*]]: tensor, %[[VAL_14:.*]]: tensor) {
-// CHECK: %[[VAL_15:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor, tensor) -> tensor
+// CHECK: %[[VAL_15:.*]] = mhlo.compare GE, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor, tensor) -> tensor
 // CHECK: %[[VAL_16:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_11]], %[[VAL_13]]) : (tensor, tensor, tensor) -> tensor
-// CHECK: %[[VAL_17:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor, tensor) -> tensor
+// CHECK: %[[VAL_17:.*]] = mhlo.compare EQ, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor, tensor) -> tensor
 // CHECK: %[[VAL_18:.*]] = mhlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor
 // CHECK: %[[VAL_19:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_12]], %[[VAL_14]]) : (tensor, tensor, tensor) -> tensor
 // CHECK: %[[VAL_20:.*]] = "mhlo.select"(%[[VAL_17]], %[[VAL_18]], %[[VAL_19]]) : (tensor, tensor, tensor) -> tensor
@@ -58,9 +58,9 @@ func.func @torch.aten.max.dim$keepdim(%arg0: !torch.vtensor<[?,?],f32>) -> (!tor
 // CHECK: %[[VAL_9:.*]] = "mhlo.dynamic_iota"(%[[VAL_8]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor
 // CHECK: %[[VAL_10:.*]]:2 = mhlo.reduce(%[[VAL_1]] init: %[[VAL_6]]), (%[[VAL_9]] init: %[[VAL_7]]) across dimensions = [1] : (tensor, tensor, tensor, tensor) -> (tensor, tensor)
 // CHECK: reducer(%[[VAL_11:.*]]: tensor, %[[VAL_13:.*]]: tensor) (%[[VAL_12:.*]]: tensor, %[[VAL_14:.*]]: tensor) {
-// CHECK: %[[VAL_15:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor, tensor) -> tensor
+// CHECK: %[[VAL_15:.*]] = mhlo.compare GE, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor, tensor) -> tensor
 // CHECK: %[[VAL_16:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_11]], %[[VAL_13]]) : (tensor, tensor, tensor) -> tensor
-// CHECK: %[[VAL_17:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor, tensor) -> tensor
+// CHECK: %[[VAL_17:.*]] = mhlo.compare EQ, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor, tensor) -> tensor
 // CHECK: %[[VAL_18:.*]] = mhlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor
 // CHECK: %[[VAL_19:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_12]], %[[VAL_14]]) : (tensor, tensor, tensor) -> tensor
 // CHECK: %[[VAL_20:.*]] = "mhlo.select"(%[[VAL_17]], %[[VAL_18]], %[[VAL_19]]) : (tensor, tensor, tensor) -> tensor
@@ -95,9 +95,9 @@ func.func @torch.aten.max.dim(%arg0: !torch.vtensor<[?,?],f32>) -> (!torch.vtens
 // CHECK: %[[VAL_9:.*]] = "mhlo.dynamic_iota"(%[[VAL_8]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor
 // CHECK: %[[VAL_10:.*]]:2 = mhlo.reduce(%[[VAL_1]] init: %[[VAL_6]]), (%[[VAL_9]] init: %[[VAL_7]]) across dimensions = [1] : (tensor, tensor, tensor, tensor) -> (tensor, tensor)
 // CHECK: reducer(%[[VAL_11:.*]]: tensor, %[[VAL_13:.*]]: tensor) (%[[VAL_12:.*]]: tensor, %[[VAL_14:.*]]: tensor) {
-// CHECK: %[[VAL_15:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor, tensor) -> tensor
+// CHECK: %[[VAL_15:.*]] = mhlo.compare GE, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor, tensor) -> tensor
 // CHECK: %[[VAL_16:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_11]], %[[VAL_13]]) : (tensor, tensor, tensor) -> tensor
-// CHECK: %[[VAL_17:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor, tensor) -> tensor
+// CHECK: %[[VAL_17:.*]] = mhlo.compare EQ, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor, tensor) -> tensor
 // CHECK: %[[VAL_18:.*]] = mhlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor
 // CHECK: %[[VAL_19:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_12]], %[[VAL_14]]) : (tensor, tensor, tensor) -> tensor
 // CHECK: %[[VAL_20:.*]] = "mhlo.select"(%[[VAL_17]], %[[VAL_18]], %[[VAL_19]]) : (tensor, tensor, tensor) -> tensor
@@ -134,9 +134,9 @@ func.func @torch.aten.argmax$keepdim(%arg0: !torch.vtensor<[?,?],f32>) -> !torch
 // CHECK: %[[VAL_9:.*]] = "mhlo.dynamic_iota"(%[[VAL_8]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor
 // CHECK: %[[VAL_10:.*]]:2 = mhlo.reduce(%[[VAL_1]] init: %[[VAL_6]]), (%[[VAL_9]] init: %[[VAL_7]]) across dimensions = [1] : (tensor, tensor, tensor, tensor) -> (tensor, tensor)
 // CHECK: reducer(%[[VAL_11:.*]]: tensor, %[[VAL_13:.*]]: tensor) (%[[VAL_12:.*]]: tensor, %[[VAL_14:.*]]: tensor) {
-// CHECK: %[[VAL_15:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor, tensor) -> tensor
+// CHECK: %[[VAL_15:.*]] = mhlo.compare GE, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor, tensor) -> tensor
 // CHECK: %[[VAL_16:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_11]], %[[VAL_13]]) : (tensor, tensor, tensor) -> tensor
-// CHECK: %[[VAL_17:.*]] = "mhlo.compare"(%[[VAL_11]], %[[VAL_13]]) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor, tensor) -> tensor
+// CHECK: %[[VAL_17:.*]] = mhlo.compare EQ, %[[VAL_11]], %[[VAL_13]], FLOAT : (tensor, tensor) -> tensor
 // CHECK: %[[VAL_18:.*]] = mhlo.minimum %[[VAL_12]], %[[VAL_14]] : tensor
 // CHECK: %[[VAL_19:.*]] = "mhlo.select"(%[[VAL_15]], %[[VAL_12]], %[[VAL_14]]) : (tensor, tensor, tensor) -> tensor
 // CHECK: %[[VAL_20:.*]] = "mhlo.select"(%[[VAL_17]], %[[VAL_18]], %[[VAL_19]]) : (tensor, tensor, tensor) -> tensor
@@ -240,4 +240,4 @@ func.func @torch.aten.sum(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<
 func.func @torch.aten.max(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[],f32> {
   %0 = torch.aten.max %arg0 : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[],f32>
   return %0 : !torch.vtensor<[],f32>
-}
\ No newline at end of file
+}