Skip to content

Commit ea3eeb4

Browse files
Author: Peiming Liu
[mlir][sparse] fuse concat and extract_slice op if possible. (#89825)
1 parent: d433873 · commit: ea3eeb4

File tree: 2 files changed (+106, −3 lines)

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 83 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,86 @@ static void concatSizesFromInputs(OpBuilder &builder,
209209

210210
namespace {
211211

212+
/// TODO: move it to tensor dialect instead.
213+
///
214+
/// Fold `tensor.concat` and `tensor.extract_slice`
215+
///
216+
/// %concat = tensor.concat dim(2) %t0, %t1
217+
/// : (tensor<1x64x1xf32>, tensor<1x64x1xf32>) -> tensor<1x64x2xf32>
218+
/// %extracted0 = tensor.extract_slice %concat[0, 0, 0][1, 64, 1][1, 1, 1]
219+
/// : tensor<1x64x2xf32> to tensor<1x64x1xf32>
220+
/// %extracted1 = tensor.extract_slice %concat[0, 0, 1][1, 64, 1][1, 1, 1]
221+
/// : tensor<1x64x2xf32> to tensor<1x64x1xf32>
222+
///
223+
/// Becomes
224+
///
225+
/// %extract0, %extract1 = %t0, %t1
226+
struct FuseExtractSliceWithConcat
227+
: public OpRewritePattern<tensor::ExtractSliceOp> {
228+
using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
229+
230+
LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
231+
PatternRewriter &rewriter) const override {
232+
auto concatOp = extractOp.getSource().getDefiningOp<tensor::ConcatOp>();
233+
if (!concatOp)
234+
return failure();
235+
236+
Location loc = extractOp.getLoc();
237+
int64_t dim = concatOp.getDim();
238+
int64_t rank = extractOp.getResultType().getRank();
239+
240+
SmallVector<OpFoldResult> srcStrides(rank, rewriter.getIndexAttr(1));
241+
SmallVector<OpFoldResult> srcOffsets(rank, rewriter.getIndexAttr(0));
242+
243+
// Compute the partial sums for the slice offsets.
244+
AffineExpr sum = rewriter.getAffineDimExpr(0);
245+
SmallVector<AffineExpr> partialSums = {sum};
246+
SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
247+
for (auto [idx, input] :
248+
llvm::enumerate(concatOp.getInputs().drop_back())) {
249+
sum = sum + rewriter.getAffineDimExpr(idx + 1);
250+
partialSums.push_back(sum);
251+
offsetStrides.push_back(
252+
rewriter.createOrFold<tensor::DimOp>(loc, input, dim));
253+
}
254+
auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
255+
partialSums, rewriter.getContext());
256+
SmallVector<OpFoldResult> dimOffsets =
257+
affine::makeComposedFoldedMultiResultAffineApply(
258+
rewriter, loc, partialSumMap, offsetStrides);
259+
260+
auto allEqual = [](ArrayRef<OpFoldResult> lhs, ArrayRef<OpFoldResult> rhs) {
261+
for (auto [l, r] : llvm::zip(lhs, rhs)) {
262+
std::optional<int64_t> staticVal = getConstantIntValue(l);
263+
if (!staticVal.has_value() || staticVal != getConstantIntValue(r))
264+
return false;
265+
}
266+
return lhs.size() == rhs.size();
267+
};
268+
269+
for (auto [i, input, offset] :
270+
llvm::enumerate(concatOp.getInputs(), dimOffsets)) {
271+
SmallVector<OpFoldResult> srcSizes =
272+
tensor::getMixedSizes(rewriter, loc, input);
273+
srcOffsets[dim] = offset;
274+
275+
SmallVector<OpFoldResult> dstSizes = extractOp.getMixedSizes();
276+
SmallVector<OpFoldResult> dstOffsets = extractOp.getMixedOffsets();
277+
SmallVector<OpFoldResult> dstStrides = extractOp.getMixedStrides();
278+
279+
if (allEqual(srcSizes, dstSizes) && allEqual(srcOffsets, dstOffsets) &&
280+
allEqual(srcStrides, dstStrides)) {
281+
Value operand = concatOp.getOperand(i);
282+
if (operand.getType() == extractOp.getResultType())
283+
rewriter.replaceOp(extractOp, operand);
284+
break;
285+
}
286+
}
287+
288+
return success();
289+
}
290+
};
291+
212292
/// Rewriting rule that converts direct yield of zero with initial allocation.
213293
struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
214294
public:
@@ -1426,9 +1506,9 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
14261506
//===---------------------------------------------------------------------===//
14271507

14281508
// Registers the rewrite patterns that run before sparsification proper;
// FuseExtractSliceWithConcat goes first so concat/extract_slice pairs are
// folded away before the other fusions look at the graph.
void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<FuseExtractSliceWithConcat, FoldInvariantYield,
               FuseSparseMultiplyOverAdd, FuseTensorCast, GenSemiRingReduction,
               GenSemiRingSelect, PrintRewriter>(ctx);
}
14331513

