[mlir][sparse] fuse concat and extract_slice op if possible. #89825
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-sparse

Author: Peiming Liu (PeimingLiu)

Full diff: https://github.com/llvm/llvm-project/pull/89825.diff

1 file affected:
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 02375f54d7152f..9e8998a8a07f35 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -209,6 +209,86 @@ static void concatSizesFromInputs(OpBuilder &builder,
namespace {
+/// TODO: move it to tensor dialect instead.
+///
+/// Fold `tensor.concat` and `tensor.extract_slice`
+///
+/// %concat = tensor.concat dim(2) %t0, %t1
+///   : (tensor<1x64x1xf32>, tensor<1x64x1xf32>) -> tensor<1x64x2xf32>
+/// %extracted0 = tensor.extract_slice %concat[0, 0, 0][1, 64, 1][1, 1, 1]
+///   : tensor<1x64x2xf32> to tensor<1x64x1xf32>
+/// %extracted1 = tensor.extract_slice %concat[0, 0, 1][1, 64, 1][1, 1, 1]
+///   : tensor<1x64x2xf32> to tensor<1x64x1xf32>
+///
+/// Becomes
+///
+/// %extract0, %extract1 = %t0, %t1
+struct FuseExtractSliceWithConcat
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto concatOp = extractOp.getSource().getDefiningOp<tensor::ConcatOp>();
+    if (!concatOp)
+      return failure();
+
+    Location loc = extractOp.getLoc();
+    int64_t dim = concatOp.getDim();
+    int64_t rank = extractOp.getResultType().getRank();
+
+    SmallVector<OpFoldResult> srcStrides(rank, rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> srcOffsets(rank, rewriter.getIndexAttr(0));
+
+    // Compute the partial sums for the slice offsets.
+    AffineExpr sum = rewriter.getAffineDimExpr(0);
+    SmallVector<AffineExpr> partialSums = {sum};
+    SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
+    for (auto [idx, input] :
+         llvm::enumerate(concatOp.getInputs().drop_back())) {
+      sum = sum + rewriter.getAffineDimExpr(idx + 1);
+      partialSums.push_back(sum);
+      offsetStrides.push_back(
+          rewriter.createOrFold<tensor::DimOp>(loc, input, dim));
+    }
+    auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
+                                        partialSums, rewriter.getContext());
+    SmallVector<OpFoldResult> dimOffsets =
+        affine::makeComposedFoldedMultiResultAffineApply(
+            rewriter, loc, partialSumMap, offsetStrides);
+
+    auto allEqual = [](ArrayRef<OpFoldResult> lhs, ArrayRef<OpFoldResult> rhs) {
+      for (auto [l, r] : llvm::zip(lhs, rhs)) {
+        std::optional<int64_t> staticVal = getConstantIntValue(l);
+        if (!staticVal.has_value() || staticVal != getConstantIntValue(r))
+          return false;
+      }
+      return lhs.size() == rhs.size();
+    };
+
+    for (auto [i, input, offset] :
+         llvm::enumerate(concatOp.getInputs(), dimOffsets)) {
+      SmallVector<OpFoldResult> srcSizes =
+          tensor::getMixedSizes(rewriter, loc, input);
+      srcOffsets[dim] = offset;
+
+      SmallVector<OpFoldResult> dstSizes = extractOp.getMixedSizes();
+      SmallVector<OpFoldResult> dstOffsets = extractOp.getMixedOffsets();
+      SmallVector<OpFoldResult> dstStrides = extractOp.getMixedStrides();
+
+      if (allEqual(srcSizes, dstSizes) && allEqual(srcOffsets, dstOffsets) &&
+          allEqual(srcStrides, dstStrides)) {
+        Value operand = concatOp.getOperand(i);
+        if (operand.getType() == extractOp.getResultType())
+          rewriter.replaceOp(extractOp, operand);
+        break;
+      }
+    }
+
+    return success();
+  }
+};
+
/// Rewriting rule that converts direct yield of zero with initial allocation.
struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
public:
@@ -1426,9 +1506,9 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
//===---------------------------------------------------------------------===//
void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
-  patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
-               GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
-      patterns.getContext());
+  patterns.add<FuseExtractSliceWithConcat, FoldInvariantYield,
+               FuseSparseMultiplyOverAdd, FuseTensorCast, GenSemiRingReduction,
+               GenSemiRingSelect, PrintRewriter>(patterns.getContext());
}
void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
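The doc comment covers the two-input case; as a hedged illustration (not taken from the PR, with made-up value names and shapes), the partial-sum offset computation generalizes to any number of concat inputs. With sizes {2, 3, 4} along the concatenation dimension, the candidate offsets are the partial sums {0, 2, 5}; an extract_slice whose offsets, sizes, and strides all match one candidate exactly folds away to that input:

// Before: candidate offsets along dim 1 are the partial sums {0, 2, 5}
// of the input sizes {2, 3, 4}.
%concat = tensor.concat dim(1) %t0, %t1, %t2
    : (tensor<8x2xf32>, tensor<8x3xf32>, tensor<8x4xf32>) -> tensor<8x9xf32>
%e0 = tensor.extract_slice %concat[0, 0][8, 2][1, 1]
    : tensor<8x9xf32> to tensor<8x2xf32>
%e1 = tensor.extract_slice %concat[0, 2][8, 3][1, 1]
    : tensor<8x9xf32> to tensor<8x3xf32>
%e2 = tensor.extract_slice %concat[0, 5][8, 4][1, 1]
    : tensor<8x9xf32> to tensor<8x4xf32>

// After the rewrite, each slice is replaced by the matching input:
// %e0, %e1, %e2 = %t0, %t1, %t2

Because allEqual only accepts values that fold to equal constants, a slice that straddles two inputs, or one with offsets, sizes, or strides that stay dynamic, leaves the concat in place.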
yeah looks good, but with a CHECK test
Force-pushed from 69fd4bf to 0d33540.
// CHECK-LABEL: func.func @fuse_concat_with_extract(
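Only the CHECK-LABEL line above is taken from the PR; the rest of this FileCheck test is a hedged sketch of what the requested regression test could look like, assuming the pattern fires under the existing --pre-sparsification-rewrite pass (argument names and CHECK lines are illustrative):

// RUN: mlir-opt %s --pre-sparsification-rewrite | FileCheck %s

// CHECK-LABEL: func.func @fuse_concat_with_extract(
// CHECK-SAME:    %[[A:.*]]: tensor<1x64x1xf32>,
// CHECK-SAME:    %[[B:.*]]: tensor<1x64x1xf32>)
// CHECK-NOT:     tensor.concat
// CHECK-NOT:     tensor.extract_slice
// CHECK:         return %[[A]], %[[B]]
func.func @fuse_concat_with_extract(%a: tensor<1x64x1xf32>,
                                    %b: tensor<1x64x1xf32>)
    -> (tensor<1x64x1xf32>, tensor<1x64x1xf32>) {
  %c = tensor.concat dim(2) %a, %b
      : (tensor<1x64x1xf32>, tensor<1x64x1xf32>) -> tensor<1x64x2xf32>
  %s0 = tensor.extract_slice %c[0, 0, 0][1, 64, 1][1, 1, 1]
      : tensor<1x64x2xf32> to tensor<1x64x1xf32>
  %s1 = tensor.extract_slice %c[0, 0, 1][1, 64, 1][1, 1, 1]
      : tensor<1x64x2xf32> to tensor<1x64x1xf32>
  return %s0, %s1 : tensor<1x64x1xf32>, tensor<1x64x1xf32>
}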
nice!