Skip to content

Commit a02dbb2

Browse files
author
Tanyo Kwok
authored
[MHLO] Init MHLO slice like op patterns (#1091)
See RFC: #999 Co-authored-by: Bairen Yi [email protected] Co-authored-by: Jiawei Wu [email protected] Co-authored-by: Tianyou Guo [email protected] Co-authored-by: Xu Yan [email protected] Co-authored-by: Ziheng Jiang [email protected]
1 parent f271e6a commit a02dbb2

File tree

5 files changed

+555
-0
lines changed

5 files changed

+555
-0
lines changed

lib/Conversion/TorchToMhlo/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
add_mlir_conversion_library(TorchMLIRTorchToMhlo
22
TorchToMhlo.cpp
33
BasicOp.cpp
4+
SliceLikeOps.cpp
45

56
ADDITIONAL_HEADER_DIRS
67
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToMhlo

lib/Conversion/TorchToMhlo/PopulatePatterns.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ namespace torch_to_mhlo {
1919
void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
2020
RewritePatternSet &patterns,
2121
ConversionTarget &target);
22+
void populateSliceLikeOpPatternsAndLegality(TypeConverter &typeConverter,
23+
RewritePatternSet &patterns,
24+
ConversionTarget &target);
25+
2226

2327
} // namespace torch_to_mhlo
2428
} // namespace torch
Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
// Also available under a BSD-style license. See LICENSE.
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
11+
12+
#include "../PassDetail.h"
13+
#include "./PopulatePatterns.h"
14+
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
16+
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
17+
#include "torch-mlir/Conversion/Utils/Utils.h"
18+
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
19+
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
20+
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
21+
22+
using namespace mlir;
23+
using namespace mlir::torch;
24+
using namespace mlir::torch::Torch;
25+
26+
#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
27+
static constexpr size_t kMhloDimSizeBits = 32;
28+
#else
29+
static constexpr size_t kMhloDimSizeBits = 64;
30+
#endif
31+
32+
namespace {
33+
34+
SmallVector<Value, 4> getDimSizesOfTensor(
35+
PatternRewriter& rewriter,
36+
Operation* op,
37+
Value value) {
38+
auto valueTy = value.getType().dyn_cast<RankedTensorType>();
39+
if (!valueTy) {
40+
op->emitOpError("getDimSizesOfTensor(): the input is not a ranked tensor");
41+
return {};
42+
}
43+
44+
auto rank = valueTy.getRank();
45+
if (rank == 0) {
46+
return {};
47+
}
48+
49+
SmallVector<Value, 4> dimSizes;
50+
dimSizes.reserve(rank);
51+
auto loc = op->getLoc();
52+
for (auto d = 0; d < rank; ++d) {
53+
dimSizes.emplace_back(rewriter.create<arith::IndexCastOp>(
54+
loc,
55+
rewriter.getIntegerType(kMhloDimSizeBits),
56+
rewriter.create<tensor::DimOp>(loc, value, d)));
57+
}
58+
return dimSizes;
59+
}
60+
61+
// A dimension index from torch.dialect might outside the range [0, dimSize].
62+
// The function is used to normalize the input index into the range.
63+
Value getNormalizedDimSizeInternal(
64+
PatternRewriter& rewriter,
65+
Operation* op,
66+
Value index,
67+
Value dimSize) {
68+
auto loc = op->getLoc();
69+
Value zero = rewriter.create<arith::ConstantOp>(
70+
loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 0));
71+
72+
// To normalize index into range [-dimSize, dimSize]
73+
// index = min(max(-dimSize, index), dimSize)
74+
auto negDimSize = rewriter.create<arith::SubIOp>(loc, zero, dimSize);
75+
index = rewriter.create<arith::MaxSIOp>(loc, negDimSize, index);
76+
index = rewriter.create<arith::MinSIOp>(loc, dimSize, index);
77+
78+
auto dimSizePlusIndex = rewriter.create<arith::AddIOp>(loc, dimSize, index);
79+
auto indexPositive = rewriter.create<arith::CmpIOp>(
80+
loc, arith::CmpIPredicate::sge, index, zero);
81+
// get positive index: (index >=0) ? index: index + dimSize
82+
return rewriter.create<arith::SelectOp>(
83+
loc, indexPositive, index, dimSizePlusIndex);
84+
}
85+
86+
Value getDynamicSliceInternal(
87+
PatternRewriter& rewriter,
88+
Operation* op,
89+
Value input,
90+
Value startIndex,
91+
Value endIndex,
92+
Value step,
93+
size_t dimIndex,
94+
ArrayRef<Value> dimSizes) {
95+
auto loc = op->getLoc();
96+
// startIndex & endIndex has been normailized into range [0, dSize]
97+
Type intType = rewriter.getIntegerType(kMhloDimSizeBits);
98+
Value zero = rewriter.create<arith::ConstantOp>(
99+
loc, rewriter.getIntegerAttr(intType, 0));
100+
Value one = rewriter.create<arith::ConstantOp>(
101+
loc, rewriter.getIntegerAttr(intType, 1));
102+
103+
SmallVector<Value, 4> startIndices;
104+
SmallVector<Value, 4> endIndices;
105+
SmallVector<Value, 4> strides;
106+
107+
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
108+
size_t rank = inputTy.getRank();
109+
startIndices.reserve(rank);
110+
endIndices.reserve(rank);
111+
strides.reserve(rank);
112+
113+
auto endIndexIsZero = rewriter.create<arith::CmpIOp>(
114+
loc, arith::CmpIPredicate::eq, endIndex, zero);
115+
endIndex = rewriter.create<arith::SelectOp>(
116+
loc, endIndexIsZero, dimSizes[dimIndex], endIndex);
117+
118+
for (size_t r = 0; r < rank; ++r) {
119+
if (r == dimIndex) {
120+
startIndices.push_back(startIndex);
121+
endIndices.push_back(endIndex);
122+
strides.push_back(step);
123+
} else {
124+
startIndices.push_back(zero);
125+
endIndices.push_back(dimSizes[r]);
126+
strides.push_back(one);
127+
}
128+
}
129+
130+
auto startTensor =
131+
rewriter.create<tensor::FromElementsOp>(loc, startIndices).getResult();
132+
auto endTensor =
133+
rewriter.create<tensor::FromElementsOp>(loc, endIndices).getResult();
134+
auto stridesTensor =
135+
rewriter.create<tensor::FromElementsOp>(loc, strides).getResult();
136+
137+
auto inputShape = inputTy.getShape();
138+
SmallVector<int64_t, 4> sliceShape(inputShape.begin(), inputShape.end());
139+
sliceShape[dimIndex] = ShapedType::kDynamicSize;
140+
auto sliceoutputTy =
141+
RankedTensorType::get(sliceShape, inputTy.getElementType());
142+
return rewriter.create<mhlo::RealDynamicSliceOp>(
143+
loc, sliceoutputTy, input, startTensor, endTensor, stridesTensor);
144+
}
145+
146+
// Get a dynamic slice of the tensor from startIndex to endIndex with stride step
147+
// on the specifed dimension. The input startIndex(default to 0),
148+
// endIndex(default to dimSize), and step(default to 1) can be optional.
149+
Value getDynamicSlice(
150+
PatternRewriter& rewriter,
151+
Operation* op,
152+
Value input,
153+
llvm::Optional<Value> startIndexOpt,
154+
llvm::Optional<Value> endIndexOpt,
155+
llvm::Optional<Value> stepOpt,
156+
int64_t dim) {
157+
auto loc = op->getLoc();
158+
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
159+
auto rank = inputTy.getRank();
160+
161+
dim = (dim + rank) % rank;
162+
Value dimSize = rewriter.create<arith::IndexCastOp>(
163+
loc,
164+
rewriter.getI64Type(),
165+
rewriter.create<tensor::DimOp>(loc, input, dim));
166+
167+
Value normStartIndex = startIndexOpt
168+
? getNormalizedDimSizeInternal(rewriter, op, *startIndexOpt, dimSize)
169+
: rewriter.create<arith::ConstantOp>(
170+
loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 0));
171+
Value normEndIndex = endIndexOpt
172+
? getNormalizedDimSizeInternal(rewriter, op, *endIndexOpt, dimSize)
173+
: dimSize;
174+
Value step = stepOpt
175+
? *stepOpt
176+
: rewriter.create<arith::ConstantOp>(
177+
loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1));
178+
179+
#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
180+
auto i32Type = rewriter.getIntegerType(kMhloDimSizeBits);
181+
normStartIndex =
182+
rewriter.create<arith::TruncIOp>(loc, i32Type, normStartIndex);
183+
normEndIndex =
184+
rewriter.create<arith::TruncIOp>(loc, i32Type, normEndIndex);
185+
step = rewriter.create<arith::TruncIOp>(loc, i32Type, step);
186+
#endif
187+
auto dimSizes = getDimSizesOfTensor(rewriter, op, input);
188+
189+
return getDynamicSliceInternal(
190+
rewriter, op, input, normStartIndex, normEndIndex, step, dim, dimSizes);
191+
}
192+
193+
template <typename AtenOpT>
194+
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
195+
public:
196+
using OpConversionPattern<AtenOpT>::OpConversionPattern;
197+
using OpAdaptor = typename AtenOpT::Adaptor;
198+
LogicalResult
199+
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
200+
ConversionPatternRewriter &rewriter) const override;
201+
};
202+
203+
template <>
204+
LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
205+
AtenSliceTensorOp op,
206+
OpAdaptor adaptor,
207+
ConversionPatternRewriter& rewriter) const {
208+
auto self = adaptor.self();
209+
auto selfTy = self.getType().template cast<RankedTensorType>();
210+
if (!selfTy)
211+
return op.emitError("Only ranked tensor types supported in MHLO Rsub");
212+
int64_t dim;
213+
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
214+
return rewriter.notifyMatchFailure(
215+
op, "Only constant dim is currently supported");
216+
217+
auto getOptionalVal = [&](Value val) -> llvm::Optional<Value> {
218+
if (val.getType().isa<Torch::NoneType>()) {
219+
return llvm::None;
220+
} else {
221+
return val;
222+
}
223+
};
224+
225+
llvm::Optional<Value> start = getOptionalVal(adaptor.start());
226+
llvm::Optional<Value> end = getOptionalVal(adaptor.end());
227+
llvm::Optional<Value> step = getOptionalVal(adaptor.step());
228+
229+
Value sliced =
230+
getDynamicSlice(rewriter, op, self, start, end, step, dim);
231+
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
232+
op, getTypeConverter()->convertType(op.getType()), sliced);
233+
234+
return success();
235+
}
236+
} // namespace
237+
238+
void mlir::torch::torch_to_mhlo::populateSliceLikeOpPatternsAndLegality(
239+
TypeConverter &typeConverter, RewritePatternSet &patterns,
240+
ConversionTarget &target) {
241+
MLIRContext *context = patterns.getContext();
242+
243+
#define INSERT_ATENOP_PATTERN(AtenOp) \
244+
target.addIllegalOp<AtenOp>(); \
245+
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
246+
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
247+
#undef INSERT_ATENOP_PATTERN
248+
249+
}

lib/Conversion/TorchToMhlo/TorchToMhlo.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> {
5151

5252
torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns,
5353
target);
54+
torch_to_mhlo::populateSliceLikeOpPatternsAndLegality(typeConverter, patterns,
55+
target);
56+
5457
if (failed(applyPartialConversion(getOperation(), target,
5558
std::move(patterns)))) {
5659
return signalPassFailure();

0 commit comments

Comments
 (0)