
[mlir][sparse] fuse concat and extract_slice op if possible. #89825

Merged: 2 commits, Apr 24, 2024

Conversation

PeimingLiu (Member)

No description provided.

@llvmbot (Member) commented Apr 23, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sparse

Author: Peiming Liu (PeimingLiu)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/89825.diff

1 file affected:

  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+83-3)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 02375f54d7152f..9e8998a8a07f35 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -209,6 +209,86 @@ static void concatSizesFromInputs(OpBuilder &builder,
 
 namespace {
 
+/// TODO: move it to tensor dialect instead.
+///
+/// Fold `tensor.concat` and `tensor.extract_slice`
+///
+/// %concat = tensor.concat dim(2) %t0, %t1
+///   : (tensor<1x64x1xf32>, tensor<1x64x1xf32>) -> tensor<1x64x2xf32>
+/// %extracted0 = tensor.extract_slice %concat[0, 0, 0][1, 64, 1][1, 1, 1]
+///   : tensor<1x64x2xf32> to tensor<1x64x1xf32>
+/// %extracted1 = tensor.extract_slice %concat[0, 0, 1][1, 64, 1][1, 1, 1]
+///   : tensor<1x64x2xf32> to tensor<1x64x1xf32>
+///
+/// Becomes
+///
+/// %extracted0, %extracted1 = %t0, %t1
+struct FuseExtractSliceWithConcat
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto concatOp = extractOp.getSource().getDefiningOp<tensor::ConcatOp>();
+    if (!concatOp)
+      return failure();
+
+    Location loc = extractOp.getLoc();
+    int64_t dim = concatOp.getDim();
+    int64_t rank = extractOp.getResultType().getRank();
+
+    SmallVector<OpFoldResult> srcStrides(rank, rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> srcOffsets(rank, rewriter.getIndexAttr(0));
+
+    // Compute the partial sums for the slice offsets.
+    AffineExpr sum = rewriter.getAffineDimExpr(0);
+    SmallVector<AffineExpr> partialSums = {sum};
+    SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
+    for (auto [idx, input] :
+         llvm::enumerate(concatOp.getInputs().drop_back())) {
+      sum = sum + rewriter.getAffineDimExpr(idx + 1);
+      partialSums.push_back(sum);
+      offsetStrides.push_back(
+          rewriter.createOrFold<tensor::DimOp>(loc, input, dim));
+    }
+    auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
+                                        partialSums, rewriter.getContext());
+    SmallVector<OpFoldResult> dimOffsets =
+        affine::makeComposedFoldedMultiResultAffineApply(
+            rewriter, loc, partialSumMap, offsetStrides);
+
+    auto allEqual = [](ArrayRef<OpFoldResult> lhs, ArrayRef<OpFoldResult> rhs) {
+      for (auto [l, r] : llvm::zip(lhs, rhs)) {
+        std::optional<int64_t> staticVal = getConstantIntValue(l);
+        if (!staticVal.has_value() || staticVal != getConstantIntValue(r))
+          return false;
+      }
+      return lhs.size() == rhs.size();
+    };
+
+    for (auto [i, input, offset] :
+         llvm::enumerate(concatOp.getInputs(), dimOffsets)) {
+      SmallVector<OpFoldResult> srcSizes =
+          tensor::getMixedSizes(rewriter, loc, input);
+      srcOffsets[dim] = offset;
+
+      SmallVector<OpFoldResult> dstSizes = extractOp.getMixedSizes();
+      SmallVector<OpFoldResult> dstOffsets = extractOp.getMixedOffsets();
+      SmallVector<OpFoldResult> dstStrides = extractOp.getMixedStrides();
+
+      if (allEqual(srcSizes, dstSizes) && allEqual(srcOffsets, dstOffsets) &&
+          allEqual(srcStrides, dstStrides)) {
+        Value operand = concatOp.getOperand(i);
+        if (operand.getType() == extractOp.getResultType())
+          rewriter.replaceOp(extractOp, operand);
+        break;
+      }
+    }
+
+    return success();
+  }
+};
+
 /// Rewriting rule that converts direct yield of zero with initial allocation.
 struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
 public:
@@ -1426,9 +1506,9 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
 //===---------------------------------------------------------------------===//
 
 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
-  patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
-               GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
-      patterns.getContext());
+  patterns.add<FuseExtractSliceWithConcat, FoldInvariantYield,
+               FuseSparseMultiplyOverAdd, FuseTensorCast, GenSemiRingReduction,
+               GenSemiRingSelect, PrintRewriter>(patterns.getContext());
 }
 
 void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,

@aartbik (Contributor) left a comment:

yeah looks good, but with a CHECK test

@PeimingLiu PeimingLiu force-pushed the fuse-concat-extract branch from 69fd4bf to 0d33540 Compare April 24, 2024 18:50



// CHECK-LABEL: func.func @fuse_concat_with_extract(

A contributor commented on this line:

nice!

@PeimingLiu PeimingLiu merged commit ea3eeb4 into llvm:main Apr 24, 2024
3 of 4 checks passed
@PeimingLiu PeimingLiu deleted the fuse-concat-extract branch April 24, 2024 20:51
Labels: mlir:sparse (Sparse compiler in MLIR), mlir

4 participants