[mlir][TilingInterface] Early return cloned ops if tile sizes are zeros. #75410

Merged
hanhanW merged 4 commits into llvm:main on Dec 19, 2023

@hanhanW hanhanW commented Dec 14, 2023

This is a trivial early-return case. If the cloned op is not returned, tiling generates tensor.extract_slice ops that extract the whole tensor, and these are not folded away. Return early to avoid this.

E.g.,

func.func @matmul_tensors(
  %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
    -> tensor<?x?xf32> {
  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
                     outs(%arg2: tensor<?x?xf32>)
    -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.tile_using_for %0 [0, 0, 0] : (!transform.any_op) -> (!transform.any_op)
    transform.yield
  }
}

Apply the transforms and canonicalize the IR:

mlir-opt --transform-interpreter -canonicalize input.mlir

we get:

module {
  func.func @matmul_tensors(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
    %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
    %dim_1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
    %extracted_slice = tensor.extract_slice %arg0[0, 0] [%dim, %dim_0] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
    %extracted_slice_2 = tensor.extract_slice %arg1[0, 0] [%dim_0, %dim_1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
    %extracted_slice_3 = tensor.extract_slice %arg2[0, 0] [%dim, %dim_1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
    %0 = linalg.matmul ins(%extracted_slice, %extracted_slice_2 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%extracted_slice_3 : tensor<?x?xf32>) -> tensor<?x?xf32>
    return %0 : tensor<?x?xf32>
  }
}

This revision returns early in that case, so we get:

func.func @matmul_tensors(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

llvmbot commented Dec 14, 2023

@llvm/pr-subscribers-mlir-scf
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Han-Chung Wang (hanhanW)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/75410.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+9-2)
  • (modified) mlir/test/Dialect/Linalg/tile-tensors.mlir (+27)
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 8057b3898012d4..20413aba8730be 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -362,14 +362,21 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
   auto clonedOp = cast<TilingInterface>(
       cloneOpAndUpdateDestinationArgs(rewriter, op, clonedOpDestination));
 
-  // 5b. Tile the cloned operation.
+  // 5b. Early return cloned op if tiling is not happening.
+  if (llvm::all_of(tileSizeVector,
+                   [](OpFoldResult v) { return isZeroIndex(v); })) {
+    return scf::SCFTilingResult{/*tiledOps=*/{clonedOp}, /*loops=*/{},
+                                clonedOp->getResults()};
+  }
+
+  // 5c. Tile the cloned operation.
   FailureOr<TilingResult> tiledImplementation =
       clonedOp.getTiledImplementation(rewriter, offsets, sizes);
   if (failed(tiledImplementation)) {
     return rewriter.notifyMatchFailure(op, "failed to tile operation");
   }
 
-  // 5c. Delete the cloned operation.
+  // 5d. Delete the cloned operation.
   rewriter.eraseOp(clonedOp);
 
   // If loops are empty, the tiled op is used as the replacement for the untiled
diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir
index e0429b1f873298..e8e63302286400 100644
--- a/mlir/test/Dialect/Linalg/tile-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir
@@ -37,6 +37,33 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: func @matmul_tensors_with_size_zeros(
+// CHECK-SAME:    %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME:    %[[TB:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME:    %[[TC:[0-9a-z]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @matmul_tensors_with_size_zeros(
+  %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
+
+//      CHECK:     %[[RES:.*]] = linalg.matmul ins(%[[TA]], %[[TB]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME:                                outs(%[[TC]] : tensor<?x?xf32>)  -> tensor<?x?xf32>
+//      CHECK:     return %[[RES]]
+  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
+                     outs(%arg2: tensor<?x?xf32>)
+    -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.tile_using_for %0 [0, 0, 0] : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @generic_op_tensors(
   %arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
   %c0 = arith.constant 0 : index
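
The early-return condition in the diff above can be sketched outside of MLIR. The following is a minimal toy model, not the code from `TileUsingInterface.cpp`: `ToyTilingResult`, `allTileSizesZero`, and `toyTile` are hypothetical names, tile sizes are plain integers rather than `OpFoldResult`, and the "one loop per non-zero tile size" rule is an assumption made for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for scf::SCFTilingResult: records which op the
// caller should use and how many loops were generated.
struct ToyTilingResult {
  std::string replacementOp; // "cloned" or "tiled"
  std::size_t numLoops;      // number of generated loops
};

// Models the llvm::all_of(tileSizeVector, isZeroIndex) check from the diff,
// using plain integers instead of OpFoldResult.
bool allTileSizesZero(const std::vector<std::int64_t> &tileSizes) {
  return std::all_of(tileSizes.begin(), tileSizes.end(),
                     [](std::int64_t size) { return size == 0; });
}

// Toy tiling driver: if no dimension is tiled, hand back the cloned op
// unchanged instead of producing a "tiled" op wrapped in whole-tensor slices.
ToyTilingResult toyTile(const std::vector<std::int64_t> &tileSizes) {
  if (allTileSizesZero(tileSizes))
    return {"cloned", 0};

  // Assumption for this sketch: one loop per non-zero tile size.
  const auto loops =
      std::count_if(tileSizes.begin(), tileSizes.end(),
                    [](std::int64_t size) { return size != 0; });
  return {"tiled", static_cast<std::size_t>(loops)};
}
```

In this model, `toyTile({0, 0, 0})` takes the early-return path and reports the cloned op with no loops, which corresponds to the `[0, 0, 0]` case exercised by the new test.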

@@ -362,14 +362,20 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
   auto clonedOp = cast<TilingInterface>(
       cloneOpAndUpdateDestinationArgs(rewriter, op, clonedOpDestination));
 
-  // 5b. Tile the cloned operation.
+  // 5b. Early return cloned op if tiling is not happening.
Contributor

This might work, but maybe we should just not clone the op if tile sizes are zero.

Contributor Author

I actually tried that, but it crashes because of the rewriter state. I did not dig into the details because I have less experience with that...

Contributor Author

My take is that the problem comes from replacing an op with itself. Do you want me to look into the crash further, or should I add more comments to the code and commit?

Contributor Author

I just tested it locally, and I think it makes sense to return the cloned op. The suggestion would lead to rewriter.replaceOp(op, op->getResults()), which crashes. Returning the cloned op, rather than the op itself or a failure, makes more sense in this case. I will add some comments to explain it.

Contributor

Yeah, can't replace an op with itself. In theory you could say the caller should check for this and not do the replace, but that is strange. Let's go with this.
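
To illustrate why replacing an op with its own results is a problem, here is a toy model, assuming that "replace" means "redirect all users to the replacement, then erase the original". `ToyOp` and `toyReplaceOp` are hypothetical names and this is not MLIR's rewriter implementation; it only sketches the reasoning in this thread.

```cpp
#include <cassert>
#include <vector>

// Hypothetical, heavily simplified op: it only tracks who uses its result.
struct ToyOp {
  std::vector<ToyOp *> users; // ops that consume this op's result
  bool erased = false;
};

// Toy "replace op": move every user of `op` over to `replacement`, then erase
// `op`. Returns false when the erase would be invalid because `op` still has
// users; a real rewriter would more likely assert or crash at that point.
bool toyReplaceOp(ToyOp &op, ToyOp &replacement) {
  // Detach the users from the original op.
  std::vector<ToyOp *> moved;
  moved.swap(op.users);

  // Attach them to the replacement. If `replacement` is `op` itself, the
  // users land right back on the op we are about to erase.
  replacement.users.insert(replacement.users.end(), moved.begin(),
                           moved.end());

  if (!op.users.empty())
    return false; // still in use, so erasing would leave dangling uses

  op.erased = true;
  return true;
}
```

Under this model, replacing a used op with itself always fails, while replacing it with a distinct clone succeeds, which matches the choice to return the cloned op from the early-return path.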

@hanhanW hanhanW merged commit 899c2be into llvm:main Dec 19, 2023
@hanhanW hanhanW deleted the tiling-early-return branch December 19, 2023 17:14