Skip to content

Commit 04e628a

Browse files
psunnIanWood1
authored andcommitted
[mlir][tosa] Enhance verify checks for PAD Op (llvm#137177)
* add padding shape verification * add and update LIT test Signed-off-by: Peng Sun <[email protected]>
1 parent 9cedd52 commit 04e628a

File tree

6 files changed

+118
-36
lines changed

6 files changed

+118
-36
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,15 +1524,51 @@ LogicalResult tosa::PadOp::verify() {
15241524
if (!inputType || !outputType)
15251525
return success();
15261526

1527-
auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
1527+
auto inputRank = inputType.getRank();
1528+
auto outputRank = outputType.getRank();
1529+
if (inputRank != outputRank)
1530+
return emitOpError() << "expect same input and output tensor rank, but got "
1531+
<< "inputRank: " << inputRank
1532+
<< ", outputRank: " << outputRank;
15281533

1529-
if (inputType.getRank() != outputType.getRank())
1530-
return emitOpError() << "expect same input and output tensor rank.";
1534+
DenseIntElementsAttr paddingAttr;
1535+
if (!matchPattern(getPadding(), m_Constant(&paddingAttr))) {
1536+
return failure();
1537+
}
1538+
1539+
auto paddingValues = paddingAttr.getValues<APInt>();
1540+
if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
1541+
return emitOpError() << "padding tensor must have " << inputRank
1542+
<< " * 2 = " << inputRank * 2 << " elements, but got "
1543+
<< paddingValues.size();
15311544

1532-
if (paddingRank != inputType.getRank() * 2)
1533-
return emitOpError() << "expected padding tensor dim 0 to have size "
1534-
<< inputType.getRank() * 2
1535-
<< " (2*rank(shape1)) but got size " << paddingRank;
1545+
auto inputShape = inputType.getShape();
1546+
auto outputShape = outputType.getShape();
1547+
1548+
for (int64_t i = 0; i < inputRank; ++i) {
1549+
int64_t padStart = paddingValues[i * 2].getSExtValue();
1550+
int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();
1551+
1552+
if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
1553+
return emitOpError()
1554+
<< "invalid padding values at dimension " << i
1555+
<< ": values must be non-negative or -1 for dynamic padding, got ["
1556+
<< padStart << ", " << padEnd << "]";
1557+
}
1558+
1559+
// Skip shape verification for dynamic input/output
1560+
if (inputShape[i] == ShapedType::kDynamic ||
1561+
outputShape[i] == ShapedType::kDynamic)
1562+
continue;
1563+
1564+
if (outputShape[i] != inputShape[i] + padStart + padEnd) {
1565+
return emitOpError() << "mismatch in output shape at dimension " << i
1566+
<< ": expected " << inputShape[i] << " + "
1567+
<< padStart << " + " << padEnd << " = "
1568+
<< (inputShape[i] + padStart + padEnd)
1569+
<< ", but got " << outputShape[i];
1570+
}
1571+
}
15361572

15371573
return success();
15381574
}

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1112,9 +1112,31 @@ bool checkErrorIfRescale(Operation *op) {
11121112
return true;
11131113
}
11141114

1115+
bool checkErrorIfPad(Operation *op) {
1116+
auto pad = dyn_cast<tosa::PadOp>(op);
1117+
if (!pad)
1118+
return true;
1119+
1120+
DenseIntElementsAttr paddingAttr;
1121+
if (!matchPattern(pad.getPadding(), m_Constant(&paddingAttr)))
1122+
// Pad verifier will catch this
1123+
return true;
1124+
1125+
for (const APInt &val : paddingAttr.getValues<APInt>()) {
1126+
if (val.getSExtValue() < 0) {
1127+
op->emitOpError() << "padding value must all be non-negative, got "
1128+
<< val.getSExtValue();
1129+
return false;
1130+
}
1131+
}
1132+
1133+
return true;
1134+
}
1135+
11151136
LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
11161137
if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
1117-
!checkErrorIfTable(op) || !checkErrorIfRescale(op))
1138+
!checkErrorIfTable(op) || !checkErrorIfRescale(op) ||
1139+
!checkErrorIfPad(op))
11181140
return failure();
11191141
return success();
11201142
}

mlir/test/Dialect/Tosa/dynamic_extension.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<256xi8>)
2020

2121
// -----
2222

23-
func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
23+
func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x22x4xi8> {
2424
%0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
25-
%1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
26-
return %1 : tensor<13x21x3xi8>
25+
%1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x22x4xi8>
26+
return %1 : tensor<13x22x4xi8>
2727
}
2828

2929
// -----

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens
262262

263263
// -----
264264

265-
func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>) -> tensor<13x21x3xf32> {
265+
func.func @test_pad_padding_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>) -> tensor<13x21x3xf32> {
266266
%pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
267267
// expected-error@+1 {{'tosa.pad' op shape operand is not compile time resolvable}}
268268
%0 = tosa.pad %arg0, %arg1, %pad_const : (tensor<13x21x3xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<13x21x3xf32>
@@ -271,19 +271,19 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>)
271271

272272
// -----
273273

274-
func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
274+
func.func @test_pad_const_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x22x4xi8> {
275275
%0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
276276
// expected-error@+1 {{'tosa.pad' op expected compile time resolvable constant, but got variable value for operand #2}}
277-
%1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
278-
return %1 : tensor<13x21x3xi8>
277+
%1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x22x4xi8>
278+
return %1 : tensor<13x22x4xi8>
279279
}
280280

281281
// -----
282282

283283
func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>) {
284284
%0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
285285
%pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
286-
// expected-error@+1 {{'tosa.pad' op expect same input and output tensor rank.}}
286+
// expected-error@+1 {{'tosa.pad' op expect same input and output tensor rank}}
287287
%1 = tosa.pad %arg0, %0, %pad_const : (tensor<13x21xf32>, !tosa.shape<4>, tensor<1xf32>) -> tensor<13x21x3xf32>
288288
}
289289

