Skip to content

Commit f7c54c4

Browse files
committed
[LoopUnroll] Fold all exits based on known trip count/multiple
Fold all exits based on known trip count/multiple information from SCEV. Previously only the latch exit or the single exit were folded. This doesn't yet eliminate ULO.TripCount and ULO.TripMultiple entirely: They're still used to a) decide whether runtime unrolling should be performed and b) for ORE remarks. However, the core unrolling logic is independent of them now. Differential Revision: https://reviews.llvm.org/D104203
1 parent dc11d4e commit f7c54c4

File tree

3 files changed

+102
-84
lines changed

3 files changed

+102
-84
lines changed

llvm/lib/Transforms/Utils/LoopUnroll.cpp

+88-70
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,37 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
328328
if (MaxTripCount && ULO.Count > MaxTripCount)
329329
ULO.Count = MaxTripCount;
330330

331+
struct ExitInfo {
332+
unsigned TripCount;
333+
unsigned TripMultiple;
334+
unsigned BreakoutTrip;
335+
bool ExitOnTrue;
336+
SmallVector<BasicBlock *> ExitingBlocks;
337+
};
338+
DenseMap<BasicBlock *, ExitInfo> ExitInfos;
339+
SmallVector<BasicBlock *, 4> ExitingBlocks;
340+
L->getExitingBlocks(ExitingBlocks);
341+
for (auto *ExitingBlock : ExitingBlocks) {
342+
// The folding code is not prepared to deal with non-branch instructions
343+
// right now.
344+
auto *BI = dyn_cast<BranchInst>(ExitingBlock->getTerminator());
345+
if (!BI)
346+
continue;
347+
348+
ExitInfo &Info = ExitInfos.try_emplace(ExitingBlock).first->second;
349+
Info.TripCount = SE->getSmallConstantTripCount(L, ExitingBlock);
350+
Info.TripMultiple = SE->getSmallConstantTripMultiple(L, ExitingBlock);
351+
if (Info.TripCount != 0) {
352+
Info.BreakoutTrip = Info.TripCount % ULO.Count;
353+
Info.TripMultiple = 0;
354+
} else {
355+
Info.BreakoutTrip = Info.TripMultiple =
356+
(unsigned)GreatestCommonDivisor64(ULO.Count, Info.TripMultiple);
357+
}
358+
Info.ExitOnTrue = !L->contains(BI->getSuccessor(0));
359+
Info.ExitingBlocks.push_back(ExitingBlock);
360+
}
361+
331362
// Are we eliminating the loop control altogether? Note that we can know
332363
// we're eliminating the backedge without knowing exactly which iteration
333364
// of the unrolled body exits.
@@ -362,31 +393,12 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
362393

363394
// A conditional branch which exits the loop, which can be optimized to an
364395
// unconditional branch in the unrolled loop in some cases.
365-
BranchInst *ExitingBI = nullptr;
366396
bool LatchIsExiting = L->isLoopExiting(LatchBlock);
367-
if (LatchIsExiting)
368-
ExitingBI = LatchBI;
369-
else if (BasicBlock *ExitingBlock = L->getExitingBlock())
370-
ExitingBI = dyn_cast<BranchInst>(ExitingBlock->getTerminator());
371397
if (!LatchBI || (LatchBI->isConditional() && !LatchIsExiting)) {
372398
LLVM_DEBUG(
373399
dbgs() << "Can't unroll; a conditional latch must exit the loop");
374400
return LoopUnrollResult::Unmodified;
375401
}
376-
LLVM_DEBUG({
377-
if (ExitingBI)
378-
dbgs() << " Exiting Block = " << ExitingBI->getParent()->getName()
379-
<< "\n";
380-
else
381-
dbgs() << " No single exiting block\n";
382-
});
383-
384-
// Warning: ExactTripCount is the exact trip count for the block ending in
385-
// ExitingBI, not neccessarily an exact exit count *for the loop*. The
386-
// distinction comes when we have an exiting latch, but the loop exits
387-
// through another exit first.
388-
const unsigned ExactTripCount = ExitingBI ?
389-
SE->getSmallConstantTripCount(L,ExitingBI->getParent()) : 0;
390402

391403
// Loops containing convergent instructions must have a count that divides
392404
// their TripMultiple.
@@ -421,6 +433,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
421433
}
422434

423435
// If we know the trip count, we know the multiple...
436+
// TODO: This is only used for the ORE code, remove it.
424437
unsigned BreakoutTrip = 0;
425438
if (ULO.TripCount != 0) {
426439
BreakoutTrip = ULO.TripCount % ULO.Count;
@@ -504,12 +517,9 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
504517
}
505518