14341514
void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: mlir-opt %s --pre-sparsification-rewrite | FileCheck %s

// 4-D sparse encoding: three compressed outer dims, dense innermost dim.
#CCCD = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : dense) }>

// Each extract_slice below reproduces exactly one concat input (full sizes,
// unit strides, offset equal to that input's start along dim 3), so both ops
// fold away and the function returns its arguments directly.
// CHECK-LABEL: func.func @fuse_concat_with_extract(
// CHECK-SAME: %[[VAL_0:.*0]]: tensor<128x32x32x1xf32, #sparse{{[0-9]*}}>,
// CHECK-SAME: %[[VAL_1:.*1]]: tensor<128x32x32x1xf32, #sparse{{[0-9]*}}>,
// CHECK-SAME: %[[VAL_2:.*2]]: tensor<128x32x32x1xf32, #sparse{{[0-9]*}}>)
// CHECK-NOT: tensor.concat
// CHECK-NOT: tensor.extract_slice
// CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]]
// CHECK: }
func.func @fuse_concat_with_extract(%t0 : tensor<128x32x32x1xf32, #CCCD>,
                                    %t1 : tensor<128x32x32x1xf32, #CCCD>,
                                    %t2 : tensor<128x32x32x1xf32, #CCCD>) -> (tensor<128x32x32x1xf32, #CCCD>, tensor<128x32x32x1xf32, #CCCD>, tensor<128x32x32x1xf32, #CCCD>) {
  %cat = tensor.concat dim(3) %t0, %t1, %t2 : (tensor<128x32x32x1xf32, #CCCD>, tensor<128x32x32x1xf32, #CCCD>, tensor<128x32x32x1xf32, #CCCD>) -> tensor<128x32x32x3xf32, #CCCD>
  %s0 = tensor.extract_slice %cat[0, 0, 0, 0] [128, 32, 32, 1] [1, 1, 1, 1] : tensor<128x32x32x3xf32, #CCCD> to tensor<128x32x32x1xf32, #CCCD>
  %s1 = tensor.extract_slice %cat[0, 0, 0, 1] [128, 32, 32, 1] [1, 1, 1, 1] : tensor<128x32x32x3xf32, #CCCD> to tensor<128x32x32x1xf32, #CCCD>
  %s2 = tensor.extract_slice %cat[0, 0, 0, 2] [128, 32, 32, 1] [1, 1, 1, 1] : tensor<128x32x32x3xf32, #CCCD> to tensor<128x32x32x1xf32, #CCCD>
  return %s0, %s1, %s2 : tensor<128x32x32x1xf32, #CCCD>, tensor<128x32x32x1xf32, #CCCD>, tensor<128x32x32x1xf32, #CCCD>
}

0 commit comments

Comments
 (0)