
Commit 0c34d7a

[mlir][tosa] Require operand/result tensors of at least rank 1 for some operations (#131335)
This commit updates the following operations (operands/results) to be of at least rank 1, so that they align with the expectations of the specification:

- ARGMAX (input)
- REDUCE_ALL (input/output)
- REDUCE_ANY (input/output)
- REDUCE_MAX (input/output)
- REDUCE_MIN (input/output)
- REDUCE_PRODUCT (input/output)
- REDUCE_SUM (input/output)
- CONCAT (each input in input1/output)
- PAD (input1/output)
- REVERSE (input1/output)
- SLICE (input1/output)
- TILE (input1/output)
- TRANSPOSE (input1/output)

In addition, PAD has been updated to allow unranked tensors for input1/output, in line with other operations.
1 parent 5c73c5c commit 0c34d7a
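For illustration only (the values %scalar and %input below are hypothetical and not taken from this patch), the practical effect is that rank-0 operands to the listed operations no longer pass verification, while operands of rank 1 or higher are unaffected:

  // Rejected after this change:
  // error: 'tosa.argmax' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<i32>'
  %0 = tosa.argmax %scalar {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>

  // Still valid: input of rank >= 1.
  %1 = tosa.argmax %input {axis = 1 : i32} : (tensor<4x8xf32>) -> tensor<4xi32>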

5 files changed: +115 −57 lines

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 25 additions & 25 deletions
@@ -41,7 +41,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor: $input,
+    Tosa_TensorAtLeast1D: $input,
     I32Attr: $axis,
     DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
@@ -1629,12 +1629,12 @@ def Tosa_ReduceAllOp : Tosa_InferTensorTypeOp<"reduce_all"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input,
+    Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1668,12 +1668,12 @@ def Tosa_ReduceAnyOp : Tosa_InferTensorTypeOp<"reduce_any"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input,
+    Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1707,13 +1707,13 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input,
+    Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis,
     DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1748,13 +1748,13 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input,
+    Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis,
     DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1789,12 +1789,12 @@ def Tosa_ReduceProductOp : Tosa_InferTensorTypeOp<"reduce_product"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input,
+    Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1828,12 +1828,12 @@ def Tosa_ReduceSumOp : Tosa_InferTensorTypeOp<"reduce_sum"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input,
+    Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1872,12 +1872,12 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat"> {
   }];
 
   let arguments = (ins
-    Variadic<Tosa_Tensor>:$input1,
+    Variadic<Tosa_TensorAtLeast1D>:$input1,
     I32Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1923,13 +1923,13 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
   }];
 
   let arguments = (ins
-    Tosa_RankedTensor:$input1,
+    Tosa_TensorAtLeast1D:$input1,
     Tosa_Shape:$padding,
     Tosa_ScalarTensor:$pad_const
   );
 
   let results = (outs
-    Tosa_RankedTensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1996,12 +1996,12 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input1,
+    Tosa_TensorAtLeast1D:$input1,
     I32Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -2028,13 +2028,13 @@ def Tosa_SliceOp : Tosa_InferShapedTypeOp<"slice"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input1,
+    Tosa_TensorAtLeast1D:$input1,
     Tosa_Shape:$start,
     Tosa_Shape:$size
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -2058,11 +2058,11 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input1,
+    Tosa_TensorAtLeast1D:$input1,
     Tosa_Shape:$multiples);
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -2093,12 +2093,12 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input1,
+    Tosa_TensorAtLeast1D:$input1,
     DenseI32ArrayAttr:$perms
   );
 
   let results = (
-    outs Tosa_Tensor:$output
+    outs Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 7 additions & 0 deletions
@@ -101,6 +101,10 @@ def AllDimensionsAreSizeOne : And<[
   IsRankedTensorTypePred,
   CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v == 1; })">]>;
 
+def AtLeastRankOne : And<[
+  IsRankedTensorTypePred,
+  CPred<"::llvm::cast<::mlir::RankedTensorType>($_self).getRank() >= 1">]>;
+
 class TosaTensorOf<
     list<Type> allowedTypes, string summary = "tosa-conformant tensor">
     : TensorOf<allowedTypes, [Or<[HasNo0Dimensions, IsUnrankedTensorTypePred]>], summary>;
@@ -183,6 +187,9 @@ def Tosa_TensorUpto4D : AnyTypeOf<[
 def Tosa_Int32TensorUpto4D : AnyTypeOf<[
   Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
 
+def Tosa_TensorAtLeast1D : AnyTypeOf<[
+  Tosa_UnrankedTensor, TosaRankedTensorOf<[Tosa_AnyNumber], [AtLeastRankOne]>], "tosa-conformant tensor of at least rank 1", "::mlir::TensorType">;
+
 //===----------------------------------------------------------------------===//
 // Generic scalar, vector, or tensor of a particular type.
 //===----------------------------------------------------------------------===//
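As a rough sketch of what the new Tosa_TensorAtLeast1D constraint accepts (the operand values below are placeholders): an unranked tensor still satisfies it because Tosa_UnrankedTensor remains in the AnyTypeOf, a ranked tensor needs rank >= 1, and a rank-0 tensor is rejected:

  // Accepted: unranked operand.
  %a = tosa.reverse %unranked {axis = 0 : i32} : (tensor<*xf32>) -> tensor<*xf32>
  // Accepted: rank-1 operand.
  %b = tosa.reverse %vec {axis = 0 : i32} : (tensor<4xf32>) -> tensor<4xf32>
  // Rejected: rank-0 operand (see the new cases in invalid.mlir below).
  %c = tosa.reverse %scalar {axis = 0 : i32} : (tensor<f32>) -> tensor<f32>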

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 7 additions & 2 deletions
@@ -1354,8 +1354,13 @@ LogicalResult tosa::PadOp::verify() {
     }
   }
 
-  RankedTensorType inputType = getInput1().getType();
-  RankedTensorType outputType = getOutput().getType();
+  RankedTensorType inputType =
+      llvm::dyn_cast<RankedTensorType>(getInput1().getType());
+  RankedTensorType outputType =
+      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
+  if (!inputType || !outputType)
+    return success();
+
   auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
 
   if (inputType.getRank() != outputType.getRank())
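Because Tosa_TensorAtLeast1D also admits unranked tensors, PadOp::verify can no longer assume ranked operand/result types; the early return success() simply skips the rank-dependent checks in that case. A hedged sketch of IR this is intended to admit (the %input value and the shape<4> padding are illustrative assumptions, not from the patch):

  %pad_const = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
  %padding = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
  // Unranked input1/output now pass the pad verifier; the rank checks are skipped.
  %0 = tosa.pad %input, %padding, %pad_const : (tensor<*xf32>, !tosa.shape<4>, tensor<1xf32>) -> tensor<*xf32>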

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 5 additions & 28 deletions
@@ -915,29 +915,6 @@ func.func @fold_abs_abs(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
 
 // -----
 
-// CHECK-LABEL: @fold_reduce_rank_zero
-func.func @fold_reduce_rank_zero() {
-  // CHECK-NOT: tosa.reduce_min
-  // CHECK-NOT: tosa.reverse
-  %0 = tensor.empty() : tensor<i32>
-  %1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
-  %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @fold_tile_rank_zero
-func.func nested @fold_tile_rank_zero() -> tensor<i32> {
-  // CHECK-NOT: tosa.tile
-  %0 = tensor.empty() : tensor<i32>
-  %cst = tosa.const_shape { values = dense<> : tensor<0xindex> } : () -> !tosa.shape<0>
-  %1 = tosa.tile %0, %cst : (tensor<i32>, !tosa.shape<0>) -> tensor<i32>
-  return %1 : tensor<i32>
-}
-
-// -----
-
 // CHECK-LABEL: @reshape_quant_nofold
 // check that segfault is fixed
 func.func @reshape_quant_nofold() -> tensor<1x1x1x1xi32> {
@@ -1015,12 +992,12 @@ func.func @cast_quant_nofold() -> tensor<!quant.uniform<i8:f32, 3.07574046018999
 // -----
 
 // CHECK-LABEL: @reverse_quant_fold
-func.func @reverse_quant_fold() -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>> {
-  // CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<0> : tensor<i8>}> : () -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+func.func @reverse_quant_fold() -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>> {
+  // CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
   // CHECK: return %[[CST]]
-  %0 = "tosa.const"() {values = dense<0> : tensor<i8>} : () -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
-  %1 = "tosa.reverse"(%0) { axis = 0 : i32 } : (tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
-  return %1 : tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+  %0 = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+  %1 = "tosa.reverse"(%0) { axis = 0 : i32 } : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+  return %1 : tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
 }
 
 // -----

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 71 additions & 2 deletions
@@ -452,9 +452,9 @@ func.func @test_reduce_sum_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
 
 // -----
 
-func.func @test_reduce_min_invalid_output_rank(%arg0 : tensor<i32>) -> () {
+func.func @test_reduce_min_invalid_output_rank(%arg0 : tensor<1xi32>) -> () {
   // expected-error@+1 {{'tosa.reduce_min' op expect output tensor rank to be equal to input tensor rank}}
-  %0 = tosa.reduce_min %arg0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
+  %0 = tosa.reduce_min %arg0 {axis = 0 : i32} : (tensor<1xi32>) -> tensor<1x10xi32>
   return
 }
 
@@ -1852,3 +1852,72 @@ func.func @test_maxpool2d_unexpected_output_width(%arg0: tensor<1x32x32x8xf32>)
            (tensor<1x32x32x8xf32>) -> tensor<1x32x2x8xf32>
   return %0 : tensor<1x32x2x8xf32>
 }
+
+// -----
+
+func.func @test_scalar_argmax(%arg0: tensor<i32>) -> tensor<i32> {
+  // expected-error@+1 {{'tosa.argmax' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<i32>'}}
+  %0 = tosa.argmax %arg0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
+  return %0 : tensor<i32>
+}
+
+// -----
+
+func.func @test_scalar_reduce_all(%arg0: tensor<i1>) -> tensor<i1> {
+  // expected-error@+1 {{'tosa.reduce_all' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<i1>'}}
+  %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<i1>) -> tensor<i1>
+  return %0 : tensor<i1>
+}
+
+// -----
+
+func.func @test_scalar_inputs_concat(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<2xf32> {
+  // expected-error@+1 {{'tosa.concat' op operand #0 must be variadic of tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<f32>, tensor<f32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// -----
+
+func.func @test_scalar_pad(%arg0: tensor<f32>) -> tensor<f32> {
+  %0 = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
+  %padding = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
+  // expected-error@+1 {{'tosa.pad' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+  %1 = tosa.pad %arg0, %padding, %0 : (tensor<f32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<f32>
+  return %1 : tensor<f32>
+}
+
+// -----
+
+func.func @test_scalar_reverse(%arg0: tensor<f32>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.reverse' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+  %0 = tosa.reverse %arg0 {axis = 0: i32} : (tensor<f32>) -> tensor<f32>
+  return %arg0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_scalar_slice(%arg0: tensor<f32>) -> tensor<f32> {
+  %0 = tosa.const_shape {values = dense<[]> : tensor<0xindex>} : () -> !tosa.shape<0>
+  %1 = tosa.const_shape {values = dense<[]> : tensor<0xindex>} : () -> !tosa.shape<0>
+  // expected-error@+1 {{'tosa.slice' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+  %2 = tosa.slice %arg0, %0, %1 : (tensor<f32>, !tosa.shape<0>, !tosa.shape<0>) -> tensor<f32>
+  return %2 : tensor<f32>
+}
+
+// -----
+
+func.func @test_scalar_tile(%arg0: tensor<f32>) -> tensor<*xf32> {
+  %cst = tosa.const_shape { values = dense<[]> : tensor<0xindex> } : () -> !tosa.shape<0>
+  // expected-error@+1 {{'tosa.tile' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+  %0 = tosa.tile %arg0, %cst: (tensor<f32>, !tosa.shape<0>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+func.func @test_scalar_output_transpose(%arg0: tensor<*xf32>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.transpose' op result #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+  %1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<*xf32>) -> tensor<f32>
+  return %1 : tensor<f32>
+}
