Skip to content

Commit 3772e0b

Browse files
author
Tanyo Kwok
authored
[NFC][MHLO] move util funcs to MhloLegalizeUtils.h/cpp (#1128)
See RFC: #999 Co-authored-by: Bairen Yi [email protected] Co-authored-by: Jiawei Wu [email protected] Co-authored-by: Tianyou Guo [email protected] Co-authored-by: Xu Yan [email protected] Co-authored-by: Ziheng Jiang [email protected]
1 parent fe3c9f5 commit 3772e0b

File tree

3 files changed

+138
-117
lines changed

3 files changed

+138
-117
lines changed

lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp

Lines changed: 106 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,14 @@
77
//
88
//===----------------------------------------------------------------------===//
99

10+
#include "./MhloLegalizeUtils.h"
1011
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
12+
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
14+
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
1115
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
1216
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
13-
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
14-
#include "./MhloLegalizeUtils.h"
17+
#include <numeric>
1518

1619
using namespace mlir;
1720
using namespace mlir::torch;
@@ -314,5 +317,106 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter,
314317
rewriter.create<mhlo::BroadcastInDimOp>(op->getLoc(), outType, input, bcast_attr);
315318
return bcast_op.getResult();
316319
}
320+
321+
// Normalize possibly-negative dimension indices (PyTorch convention) into
// their non-negative equivalents for a tensor of the given rank.
SmallVector<size_t> toPositiveDims(ArrayRef<int64_t> dims, int64_t rank) {
  SmallVector<size_t> normalized;
  normalized.reserve(rank);
  for (int64_t dim : dims)
    normalized.push_back(static_cast<size_t>(toPositiveDim(dim, rank)));
  return normalized;
}
329+
330+
// Materialize the runtime sizes of the requested dimensions of `value` as
// integer Values of width kMhloDimSizeBits. Fails if `value` is not a
// ranked tensor. Negative entries in `inpDims` are normalized first.
FailureOr<SmallVector<Value, 4>>
getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value,
                    ArrayRef<int64_t> inpDims) {
  auto tensorTy = value.getType().dyn_cast<RankedTensorType>();
  if (!tensorTy) {
    return rewriter.notifyMatchFailure(
        op, "getDimSizesOfTensor(): the input is not a ranked tensor");
  }

  auto loc = op->getLoc();
  auto normalizedDims = toPositiveDims(inpDims, tensorTy.getRank());
  SmallVector<Value, 4> sizes;
  sizes.reserve(normalizedDims.size());

  for (auto dim : normalizedDims) {
    // tensor.dim yields an index; cast it to the integer width MHLO shape
    // computations use.
    Value dimSize = rewriter.create<tensor::DimOp>(loc, value, dim);
    sizes.emplace_back(rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getIntegerType(kMhloDimSizeBits), dimSize));
  }
  return sizes;
}
352+
353+
// Convenience overload: materialize the runtime sizes of ALL dimensions of
// `value`. Fails if `value` is not a ranked tensor.
FailureOr<SmallVector<Value, 4>>
getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value) {
  auto tensorTy = value.getType().dyn_cast<RankedTensorType>();
  if (!tensorTy) {
    return rewriter.notifyMatchFailure(
        op, "getDimSizesOfTensor(): the input is not a ranked tensor");
  }

  // Request every axis: [0, 1, ..., rank-1].
  std::vector<int64_t> allDims(tensorTy.getRank());
  std::iota(allDims.begin(), allDims.end(), 0);
  return getDimSizesOfTensor(rewriter, op, value, allDims);
}
367+
368+
FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
                                 Value tensor,
                                 ArrayRef<int64_t> inputUnsqzDims) {
  // Returns a new tensor with dims of size 1 inserted at the specified
  // position.
  //
  // The position indices (counted in the rank of the RETURNED tensor, and
  // required to be in strictly increasing order) are specified with
  // unsqzDims. Indices must be in-order, and in range of tensor rank. Thus,
  // unsqueeze a rank 1 tensor with {0, 2}, {0, 1, 3}, {0, 1, 2} are all
  // valid dimension sets, but {0, 3}, {2} are not.
  auto dimSizesInfo = getDimSizesOfTensor(rewriter, op, tensor);
  if (failed(dimSizesInfo))
    return rewriter.notifyMatchFailure(
        op, "failed to get dimension sizes of the input");

  auto dimSizes = *dimSizesInfo;
  auto rank = dimSizes.size();
  size_t newRank = rank + inputUnsqzDims.size();
  auto unsqzDims = toPositiveDims(inputUnsqzDims, newRank);
  // Validate strict ordering between every adjacent pair. NOTE: the loop
  // starts at k = 1 (the original `k > 1` guard skipped the comparison of
  // unsqzDims[0] and unsqzDims[1], letting out-of-order or duplicate leading
  // entries slip through).
  for (size_t k = 1, sz = unsqzDims.size(); k < sz; ++k)
    if (unsqzDims[k] <= unsqzDims[k - 1])
      return rewriter.notifyMatchFailure(
          op, "unsqueeze dimensions must be specified in order");

  auto loc = op->getLoc();
  // getDimSizesOfTensor succeeded above, so the type is a RankedTensorType.
  auto rankTy = tensor.getType().dyn_cast<RankedTensorType>();
  auto oldShape = rankTy.getShape();
  Type intType = rewriter.getIntegerType(kMhloDimSizeBits);
  auto one = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getIntegerAttr(intType, 1));

  // Merge the original dim sizes with size-1 entries at the unsqueeze
  // positions: i walks the old dims, j walks unsqzDims, k walks the output.
  std::vector<Value> newDimSizes;
  std::vector<int64_t> newShape;
  newDimSizes.reserve(newRank);
  newShape.reserve(newRank);
  for (size_t k = 0, i = 0, j = 0; k < newRank; ++k) {
    if (j < unsqzDims.size() && unsqzDims[j] == k) {
      newDimSizes.push_back(one);
      newShape.push_back(1);
      j++;
    } else {
      newDimSizes.push_back(dimSizes[i]);
      newShape.push_back(oldShape[i]);
      i++;
    }
  }

  // Reshape dynamically so unknown dim sizes are handled correctly.
  auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
  auto mhloShape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
  return rewriter.create<mhlo::DynamicReshapeOp>(loc, outTy, tensor, mhloShape)
      .getResult();
}
420+
317421
} // namespace mhlo
318422
} // namespace mlir

lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@
1919

2020
namespace mlir {
2121
namespace mhlo {
22+
#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
23+
static constexpr size_t kMhloDimSizeBits = 32;
24+
#else
25+
static constexpr size_t kMhloDimSizeBits = 64;
26+
#endif
2227

2328
using mlir::ConversionPatternRewriter;
2429

@@ -56,6 +61,23 @@ LogicalResult torchAlphaToMhloTensor(ConversionPatternRewriter &rewriter,
5661

5762
Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
5863
TensorType outType);
64+
65+
SmallVector<size_t> toPositiveDims(ArrayRef<int64_t> dims, int64_t rank);
66+
67+
// Get the dimension sizes of the input tensor, given the dimension axes
68+
FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
69+
Operation *op, Value value,
70+
ArrayRef<int64_t> inpDims);
71+
72+
// Get the dimension sizes of the input tensor
73+
FailureOr<SmallVector<Value, 4>>
74+
getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value);
75+
76+
// Get a tensor that unsqueezed the specified dimensions of the input tensor
77+
FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
78+
Value tensor,
79+
ArrayRef<int64_t> inputUnsqzDims);
80+
5981
} // namespace mhlo
6082
} // namespace mlir
6183

lib/Conversion/TorchToMhlo/ViewLikeOps.cpp

Lines changed: 10 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@
1010
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
1111

1212
#include "../PassDetail.h"
13+
#include "./MhloLegalizeUtils.h"
1314
#include "./PopulatePatterns.h"
15+
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
1416
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1517
#include "mlir/Dialect/Tensor/IR/Tensor.h"
16-
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
1718
#include "torch-mlir/Conversion/Utils/Utils.h"
1819
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
1920
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@@ -28,61 +29,7 @@ using namespace mlir::torch;
2829
using namespace mlir::torch::Torch;
2930
using namespace mlir::torch::TorchConversion;
3031