@@ -297,17 +297,7 @@ func.func @test_concat_input_rank_mismatch(%arg0: tensor<1x2x3xf32>, %arg1: tens
297297

298298
// -----
299299

300-
func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>) {
301-
%0 = tosa.const_shape {values = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6>
302-
%pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
303-
// expected-error@+1 {{'tosa.pad' op expected padding tensor dim 0 to have size 4 (2*rank(shape1)) but got size 6}}
304-
%1 = tosa.pad %arg0, %0, %pad_const : (tensor<13x21xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<13x21xf32>
305-
return
306-
}
307-
308-
// -----
309-
310-
func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tensor<2x2xi32>) {
300+
func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>) {
311301
%0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
312302
%1 = "tosa.const"() {values = dense<3.14> : tensor<2xf32>} : () -> tensor<2xf32>
313303
// expected-error@+1 {{'tosa.pad' op operand #2 must be tosa-conformant scalar tensor of number values, but got 'tensor<2xf32>'}}
@@ -317,12 +307,12 @@ func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>, %arg1: tenso
317307

318308
// -----
319309

320-
func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
321-
%0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
322-
%pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
323-
// expected-error@+1 {{'tosa.pad' op expected padding tensor dim 0 to have size 6 (2*rank(shape1)) but got size 4}}
324-
%1 = tosa.pad %arg0, %0, %pad_const : (tensor<13x21x3xf32>, !tosa.shape<4>, tensor<1xf32>) -> tensor<13x21x3xf32>
325-
return %1 : tensor<13x21x3xf32>
310+
func.func @test_pad_invalid_padding_value(%arg0: tensor<10xf32>) {
311+
%0 = tosa.const_shape {values = dense<[-1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
312+
%1 = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
313+
// expected-error@+1 {{padding value must all be non-negative, got -1}}
314+
%2 = tosa.pad %arg0, %0, %1 : (tensor<10xf32>, !tosa.shape<2>, tensor<1xf32>) -> tensor<10xf32>
315+
return
326316
}
327317

328318
// -----

mlir/test/Dialect/Tosa/invalid_extension.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -407,11 +407,11 @@ func.func @test_inexact_round_rescale(%arg0: tensor<13x21x3xi8>) -> tensor<13x21
407407

408408
// -----
409409

410-
func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
410+
func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x22x4xi8> {
411411
%0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
412412
// expected-error@+1 {{'tosa.pad' op expected compile time resolvable constant, but got variable value for operand #2}}
413-
%1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
414-
return %1 : tensor<13x21x3xi8>
413+
%1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x22x4xi8>
414+
return %1 : tensor<13x22x4xi8>
415415
}
416416

417417
// -----

mlir/test/Dialect/Tosa/verifier.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,3 +403,37 @@ func.func @test_gather_invalid_out_C(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1
403403
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<13x26x8xf32>
404404
return %0 : tensor<13x26x8xf32>
405405
}
406+
407+
// -----
408+
func.func @test_pad_padding_shape_mismatch(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
409+
%0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
410+
%pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
411+
// expected-error@+1 {{'tosa.pad' op padding tensor must have 3 * 2 = 6 elements, but got 4}}
412+
%1 = tosa.pad %arg0, %0, %pad_const : (tensor<13x21x3xf32>, !tosa.shape<4>, tensor<1xf32>) -> tensor<13x21x3xf32>
413+
return %1 : tensor<13x21x3xf32>
414+
}
415+
416+
// -----
417+
func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>) {
418+
%0 = tosa.const_shape {values = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6>
419+
%pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
420+
// expected-error@+1 {{'tosa.pad' op padding tensor must have 2 * 2 = 4 elements, but got 6}}
421+
%1 = tosa.pad %arg0, %0, %pad_const : (tensor<13x21xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<13x21xf32>
422+
return
423+
}
424+
425+
// -----
426+
func.func @test_pad_output_mismatch(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi8>) -> tensor<13x21x3xi8> {
427+
%0 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
428+
// expected-error@+1 {{mismatch in output shape at dimension 1: expected 21 + 0 + 1 = 22, but got 21}}
429+
%1 = tosa.pad %arg0, %0, %arg1 : (tensor<13x21x3xi8>, !tosa.shape<6>, tensor<1xi8>) -> tensor<13x21x3xi8>
430+
return %1 : tensor<13x21x3xi8>
431+
}
432+
433+
// -----
434+
func.func @test_pad_invalid_padding_value(%arg0: tensor<10xi8>, %arg1: tensor<1xi8>) -> tensor<10xi8> {
435+
%0 = tosa.const_shape {values = dense<[-2, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
436+
// expected-error@+1 {{invalid padding values at dimension 0: values must be non-negative or -1 for dynamic padding, got [-2, 2]}}
437+
%1 = tosa.pad %arg0, %0, %arg1 : (tensor<10xi8>, !tosa.shape<2>, tensor<1xi8>) -> tensor<10xi8>
438+
return %1 : tensor<10xi8>
439+
}

0 commit comments

Comments
 (0)