diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 8057b3898012d..1b6b4db9d2090 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -362,14 +362,22 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, auto clonedOp = cast<TilingInterface>( cloneOpAndUpdateDestinationArgs(rewriter, op, clonedOpDestination)); - // 5b. Tile the cloned operation. + // 5b. Early return cloned op if tiling is not happening. We cannot return + // the original op because it could lead to + // `rewriter.replaceOp(op, op->getResults())` and the user would get a crash. + if (llvm::all_of(tileSizeVector, isZeroIndex)) { + return scf::SCFTilingResult{/*tiledOps=*/{clonedOp}, /*loops=*/{}, + clonedOp->getResults()}; + } + + // 5c. Tile the cloned operation. FailureOr<TilingResult> tiledImplementation = clonedOp.getTiledImplementation(rewriter, offsets, sizes); if (failed(tiledImplementation)) { return rewriter.notifyMatchFailure(op, "failed to tile operation"); } - // 5c. Delete the cloned operation. + // 5d. Delete the cloned operation. 
rewriter.eraseOp(clonedOp); // If loops are empty, the tiled op is used as the replacement for the untiled diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir index e0429b1f87329..e8e6330228640 100644 --- a/mlir/test/Dialect/Linalg/tile-tensors.mlir +++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir @@ -37,6 +37,33 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: func @matmul_tensors_with_size_zeros( +// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<?x?xf32> +// CHECK-SAME: %[[TB:[0-9a-z]+]]: tensor<?x?xf32> +// CHECK-SAME: %[[TC:[0-9a-z]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> { +func.func @matmul_tensors_with_size_zeros( + %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) + -> tensor<?x?xf32> { + +// CHECK: %[[RES:.*]] = linalg.matmul ins(%[[TA]], %[[TB]] : tensor<?x?xf32>, tensor<?x?xf32>) +// CHECK-SAME: outs(%[[TC]] : tensor<?x?xf32>) -> tensor<?x?xf32> +// CHECK: return %[[RES]] + %0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>) + outs(%arg2: tensor<?x?xf32>) + -> tensor<?x?xf32> + return %0 : tensor<?x?xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.tile_using_for %0 [0, 0, 0] : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} + +// ----- + func.func @generic_op_tensors( %arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> { %c0 = arith.constant 0 : index