31-
#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
32-
static constexpr size_t kMhloDimSizeBits = 32;
33-
#else
34-
static constexpr size_t kMhloDimSizeBits = 64;
35-
#endif
36-
3732
namespace {
38-
39-
SmallVector<size_t> toPositiveDims(ArrayRef<int64_t> dims, int64_t rank) {
40-
SmallVector<size_t> posDims;
41-
posDims.reserve(rank);
42-
std::transform(
43-
dims.begin(), dims.end(), std::back_inserter(posDims),
44-
[rank](int64_t d) -> size_t { return toPositiveDim(d, rank); });
45-
return posDims;
46-
}
47-
48-
FailureOr<SmallVector<Value, 4>>
49-
getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value,
50-
ArrayRef<int64_t> inpDims) {
51-
auto valueTy = value.getType().dyn_cast<RankedTensorType>();
52-
if (!valueTy) {
53-
return rewriter.notifyMatchFailure(
54-
op, "getDimSizesOfTensor(): the input is not a ranked tensor");
55-
}
56-
57-
auto rank = valueTy.getRank();
58-
auto dims = toPositiveDims(inpDims, rank);
59-
SmallVector<Value, 4> dimSizes;
60-
dimSizes.reserve(dims.size());
61-
62-
auto loc = op->getLoc();
63-
for (auto d : dims) {
64-
dimSizes.emplace_back(rewriter.create<arith::IndexCastOp>(
65-
loc, rewriter.getIntegerType(kMhloDimSizeBits),
66-
rewriter.create<tensor::DimOp>(loc, value, d)));
67-
}
68-
return dimSizes;
69-
}
70-
71-
FailureOr<SmallVector<Value, 4>>
72-
getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value) {
73-
auto valueTy = value.getType().dyn_cast<RankedTensorType>();
74-
if (!valueTy) {
75-
return rewriter.notifyMatchFailure(
76-
op, "getDimSizesOfTensor(): the input is not a ranked tensor");
77-
}
78-
79-
auto rank = valueTy.getRank();
80-
// Get int vector [0, 1, ..., rank-1]
81-
std::vector<int64_t> dims(rank);
82-
std::iota(dims.begin(), dims.end(), 0);
83-
return getDimSizesOfTensor(rewriter, op, value, dims);
84-
}
85-
8633
// A dimension index from torch.dialect might outside the range [0, dimSize].
8734
// The function is used to normalize the input index into the range.
8835
Value getNormalizedDimSizeInternal(PatternRewriter &rewriter, Operation *op,
@@ -111,7 +58,7 @@ Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
11158
ArrayRef<Value> dimSizes) {
11259
auto loc = op->getLoc();
11360
// startIndex & endIndex have been normalized into range [0, dSize]
114-
Type intType = rewriter.getIntegerType(kMhloDimSizeBits);
61+
Type intType = rewriter.getIntegerType(mhlo::kMhloDimSizeBits);
11562
Value zero = rewriter.create<arith::ConstantOp>(
11663
loc, rewriter.getIntegerAttr(intType, 0));
11764
Value one = rewriter.create<arith::ConstantOp>(
@@ -192,14 +139,14 @@ FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
192139
loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1));
193140

194141
#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
195-
auto i32Type = rewriter.getIntegerType(kMhloDimSizeBits);
142+
auto i32Type = rewriter.getIntegerType(mhlo::kMhloDimSizeBits);
196143
normStartIndex =
197144
rewriter.create<arith::TruncIOp>(loc, i32Type, normStartIndex);
198145
normEndIndex = rewriter.create<arith::TruncIOp>(loc, i32Type, normEndIndex);
199146
step = rewriter.create<arith::TruncIOp>(loc, i32Type, step);
200147
#endif
201148
FailureOr<SmallVector<Value, 4>> dimSizesInfo =
202-
getDimSizesOfTensor(rewriter, op, input);
149+
mhlo::getDimSizesOfTensor(rewriter, op, input);
203150
if (failed(dimSizesInfo))
204151
return rewriter.notifyMatchFailure(
205152
op, "failed to get dimension sizes of the input");
@@ -305,7 +252,7 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenOpT> {
305252
});
306253
#endif
307254

