Skip to content

Commit e046f20

Browse files
authored
[mlir][tosa] Enhance error_if and verify checks for RESCALE Op (llvm#137021)
* add verifier for rank-0 input with per-channel * add checkErrorIfRescale to tosa validation pass that align with TOSAv1.0 * add LIT tests Signed-off-by: Peng Sun <[email protected]>
1 parent 4955c3c commit e046f20

File tree

5 files changed

+207
-1
lines changed

5 files changed

+207
-1
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3206,6 +3206,12 @@ LogicalResult RescaleOp::verify() {
32063206
// otherwise numChannel is dimension in input shape's last axis
32073207
int64_t numChannels = 1;
32083208
if (getPerChannel()) {
3209+
if (inputType.getRank() < 1) {
3210+
emitOpError("requires input to be at least rank 1 when per_channel is "
3211+
"true, but got rank ")
3212+
<< inputType.getRank();
3213+
return failure();
3214+
}
32093215
numChannels = inputType.getDimSize(inputType.getRank() - 1);
32103216
}
32113217

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1033,8 +1033,88 @@ bool checkErrorIfTable(Operation *op) {
10331033
return true;
10341034
}
10351035

1036+
bool checkErrorIfRescale(Operation *op) {
1037+
auto rescale = dyn_cast<tosa::RescaleOp>(op);
1038+
if (!rescale)
1039+
return true;
1040+
1041+
auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
1042+
auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
1043+
if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
1044+
!outputType.getElementType().isInteger())
1045+
return true;
1046+
1047+
auto inElemType = inputType.getElementType();
1048+
auto outElemType = outputType.getElementType();
1049+
auto inWidth = inElemType.getIntOrFloatBitWidth();
1050+
auto outWidth = outElemType.getIntOrFloatBitWidth();
1051+
1052+
bool inputUnsigned = rescale.getInputUnsigned();
1053+
bool outputUnsigned = rescale.getOutputUnsigned();
1054+
1055+
bool scale32 = rescale.getScale32();
1056+
auto roundingMode = rescale.getRoundingMode();
1057+
1058+
// ERROR_IF(scale32 && is_same<in_t,i48_t>())
1059+
if (scale32 && inWidth == 48) {
1060+
op->emitOpError() << "scale32 is not allowed with 48-bit input.";
1061+
return false;
1062+
}
1063+
1064+
// ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND))
1065+
if (!scale32 && roundingMode == "DOUBLE_ROUND") {
1066+
op->emitOpError() << "DOUBLE_ROUND is only allowed with scale32=true.";
1067+
return false;
1068+
}
1069+
1070+
// ERROR_IF(input_unsigned && output_unsigned)
1071+
if (inputUnsigned && outputUnsigned) {
1072+
op->emitOpError() << "input and output cannot be both unsigned.";
1073+
return false;
1074+
}
1075+
1076+
// ERROR_IF(is_same<out_t,i32_t>() && input_unsigned)
1077+
if (outWidth == 32 && inputUnsigned) {
1078+
op->emitOpError() << "i32 output type is not allowed with unsigned input.";
1079+
return false;
1080+
}
1081+
1082+
// ERROR_IF(is_same<in_t,i32_t>() && output_unsigned)
1083+
if (inWidth == 32 && outputUnsigned) {
1084+
op->emitOpError() << "i32 input type is not allowed with unsigned output.";
1085+
return false;
1086+
}
1087+
1088+
// ERROR_IF(is_same<in_t,i48_t>() && output_unsigned)
1089+
if (inWidth == 48 && outputUnsigned) {
1090+
op->emitOpError() << "i48 input type is not allowed with unsigned output.";
1091+
return false;
1092+
}
1093+
1094+
// ERROR_IF(is_same<in_t, i48_t> && input_unsigned)
1095+
if (inWidth == 48 && inputUnsigned) {
1096+
op->emitOpError() << "i48 input type cannot be unsigned.";
1097+
return false;
1098+
}
1099+
1100+
// ERROR_IF(is_same<in_t, i32_t> && input_unsigned)
1101+
if (inWidth == 32 && inputUnsigned) {
1102+
op->emitOpError() << "i32 input type cannot be unsigned.";
1103+
return false;
1104+
}
1105+
1106+
// ERROR_IF(is_same<out_t, i32_t> && output_unsigned)
1107+
if (outWidth == 32 && outputUnsigned) {
1108+
op->emitOpError() << "i32 output type cannot be unsigned.";
1109+
return false;
1110+
}
1111+
1112+
return true;
1113+
}
1114+
10361115
LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
1037-
if (!checkErrorIfResize(op) || !checkErrorIfMul(op) || !checkErrorIfTable(op))
1116+
if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
1117+
!checkErrorIfTable(op) || !checkErrorIfRescale(op))
10381118
return failure();
10391119
return success();
10401120
}

