Skip to content

Commit 899c2be

Browse files
authored
[mlir][TilingInterface] Early return cloned ops if tile sizes are zeros. (#75410)
It is a trivial early-return case. If the cloned ops are not returned, it will generate an `extract_slice` op that extracts the whole slice. However, that op is not folded away. We early-return to avoid this case. E.g., ```mlir func.func @matmul_tensors( %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> { %0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2: tensor<?x?xf32>) -> tensor<?x?xf32> return %0 : tensor<?x?xf32> } module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.tile_using_for %0 [0, 0, 0] : (!transform.any_op) -> (!transform.any_op) transform.yield } } ``` Applying the transforms and canonicalizing the IR: ``` mlir-opt --transform-interpreter -canonicalize input.mlir ``` we will get ```mlir module { func.func @matmul_tensors(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> { %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32> %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32> %dim_1 = tensor.dim %arg1, %c1 : tensor<?x?xf32> %extracted_slice = tensor.extract_slice %arg0[0, 0] [%dim, %dim_0] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32> %extracted_slice_2 = tensor.extract_slice %arg1[0, 0] [%dim_0, %dim_1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32> %extracted_slice_3 = tensor.extract_slice %arg2[0, 0] [%dim, %dim_1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32> %0 = linalg.matmul ins(%extracted_slice, %extracted_slice_2 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%extracted_slice_3 : tensor<?x?xf32>) -> tensor<?x?xf32> return %0 : tensor<?x?xf32> } } ``` The revision early-returns in this case, so we instead get: ```mlir func.func @matmul_tensors(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: 
tensor<?x?xf32>) -> tensor<?x?xf32> { %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> return %0 : tensor<?x?xf32> } ```
1 parent ac82c8b commit 899c2be

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -362,14 +362,22 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
362362
auto clonedOp = cast<TilingInterface>(
363363
cloneOpAndUpdateDestinationArgs(rewriter, op, clonedOpDestination));
364364

365-
// 5b. Tile the cloned operation.
365+
// 5b. Early return cloned op if tiling is not happening. We can not return
366+
// the original op because it could lead to
367+
// `rewriter.replaceOp(op, op->getResults())` and user would get crash.
368+
if (llvm::all_of(tileSizeVector, isZeroIndex)) {
369+
return scf::SCFTilingResult{/*tiledOps=*/{clonedOp}, /*loops=*/{},
370+
clonedOp->getResults()};
371+
}
372+
373+
// 5c. Tile the cloned operation.
366374
FailureOr<TilingResult> tiledImplementation =
367375
clonedOp.getTiledImplementation(rewriter, offsets, sizes);
368376
if (failed(tiledImplementation)) {
369377
return rewriter.notifyMatchFailure(op, "failed to tile operation");
370378
}
371379

372-
// 5c. Delete the cloned operation.
380+
// 5d. Delete the cloned operation.
373381
rewriter.eraseOp(clonedOp);
374382

375383
// If loops are empty, the tiled op is used as the replacement for the untiled

mlir/test/Dialect/Linalg/tile-tensors.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,33 @@ module attributes {transform.with_named_sequence} {
3737

3838
// -----
3939

40+
// CHECK-LABEL: func @matmul_tensors_with_size_zeros(
41+
// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
42+
// CHECK-SAME: %[[TB:[0-9a-z]+]]: tensor<?x?xf32>
43+
// CHECK-SAME: %[[TC:[0-9a-z]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
44+
func.func @matmul_tensors_with_size_zeros(
45+
%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
46+
-> tensor<?x?xf32> {
47+
48+
// CHECK: %[[RES:.*]] = linalg.matmul ins(%[[TA]], %[[TB]] : tensor<?x?xf32>, tensor<?x?xf32>)
49+
// CHECK-SAME: outs(%[[TC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
50+
// CHECK: return %[[RES]]
51+
%0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
52+
outs(%arg2: tensor<?x?xf32>)
53+
-> tensor<?x?xf32>
54+
return %0 : tensor<?x?xf32>
55+
}
56+
57+
module attributes {transform.with_named_sequence} {
58+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
59+
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
60+
%1 = transform.structured.tile_using_for %0 [0, 0, 0] : (!transform.any_op) -> (!transform.any_op)
61+
transform.yield
62+
}
63+
}
64+
65+
// -----
66+
4067
func.func @generic_op_tensors(
4168
%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
4269
%c0 = arith.constant 0 : index

0 commit comments

Comments (0)