Skip to content

Commit bc194a5

Browse files
[mlir][SCF] Do not peel loops inside partial iterations
Do not apply loop peeling to loops that are contained in the partial iteration of an already peeled loop. This avoids code explosion when dealing with large loop nests. The behavior can be controlled with a new pass option, `skip-partial`. Differential Revision: https://reviews.llvm.org/D108542
1 parent 2556f58 commit bc194a5

File tree

4 files changed

+88
-10
lines changed

4 files changed

+88
-10
lines changed

mlir/include/mlir/Dialect/SCF/Passes.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@ def SCFForLoopPeeling
2121
: FunctionPass<"for-loop-peeling"> {
2222
let summary = "Peel `for` loops at their upper bounds.";
2323
let constructor = "mlir::createForLoopPeelingPass()";
24+
let options = [
25+
Option<"skipPartial", "skip-partial", "bool",
26+
/*default=*/"true",
27+
"Do not peel loops inside of the last, partial iteration of another "
28+
"already peeled loop.">
29+
];
2430
let dependentDialects = ["AffineDialect"];
2531
}
2632

mlir/include/mlir/Dialect/SCF/Transforms.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ void naivelyFuseParallelOps(Region &region);
4141

4242
/// Rewrite a for loop with bounds/step that potentially do not divide evenly
4343
/// into a for loop where the step divides the iteration space evenly, followed
44-
/// by an scf.if for the last (partial) iteration (if any). This transformation
45-
/// is called "loop peeling".
44+
/// by an scf.if for the last (partial) iteration (if any; returned via `ifOp`).
45+
/// This transformation is called "loop peeling".
4646
///
4747
/// This transformation is beneficial for a wide range of transformations such
4848
/// as vectorization or loop tiling: It enables additional canonicalizations
@@ -81,7 +81,8 @@ void naivelyFuseParallelOps(Region &region);
8181
/// Note: This function rewrites the given scf.for loop in-place and creates a
8282
/// new scf.if operation for the last iteration. It replaces all uses of the
8383
/// unpeeled loop with the results of the newly generated scf.if.
84-
LogicalResult peelAndCanonicalizeForLoop(RewriterBase &rewriter, ForOp forOp);
84+
LogicalResult peelAndCanonicalizeForLoop(RewriterBase &rewriter, ForOp forOp,
85+
scf::IfOp &ifOp);
8586

8687
/// Tile a parallel loop of the form
8788
/// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)

mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -362,9 +362,9 @@ static LogicalResult rewritePeeledAffineOp(RewriterBase &rewriter,
362362
}
363363

364364
LogicalResult mlir::scf::peelAndCanonicalizeForLoop(RewriterBase &rewriter,
365-
ForOp forOp) {
365+
ForOp forOp,
366+
scf::IfOp &ifOp) {
366367
Value ub = forOp.upperBound();
367-
scf::IfOp ifOp;
368368
Value splitBound;
369369
if (failed(peelForLoop(rewriter, forOp, ifOp, splitBound)))
370370
return failure();
@@ -383,23 +383,45 @@ LogicalResult mlir::scf::peelAndCanonicalizeForLoop(RewriterBase &rewriter,
383383
}
384384

385385
static constexpr char kPeeledLoopLabel[] = "__peeled_loop__";
386+
static constexpr char kPartialIterationLabel[] = "__partial_iteration__";
386387

387388
namespace {
388389
struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
389-
using OpRewritePattern<ForOp>::OpRewritePattern;
390+
ForLoopPeelingPattern(MLIRContext *ctx, bool skipPartial)
391+
: OpRewritePattern<ForOp>(ctx), skipPartial(skipPartial) {}
390392

391393
LogicalResult matchAndRewrite(ForOp forOp,
392394
PatternRewriter &rewriter) const override {
395+
// Do not peel already peeled loops.
393396
if (forOp->hasAttr(kPeeledLoopLabel))
394397
return failure();
395-
if (failed(peelAndCanonicalizeForLoop(rewriter, forOp)))
398+
if (skipPartial) {
399+
// No peeling of loops inside the partial iteration (scf.if) of another
400+
// peeled loop.
401+
Operation *op = forOp.getOperation();
402+
while ((op = op->getParentOfType<scf::IfOp>())) {
403+
if (op->hasAttr(kPartialIterationLabel))
404+
return failure();
405+
}
406+
}
407+
// Apply loop peeling.
408+
scf::IfOp ifOp;
409+
if (failed(peelAndCanonicalizeForLoop(rewriter, forOp, ifOp)))
396410
return failure();
397411
// Apply label, so that the same loop is not rewritten a second time.
398412
rewriter.updateRootInPlace(forOp, [&]() {
399413
forOp->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
400414
});
415+
ifOp->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
401416
return success();
402417
}
418+
419+
/// If set to true, loops inside partial iterations of another peeled loop
420+
/// are not peeled. This reduces the size of the generated code. Partial
421+
/// iterations are not usually performance critical.
422+
/// Note: Takes into account the entire chain of parent operations, not just
423+
/// the direct parent.
424+
bool skipPartial;
403425
};
404426
} // namespace
405427