mlir/test/Dialect/Tosa/error_if_check.mlir

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,99 @@ func.func @test_i8_table_size(%arg0: tensor<2x64xi8>, %arg1: tensor<513xi8>) ->
129129
%0 = tosa.table %arg0, %arg1 : (tensor<2x64xi8>, tensor<513xi8>) -> tensor<2x64xi8>
130130
return %0 : tensor<2x64xi8>
131131
}
132+
133+
// -----
134+
// CHECK-LABEL: test_error_scale32_with_i48
135+
func.func @test_error_scale32_with_i48(%arg0: tensor<1xi48>) -> tensor<1xi8> {
136+
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
137+
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
138+
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
139+
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
140+
// expected-error@+1 {{'tosa.rescale' op scale32 is not allowed with 48-bit input}}
141+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1xi48>, tensor<1xi32>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
142+
return %0 : tensor<1xi8>
143+
}
144+
145+
// -----
146+
// CHECK-LABEL: test_error_input_output_unsigned
147+
func.func @test_error_input_output_unsigned(%arg0: tensor<1xi8>) -> tensor<1xi16> {
148+
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
149+
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
150+
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
151+
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
152+
// expected-error@+1 {{'tosa.rescale' op input and output cannot be both unsigned}}
153+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = true} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<1xi16>
154+
return %0 : tensor<1xi16>
155+
}
156+
157+
// -----
158+
// CHECK-LABEL: test_error_i32_output_unsigned_input
159+
func.func @test_error_i32_output_unsigned_input(%arg0: tensor<1xi8>) -> tensor<1xi32> {
160+
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
161+
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
162+
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
163+
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
164+
// expected-error@+1 {{'tosa.rescale' op i32 output type is not allowed with unsigned input}}
165+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1xi32>
166+
return %0 : tensor<1xi32>
167+
}
168+
169+
// -----
170+
// CHECK-LABEL: test_error_i32_input_unsigned_output
171+
func.func @test_error_i32_input_unsigned_output(%arg0: tensor<1xi32>) -> tensor<1xi8> {
172+
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
173+
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
174+
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
175+
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
176+
// expected-error@+1 {{'tosa.rescale' op i32 input type is not allowed with unsigned output}}
177+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi32>, tensor<1xi16>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1xi8>
178+
return %0 : tensor<1xi8>
179+
}
180+
181+
// -----
182+
// CHECK-LABEL: test_error_i48_input_unsigned_output
183+
func.func @test_error_i48_input_unsigned_output(%arg0: tensor<1xi48>) -> tensor<1xi8> {
184+
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
185+
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
186+
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
187+
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
188+
// expected-error@+1 {{'tosa.rescale' op i48 input type is not allowed with unsigned output}}
189+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
190+
return %0 : tensor<1xi8>
191+
}
192+
193+
// -----
194+
// CHECK-LABEL: test_error_i48_unsigned_input
195+
func.func @test_error_i48_input_unsigned_output(%arg0: tensor<1xi48>) -> tensor<1xi8> {
196+
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
197+
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
198+
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
199+
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
200+
// expected-error@+1 {{'tosa.rescale' op i48 input type cannot be unsigned}}
201+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
202+
return %0 : tensor<1xi8>
203+
}
204+
205+
// -----
206+
// CHECK-LABEL: test_error_i32_unsigned_input
207+
func.func @test_error_i32_input_unsigned_output(%arg0: tensor<1xi32>) -> tensor<1xi8> {
208+
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
209+
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
210+
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
211+
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
212+
// expected-error@+1 {{'tosa.rescale' op i32 input type cannot be unsigned}}
213+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi32>, tensor<1xi16>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1xi8>
214+
return %0 : tensor<1xi8>
215+
}
216+
217+
// -----
218+
// CHECK-LABEL: test_error_i32_unsigned_output
219+
func.func @test_error_i32_unsigned_output(%arg0: tensor<1xi8>) -> tensor<1xi32> {
220+
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
221+
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
222+
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
223+
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
224+
// expected-error@+1 {{'tosa.rescale' op i32 output type cannot be unsigned}}
225+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1xi32>
226+
return %0 : tensor<1xi32>
227+
}

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1638,6 +1638,18 @@ func.func @test_rescale_invalid_non_perchannel_shift_shape(%arg0: tensor<13x21x3
16381638
return %0 : tensor<13x21x3xi16>
16391639
}
16401640

1641+
// -----
1642+
// CHECK-LABEL: test_error_double_round_without_scale32
1643+
func.func @test_error_double_round_without_scale32(%arg0: tensor<1xi8>) -> tensor<1xi16> {
1644+
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
1645+
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
1646+
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
1647+
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
1648+
// expected-error@+1 {{'tosa.rescale' op DOUBLE_ROUND is only allowed with scale32=true}}
1649+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "DOUBLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<1xi16>
1650+
return %0 : tensor<1xi16>
1651+
}
1652+
16411653
// -----
16421654
// CHECK-LABEL: test_matmul_a_zp_same_element_type
16431655
func.func @test_matmul_a_zp_same_element_type(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {

mlir/test/Dialect/Tosa/verifier.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,3 +358,15 @@ func.func @test_concat_axis_sum_error(%arg0: tensor<1x2xf32>, %arg1: tensor<2x?x
358358
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
359359
return %0 : tensor<2x?xf32>
360360
}
361+
362+
// -----
363+
364+
func.func @test_error_scalar_input_with_per_channel(%arg0: tensor<i8>) -> tensor<i16> {
365+
%multiplier = "tosa.const"() {values = dense<4> : tensor<1xi32> } : () -> tensor<1xi32>
366+
%shift = "tosa.const"() {values = dense<2> : tensor<1xi8> } : () -> tensor<1xi8>
367+
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
368+
%output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
369+
// expected-error@+1 {{'tosa.rescale' op requires input to be at least rank 1 when per_channel is true, but got rank 0}}
370+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<i8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<i16>
371+
return %0 : tensor<i16>
372+
}

0 commit comments

Comments
 (0)