506519
std::vector<BasicBlock *> Headers;
507-
std::vector<BasicBlock *> ExitingBlocks;
508520
std::vector<BasicBlock *> Latches;
509521
Headers.push_back(Header);
510522
Latches.push_back(LatchBlock);
511-
if (ExitingBI)
512-
ExitingBlocks.push_back(ExitingBI->getParent());
513523

514524
// The current on-the-fly SSA update requires blocks to be processed in
515525
// reverse postorder so that LastValueMap contains the correct value at each
@@ -609,9 +619,9 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
609619

610620
// Keep track of the exiting block and its successor block contained in
611621
// the loop for the current iteration.
612-
if (ExitingBI)
613-
if (*BB == ExitingBlocks[0])
614-
ExitingBlocks.push_back(New);
622+
auto ExitInfoIt = ExitInfos.find(*BB);
623+
if (ExitInfoIt != ExitInfos.end())
624+
ExitInfoIt->second.ExitingBlocks.push_back(New);
615625

616626
NewBlocks.push_back(New);
617627
UnrolledLoopBlocks.push_back(New);
@@ -701,71 +711,79 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
701711

702712
DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
703713

704-
if (ExitingBI) {
705-
auto SetDest = [&](BasicBlock *Src, bool WillExit, bool ExitOnTrue) {
706-
auto *Term = cast<BranchInst>(Src->getTerminator());
707-
const unsigned Idx = ExitOnTrue ^ WillExit;
708-
BasicBlock *Dest = Term->getSuccessor(Idx);
709-
BasicBlock *DeadSucc = Term->getSuccessor(1-Idx);
714+
auto SetDest = [&](BasicBlock *Src, bool WillExit, bool ExitOnTrue) {
715+
auto *Term = cast<BranchInst>(Src->getTerminator());
716+
const unsigned Idx = ExitOnTrue ^ WillExit;
717+
BasicBlock *Dest = Term->getSuccessor(Idx);
718+
BasicBlock *DeadSucc = Term->getSuccessor(1-Idx);
710719

711-
// Remove predecessors from all non-Dest successors.
712-
DeadSucc->removePredecessor(Src, /* KeepOneInputPHIs */ true);
720+
// Remove predecessors from all non-Dest successors.
721+
DeadSucc->removePredecessor(Src, /* KeepOneInputPHIs */ true);
713722

714-
// Replace the conditional branch with an unconditional one.
715-
BranchInst::Create(Dest, Term);
716-
Term->eraseFromParent();
723+
// Replace the conditional branch with an unconditional one.
724+
BranchInst::Create(Dest, Term);
725+
Term->eraseFromParent();
717726

718-
DTU.applyUpdates({{DominatorTree::Delete, Src, DeadSucc}});
719-
};
727+
DTU.applyUpdates({{DominatorTree::Delete, Src, DeadSucc}});
728+
};
720729

721-
auto WillExit = [&](unsigned i, unsigned j) -> Optional<bool> {
722-
if (CompletelyUnroll) {
723-
if (PreserveOnlyFirst) {
724-
if (i == 0)
725-
return None;
726-
return j == 0;
727-
}
728-
// Complete (but possibly inexact) unrolling
729-
if (j == 0)
730-
return true;
731-
// Warning: ExactTripCount is the trip count of the exiting
732-
// block which ends in ExitingBI, not neccessarily the loop.
733-
if (ExactTripCount && j != ExactTripCount)
734-
return false;
735-
return None;
730+
auto WillExit = [&](const ExitInfo &Info, unsigned i, unsigned j,
731+
bool IsLatch) -> Optional<bool> {
732+
if (CompletelyUnroll) {
733+
if (PreserveOnlyFirst) {
734+
if (i == 0)
735+
return None;
736+
return j == 0;
736737
}
737-
738-
if (RuntimeTripCount && j != 0)
738+
// Complete (but possibly inexact) unrolling
739+
if (j == 0)
740+
return true;
741+
if (Info.TripCount && j != Info.TripCount)
739742
return false;
743+
return None;
744+
}
740745

741-
if (j != BreakoutTrip &&
742-
(ULO.TripMultiple == 0 || j % ULO.TripMultiple != 0)) {
743-
// If we know the trip count or a multiple of it, we can safely use an
744-
// unconditional branch for some iterations.
746+
if (RuntimeTripCount) {
747+
// If runtime unrolling inserts a prologue, information about non-latch
748+
// exits may be stale.
749+
if (IsLatch && j != 0)
745750
return false;
746-
}
747751
return None;
748-
};
752+
}
753+
754+
if (j != Info.BreakoutTrip &&
755+
(Info.TripMultiple == 0 || j % Info.TripMultiple != 0)) {
756+
// If we know the trip count or a multiple of it, we can safely use an
757+
// unconditional branch for some iterations.
758+
return false;
759+
}
760+
return None;
761+
};
749762

750-
// Fold branches for iterations where we know that they will exit or not
751-
// exit.
752-
bool ExitOnTrue = !L->contains(ExitingBI->getSuccessor(0));
753-
for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
763+
// Fold branches for iterations where we know that they will exit or not
764+
// exit.
765+
for (const auto &Pair : ExitInfos) {
766+
const ExitInfo &Info = Pair.second;
767+
for (unsigned i = 0, e = Info.ExitingBlocks.size(); i != e; ++i) {
754768
// The branch destination.
755769
unsigned j = (i + 1) % e;
756-
Optional<bool> KnownWillExit = WillExit(i, j);
770+
bool IsLatch = Pair.first == LatchBlock;
771+
Optional<bool> KnownWillExit = WillExit(Info, i, j, IsLatch);
757772
if (!KnownWillExit)
758773
continue;
759774

760-
// TODO: Also fold known-exiting branches for non-latch exits.
761-
if (*KnownWillExit && !LatchIsExiting)
775+
// We don't fold known-exiting branches for non-latch exits here,
776+
// because this ensures that both all loop blocks and all exit blocks
777+
// remain reachable in the CFG.
778+
// TODO: We could fold these branches, but it would require much more
779+
// sophisticated updates to LoopInfo.
780+
if (*KnownWillExit && !IsLatch)
762781
continue;
763782

764-
SetDest(ExitingBlocks[i], *KnownWillExit, ExitOnTrue);
783+
SetDest(Info.ExitingBlocks[i], *KnownWillExit, Info.ExitOnTrue);
765784
}
766785
}
767786

768-
769787
// When completely unrolling, the last latch becomes unreachable.
770788
if (!LatchIsExiting && CompletelyUnroll)
771789
changeToUnreachable(Latches.back()->getTerminator(), /* UseTrap */ false,

llvm/test/Transforms/LoopUnroll/multiple-exits.ll

+11-11
Original file line numberDiff line numberDiff line change
@@ -9,49 +9,49 @@ define void @test1() {
99
; CHECK-NEXT: br label [[LOOP:%.*]]
1010
; CHECK: loop:
1111
; CHECK-NEXT: call void @bar()
12-
; CHECK-NEXT: br i1 true, label [[LATCH:%.*]], label [[EXIT:%.*]]
12+
; CHECK-NEXT: br label [[LATCH:%.*]]
1313
; CHECK: latch:
1414
; CHECK-NEXT: call void @bar()
1515
; CHECK-NEXT: call void @bar()
16-
; CHECK-NEXT: br i1 true, label [[LATCH_1:%.*]], label [[EXIT]]
16+
; CHECK-NEXT: br label [[LATCH_1:%.*]]
1717
; CHECK: exit:
1818
; CHECK-NEXT: ret void
1919
; CHECK: latch.1:
2020
; CHECK-NEXT: call void @bar()
2121
; CHECK-NEXT: call void @bar()
22-
; CHECK-NEXT: br i1 true, label [[LATCH_2:%.*]], label [[EXIT]]
22+
; CHECK-NEXT: br label [[LATCH_2:%.*]]
2323
; CHECK: latch.2:
2424
; CHECK-NEXT: call void @bar()
2525
; CHECK-NEXT: call void @bar()
26-
; CHECK-NEXT: br i1 true, label [[LATCH_3:%.*]], label [[EXIT]]
26+
; CHECK-NEXT: br label [[LATCH_3:%.*]]
2727
; CHECK: latch.3:
2828
; CHECK-NEXT: call void @bar()
2929
; CHECK-NEXT: call void @bar()
30-
; CHECK-NEXT: br i1 true, label [[LATCH_4:%.*]], label [[EXIT]]
30+
; CHECK-NEXT: br label [[LATCH_4:%.*]]
3131
; CHECK: latch.4:
3232
; CHECK-NEXT: call void @bar()
3333
; CHECK-NEXT: call void @bar()
34-
; CHECK-NEXT: br i1 true, label [[LATCH_5:%.*]], label [[EXIT]]
34+
; CHECK-NEXT: br label [[LATCH_5:%.*]]
3535
; CHECK: latch.5:
3636
; CHECK-NEXT: call void @bar()
3737
; CHECK-NEXT: call void @bar()
38-
; CHECK-NEXT: br i1 true, label [[LATCH_6:%.*]], label [[EXIT]]
38+
; CHECK-NEXT: br label [[LATCH_6:%.*]]
3939
; CHECK: latch.6:
4040
; CHECK-NEXT: call void @bar()
4141
; CHECK-NEXT: call void @bar()
42-
; CHECK-NEXT: br i1 true, label [[LATCH_7:%.*]], label [[EXIT]]
42+
; CHECK-NEXT: br label [[LATCH_7:%.*]]
4343
; CHECK: latch.7:
4444
; CHECK-NEXT: call void @bar()
4545
; CHECK-NEXT: call void @bar()
46-
; CHECK-NEXT: br i1 true, label [[LATCH_8:%.*]], label [[EXIT]]
46+
; CHECK-NEXT: br label [[LATCH_8:%.*]]
4747
; CHECK: latch.8:
4848
; CHECK-NEXT: call void @bar()
4949
; CHECK-NEXT: call void @bar()
50-
; CHECK-NEXT: br i1 true, label [[LATCH_9:%.*]], label [[EXIT]]
50+
; CHECK-NEXT: br label [[LATCH_9:%.*]]
5151
; CHECK: latch.9:
5252
; CHECK-NEXT: call void @bar()
5353
; CHECK-NEXT: call void @bar()
54-
; CHECK-NEXT: br i1 false, label [[LATCH_10:%.*]], label [[EXIT]]
54+
; CHECK-NEXT: br i1 false, label [[LATCH_10:%.*]], label [[EXIT:%.*]]
5555
; CHECK: latch.10:
5656
; CHECK-NEXT: call void @bar()
5757
; CHECK-NEXT: br label [[EXIT]]

llvm/test/Transforms/LoopUnroll/nonlatchcondbr.ll

+3-3
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ define void @test3(i32* noalias %A, i1 %cond) {
168168
; CHECK-NEXT: call void @bar(i32 [[TMP0]])
169169
; CHECK-NEXT: br i1 [[COND:%.*]], label [[FOR_BODY:%.*]], label [[FOR_END:%.*]]
170170
; CHECK: for.body:
171-
; CHECK-NEXT: br i1 true, label [[FOR_BODY_FOR_BODY_CRIT_EDGE:%.*]], label [[FOR_END]]
171+
; CHECK-NEXT: br label [[FOR_BODY_FOR_BODY_CRIT_EDGE:%.*]]
172172
; CHECK: for.body.for.body_crit_edge:
173173
; CHECK-NEXT: [[ARRAYIDX_PHI_TRANS_INSERT:%.*]] = getelementptr inbounds i32, i32* [[A]], i64 1
174174
; CHECK-NEXT: [[DOTPRE:%.*]] = load i32, i32* [[ARRAYIDX_PHI_TRANS_INSERT]], align 4
@@ -177,14 +177,14 @@ define void @test3(i32* noalias %A, i1 %cond) {
177177
; CHECK: for.end:
178178
; CHECK-NEXT: ret void
179179
; CHECK: for.body.1:
180-
; CHECK-NEXT: br i1 true, label [[FOR_BODY_FOR_BODY_CRIT_EDGE_1:%.*]], label [[FOR_END]]
180+
; CHECK-NEXT: br label [[FOR_BODY_FOR_BODY_CRIT_EDGE_1:%.*]]
181181
; CHECK: for.body.for.body_crit_edge.1:
182182
; CHECK-NEXT: [[ARRAYIDX_PHI_TRANS_INSERT_1:%.*]] = getelementptr inbounds i32, i32* [[A]], i64 2
183183
; CHECK-NEXT: [[DOTPRE_1:%.*]] = load i32, i32* [[ARRAYIDX_PHI_TRANS_INSERT_1]], align 4
184184
; CHECK-NEXT: call void @bar(i32 [[DOTPRE_1]])
185185
; CHECK-NEXT: br i1 [[COND]], label [[FOR_BODY_2:%.*]], label [[FOR_END]]
186186
; CHECK: for.body.2:
187-
; CHECK-NEXT: br i1 true, label [[FOR_BODY_FOR_BODY_CRIT_EDGE_2:%.*]], label [[FOR_END]]
187+
; CHECK-NEXT: br label [[FOR_BODY_FOR_BODY_CRIT_EDGE_2:%.*]]
188188
; CHECK: for.body.for.body_crit_edge.2:
189189
; CHECK-NEXT: [[ARRAYIDX_PHI_TRANS_INSERT_2:%.*]] = getelementptr inbounds i32, i32* [[A]], i64 3
190190
; CHECK-NEXT: [[DOTPRE_2:%.*]] = load i32, i32* [[ARRAYIDX_PHI_TRANS_INSERT_2]], align 4

0 commit comments

Comments
 (0)