Skip to content

Commit 057fc8e

Browse files
committed
[ODS] Use Adaptor Trait for Shaped Type Inference
Author inferReturnTypeComponents methods with the Op Adaptor by using the InferShapedTypeOpAdaptor. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D155243
1 parent 04cc892 commit 057fc8e

File tree

6 files changed

+240
-215
lines changed

6 files changed

+240
-215
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 32 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,7 @@ include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
3232
//===----------------------------------------------------------------------===//
3333
// Operator: argmax
3434
//===----------------------------------------------------------------------===//
35-
def Tosa_ArgMaxOp : Tosa_Op<"argmax", [
36-
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
37-
["inferReturnTypeComponents"]>,
38-
Pure]> {
35+
def Tosa_ArgMaxOp : Tosa_Op<"argmax", [InferShapedTypeOpAdaptor, Pure]> {
3936
let summary = "Perform argmax on the input.";
4037

4138
let description = [{
@@ -62,10 +59,7 @@ def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>;
6259
//===----------------------------------------------------------------------===//
6360
// Operator: avg_pool2d
6461
//===----------------------------------------------------------------------===//
65-
def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [
66-
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
67-
["inferReturnTypeComponents"]>,
68-
Pure]> {
62+
def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [InferShapedTypeOpAdaptor, Pure]> {
6963
let summary = "Performs max pooling on the input.";
7064

7165
let description = [{
@@ -95,10 +89,7 @@ def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [
9589
//===----------------------------------------------------------------------===//
9690
// Operator: conv2d
9791
//===----------------------------------------------------------------------===//
98-
def Tosa_Conv2DOp : Tosa_Op<"conv2d", [
99-
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
100-
["inferReturnTypeComponents"]>,
101-
Pure]> {
92+
def Tosa_Conv2DOp : Tosa_Op<"conv2d", [InferShapedTypeOpAdaptor, Pure]> {
10293
let summary = "2D Convolution Operator";
10394

10495
let description = [{
@@ -128,10 +119,7 @@ def Tosa_Conv2DOp : Tosa_Op<"conv2d", [
128119
//===----------------------------------------------------------------------===//
129120
// Operator: conv3d
130121
//===----------------------------------------------------------------------===//
131-
def Tosa_Conv3DOp : Tosa_Op<"conv3d", [
132-
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
133-
["inferReturnTypeComponents"]>,
134-
Pure]> {
122+
def Tosa_Conv3DOp : Tosa_Op<"conv3d", [InferShapedTypeOpAdaptor, Pure]> {
135123
let summary = "3D Convolution operator";
136124

137125
let description = [{
@@ -160,10 +148,8 @@ def Tosa_Conv3DOp : Tosa_Op<"conv3d", [
160148
//===----------------------------------------------------------------------===//
161149
// Operator: depthwise_conv2d
162150
//===----------------------------------------------------------------------===//
163-
def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [
164-
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
165-
["inferReturnTypeComponents"]>,
166-
Pure]> {
151+
def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d",
152+
[InferShapedTypeOpAdaptor, Pure]> {
167153
let summary = "Depthwise 2D Convolution operator";
168154

169155
let description = [{
@@ -193,10 +179,7 @@ def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [
193179
//===----------------------------------------------------------------------===//
194180
// Operator: fft2d
195181
//===----------------------------------------------------------------------===//
196-
def Tosa_FFT2dOp : Tosa_Op<"fft2d", [
197-
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
198-
["inferReturnTypeComponents"]>,
199-
Pure]> {
182+
def Tosa_FFT2dOp : Tosa_Op<"fft2d", [InferShapedTypeOpAdaptor, Pure]> {
200183
let summary = "Performs FFT2D operation on the input.";
201184

202185
let description = [{
@@ -224,9 +207,7 @@ def Tosa_FFT2dOp : Tosa_Op<"fft2d", [
224207
// Operator: fully_connected
225208
//===----------------------------------------------------------------------===//
226209
def Tosa_FullyConnectedOp : Tosa_Op<"fully_connected", [
227-
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
228-
["inferReturnTypeComponents"]>,
229-
Pure]> {
210+
InferShapedTypeOpAdaptor, Pure]> {
230211
let summary = "Fully Connected operator";
231212

232213
let description = [{
@@ -251,10 +232,7 @@ def Tosa_FullyConnectedOp : Tosa_Op<"fully_connected", [
251232
//===----------------------------------------------------------------------===//
252233
// Operator: matmul
253234
//===----------------------------------------------------------------------===//
254-
def Tosa_MatMulOp : Tosa_Op<"matmul", [
255-
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
256-
["inferReturnTypeComponents"]>,
257-
Pure]> {
235+
def Tosa_MatMulOp : Tosa_Op<"matmul", [InferShapedTypeOpAdaptor, Pure]> {
258236
let summary = "Matrix multiplication with bias";
259237

260238
let description = [{
@@ -279,10 +257,7 @@ def Tosa_MatMulOp : Tosa_Op<"matmul", [
279257
//===----------------------------------------------------------------------===//
280258
// Operator: max_pool2d
281259
//===----------------------------------------------------------------------===//
282-
def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [
283-
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
284-
["inferReturnTypeComponents"]>,
285-
Pure]> {
260+
def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [InferShapedTypeOpAdaptor, Pure]> {
286261
let summary = "Performs max pooling on the input.";
287262

288263
let description = [{
@@ -310,10 +285,7 @@ def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [
310285
//===----------------------------------------------------------------------===//
311286
// Operator: rfft2d
312287
//===----------------------------------------------------------------------===//
313-
def Tosa_RFFT2dOp : Tosa_Op<"rfft2d", [
314-
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
315-
["inferReturnTypeComponents"]>,
316-
Pure]> {
288+
def Tosa_RFFT2dOp : Tosa_Op<"rfft2d", [InferShapedTypeOpAdaptor, Pure]> {
317289
let summary = "Performs RFFT2D operation on the input.";
318290

319291
let description = [{
@@ -338,10 +310,8 @@ def Tosa_RFFT2dOp : Tosa_Op<"rfft2d", [
338310
//===----------------------------------------------------------------------===//
339311
// Operator: transpose_conv2d
340312
//===----------------------------------------------------------------------===//
341-
def Tosa_TransposeConv2DOp : Tosa_Op<"transpose_conv2d", [
342-
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
343-
["inferReturnTypeComponents"]>,
344-
Pure]> {
313+
def Tosa_TransposeConv2DOp : Tosa_Op<"transpose_conv2d",
314+
[InferShapedTypeOpAdaptor, Pure]> {
345315
let summary = "Transpose 2D Convolution operator.";
346316

347317
let description = [{
@@ -828,10 +798,7 @@ def Tosa_SubOp : Tosa_ElemWiseBinaryOp<"sub"> {
828798
//===----------------------------------------------------------------------===//
829799
// Operator: table
830800
//===----------------------------------------------------------------------===//
831-
def Tosa_TableOp : Tosa_Op<"table", [
832-
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
833-
["inferReturnTypeComponents"]>,
834-
Pure]> {
801+
def Tosa_TableOp : Tosa_Op<"table", [InferShapedTypeOpAdaptor, Pure]> {
835802
let summary = "Table lookup op";
836803

837804
let description = [{
@@ -1214,7 +1181,7 @@ def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [
12141181
// Operator: reduce_all
12151182
//===----------------------------------------------------------------------===//
12161183
def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [
1217-
InferTensorType, Pure]> {
1184+
InferTensorTypeAdaptor, Pure]> {
12181185
let summary = "Reduce All operator";
12191186

12201187
let description = [{
@@ -1243,7 +1210,7 @@ def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [
12431210
// Operator: reduce_any
12441211
//===----------------------------------------------------------------------===//
12451212
def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [
1246-
InferTensorType, Pure]> {
1213+
InferTensorTypeAdaptor, Pure]> {
12471214
let summary = "Reduce Any operator";
12481215

12491216
let description = [{
@@ -1272,7 +1239,7 @@ def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [
12721239
// Operator: reduce_max
12731240
//===----------------------------------------------------------------------===//
12741241
def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [
1275-
InferTensorType, Pure]> {
1242+
InferTensorTypeAdaptor, Pure]> {
12761243
let summary = "Reduce Max operator";
12771244

12781245
let description = [{
@@ -1301,7 +1268,7 @@ def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [
13011268
// Operator: reduce_min
13021269
//===----------------------------------------------------------------------===//
13031270
def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [
1304-
InferTensorType, Pure]> {
1271+
InferTensorTypeAdaptor, Pure]> {
13051272
let summary = "Reduce Min operator";
13061273

13071274
let description = [{
@@ -1330,7 +1297,7 @@ def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [
13301297
// Operator: reduce_prod
13311298
//===----------------------------------------------------------------------===//
13321299
def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [
1333-
InferTensorType, Pure]> {
1300+
InferTensorTypeAdaptor, Pure]> {
13341301
let summary = "Reduce Prod operator";
13351302

13361303
let description = [{
@@ -1359,7 +1326,7 @@ def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [
13591326
// Operator: reduce_sum
13601327
//===----------------------------------------------------------------------===//
13611328
def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [
1362-
InferTensorType, Pure]> {
1329+
InferTensorTypeAdaptor, Pure]> {
13631330
let summary = "Reduce Sum operator";
13641331

13651332
let description = [{
@@ -1393,7 +1360,7 @@ def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [
13931360
// Operator: concat
13941361
//===----------------------------------------------------------------------===//
13951362
def Tosa_ConcatOp : Tosa_Op<"concat", [
1396-
InferTensorType, Pure]> {
1363+
InferTensorTypeAdaptor, Pure]> {
13971364
let summary = "Concatenates tensors along one dimension.";
13981365

13991366
let description = [{
@@ -1423,10 +1390,7 @@ def Tosa_ConcatOp : Tosa_Op<"concat", [
14231390
//===----------------------------------------------------------------------===//
14241391
// Operator: pad
14251392
//===----------------------------------------------------------------------===//
1426-
def Tosa_PadOp : Tosa_Op<"pad", [
1427-
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
1428-
["inferReturnTypeComponents"]>,
1429-
Pure]> {
1393+
def Tosa_PadOp : Tosa_Op<"pad", [InferShapedTypeOpAdaptor, Pure]> {
14301394
let summary = "Pads a tensor with value specified.";
14311395

14321396
let description = [{
@@ -1471,7 +1435,7 @@ def Tosa_PadOp : Tosa_Op<"pad", [
14711435
// Operator: reshape
14721436
//===----------------------------------------------------------------------===//
14731437
def Tosa_ReshapeOp: Tosa_Op<"reshape", [
1474-
InferTensorType, Pure]> {
1438+
InferTensorTypeAdaptor, Pure]> {
14751439
let summary = "Reshape operator";
14761440

14771441
let description = [{
@@ -1528,9 +1492,7 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
15281492
//===----------------------------------------------------------------------===//
15291493
// Operator: slice
15301494
//===----------------------------------------------------------------------===//
1531-
def Tosa_SliceOp: Tosa_Op<"slice", [
1532-
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
1533-
["inferReturnTypeComponents"]>, Pure]> {
1495+
def Tosa_SliceOp: Tosa_Op<"slice", [InferShapedTypeOpAdaptor, Pure]> {
15341496
let summary = "Slice operator";
15351497

15361498
let description = [{
@@ -1556,10 +1518,7 @@ def Tosa_SliceOp: Tosa_Op<"slice", [
15561518
//===----------------------------------------------------------------------===//
15571519
// Operator: tile
15581520
//===----------------------------------------------------------------------===//
1559-
def Tosa_TileOp: Tosa_Op<"tile", [
1560-
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
1561-
["inferReturnTypeComponents"]>,
1562-
Pure]> {
1521+
def Tosa_TileOp: Tosa_Op<"tile", [InferShapedTypeOpAdaptor, Pure]> {
15631522
let summary = "Tile operator";
15641523

15651524
let description = [{
@@ -1580,10 +1539,7 @@ def Tosa_TileOp: Tosa_Op<"tile", [
15801539
//===----------------------------------------------------------------------===//
15811540
// Operator: transpose
15821541
//===----------------------------------------------------------------------===//
1583-
def Tosa_TransposeOp : Tosa_Op<"transpose", [
1584-
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
1585-
["inferReturnTypeComponents"]>,
1586-
Pure]> {
1542+
def Tosa_TransposeOp : Tosa_Op<"transpose", [InferShapedTypeOpAdaptor, Pure]> {
15871543
let summary = "Transpose operator";
15881544

15891545
let description = [{
@@ -1615,10 +1571,7 @@ def Tosa_TransposeOp : Tosa_Op<"transpose", [
16151571
//===----------------------------------------------------------------------===//
16161572
// Operator: gather
16171573
//===----------------------------------------------------------------------===//
1618-
def Tosa_GatherOp : Tosa_Op<"gather", [
1619-
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
1620-
["inferReturnTypeComponents"]>,
1621-
Pure]> {
1574+
def Tosa_GatherOp : Tosa_Op<"gather", [InferShapedTypeOpAdaptor, Pure]> {
16221575
let summary = "Gather operation,";
16231576

16241577
let description = [{
@@ -1639,10 +1592,7 @@ def Tosa_GatherOp : Tosa_Op<"gather", [
16391592
//===----------------------------------------------------------------------===//
16401593
// Operator: scatter
16411594
//===----------------------------------------------------------------------===//
1642-
def Tosa_ScatterOp : Tosa_Op<"scatter", [
1643-
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
1644-
["inferReturnTypeComponents"]>,
1645-
Pure]> {
1595+
def Tosa_ScatterOp : Tosa_Op<"scatter", [InferShapedTypeOpAdaptor, Pure]> {
16461596
let summary = "Scatter operation,";
16471597

16481598
let description = [{
@@ -1669,10 +1619,7 @@ def Tosa_ScatterOp : Tosa_Op<"scatter", [
16691619
//===----------------------------------------------------------------------===//
16701620
// Operator: resize
16711621
//===----------------------------------------------------------------------===//
1672-
def Tosa_ResizeOp : Tosa_Op<"resize", [
1673-
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
1674-
["inferReturnTypeComponents"]>,
1675-
Pure]> {
1622+
def Tosa_ResizeOp : Tosa_Op<"resize", [InferShapedTypeOpAdaptor, Pure]> {
16761623

16771624
let summary = "Resize operation, supports various resize/upsample modes";
16781625

@@ -1898,9 +1845,8 @@ def Tosa_CustomOp : Tosa_Op<"custom"> {
18981845
//===----------------------------------------------------------------------===//
18991846
// Further described in docs/Rationale/RationaleTOSADialect.md .
19001847
//===----------------------------------------------------------------------===//
1901-
def Tosa_IfOp : Tosa_Op<"cond_if", [
1902-
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
1903-
["inferReturnTypeComponents"]>,
1848+
def Tosa_IfOp : Tosa_Op<"cond_if",
1849+
[InferShapedTypeOpAdaptor,
19041850
SingleBlockImplicitTerminator<"YieldOp">,
19051851
RecursiveMemoryEffects]> {
19061852
let summary = "Conditional if operator";
@@ -1933,8 +1879,7 @@ def Tosa_IfOp : Tosa_Op<"cond_if", [
19331879
//===----------------------------------------------------------------------===//
19341880
def Tosa_WhileOp : Tosa_Op<"while_loop", [
19351881
DeclareOpInterfaceMethods<LoopLikeOpInterface>,
1936-
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
1937-
["inferReturnTypeComponents"]>,
1882+
InferShapedTypeOpAdaptor,
19381883
SingleBlockImplicitTerminator<"YieldOp">,
19391884
RecursiveMemoryEffects]> {
19401885
let summary = "output = input; While (Cond(output)) {output = Body(output)}";

mlir/include/mlir/Interfaces/InferTypeOpInterface.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,10 @@ template <typename ConcreteType>
262262
class InferTypeOpAdaptor : public TraitBase<ConcreteType, InferTypeOpAdaptor> {
263263
};
264264

265+
template <typename ConcreteType>
266+
class InferShapedTypeOpAdaptor
267+
: public TraitBase<ConcreteType, InferShapedTypeOpAdaptor> {};
268+
265269
/// Tensor type inference trait that constructs a tensor from the inferred
266270
/// shape and elemental types.
267271
/// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface.

0 commit comments

Comments
 (0)