10
10
#include " torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
11
11
12
12
#include " ../PassDetail.h"
13
+ #include " ./MhloLegalizeUtils.h"
13
14
#include " ./PopulatePatterns.h"
15
+ #include " mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
14
16
#include " mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15
17
#include " mlir/Dialect/Tensor/IR/Tensor.h"
16
- #include " mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
17
18
#include " torch-mlir/Conversion/Utils/Utils.h"
18
19
#include " torch-mlir/Dialect/Torch/IR/TorchDialect.h"
19
20
#include " torch-mlir/Dialect/Torch/IR/TorchOps.h"
@@ -28,61 +29,7 @@ using namespace mlir::torch;
28
29
using namespace mlir ::torch::Torch;
29
30
using namespace mlir ::torch::TorchConversion;
30
31
31
- #ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
32
- static constexpr size_t kMhloDimSizeBits = 32 ;
33
- #else
34
- static constexpr size_t kMhloDimSizeBits = 64 ;
35
- #endif
36
-
37
32
namespace {
38
-
39
- SmallVector<size_t > toPositiveDims (ArrayRef<int64_t > dims, int64_t rank) {
40
- SmallVector<size_t > posDims;
41
- posDims.reserve (rank);
42
- std::transform (
43
- dims.begin (), dims.end (), std::back_inserter (posDims),
44
- [rank](int64_t d) -> size_t { return toPositiveDim (d, rank); });
45
- return posDims;
46
- }
47
-
48
- FailureOr<SmallVector<Value, 4 >>
49
- getDimSizesOfTensor (PatternRewriter &rewriter, Operation *op, Value value,
50
- ArrayRef<int64_t > inpDims) {
51
- auto valueTy = value.getType ().dyn_cast <RankedTensorType>();
52
- if (!valueTy) {
53
- return rewriter.notifyMatchFailure (
54
- op, " getDimSizesOfTensor(): the input is not a ranked tensor" );
55
- }
56
-
57
- auto rank = valueTy.getRank ();
58
- auto dims = toPositiveDims (inpDims, rank);
59
- SmallVector<Value, 4 > dimSizes;
60
- dimSizes.reserve (dims.size ());
61
-
62
- auto loc = op->getLoc ();
63
- for (auto d : dims) {
64
- dimSizes.emplace_back (rewriter.create <arith::IndexCastOp>(
65
- loc, rewriter.getIntegerType (kMhloDimSizeBits ),
66
- rewriter.create <tensor::DimOp>(loc, value, d)));
67
- }
68
- return dimSizes;
69
- }
70
-
71
- FailureOr<SmallVector<Value, 4 >>
72
- getDimSizesOfTensor (PatternRewriter &rewriter, Operation *op, Value value) {
73
- auto valueTy = value.getType ().dyn_cast <RankedTensorType>();
74
- if (!valueTy) {
75
- return rewriter.notifyMatchFailure (
76
- op, " getDimSizesOfTensor(): the input is not a ranked tensor" );
77
- }
78
-
79
- auto rank = valueTy.getRank ();
80
- // Get int vector [0, 1, ..., rank-1]
81
- std::vector<int64_t > dims (rank);
82
- std::iota (dims.begin (), dims.end (), 0 );
83
- return getDimSizesOfTensor (rewriter, op, value, dims);
84
- }
85
-
86
33
// A dimension index from torch.dialect might outside the range [0, dimSize].
87
34
// The function is used to normalize the input index into the range.
88
35
Value getNormalizedDimSizeInternal (PatternRewriter &rewriter, Operation *op,
@@ -111,7 +58,7 @@ Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
111
58
ArrayRef<Value> dimSizes) {
112
59
auto loc = op->getLoc ();
113
60
// startIndex & endIndex has been normailized into range [0, dSize]
114
- Type intType = rewriter.getIntegerType (kMhloDimSizeBits );
61
+ Type intType = rewriter.getIntegerType (mhlo:: kMhloDimSizeBits );
115
62
Value zero = rewriter.create <arith::ConstantOp>(
116
63
loc, rewriter.getIntegerAttr (intType, 0 ));
117
64
Value one = rewriter.create <arith::ConstantOp>(
@@ -192,14 +139,14 @@ FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
192
139
loc, rewriter.getIntegerAttr (rewriter.getI64Type (), 1 ));
193
140
194
141
#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
195
- auto i32Type = rewriter.getIntegerType (kMhloDimSizeBits );
142
+ auto i32Type = rewriter.getIntegerType (mhlo:: kMhloDimSizeBits );
196
143
normStartIndex =
197
144
rewriter.create <arith::TruncIOp>(loc, i32Type, normStartIndex);
198
145
normEndIndex = rewriter.create <arith::TruncIOp>(loc, i32Type, normEndIndex);
199
146
step = rewriter.create <arith::TruncIOp>(loc, i32Type, step);
200
147
#endif
201
148
FailureOr<SmallVector<Value, 4 >> dimSizesInfo =
202
- getDimSizesOfTensor (rewriter, op, input);
149
+ mhlo:: getDimSizesOfTensor (rewriter, op, input);
203
150
if (failed (dimSizesInfo))
204
151
return rewriter.notifyMatchFailure (
205
152
op, " failed to get dimension sizes of the input" );
@@ -305,7 +252,7 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenOpT> {
305
252
});
306
253
#endif
307
254
308
- Type intType = rewriter.getIntegerType (kMhloDimSizeBits );
255
+ Type intType = rewriter.getIntegerType (mhlo:: kMhloDimSizeBits );
309
256
Value numel = rewriter.create <arith::ConstantOp>(
310
257
loc, rewriter.getIntegerAttr (intType, 1 ));
311
258
for (auto d : dimSizes) {
@@ -350,59 +297,6 @@ bool ConvertAtenViewOp<AtenReshapeOp>::getAtenViewOpSizes(
350
297
return getListConstructElements (adaptor.shape (), dimSizes);
351
298
}
352
299
353
- FailureOr<Value> unsqueezeTensor (PatternRewriter &rewriter, Operation *op,
354
- Value tensor,
355
- ArrayRef<int64_t > inputUnsqzDims) {
356
- // Returns a new tensor with dims of size 1 inserted at the specified
357
- // position.
358
- //
359
- // The position indices (must be high to low dimension number of the returned
360
- // tensor) are specified with unsqzDims. Indices must be in-order, and in
361
- // range of tensor rank. Thus, unsqueeze a rank 1 tensor with {0, 2}, {0, 1,
362
- // 3}, {0, 1, 2} are all valid dimension sets, but {0, 3}, {2} are not.
363
- auto dimSizesInfo = getDimSizesOfTensor (rewriter, op, tensor);
364
- if (failed (dimSizesInfo))
365
- return rewriter.notifyMatchFailure (
366
- op, " failed to get dimension sizes of the input" );
367
-
368
- auto dimSizes = *dimSizesInfo;
369
- auto rank = dimSizes.size ();
370
- size_t newRank = rank + inputUnsqzDims.size ();
371
- auto unsqzDims = toPositiveDims (inputUnsqzDims, newRank);
372
- for (size_t k = 0 , sz = unsqzDims.size (); k < sz; ++k)
373
- if (k > 1 && unsqzDims[k] <= unsqzDims[k - 1 ])
374
- return rewriter.notifyMatchFailure (
375
- op, " unsqueeze dimensions must be specified in order" );
376
-
377
- auto loc = op->getLoc ();
378
- auto rankTy = tensor.getType ().dyn_cast <RankedTensorType>();
379
- auto oldShape = rankTy.getShape ();
380
- Type intType = rewriter.getIntegerType (kMhloDimSizeBits );
381
- auto one = rewriter.create <arith::ConstantOp>(
382
- loc, rewriter.getIntegerAttr (intType, 1 ));
383
-
384
- std::vector<Value> newDimSizes;
385
- std::vector<int64_t > newShape;
386
- newDimSizes.reserve (newRank);
387
- newShape.reserve (newRank);
388
- for (size_t k = 0 , i = 0 , j = 0 ; k < newRank; ++k) {
389
- if (j < unsqzDims.size () && unsqzDims[j] == k) {
390
- newDimSizes.push_back (one);
391
- newShape.push_back (1 );
392
- j++;
393
- } else {
394
- newDimSizes.push_back (dimSizes[i]);
395
- newShape.push_back (oldShape[i]);
396
- i++;
397
- }
398
- }
399
-
400
- auto outTy = RankedTensorType::get (newShape, rankTy.getElementType ());
401
- auto mhloShape = rewriter.create <tensor::FromElementsOp>(loc, newDimSizes);
402
- return rewriter.create <mhlo::DynamicReshapeOp>(loc, outTy, tensor, mhloShape)
403
- .getResult ();
404
- }
405
-
406
300
template <>
407
301
LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
408
302
AtenSqueezeOp op, OpAdaptor adaptor,
@@ -428,7 +322,7 @@ LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
428
322
dims.push_back (r);
429
323
}
430
324
431
- auto newDimSizesInfo = getDimSizesOfTensor (rewriter, op, self, dims);
325
+ auto newDimSizesInfo = mhlo:: getDimSizesOfTensor (rewriter, op, self, dims);
432
326
if (failed (newDimSizesInfo))
433
327
return rewriter.notifyMatchFailure (
434
328
op, " failed to get dimension sizes of the input" );
@@ -471,7 +365,7 @@ LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
471
365
SmallVector<int64_t , 4 > dims (rank);
472
366
std::iota (dims.begin (), dims.end (), 0 );
473
367
dims.erase (dims.begin () + dim);
474
- auto newDimSizesInfo = getDimSizesOfTensor (rewriter, op, self, dims);
368
+ auto newDimSizesInfo = mhlo:: getDimSizesOfTensor (rewriter, op, self, dims);
475
369
if (failed (newDimSizesInfo))
476
370
return rewriter.notifyMatchFailure (
477
371
op, " failed to get dimension sizes of the input" );
@@ -496,7 +390,8 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
496
390
if (!matchPattern (op.dim (), m_TorchConstantInt (&dim)))
497
391
return op->emitError (" dim must be a Scalar constant" );
498
392
499
- auto unsqzTensorInfo = unsqueezeTensor (rewriter, op, adaptor.self (), {dim});
393
+ auto unsqzTensorInfo =
394
+ mhlo::unsqueezeTensor (rewriter, op, adaptor.self (), {dim});
500
395
if (failed (unsqzTensorInfo))
501
396
return rewriter.notifyMatchFailure (op,
502
397
" failed to create unsqueezed tensor" );
0 commit comments