+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
+
+#include "../PassDetail.h"
+#include "./PopulatePatterns.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "torch-mlir/Conversion/Utils/Utils.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
+
+using namespace mlir;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+
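+// Dimension-size values are materialized as either i32 or i64 scalars,
+// depending on how torch-mlir was configured.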
+#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
+static constexpr size_t kMhloDimSizeBits = 32;
+#else
+static constexpr size_t kMhloDimSizeBits = 64;
+#endif
+
+namespace {
+
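+// Returns the size of each dimension of `value` as a kMhloDimSizeBits-wide
+// integer SSA value. Returns an empty vector for rank-0 tensors, or (after
+// emitting an error) when the input is not a ranked tensor.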
+SmallVector<Value, 4> getDimSizesOfTensor(
+    PatternRewriter& rewriter,
+    Operation* op,
+    Value value) {
+  auto valueTy = value.getType().dyn_cast<RankedTensorType>();
+  if (!valueTy) {
+    op->emitOpError("getDimSizesOfTensor(): the input is not a ranked tensor");
+    return {};
+  }
+
+  auto rank = valueTy.getRank();
+  if (rank == 0) {
+    return {};
+  }
+
+  SmallVector<Value, 4> dimSizes;
+  dimSizes.reserve(rank);
+  auto loc = op->getLoc();
+  for (int64_t d = 0; d < rank; ++d) {
+    dimSizes.emplace_back(rewriter.create<arith::IndexCastOp>(
+        loc,
+        rewriter.getIntegerType(kMhloDimSizeBits),
+        rewriter.create<tensor::DimOp>(loc, value, d)));
+  }
+  return dimSizes;
+}
+
+// A dimension index coming from the torch dialect might be outside the range
+// [0, dimSize]. This function normalizes the given index into that range.
+Value getNormalizedDimSizeInternal(
+    PatternRewriter& rewriter,
+    Operation* op,
+    Value index,
+    Value dimSize) {
+  auto loc = op->getLoc();
+  Value zero = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 0));
+
+  // First clamp the index into the range [-dimSize, dimSize]:
+  //   index = min(max(-dimSize, index), dimSize)
+  auto negDimSize = rewriter.create<arith::SubIOp>(loc, zero, dimSize);
+  index = rewriter.create<arith::MaxSIOp>(loc, negDimSize, index);
+  index = rewriter.create<arith::MinSIOp>(loc, dimSize, index);
+
+  auto dimSizePlusIndex = rewriter.create<arith::AddIOp>(loc, dimSize, index);
+  auto indexPositive = rewriter.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::sge, index, zero);
+  // Get the positive index: (index >= 0) ? index : index + dimSize.
+  return rewriter.create<arith::SelectOp>(
+      loc, indexPositive, index, dimSizePlusIndex);
+}
+
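+// Emits an mhlo::RealDynamicSliceOp that slices `input` along `dimIndex`
+// from `startIndex` to `endIndex` with the given `step`, keeping the full
+// extent of every other dimension.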
+Value getDynamicSliceInternal(
+    PatternRewriter& rewriter,
+    Operation* op,
+    Value input,
+    Value startIndex,
+    Value endIndex,
+    Value step,
+    size_t dimIndex,
+    ArrayRef<Value> dimSizes) {
+  auto loc = op->getLoc();
+  // startIndex & endIndex have already been normalized into the range
+  // [0, dimSize].
+  Type intType = rewriter.getIntegerType(kMhloDimSizeBits);
+  Value zero = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getIntegerAttr(intType, 0));
+  Value one = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getIntegerAttr(intType, 1));
+
+  SmallVector<Value, 4> startIndices;
+  SmallVector<Value, 4> endIndices;
+  SmallVector<Value, 4> strides;
+
+  auto inputTy = input.getType().cast<RankedTensorType>();
+  size_t rank = inputTy.getRank();
+  startIndices.reserve(rank);
+  endIndices.reserve(rank);
+  strides.reserve(rank);
+
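+  // A (normalized) end index of 0 means "slice to the end of this dimension".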
+  auto endIndexIsZero = rewriter.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::eq, endIndex, zero);
+  endIndex = rewriter.create<arith::SelectOp>(
+      loc, endIndexIsZero, dimSizes[dimIndex], endIndex);
+
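+  // Use the requested start/end/step on the sliced dimension and the full
+  // range with stride 1 on every other dimension.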
+  for (size_t r = 0; r < rank; ++r) {
+    if (r == dimIndex) {
+      startIndices.push_back(startIndex);
+      endIndices.push_back(endIndex);
+      strides.push_back(step);
+    } else {
+      startIndices.push_back(zero);
+      endIndices.push_back(dimSizes[r]);
+      strides.push_back(one);
+    }
+  }
+
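+  // mhlo::RealDynamicSliceOp takes its start/end/stride operands as
+  // 1-D tensors.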
+  auto startTensor =
+      rewriter.create<tensor::FromElementsOp>(loc, startIndices).getResult();
+  auto endTensor =
+      rewriter.create<tensor::FromElementsOp>(loc, endIndices).getResult();
+  auto stridesTensor =
+      rewriter.create<tensor::FromElementsOp>(loc, strides).getResult();
+
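+  // The size of the sliced dimension is not known statically, so mark it
+  // dynamic in the result type.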
+  auto inputShape = inputTy.getShape();
+  SmallVector<int64_t, 4> sliceShape(inputShape.begin(), inputShape.end());
+  sliceShape[dimIndex] = ShapedType::kDynamicSize;
+  auto sliceOutputTy =
+      RankedTensorType::get(sliceShape, inputTy.getElementType());
+  return rewriter.create<mhlo::RealDynamicSliceOp>(
+      loc, sliceOutputTy, input, startTensor, endTensor, stridesTensor);
+}
+
+// Gets a dynamic slice of the tensor from startIndex to endIndex with stride
+// step on the specified dimension. startIndex (defaults to 0), endIndex
+// (defaults to dimSize), and step (defaults to 1) are all optional.
+Value getDynamicSlice(
+    PatternRewriter& rewriter,
+    Operation* op,
+    Value input,
+    llvm::Optional<Value> startIndexOpt,
+    llvm::Optional<Value> endIndexOpt,
+    llvm::Optional<Value> stepOpt,
+    int64_t dim) {
+  auto loc = op->getLoc();
+  auto inputTy = input.getType().cast<RankedTensorType>();
+  auto rank = inputTy.getRank();
+
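+  // Normalize a possibly negative dim into the range [0, rank).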
+  dim = (dim + rank) % rank;
+  Value dimSize = rewriter.create<arith::IndexCastOp>(
+      loc,
+      rewriter.getI64Type(),
+      rewriter.create<tensor::DimOp>(loc, input, dim));
+
+  Value normStartIndex = startIndexOpt
+      ? getNormalizedDimSizeInternal(rewriter, op, *startIndexOpt, dimSize)
+      : rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 0));
+  Value normEndIndex = endIndexOpt
+      ? getNormalizedDimSizeInternal(rewriter, op, *endIndexOpt, dimSize)
+      : dimSize;
+  Value step = stepOpt
+      ? *stepOpt
+      : rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1));
+
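+  // When dimension sizes are materialized as i32, truncate the i64 indices
+  // computed above to match.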
+#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
+  auto i32Type = rewriter.getIntegerType(kMhloDimSizeBits);
+  normStartIndex =
+      rewriter.create<arith::TruncIOp>(loc, i32Type, normStartIndex);
+  normEndIndex =
+      rewriter.create<arith::TruncIOp>(loc, i32Type, normEndIndex);
+  step = rewriter.create<arith::TruncIOp>(loc, i32Type, step);
+#endif
+  auto dimSizes = getDimSizesOfTensor(rewriter, op, input);
+
+  return getDynamicSliceInternal(
+      rewriter, op, input, normStartIndex, normEndIndex, step, dim, dimSizes);
+}
+
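+// Generic conversion pattern for aten ops; matchAndRewrite is specialized
+// below for each supported op.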
+template <typename AtenOpT>
+class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+template <>
+LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
+    AtenSliceTensorOp op,
+    OpAdaptor adaptor,
+    ConversionPatternRewriter& rewriter) const {
+  auto self = adaptor.self();
+  auto selfTy = self.getType().dyn_cast<RankedTensorType>();
+  if (!selfTy)
+    return op.emitError("Only ranked tensor types supported in MHLO slice");
+  int64_t dim;
+  if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
+    return rewriter.notifyMatchFailure(
+        op, "Only constant dim is currently supported");
+
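+  // Map torch `none` operands to llvm::None so that the defaults
+  // (0, dimSize, 1) are applied in getDynamicSlice.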
+  auto getOptionalVal = [&](Value val) -> llvm::Optional<Value> {
+    if (val.getType().isa<Torch::NoneType>()) {
+      return llvm::None;
+    } else {
+      return val;
+    }
+  };
+
+  llvm::Optional<Value> start = getOptionalVal(adaptor.start());
+  llvm::Optional<Value> end = getOptionalVal(adaptor.end());
+  llvm::Optional<Value> step = getOptionalVal(adaptor.step());
+
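+  // Emit the dynamic slice, then adapt the element type to the converted
+  // result type.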
+  Value sliced =
+      getDynamicSlice(rewriter, op, self, start, end, step, dim);
+  rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
+      op, getTypeConverter()->convertType(op.getType()), sliced);
+
+  return success();
+}
+} // namespace
+
+void mlir::torch::torch_to_mhlo::populateSliceLikeOpPatternsAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target) {
+  MLIRContext *context = patterns.getContext();
+
+#define INSERT_ATENOP_PATTERN(AtenOp) \
+  target.addIllegalOp<AtenOp>(); \
+  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
+  INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
+#undef INSERT_ATENOP_PATTERN
+}