@@ -209,6 +209,86 @@ static void concatSizesFromInputs(OpBuilder &builder,
209
209
210
210
namespace {
211
211
212
+ // / TODO: move it to tensor dialect instead.
213
+ // /
214
+ // / Fold `tensor.concat` and `tensor.extract_slice`
215
+ // /
216
+ // / %concat = tensor.concat dim(2) %t0, %t1
217
+ // / : (tensor<1x64x1xf32>, tensor<1x64x1xf32>) -> tensor<1x64x2xf32>
218
+ // / %extracted0 = tensor.extract_slice %concat[0, 0, 0][1, 64, 1][1, 1, 1]
219
+ // / : tensor<1x64x2xf32> to tensor<1x64x1xf32>
220
+ // / %extracted1 = tensor.extract_slice %concat[0, 0, 1][1, 64, 1][1, 1, 1]
221
+ // / : tensor<1x64x2xf32> to tensor<1x64x1xf32>
222
+ // /
223
+ // / Becomes
224
+ // /
225
+ // / %extract0, %extract1 = %t0, %t1
226
+ struct FuseExtractSliceWithConcat
227
+ : public OpRewritePattern<tensor::ExtractSliceOp> {
228
+ using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
229
+
230
+ LogicalResult matchAndRewrite (tensor::ExtractSliceOp extractOp,
231
+ PatternRewriter &rewriter) const override {
232
+ auto concatOp = extractOp.getSource ().getDefiningOp <tensor::ConcatOp>();
233
+ if (!concatOp)
234
+ return failure ();
235
+
236
+ Location loc = extractOp.getLoc ();
237
+ int64_t dim = concatOp.getDim ();
238
+ int64_t rank = extractOp.getResultType ().getRank ();
239
+
240
+ SmallVector<OpFoldResult> srcStrides (rank, rewriter.getIndexAttr (1 ));
241
+ SmallVector<OpFoldResult> srcOffsets (rank, rewriter.getIndexAttr (0 ));
242
+
243
+ // Compute the partial sums for the slice offsets.
244
+ AffineExpr sum = rewriter.getAffineDimExpr (0 );
245
+ SmallVector<AffineExpr> partialSums = {sum};
246
+ SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr (0 )};
247
+ for (auto [idx, input] :
248
+ llvm::enumerate (concatOp.getInputs ().drop_back ())) {
249
+ sum = sum + rewriter.getAffineDimExpr (idx + 1 );
250
+ partialSums.push_back (sum);
251
+ offsetStrides.push_back (
252
+ rewriter.createOrFold <tensor::DimOp>(loc, input, dim));
253
+ }
254
+ auto partialSumMap = AffineMap::get (concatOp.getInputs ().size (), 0 ,
255
+ partialSums, rewriter.getContext ());
256
+ SmallVector<OpFoldResult> dimOffsets =
257
+ affine::makeComposedFoldedMultiResultAffineApply (
258
+ rewriter, loc, partialSumMap, offsetStrides);
259
+
260
+ auto allEqual = [](ArrayRef<OpFoldResult> lhs, ArrayRef<OpFoldResult> rhs) {
261
+ for (auto [l, r] : llvm::zip (lhs, rhs)) {
262
+ std::optional<int64_t > staticVal = getConstantIntValue (l);
263
+ if (!staticVal.has_value () || staticVal != getConstantIntValue (r))
264
+ return false ;
265
+ }
266
+ return lhs.size () == rhs.size ();
267
+ };
268
+
269
+ for (auto [i, input, offset] :
270
+ llvm::enumerate (concatOp.getInputs (), dimOffsets)) {
271
+ SmallVector<OpFoldResult> srcSizes =
272
+ tensor::getMixedSizes (rewriter, loc, input);
273
+ srcOffsets[dim] = offset;
274
+
275
+ SmallVector<OpFoldResult> dstSizes = extractOp.getMixedSizes ();
276
+ SmallVector<OpFoldResult> dstOffsets = extractOp.getMixedOffsets ();
277
+ SmallVector<OpFoldResult> dstStrides = extractOp.getMixedStrides ();
278
+
279
+ if (allEqual (srcSizes, dstSizes) && allEqual (srcOffsets, dstOffsets) &&
280
+ allEqual (srcStrides, dstStrides)) {
281
+ Value operand = concatOp.getOperand (i);
282
+ if (operand.getType () == extractOp.getResultType ())
283
+ rewriter.replaceOp (extractOp, operand);
284
+ break ;
285
+ }
286
+ }
287
+
288
+ return success ();
289
+ }
290
+ };
291
+
212
292
// / Rewriting rule that converts direct yield of zero with initial allocation.
213
293
struct FoldInvariantYield : public OpRewritePattern <GenericOp> {
214
294
public:
@@ -1426,9 +1506,9 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
1426
1506
// ===---------------------------------------------------------------------===//
1427
1507
1428
1508
/// Registers the rewrite patterns that run ahead of sparsification: the
/// concat/extract_slice fusion plus the sparse-specific fusions, semi-ring
/// conversions, and print lowering.
void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
  patterns.add<FuseExtractSliceWithConcat, FoldInvariantYield,
               FuseSparseMultiplyOverAdd, FuseTensorCast, GenSemiRingReduction,
               GenSemiRingSelect, PrintRewriter>(patterns.getContext());
}
1433
1513
1434
1514
void mlir::populateLowerSparseOpsToForeachPatterns (RewritePatternSet &patterns,
0 commit comments