@@ -44,9 +44,10 @@ struct LoopPipelinerInternal {
44
44
unsigned maxStage = 0 ;
45
45
DenseMap<Operation *, unsigned > stages;
46
46
std::vector<Operation *> opOrder;
47
- int64_t ub;
48
- int64_t lb;
49
- int64_t step;
47
+ Value ub;
48
+ Value lb;
49
+ Value step;
50
+ bool dynamicLoop;
50
51
PipeliningOption::AnnotationlFnType annotateFn = nullptr ;
51
52
bool peelEpilogue;
52
53
PipeliningOption::PredicateOpFn predicateFn = nullptr ;
@@ -96,25 +97,41 @@ bool LoopPipelinerInternal::initializeLoopInfo(
96
97
ForOp op, const PipeliningOption &options) {
97
98
LDBG (" Start initializeLoopInfo" );
98
99
forOp = op;
99
- auto upperBoundCst =
100
- forOp.getUpperBound ().getDefiningOp <arith::ConstantIndexOp>();
101
- auto lowerBoundCst =
102
- forOp.getLowerBound ().getDefiningOp <arith::ConstantIndexOp>();
103
- auto stepCst = forOp.getStep ().getDefiningOp <arith::ConstantIndexOp>();
100
+ ub = forOp.getUpperBound ();
101
+ lb = forOp.getLowerBound ();
102
+ step = forOp.getStep ();
103
+
104
+ dynamicLoop = true ;
105
+ auto upperBoundCst = getConstantIntValue (ub);
106
+ auto lowerBoundCst = getConstantIntValue (lb);
107
+ auto stepCst = getConstantIntValue (step);
104
108
if (!upperBoundCst || !lowerBoundCst || !stepCst) {
105
- LDBG (" --no constant bounds or step -> BAIL" );
106
- return false ;
109
+ if (!options.supportDynamicLoops ) {
110
+ LDBG (" --dynamic loop not supported -> BAIL" );
111
+ return false ;
112
+ }
113
+ } else {
114
+ int64_t ubImm = upperBoundCst.value ();
115
+ int64_t lbImm = lowerBoundCst.value ();
116
+ int64_t stepImm = stepCst.value ();
117
+ int64_t numIteration = ceilDiv (ubImm - lbImm, stepImm);
118
+ if (numIteration > maxStage) {
119
+ dynamicLoop = false ;
120
+ } else if (!options.supportDynamicLoops ) {
121
+ LDBG (" --fewer loop iterations than pipeline stages -> BAIL" );
122
+ return false ;
123
+ }
107
124
}
108
- ub = upperBoundCst.value ();
109
- lb = lowerBoundCst.value ();
110
- step = stepCst.value ();
111
125
peelEpilogue = options.peelEpilogue ;
112
126
predicateFn = options.predicateFn ;
113
- if (!peelEpilogue && predicateFn == nullptr ) {
127
+ if (( !peelEpilogue || dynamicLoop) && predicateFn == nullptr ) {
114
128
LDBG (" --no epilogue or predicate set -> BAIL" );
115
129
return false ;
116
130
}
117
- int64_t numIteration = ceilDiv (ub - lb, step);
131
+ if (dynamicLoop && peelEpilogue) {
132
+ LDBG (" --dynamic loop doesn't support epilogue yet -> BAIL" );
133
+ return false ;
134
+ }
118
135
std::vector<std::pair<Operation *, unsigned >> schedule;
119
136
options.getScheduleFn (forOp, schedule);
120
137
if (schedule.empty ()) {
@@ -128,10 +145,6 @@ bool LoopPipelinerInternal::initializeLoopInfo(
128
145
stages[opSchedule.first ] = opSchedule.second ;
129
146
opOrder.push_back (opSchedule.first );
130
147
}
131
- if (numIteration <= maxStage) {
132
- LDBG (" --fewer loop iterations than pipeline stages -> BAIL" );
133
- return false ;
134
- }
135
148
136
149
// All operations need to have a stage.
137
150
for (Operation &op : forOp.getBody ()->without_terminator ()) {
@@ -204,10 +217,31 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
204
217
setValueMapping (arg, operand.get (), 0 );
205
218
}
206
219
auto yield = cast<scf::YieldOp>(forOp.getBody ()->getTerminator ());
220
+ Location loc = forOp.getLoc ();
221
+ SmallVector<Value> predicates (maxStage);
207
222
for (int64_t i = 0 ; i < maxStage; i++) {
223
+ if (dynamicLoop) {
224
+ Type t = ub.getType ();
225
+ // pred = ub > lb + (i * step)
226
+ Value iv = rewriter.create <arith::AddIOp>(
227
+ loc, lb,
228
+ rewriter.create <arith::MulIOp>(
229
+ loc, step,
230
+ rewriter.create <arith::ConstantOp>(
231
+ loc, rewriter.getIntegerAttr (t, i))));
232
+ predicates[i] = rewriter.create <arith::CmpIOp>(
233
+ loc, arith::CmpIPredicate::slt, iv, ub);
234
+ }
235
+
208
236
// special handling for induction variable as the increment is implicit.
209
- Value iv =
210
- rewriter.create <arith::ConstantIndexOp>(forOp.getLoc (), lb + i * step);
237
+ // iv = lb + i * step
238
+ Type t = lb.getType ();
239
+ Value iv = rewriter.create <arith::AddIOp>(
240
+ loc, lb,
241
+ rewriter.create <arith::MulIOp>(
242
+ loc, step,
243
+ rewriter.create <arith::ConstantOp>(loc,
244
+ rewriter.getIntegerAttr (t, i))));
211
245
setValueMapping (forOp.getInductionVar (), iv, i);
212
246
for (Operation *op : opOrder) {
213
247
if (stages[op] > i)
@@ -220,6 +254,12 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
220
254
newOperand->set (replacement);
221
255
}
222
256
});
257
+ int predicateIdx = i - stages[op];
258
+ if (predicates[predicateIdx]) {
259
+ newOp = predicateFn (rewriter, newOp, predicates[predicateIdx]);
260
+ assert (newOp && " failed to predicate op." );
261
+ }
262
+ rewriter.setInsertionPointAfter (newOp);
223
263
if (annotateFn)
224
264
annotateFn (newOp, PipeliningOption::PipelinerPart::Prologue, i);
225
265
for (unsigned destId : llvm::seq (unsigned (0 ), op->getNumResults ())) {
@@ -326,9 +366,16 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop(
326
366
// `numStages - 1` iterations. Then we adjust the upper bound to remove those
327
367
// iterations.
328
368
Value newUb = forOp.getUpperBound ();
329
- if (peelEpilogue)
330
- newUb = rewriter.create <arith::ConstantIndexOp>(forOp.getLoc (),
331
- ub - maxStage * step);
369
+ if (peelEpilogue) {
370
+ Type t = ub.getType ();
371
+ Location loc = forOp.getLoc ();
372
+ // newUb = ub - maxStage * step
373
+ Value maxStageValue = rewriter.create <arith::ConstantOp>(
374
+ loc, rewriter.getIntegerAttr (t, maxStage));
375
+ Value maxStageByStep =
376
+ rewriter.create <arith::MulIOp>(loc, step, maxStageValue);
377
+ newUb = rewriter.create <arith::SubIOp>(loc, ub, maxStageByStep);
378
+ }
332
379
auto newForOp =
333
380
rewriter.create <scf::ForOp>(forOp.getLoc (), forOp.getLowerBound (), newUb,
334
381
forOp.getStep (), newLoopArg);
@@ -358,9 +405,17 @@ LogicalResult LoopPipelinerInternal::createKernel(
358
405
SmallVector<Value> predicates (maxStage + 1 , nullptr );
359
406
if (!peelEpilogue) {
360
407
// Create a predicate for each stage except the last stage.
408
+ Location loc = newForOp.getLoc ();
409
+ Type t = ub.getType ();
361
410
for (unsigned i = 0 ; i < maxStage; i++) {
362
- Value c = rewriter.create <arith::ConstantIndexOp>(
363
- newForOp.getLoc (), ub - (maxStage - i) * step);
411
+ // c = ub - (maxStage - i) * step
412
+ Value c = rewriter.create <arith::SubIOp>(
413
+ loc, ub,
414
+ rewriter.create <arith::MulIOp>(
415
+ loc, step,
416
+ rewriter.create <arith::ConstantOp>(
417
+ loc, rewriter.getIntegerAttr (t, int64_t (maxStage - i)))));
418
+
364
419
Value pred = rewriter.create <arith::CmpIOp>(
365
420
newForOp.getLoc (), arith::CmpIPredicate::slt,
366
421
newForOp.getInductionVar (), c);
@@ -383,8 +438,14 @@ LogicalResult LoopPipelinerInternal::createKernel(
383
438
// version incremented based on the stage where it is used.
384
439
if (operand->get () == forOp.getInductionVar ()) {
385
440
rewriter.setInsertionPoint (newOp);
386
- Value offset = rewriter.create <arith::ConstantIndexOp>(
387
- forOp.getLoc (), (maxStage - stages[op]) * step);
441
+
442
+ // offset = (maxStage - stages[op]) * step
443
+ Type t = step.getType ();
444
+ Value offset = rewriter.create <arith::MulIOp>(
445
+ forOp.getLoc (), step,
446
+ rewriter.create <arith::ConstantOp>(
447
+ forOp.getLoc (),
448
+ rewriter.getIntegerAttr (t, maxStage - stages[op])));
388
449
Value iv = rewriter.create <arith::AddIOp>(
389
450
forOp.getLoc (), newForOp.getInductionVar (), offset);
390
451
nestedNewOp->setOperand (operand->getOperandNumber (), iv);
@@ -508,8 +569,24 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter) {
508
569
// Emit different versions of the induction variable. They will be
509
570
// removed by dead code if not used.
510
571
for (int64_t i = 0 ; i < maxStage; i++) {
511
- Value newlastIter = rewriter.create <arith::ConstantIndexOp>(
512
- forOp.getLoc (), lb + step * ((((ub - 1 ) - lb) / step) - i));
572
+ Location loc = forOp.getLoc ();
573
+ Type t = lb.getType ();
574
+ Value minusOne =
575
+ rewriter.create <arith::ConstantOp>(loc, rewriter.getIntegerAttr (t, -1 ));
576
+ // number of iterations = ((ub - 1) - lb) / step
577
+ Value totalNumIteration = rewriter.create <arith::DivUIOp>(
578
+ loc,
579
+ rewriter.create <arith::SubIOp>(
580
+ loc, rewriter.create <arith::AddIOp>(loc, ub, minusOne), lb),
581
+ step);
582
+ // newLastIter = lb + step * ((((ub - 1) - lb) / step) - i)
583
+ Value minusI =
584
+ rewriter.create <arith::ConstantOp>(loc, rewriter.getIntegerAttr (t, -i));
585
+ Value newlastIter = rewriter.create <arith::AddIOp>(
586
+ loc, lb,
587
+ rewriter.create <arith::MulIOp>(
588
+ loc, step,
589
+ rewriter.create <arith::AddIOp>(loc, totalNumIteration, minusI)));
513
590
setValueMapping (forOp.getInductionVar (), newlastIter, maxStage - i);
514
591
}
515
592
// Emit `maxStage - 1` epilogue part that includes operations from stages
0 commit comments