Skip to content

Commit ef11283

Browse files
authored
[MLIR][SCF] Add support for pipelining dynamic loops (llvm#74350)
Support loops without static boundaries. Since the number of iteration is not known we need to predicate prologue and epilogue in case the number of iterations is smaller than the number of stages. This patch includes work from @chengjunlu
1 parent 9a46518 commit ef11283

File tree

5 files changed

+173
-45
lines changed

5 files changed

+173
-45
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,13 @@ struct PipeliningOption {
128128
/// lambda to generate the predicated version of operations.
129129
bool peelEpilogue = true;
130130

131+
/// Control whether the transformation checks that the number of iterations is
132+
/// greater or equal to the number of stages and skip the transformation if
133+
/// this is not the case. If the loop is dynamic and this is set to true and
134+
/// the loop bounds are not static the pipeliner will have to predicate
135+
/// operations in the the prologue/epilogue.
136+
bool supportDynamicLoops = false;
137+
131138
// Callback to predicate operations when the prologue or epilogue are not
132139
// peeled. This takes the original operation, an i1 predicate value and the
133140
// pattern rewriter. It is expected to replace the given operation with

mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp

Lines changed: 107 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,10 @@ struct LoopPipelinerInternal {
4444
unsigned maxStage = 0;
4545
DenseMap<Operation *, unsigned> stages;
4646
std::vector<Operation *> opOrder;
47-
int64_t ub;
48-
int64_t lb;
49-
int64_t step;
47+
Value ub;
48+
Value lb;
49+
Value step;
50+
bool dynamicLoop;
5051
PipeliningOption::AnnotationlFnType annotateFn = nullptr;
5152
bool peelEpilogue;
5253
PipeliningOption::PredicateOpFn predicateFn = nullptr;
@@ -96,25 +97,41 @@ bool LoopPipelinerInternal::initializeLoopInfo(
9697
ForOp op, const PipeliningOption &options) {
9798
LDBG("Start initializeLoopInfo");
9899
forOp = op;
99-
auto upperBoundCst =
100-
forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
101-
auto lowerBoundCst =
102-
forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
103-
auto stepCst = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
100+
ub = forOp.getUpperBound();
101+
lb = forOp.getLowerBound();
102+
step = forOp.getStep();
103+
104+
dynamicLoop = true;
105+
auto upperBoundCst = getConstantIntValue(ub);
106+
auto lowerBoundCst = getConstantIntValue(lb);
107+
auto stepCst = getConstantIntValue(step);
104108
if (!upperBoundCst || !lowerBoundCst || !stepCst) {
105-
LDBG("--no constant bounds or step -> BAIL");
106-
return false;
109+
if (!options.supportDynamicLoops) {
110+
LDBG("--dynamic loop not supported -> BAIL");
111+
return false;
112+
}
113+
} else {
114+
int64_t ubImm = upperBoundCst.value();
115+
int64_t lbImm = lowerBoundCst.value();
116+
int64_t stepImm = stepCst.value();
117+
int64_t numIteration = ceilDiv(ubImm - lbImm, stepImm);
118+
if (numIteration > maxStage) {
119+
dynamicLoop = false;
120+
} else if (!options.supportDynamicLoops) {
121+
LDBG("--fewer loop iterations than pipeline stages -> BAIL");
122+
return false;
123+
}
107124
}
108-
ub = upperBoundCst.value();
109-
lb = lowerBoundCst.value();
110-
step = stepCst.value();
111125
peelEpilogue = options.peelEpilogue;
112126
predicateFn = options.predicateFn;
113-
if (!peelEpilogue && predicateFn == nullptr) {
127+
if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) {
114128
LDBG("--no epilogue or predicate set -> BAIL");
115129
return false;
116130
}
117-
int64_t numIteration = ceilDiv(ub - lb, step);
131+
if (dynamicLoop && peelEpilogue) {
132+
LDBG("--dynamic loop doesn't support epilogue yet -> BAIL");
133+
return false;
134+
}
118135
std::vector<std::pair<Operation *, unsigned>> schedule;
119136
options.getScheduleFn(forOp, schedule);
120137
if (schedule.empty()) {
@@ -128,10 +145,6 @@ bool LoopPipelinerInternal::initializeLoopInfo(
128145
stages[opSchedule.first] = opSchedule.second;
129146
opOrder.push_back(opSchedule.first);
130147
}
131-
if (numIteration <= maxStage) {
132-
LDBG("--fewer loop iterations than pipeline stages -> BAIL");
133-
return false;
134-
}
135148

136149
// All operations need to have a stage.
137150
for (Operation &op : forOp.getBody()->without_terminator()) {
@@ -204,10 +217,31 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
204217
setValueMapping(arg, operand.get(), 0);
205218
}
206219
auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
220+
Location loc = forOp.getLoc();
221+
SmallVector<Value> predicates(maxStage);
207222
for (int64_t i = 0; i < maxStage; i++) {
223+
if (dynamicLoop) {
224+
Type t = ub.getType();
225+
// pred = ub > lb + (i * step)
226+
Value iv = rewriter.create<arith::AddIOp>(
227+
loc, lb,
228+
rewriter.create<arith::MulIOp>(
229+
loc, step,
230+
rewriter.create<arith::ConstantOp>(
231+
loc, rewriter.getIntegerAttr(t, i))));
232+
predicates[i] = rewriter.create<arith::CmpIOp>(
233+
loc, arith::CmpIPredicate::slt, iv, ub);
234+
}
235+
208236
// special handling for induction variable as the increment is implicit.
209-
Value iv =
210-
rewriter.create<arith::ConstantIndexOp>(forOp.getLoc(), lb + i * step);
237+
// iv = lb + i * step
238+
Type t = lb.getType();
239+
Value iv = rewriter.create<arith::AddIOp>(
240+
loc, lb,
241+
rewriter.create<arith::MulIOp>(
242+
loc, step,
243+
rewriter.create<arith::ConstantOp>(loc,
244+
rewriter.getIntegerAttr(t, i))));
211245
setValueMapping(forOp.getInductionVar(), iv, i);
212246
for (Operation *op : opOrder) {
213247
if (stages[op] > i)
@@ -220,6 +254,12 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
220254
newOperand->set(replacement);
221255
}
222256
});
257+
int predicateIdx = i - stages[op];
258+
if (predicates[predicateIdx]) {
259+
newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
260+
assert(newOp && "failed to predicate op.");
261+
}
262+
rewriter.setInsertionPointAfter(newOp);
223263
if (annotateFn)
224264
annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
225265
for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
@@ -326,9 +366,16 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop(
326366
// `numStages - 1` iterations. Then we adjust the upper bound to remove those
327367
// iterations.
328368
Value newUb = forOp.getUpperBound();
329-
if (peelEpilogue)
330-
newUb = rewriter.create<arith::ConstantIndexOp>(forOp.getLoc(),
331-
ub - maxStage * step);
369+
if (peelEpilogue) {
370+
Type t = ub.getType();
371+
Location loc = forOp.getLoc();
372+
// newUb = ub - maxStage * step
373+
Value maxStageValue = rewriter.create<arith::ConstantOp>(
374+
loc, rewriter.getIntegerAttr(t, maxStage));
375+
Value maxStageByStep =
376+
rewriter.create<arith::MulIOp>(loc, step, maxStageValue);
377+
newUb = rewriter.create<arith::SubIOp>(loc, ub, maxStageByStep);
378+
}
332379
auto newForOp =
333380
rewriter.create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb,
334381
forOp.getStep(), newLoopArg);
@@ -358,9 +405,17 @@ LogicalResult LoopPipelinerInternal::createKernel(
358405
SmallVector<Value> predicates(maxStage + 1, nullptr);
359406
if (!peelEpilogue) {
360407
// Create a predicate for each stage except the last stage.
408+
Location loc = newForOp.getLoc();
409+
Type t = ub.getType();
361410
for (unsigned i = 0; i < maxStage; i++) {
362-
Value c = rewriter.create<arith::ConstantIndexOp>(
363-
newForOp.getLoc(), ub - (maxStage - i) * step);
411+
// c = ub - (maxStage - i) * step
412+
Value c = rewriter.create<arith::SubIOp>(
413+
loc, ub,
414+
rewriter.create<arith::MulIOp>(
415+
loc, step,
416+
rewriter.create<arith::ConstantOp>(
417+
loc, rewriter.getIntegerAttr(t, int64_t(maxStage - i)))));
418+
364419
Value pred = rewriter.create<arith::CmpIOp>(
365420
newForOp.getLoc(), arith::CmpIPredicate::slt,
366421
newForOp.getInductionVar(), c);
@@ -383,8 +438,14 @@ LogicalResult LoopPipelinerInternal::createKernel(
383438
// version incremented based on the stage where it is used.
384439
if (operand->get() == forOp.getInductionVar()) {
385440
rewriter.setInsertionPoint(newOp);
386-
Value offset = rewriter.create<arith::ConstantIndexOp>(
387-
forOp.getLoc(), (maxStage - stages[op]) * step);
441+
442+
// offset = (maxStage - stages[op]) * step
443+
Type t = step.getType();
444+
Value offset = rewriter.create<arith::MulIOp>(
445+
forOp.getLoc(), step,
446+
rewriter.create<arith::ConstantOp>(
447+
forOp.getLoc(),
448+
rewriter.getIntegerAttr(t, maxStage - stages[op])));
388449
Value iv = rewriter.create<arith::AddIOp>(
389450
forOp.getLoc(), newForOp.getInductionVar(), offset);
390451
nestedNewOp->setOperand(operand->getOperandNumber(), iv);
@@ -508,8 +569,24 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter) {
508569
// Emit different versions of the induction variable. They will be
509570
// removed by dead code if not used.
510571
for (int64_t i = 0; i < maxStage; i++) {
511-
Value newlastIter = rewriter.create<arith::ConstantIndexOp>(
512-
forOp.getLoc(), lb + step * ((((ub - 1) - lb) / step) - i));
572+
Location loc = forOp.getLoc();
573+
Type t = lb.getType();
574+
Value minusOne =
575+
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
576+
// number of iterations = ((ub - 1) - lb) / step
577+
Value totalNumIteration = rewriter.create<arith::DivUIOp>(
578+
loc,
579+
rewriter.create<arith::SubIOp>(
580+
loc, rewriter.create<arith::AddIOp>(loc, ub, minusOne), lb),
581+
step);
582+
// newLastIter = lb + step * ((((ub - 1) - lb) / step) - i)
583+
Value minusI =
584+
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
585+
Value newlastIter = rewriter.create<arith::AddIOp>(
586+
loc, lb,
587+
rewriter.create<arith::MulIOp>(
588+
loc, step,
589+
rewriter.create<arith::AddIOp>(loc, totalNumIteration, minusI)));
513590
setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
514591
}
515592
// Emit `maxStage - 1` epilogue part that includes operations from stages

mlir/test/Dialect/NVGPU/transform-pipeline-shared.mlir

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics | FileCheck %s
1+
// RUN: mlir-opt %s --transform-interpreter -canonicalize --split-input-file --verify-diagnostics | FileCheck %s
22

33
func.func @simple_depth_2_unpeeled(%global: memref<?xf32>, %result: memref<?xf32> ) {
44
%c0 = arith.constant 0 : index
@@ -78,15 +78,19 @@ module attributes {transform.with_named_sequence} {
7878

7979
// CHECK-LABEL: @async_depth_2_predicated
8080
// CHECK-SAME: %[[GLOBAL:.+]]: memref
81-
func.func @async_depth_2_predicated(%global: memref<?xf32>) {
81+
func.func @async_depth_2_predicated(%global: memref<?xf32>, %alloc_size: index) {
8282
%c0 = arith.constant 0 : index
8383
%c98 = arith.constant 98 : index
8484
%c100 = arith.constant 100 : index
85-
%c200 = arith.constant 200 : index
86-
// CHECK: %[[C4:.+]] = arith.constant 4
85+
// CHECK-DAG: %[[C4:.+]] = arith.constant 4
86+
// CHECK-DAG: %[[C90:.+]] = arith.constant 90
87+
// CHECK-DAG: %[[C96:.+]] = arith.constant 96
88+
// CHECK-DAG: %[[C8:.+]] = arith.constant 8
89+
// CHECK-DAG: %[[C2:.+]] = arith.constant 2
90+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0
8791
%c4 = arith.constant 4 : index
8892
// CHECK: %[[SHARED:.+]] = memref.alloc{{.*}} #gpu.address_space<workgroup>
89-
%shared = memref.alloc(%c200) : memref<?xf32, #gpu.address_space<workgroup>>
93+
%shared = memref.alloc(%alloc_size) : memref<?xf32, #gpu.address_space<workgroup>>
9094
%c0f = arith.constant 0.0 : f32
9195
// CHECK: %[[TOKEN0:.+]] = nvgpu.device_async_copy
9296
// CHECK: %[[TOKEN1:.+]] = nvgpu.device_async_copy
@@ -95,16 +99,11 @@ func.func @async_depth_2_predicated(%global: memref<?xf32>) {
9599
// CHECK-SAME: %[[ITER_ARG1:.+]] = %[[TOKEN1]]
96100
scf.for %i = %c0 to %c98 step %c4 {
97101
// Condition for the predication "select" below.
98-
// CHECK: %[[C90:.+]] = arith.constant 90
99102
// CHECK: %[[CMP0:.+]] = arith.cmpi slt, %[[I]], %[[C90]]
100103
// CHECK: nvgpu.device_async_wait %[[ITER_ARG0]] {numGroups = 1
101-
102104
// Original "select" with updated induction variable.
103-
// CHECK: %[[C96:.+]] = arith.constant 96
104-
// CHECK: %[[C8:.+]] = arith.constant 8
105105
// CHECK: %[[I_PLUS_8:.+]] = arith.addi %[[I]], %[[C8]]
106106
// CHECK: %[[CMP1:.+]] = arith.cmpi slt, %[[I_PLUS_8]], %[[C96]]
107-
// CHECK: %[[C2:.+]] = arith.constant 2
108107
// CHECK: %[[SELECTED0:.+]] = arith.select %[[CMP1]], %[[C4]], %[[C2]]
109108
%c96 = arith.constant 96 : index
110109
%cond = arith.cmpi slt, %i, %c96 : index
@@ -113,14 +112,11 @@ func.func @async_depth_2_predicated(%global: memref<?xf32>) {
113112

114113
// Updated induction variables (two more) for the device_async_copy below.
115114
// These are generated repeatedly by the pipeliner.
116-
// CHECK: %[[C8_2:.+]] = arith.constant 8
117-
// CHECK: %[[I_PLUS_8_2:.+]] = arith.addi %[[I]], %[[C8_2]]
118-
// CHECK: %[[C8_3:.+]] = arith.constant 8
119-
// CHECK: %[[I_PLUS_8_3:.+]] = arith.addi %[[I]], %[[C8_3]]
115+
// CHECK: %[[I_PLUS_8_2:.+]] = arith.addi %[[I]], %[[C8]]
116+
// CHECK: %[[I_PLUS_8_3:.+]] = arith.addi %[[I]], %[[C8]]
120117

121118
// The second "select" is generated by predication and selects 0 for
122119
// the two last iterations.
123-
// CHECK: %[[C0:.+]] = arith.constant 0
124120
// CHECK: %[[SELECTED1:.+]] = arith.select %[[CMP0]], %[[SELECTED0]], %[[C0]]
125121
// CHECK: %[[ASYNC_TOKEN:.+]] = nvgpu.device_async_copy %[[GLOBAL]][%[[I_PLUS_8_3]]], %[[SHARED]][%[[I_PLUS_8_2]]], 4, %[[SELECTED1]]
126122
%token = nvgpu.device_async_copy %global[%i], %shared[%i], 4, %read_size

mlir/test/Dialect/SCF/loop-pipelining.mlir

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,3 +723,50 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>) {
723723
memref.store %r, %result[%c1] : memref<?xf32>
724724
return
725725
}
726+
727+
// -----
728+
729+
// NOEPILOGUE-LABEL: dynamic_loop(
730+
// NOEPILOGUE-SAME: %[[A:.*]]: memref<?xf32>, %[[R:.*]]: memref<?xf32>, %[[LB:.+]]: index, %[[UB:.+]]: index, %[[STEP:.+]]: index) {
731+
// NOEPILOGUE-DAG: %[[C2:.+]] = arith.constant 2 : index
732+
// NOEPILOGUE-DAG: %[[CSTF:.+]] = arith.constant 1.000000e+00 : f32
733+
// Prologue:
734+
// NOEPILOGUE: %[[P_I0:.+]] = arith.cmpi slt, %[[LB]], %[[UB]] : index
735+
// NOEPILOGUE: %[[L0:.+]] = scf.if %[[P_I0]] -> (f32) {
736+
// NOEPILOGUE-NEXT: memref.load %[[A]][%[[LB]]] : memref<?xf32>
737+
// NOEPILOGUE: %[[IV1:.+]] = arith.addi %[[LB]], %[[STEP]] : index
738+
// NOEPILOGUE: %[[P_I1:.+]] = arith.cmpi slt, %[[IV1]], %[[UB]] : index
739+
// NOEPILOGUE: %[[IV1_2:.+]] = arith.addi %[[LB]], %[[STEP]] : index
740+
// NOEPILOGUE: %[[V0:.+]] = scf.if %[[P_I0]] -> (f32) {
741+
// NOEPILOGUE-NEXT: arith.addf %[[L0]], %[[CSTF]] : f32
742+
// NOEPILOGUE: %[[L1:.+]] = scf.if %[[P_I1]] -> (f32) {
743+
// NOEPILOGUE-NEXT: memref.load %[[A]][%[[IV1_2]]] : memref<?xf32>
744+
// NOEPILOGUE: scf.for %[[IV2:.+]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[V1:.+]] = %[[V0]], %[[L2:.+]] = %[[L1]]) -> (f32, f32) {
745+
// NOEPILOGUE-DAG: %[[S2:.+]] = arith.muli %[[STEP]], %[[C2]] : index
746+
// NOEPILOGUE-DAG: %[[IT2:.+]] = arith.subi %[[UB]], %[[S2]] : index
747+
// NOEPILOGUE-DAG: %[[P_I2:.+]] = arith.cmpi slt, %[[IV2]], %[[IT2]] : index
748+
// NOEPILOGUE-DAG: %[[IT3:.+]] = arith.subi %[[UB]], %[[STEP]] : index
749+
// NOEPILOGUE-DAG: %[[P_I3:.+]] = arith.cmpi slt, %[[IV2]], %[[IT3]] : index
750+
// NOEPILOGUE: memref.store %[[V1]], %[[R]][%[[IV2]]] : memref<?xf32>
751+
// NOEPILOGUE: %[[V2:.+]] = scf.if %[[P_I3]] -> (f32) {
752+
// NOEPILOGUE: arith.addf %[[L2]], %[[CSTF]] : f32
753+
// NOEPILOGUE: %[[IT4:.+]] = arith.muli %[[STEP]], %[[C2]] : index
754+
// NOEPILOGUE: %[[IV3:.+]] = arith.addi %[[IV2]], %[[IT4]] : index
755+
// NOEPILOGUE: %[[L3:.+]] = scf.if %[[P_I2]] -> (f32) {
756+
// NOEPILOGUE: memref.load %[[A]][%[[IV3]]] : memref<?xf32>
757+
// NOEPILOGUE: scf.yield %[[V2]], %[[L3]] : f32, f32
758+
759+
// In case dynamic loop pipelining is off check that the transformation didn't
760+
// apply.
761+
// CHECK-LABEL: dynamic_loop(
762+
// CHECK-NOT: memref.load
763+
// CHECK: scf.for
764+
func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
765+
%cf = arith.constant 1.0 : f32
766+
scf.for %i0 = %lb to %ub step %step {
767+
%A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
768+
%A1_elem = arith.addf %A_elem, %cf { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
769+
memref.store %A1_elem, %result[%i0] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 0 } : memref<?xf32>
770+
} { __test_pipelining_loop__ }
771+
return
772+
}

mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ struct TestSCFPipeliningPass
217217
if (annotatePipeline)
218218
options.annotateFn = annotate;
219219
if (noEpiloguePeeling) {
220+
options.supportDynamicLoops = true;
220221
options.peelEpilogue = false;
221222
options.predicateFn = predicateOp;
222223
}

0 commit comments

Comments
 (0)