@@ -424,11 +446,14 @@ struct ForLoopPeeling : public SCFForLoopPeelingBase<ForLoopPeeling> {
424446
FuncOp funcOp = getFunction();
425447
MLIRContext *ctx = funcOp.getContext();
426448
RewritePatternSet patterns(ctx);
427-
patterns.add<ForLoopPeelingPattern>(ctx);
449+
patterns.add<ForLoopPeelingPattern>(ctx, skipPartial);
428450
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
429451

430-
// Drop the marker.
431-
funcOp.walk([](ForOp op) { op->removeAttr(kPeeledLoopLabel); });
452+
// Drop the markers.
453+
funcOp.walk([](Operation *op) {
454+
op->removeAttr(kPeeledLoopLabel);
455+
op->removeAttr(kPartialIterationLabel);
456+
});
432457
}
433458
};
434459
} // namespace

mlir/test/Dialect/SCF/for-loop-peeling.mlir

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: mlir-opt %s -for-loop-peeling -canonicalize -split-input-file | FileCheck %s
2+
// RUN: mlir-opt %s -for-loop-peeling=skip-partial=false -canonicalize -split-input-file | FileCheck %s -check-prefix=CHECK-NO-SKIP
23

34
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0, s1, s2] -> (s1 - (s1 - s0) mod s2)>
45
// CHECK-DAG: #[[MAP1:.*]] = affine_map<()[s0, s1, s2] -> (-(s0 - (s0 - s1) mod s2) + s0)>
@@ -223,3 +224,48 @@ func @test_affine_min_rewrite(%lb : index, %ub: index,
223224
}
224225
return
225226
}
227+
228+
// -----
229+
230+
// CHECK: func @nested_loops
231+
// CHECK: scf.for {{.*}} {
232+
// CHECK: scf.for {{.*}} {
233+
// CHECK: }
234+
// CHECK: scf.if {{.*}} {
235+
// CHECK: }
236+
// CHECK: }
237+
// CHECK: scf.if {{.*}} {
238+
// CHECK: scf.for {{.*}} {
239+
// CHECK: }
240+
// CHECK-NOT: scf.if
241+
// CHECK: }
242+
243+
// CHECK-NO-SKIP: func @nested_loops
244+
// CHECK-NO-SKIP: scf.for {{.*}} {
245+
// CHECK-NO-SKIP: scf.for {{.*}} {
246+
// CHECK-NO-SKIP: }
247+
// CHECK-NO-SKIP: scf.if {{.*}} {
248+
// CHECK-NO-SKIP: }
249+
// CHECK-NO-SKIP: }
250+
// CHECK-NO-SKIP: scf.if {{.*}} {
251+
// CHECK-NO-SKIP: scf.for {{.*}} {
252+
// CHECK-NO-SKIP: }
253+
// CHECK-NO-SKIP: scf.if {{.*}} {
254+
// CHECK-NO-SKIP: }
255+
// CHECK-NO-SKIP: }
256+
#map = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>
257+
func @nested_loops(%lb0: index, %lb1 : index, %ub0: index, %ub1: index,
258+
%step: index) -> i32 {
259+
%c0 = constant 0 : i32
260+
%r0 = scf.for %iv0 = %lb0 to %ub0 step %step iter_args(%arg0 = %c0) -> i32 {
261+
%r1 = scf.for %iv1 = %lb1 to %ub1 step %step iter_args(%arg1 = %arg0) -> i32 {
262+
%s = affine.min #map(%ub1, %iv1)[%step]
263+
%casted = index_cast %s : index to i32
264+
%0 = addi %arg1, %casted : i32
265+
scf.yield %0 : i32
266+
}
267+
%1 = addi %arg0, %r1 : i32
268+
scf.yield %1 : i32
269+
}
270+
return %r0 : i32
271+
}

0 commit comments

Comments (0)