Commit bbb88c5

Author: Tanyo Kwok
[MHLO] Init MHLO view like op patterns
See RFC: llvm#999

Co-authored-by: Bairen Yi [email protected]
Co-authored-by: Jiawei Wu [email protected]
Co-authored-by: Tianyou Guo [email protected]
Co-authored-by: Xu Yan [email protected]
Co-authored-by: Ziheng Jiang [email protected]
1 parent ad283c1 commit bbb88c5
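
In short, the new patterns lower torch.aten.view and torch.aten.reshape to a static mhlo.reshape when either side is rank-0, and to chlo.dynamic_reshape otherwise. A condensed sketch of the dynamic case, distilled from the tests in this commit (built with the new i32-truncation option enabled; SSA names are illustrative):

  %sizes = torch.prim.ListConstruct %int-1, %int224 : (!torch.int, !torch.int) -> !torch.list<int>
  %view = torch.aten.view %arg0, %sizes : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,224],f32>

  // lowers roughly to:
  %d0 = torch_c.to_i64 %int-1
  %d1 = torch_c.to_i64 %int224
  %t0 = arith.trunci %d0 : i64 to i32
  %t1 = arith.trunci %d1 : i64 to i32
  %shape = tensor.from_elements %t0, %t1 : tensor<2xi32>
  %out = "chlo.dynamic_reshape"(%self, %shape) : (tensor<?x?x?x?xf32>, tensor<2xi32>) -> tensor<?x224xf32>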

File tree

6 files changed, +265 -1 lines changed


CMakeLists.txt

Lines changed: 5 additions & 0 deletions
@@ -39,6 +39,11 @@ endmacro()
 option(TORCH_MLIR_ENABLE_MHLO "Add mhlo dialect" ON)
 if(TORCH_MLIR_ENABLE_MHLO)
   add_definitions(-DTORCH_MLIR_ENABLE_MHLO)
+  option(TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
+         "Enable truncating dimension sizes from i64 to i32 (unsafe)" OFF)
+  if(TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32)
+    add_definitions(-DTORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32)
+  endif()
 endif()

 torch_mlir_add_llvm_external_project(
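
The new option defaults to OFF; presumably it is enabled at configure time with -DTORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32=ON, which is how the tests below obtain their arith.trunci ops and i32 shape tensors. A hypothetical sketch of the same lowering with the option left OFF (no truncation, so the shape tensor keeps the i64 element type produced by torch_c.to_i64; SSA names are illustrative):

  // Assumed output with TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32 unset:
  %shape = tensor.from_elements %dim0, %dim1 : tensor<2xi64>
  %out = "chlo.dynamic_reshape"(%self, %shape) : (tensor<?x?x?x?xf32>, tensor<2xi64>) -> tensor<?x224xf32>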

lib/Conversion/TorchToMhlo/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
@@ -1,12 +1,14 @@
 add_mlir_conversion_library(TorchMLIRTorchToMhlo
   TorchToMhlo.cpp
   BasicOp.cpp
+  ViewLikeOps.cpp

   ADDITIONAL_HEADER_DIRS
   ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToMhlo

   DEPENDS
   MhloDialect
+  ChloDialect
   TorchMLIRConversionPassIncGen

   LINK_COMPONENTS
@@ -16,6 +18,7 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo
   MLIRIR
   MLIRPass
   MhloDialect
+  ChloDialect
   TorchMLIRTorchDialect
 )
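
ChloDialect is added to both DEPENDS and the link libraries because the new ViewLikeOps.cpp emits chlo.dynamic_reshape ops.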

lib/Conversion/TorchToMhlo/PopulatePatterns.h

Lines changed: 4 additions & 0 deletions
@@ -19,6 +19,10 @@ namespace torch_to_mhlo {
 void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
                                         RewritePatternSet &patterns,
                                         ConversionTarget &target);
+void populateViewLikeOpPatternsAndLegality(TypeConverter &typeConverter,
+                                           RewritePatternSet &patterns,
+                                           ConversionTarget &target);
+

 } // namespace torch_to_mhlo
 } // namespace torch
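
The new declaration mirrors populateBasicOpPatternsAndLegality above; its definition sits at the bottom of the new ViewLikeOps.cpp below.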

lib/Conversion/TorchToMhlo/TorchToMhlo.cpp

Lines changed: 6 additions & 1 deletion
@@ -11,6 +11,7 @@

 #include "../PassDetail.h"
 #include "./PopulatePatterns.h"
+#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -32,6 +33,7 @@ namespace {
 class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> {
 public:
   void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<chlo::ChloDialect>();
     registry.insert<mhlo::MhloDialect>();
     registry.insert<tensor::TensorDialect>();
     registry.insert<arith::ArithmeticDialect>();
@@ -40,7 +42,7 @@ class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> {
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ConversionTarget target(*context);
-    target.addLegalDialect<mhlo::MhloDialect, tensor::TensorDialect,
+    target.addLegalDialect<chlo::ChloDialect, mhlo::MhloDialect, tensor::TensorDialect,
                            arith::ArithmeticDialect, Torch::TorchDialect>();

     TypeConverter typeConverter;
@@ -51,6 +53,9 @@ class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> {

     torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns,
                                                       target);
+    torch_to_mhlo::populateViewLikeOpPatternsAndLegality(typeConverter, patterns,
+                                                         target);
+
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
       return signalPassFailure();
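
With chlo::ChloDialect registered as a dependent dialect and marked legal, the chlo.dynamic_reshape ops produced by the new patterns are left intact by the partial conversion.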

lib/Conversion/TorchToMhlo/ViewLikeOps.cpp (new file)

Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"

#include "../PassDetail.h"
#include "./PopulatePatterns.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include <numeric>

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::TorchConversion;

namespace {

// A conversion pattern template for view-like ops (aten.view, aten.reshape).
// The ops share one lowering; only the extraction of the target sizes is
// specialized per op, via getAtenViewOpSizes below.
template <typename AtenOpT>
class ConvertAtenViewOp : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;

  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto rankType =
        adaptor.self().getType().template dyn_cast<RankedTensorType>();
    if (!rankType)
      return op.emitError("Only ranked tensor types are currently supported");

    SmallVector<Value, 4> dimSizes;
    if (!getAtenViewOpSizes(op, adaptor, rewriter, dimSizes)) {
      return op.emitError("Dim sizes must be a list of scalars");
    }

    auto loc = op.getLoc();
    auto newRank = dimSizes.size();
    // If either side is rank-0, the shape is fully static, so a plain
    // mhlo.reshape suffices.
    if (newRank == 0 || rankType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
          op,
          OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
              op.getType()),
          adaptor.self());
      return success();
    }

    // Cast each !torch.int dim size to a builtin i64 value.
    std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) {
      dSize = rewriter.create<ToI64Op>(loc, dSize).getResult();
    });

#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
    // i64 arithmetic is much slower than i32 on some platforms, such as
    // NVIDIA GPUs. One can truncate i64 to i32 given that dim sizes are
    // unlikely to exceed the range of i32 (4 GiB).
    std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) {
      // dimSize: cast i64 -> i32
      dSize =
          rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), dSize);
    });
#endif

    // Pack the dim sizes into a 1-D shape tensor and reshape dynamically.
    Value mhloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
    rewriter.replaceOpWithNewOp<chlo::DynamicReshapeOp>(
        op,
        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
            op.getType()),
        adaptor.self(), mhloShape);
    return success();
  }

  bool getAtenViewOpSizes(AtenOpT op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter,
                          SmallVector<Value, 4> &dimSizes) const;
};

// aten.view carries the target shape in its `size` operand.
template <>
bool ConvertAtenViewOp<AtenViewOp>::getAtenViewOpSizes(
    AtenViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
    SmallVector<Value, 4> &dimSizes) const {
  return getListConstructElements(adaptor.size(), dimSizes);
}

// aten.reshape carries the target shape in its `shape` operand.
template <>
bool ConvertAtenViewOp<AtenReshapeOp>::getAtenViewOpSizes(
    AtenReshapeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
    SmallVector<Value, 4> &dimSizes) const {
  return getListConstructElements(adaptor.shape(), dimSizes);
}

} // namespace

void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target) {
  MLIRContext *context = patterns.getContext();

#define INSERT_VIEW_OP_PATTERN(AtenOp)                                         \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenViewOp<AtenOp>>(typeConverter, context);
  INSERT_VIEW_OP_PATTERN(AtenViewOp);
  INSERT_VIEW_OP_PATTERN(AtenReshapeOp);
#undef INSERT_VIEW_OP_PATTERN
}
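
Note that an inferred dimension (-1 in the size list) is not resolved by the pattern itself: the -1 is converted and packed into the shape tensor like any other size, presumably leaving chlo.dynamic_reshape to resolve it at runtime (see the minus1 test below).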

test/Conversion/TorchToMhlo/view_like.mlir (new file)

Lines changed: 115 additions & 0 deletions

@@ -0,0 +1,115 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL:  func.func @torch.aten.view$view_like(
// CHECK-SAME:   %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,224],f32> {
// CHECK:        %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK:        %[[INT:.*]]-1 = torch.constant.int -1
// CHECK:        %[[INT224:.*]] = torch.constant.int 224
// CHECK:        %[[T1:.*]] = torch.prim.ListConstruct %[[INT]]-1, %[[INT]]224 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK:        %[[T2:.*]] = torch_c.to_i64 %[[INT]]-1
// CHECK:        %[[T3:.*]] = torch_c.to_i64 %[[INT224]]
// CHECK:        %[[T4:.*]] = arith.trunci %[[T2]] : i64 to i32
// CHECK:        %[[T5:.*]] = arith.trunci %[[T3]] : i64 to i32
// CHECK:        %[[T6:.*]] = tensor.from_elements %[[T4]], %[[T5]] : tensor<2xi32>
// CHECK:        %[[T7:.*]] = "chlo.dynamic_reshape"(%[[T0]], %[[T6]]) : (tensor<?x?x?x?xf32>, tensor<2xi32>) -> tensor<?x224xf32>
// CHECK:        %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x224xf32> -> !torch.vtensor<[?,224],f32>
// CHECK:        return %[[T8]] : !torch.vtensor<[?,224],f32>
func.func @torch.aten.view$view_like(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,224],f32> {
  %int-1 = torch.constant.int -1
  %int224 = torch.constant.int 224
  %0 = torch.prim.ListConstruct %int-1, %int224 : (!torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,224],f32>
  return %1 : !torch.vtensor<[?,224],f32>
}

// -----

// CHECK-LABEL:  func.func @torch.aten.reshape$view_like(
// CHECK-SAME:   %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,120,4,64],f32> {
// CHECK:        %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?,?],f32> -> tensor<?x?x?x?x?xf32>
// CHECK:        %[[INT:.*]]-1 = torch.constant.int -1
// CHECK:        %[[INT120:.*]] = torch.constant.int 120
// CHECK:        %[[INT4:.*]] = torch.constant.int 4
// CHECK:        %[[INT64:.*]] = torch.constant.int 64
// CHECK:        %[[T1:.*]] = torch.prim.ListConstruct %[[INT]]-1, %[[INT]]120, %[[INT]]4, %[[INT]]64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK:        %[[T2:.*]] = torch_c.to_i64 %[[INT]]-1
// CHECK:        %[[T3:.*]] = torch_c.to_i64 %[[INT120]]
// CHECK:        %[[T4:.*]] = torch_c.to_i64 %[[INT4]]
// CHECK:        %[[T5:.*]] = torch_c.to_i64 %[[INT64]]
// CHECK:        %[[T6:.*]] = arith.trunci %[[T2]] : i64 to i32
// CHECK:        %[[T7:.*]] = arith.trunci %[[T3]] : i64 to i32
// CHECK:        %[[T8:.*]] = arith.trunci %[[T4]] : i64 to i32
// CHECK:        %[[T9:.*]] = arith.trunci %[[T5]] : i64 to i32
// CHECK:        %[[T10:.*]] = tensor.from_elements %[[T6]], %[[T7]], %[[T8]], %[[T9]] : tensor<4xi32>
// CHECK:        %[[T11:.*]] = "chlo.dynamic_reshape"(%[[T0]], %[[T10]]) : (tensor<?x?x?x?x?xf32>, tensor<4xi32>) -> tensor<?x120x4x64xf32>
// CHECK:        %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<?x120x4x64xf32> -> !torch.vtensor<[?,120,4,64],f32>
// CHECK:        return %[[T12]] : !torch.vtensor<[?,120,4,64],f32>
func.func @torch.aten.reshape$view_like(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,120,4,64],f32> {
  %int-1 = torch.constant.int -1
  %int120 = torch.constant.int 120
  %int4 = torch.constant.int 4
  %int64 = torch.constant.int 64
  %0 = torch.prim.ListConstruct %int-1, %int120, %int4, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.aten.reshape %arg0, %0 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,120,4,64],f32>
  return %1 : !torch.vtensor<[?,120,4,64],f32>
}

// -----

// CHECK-LABEL:  func.func @torch.aten.view.minus1$view_like(
// CHECK-SAME:   %[[ARG0:.*]]: !torch.vtensor<[2,3,?,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
// CHECK:        %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3,?,?],f32> -> tensor<2x3x?x?xf32>
// CHECK:        %[[INT:.*]]-1 = torch.constant.int -1
// CHECK:        %[[INT1:.*]] = torch.constant.int 1
// CHECK:        %[[INT0:.*]] = torch.constant.int 0
// CHECK:        %[[T1:.*]] = torch.aten.size.int %[[ARG0]], %[[INT0]] : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int
// CHECK:        %[[T2:.*]] = torch.aten.size.int %[[ARG0]], %[[INT1]] : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int
// CHECK:        %[[T3:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]], %[[INT]]-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK:        %[[T4:.*]] = torch_c.to_i64 %[[T1]]
// CHECK:        %[[T5:.*]] = torch_c.to_i64 %[[T2]]
// CHECK:        %[[T6:.*]] = torch_c.to_i64 %[[INT]]-1
// CHECK:        %[[T7:.*]] = arith.trunci %[[T4]] : i64 to i32
// CHECK:        %[[T8:.*]] = arith.trunci %[[T5]] : i64 to i32
// CHECK:        %[[T9:.*]] = arith.trunci %[[T6]] : i64 to i32
// CHECK:        %[[T10:.*]] = tensor.from_elements %[[T7]], %[[T8]], %[[T9]] : tensor<3xi32>
// CHECK:        %[[T11:.*]] = "chlo.dynamic_reshape"(%[[T0]], %[[T10]]) : (tensor<2x3x?x?xf32>, tensor<3xi32>) -> tensor<2x3x?xf32>
// CHECK:        %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<2x3x?xf32> -> !torch.vtensor<[2,3,?],f32>
// CHECK:        return %[[T12]] : !torch.vtensor<[2,3,?],f32>
func.func @torch.aten.view.minus1$view_like(%arg0: !torch.vtensor<[2,3,?,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
  %int-1 = torch.constant.int -1
  %int1 = torch.constant.int 1
  %int0 = torch.constant.int 0
  %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int
  %1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int
  %2 = torch.prim.ListConstruct %0, %1, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[2,3,?,?],f32>, !torch.list<int> -> !torch.vtensor<[2,3,?],f32>
  return %3 : !torch.vtensor<[2,3,?],f32>
}

// -----

// CHECK-LABEL:  func.func @torch.aten.view.to_rank1$view_like(
// CHECK-SAME:   %[[ARG0:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
// CHECK:        %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK:        %[[INT1:.*]] = torch.constant.int 1
// CHECK:        %[[T1:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list<int>
// CHECK:        %[[T2:.*]] = "mhlo.reshape"(%[[T0]]) : (tensor<f32>) -> tensor<1xf32>
// CHECK:        %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
// CHECK:        return %[[T3]] : !torch.vtensor<[1],f32>
func.func @torch.aten.view.to_rank1$view_like(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
  %int1 = torch.constant.int 1
  %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
  %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[1],f32>
  return %1 : !torch.vtensor<[1],f32>
}

// -----

// CHECK-LABEL:  func.func @torch.aten.view.to_rank0$view_like(
// CHECK-SAME:   %[[ARG0:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[],f32> {
// CHECK:        %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1],f32> -> tensor<1xf32>
// CHECK:        %[[T1:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK:        %[[T2:.*]] = "mhlo.reshape"(%[[T0]]) : (tensor<1xf32>) -> tensor<f32>
// CHECK:        %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<f32> -> !torch.vtensor<[],f32>
// CHECK:        return %[[T3]] : !torch.vtensor<[],f32>
func.func @torch.aten.view.to_rank0$view_like(%arg0: !torch.vtensor<[1],f32>) -> !torch.vtensor<[],f32> {
  %0 = torch.prim.ListConstruct : () -> !torch.list<int>
  %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[1],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
  return %1 : !torch.vtensor<[],f32>
}