308-
Type intType = rewriter.getIntegerType(kMhloDimSizeBits);
255+
Type intType = rewriter.getIntegerType(mhlo::kMhloDimSizeBits);
309256
Value numel = rewriter.create<arith::ConstantOp>(
310257
loc, rewriter.getIntegerAttr(intType, 1));
311258
for (auto d : dimSizes) {
@@ -350,59 +297,6 @@ bool ConvertAtenViewOp<AtenReshapeOp>::getAtenViewOpSizes(
350297
return getListConstructElements(adaptor.shape(), dimSizes);
351298
}
352299

353-
FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
354-
Value tensor,
355-
ArrayRef<int64_t> inputUnsqzDims) {
356-
// Returns a new tensor with dims of size 1 inserted at the specified
357-
// position.
358-
//
359-
// The position indices (must be high to low dimension number of the returned
360-
// tensor) are specified with unsqzDims. Indices must be in-order, and in
361-
// range of tensor rank. Thus, unsqueeze a rank 1 tensor with {0, 2}, {0, 1,
362-
// 3}, {0, 1, 2} are all valid dimension sets, but {0, 3}, {2} are not.
363-
auto dimSizesInfo = getDimSizesOfTensor(rewriter, op, tensor);
364-
if (failed(dimSizesInfo))
365-
return rewriter.notifyMatchFailure(
366-
op, "failed to get dimension sizes of the input");
367-
368-
auto dimSizes = *dimSizesInfo;
369-
auto rank = dimSizes.size();
370-
size_t newRank = rank + inputUnsqzDims.size();
371-
auto unsqzDims = toPositiveDims(inputUnsqzDims, newRank);
372-
for (size_t k = 0, sz = unsqzDims.size(); k < sz; ++k)
373-
if (k > 1 && unsqzDims[k] <= unsqzDims[k - 1])
374-
return rewriter.notifyMatchFailure(
375-
op, "unsqueeze dimensions must be specified in order");
376-
377-
auto loc = op->getLoc();
378-
auto rankTy = tensor.getType().dyn_cast<RankedTensorType>();
379-
auto oldShape = rankTy.getShape();
380-
Type intType = rewriter.getIntegerType(kMhloDimSizeBits);
381-
auto one = rewriter.create<arith::ConstantOp>(
382-
loc, rewriter.getIntegerAttr(intType, 1));
383-
384-
std::vector<Value> newDimSizes;
385-
std::vector<int64_t> newShape;
386-
newDimSizes.reserve(newRank);
387-
newShape.reserve(newRank);
388-
for (size_t k = 0, i = 0, j = 0; k < newRank; ++k) {
389-
if (j < unsqzDims.size() && unsqzDims[j] == k) {
390-
newDimSizes.push_back(one);
391-
newShape.push_back(1);
392-
j++;
393-
} else {
394-
newDimSizes.push_back(dimSizes[i]);
395-
newShape.push_back(oldShape[i]);
396-
i++;
397-
}
398-
}
399-
400-
auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
401-
auto mhloShape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
402-
return rewriter.create<mhlo::DynamicReshapeOp>(loc, outTy, tensor, mhloShape)
403-
.getResult();
404-
}
405-
406300
template <>
407301
LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
408302
AtenSqueezeOp op, OpAdaptor adaptor,
@@ -428,7 +322,7 @@ LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
428322
dims.push_back(r);
429323
}
430324

431-
auto newDimSizesInfo = getDimSizesOfTensor(rewriter, op, self, dims);
325+
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims);
432326
if (failed(newDimSizesInfo))
433327
return rewriter.notifyMatchFailure(
434328
op, "failed to get dimension sizes of the input");
@@ -471,7 +365,7 @@ LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
471365
SmallVector<int64_t, 4> dims(rank);
472366
std::iota(dims.begin(), dims.end(), 0);
473367
dims.erase(dims.begin() + dim);
474-
auto newDimSizesInfo = getDimSizesOfTensor(rewriter, op, self, dims);
368+
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims);
475369
if (failed(newDimSizesInfo))
476370
return rewriter.notifyMatchFailure(
477371
op, "failed to get dimension sizes of the input");
@@ -496,7 +390,8 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
496390
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
497391
return op->emitError("dim must be a Scalar constant");
498392

499-
auto unsqzTensorInfo = unsqueezeTensor(rewriter, op, adaptor.self(), {dim});
393+
auto unsqzTensorInfo =
394+
mhlo::unsqueezeTensor(rewriter, op, adaptor.self(), {dim});
500395
if (failed(unsqzTensorInfo))
501396
return rewriter.notifyMatchFailure(op,
502397
"failed to create unsqueezed tensor");

0 commit comments

Comments
 (0)