Skip to content

Commit bb4121e

Browse files
witstorm95 authored and ChuanqiXu9 committed
[Coroutines] Add an O(n) algorithm for computing the cross suspend point
information.

Fixes #62348.

Propagate the cross suspend point information by visiting the CFG. At most two traversals are needed to collect all of the cross suspend point information.

Before the patch:

```
n: 20000
4.31user 0.11system 0:04.44elapsed 99%CPU (0avgtext+0avgdata 552352maxresident)k
0inputs+8848outputs (0major+126254minor)pagefaults 0swaps
n: 40000
11.24user 0.40system 0:11.66elapsed 99%CPU (0avgtext+0avgdata 1788404maxresident)k
0inputs+17600outputs (0major+431105minor)pagefaults 0swaps
n: 60000
21.65user 0.96system 0:22.62elapsed 99%CPU (0avgtext+0avgdata 3809836maxresident)k
0inputs+26352outputs (0major+934749minor)pagefaults 0swaps
n: 80000
37.05user 1.53system 0:38.58elapsed 99%CPU (0avgtext+0avgdata 6602396maxresident)k
0inputs+35096outputs (0major+1622584minor)pagefaults 0swaps
n: 100000
51.87user 2.67system 0:54.54elapsed 99%CPU (0avgtext+0avgdata 10210736maxresident)k
0inputs+43848outputs (0major+2518945minor)pagefaults 0swaps
```

After the patch:

```
n: 20000
3.17user 0.16system 0:03.33elapsed 100%CPU (0avgtext+0avgdata 551736maxresident)k
0inputs+8848outputs (0major+126192minor)pagefaults 0swaps
n: 40000
6.10user 0.42system 0:06.54elapsed 99%CPU (0avgtext+0avgdata 1787848maxresident)k
0inputs+17600outputs (0major+432212minor)pagefaults 0swaps
n: 60000
9.13user 0.89system 0:10.03elapsed 99%CPU (0avgtext+0avgdata 3809108maxresident)k
0inputs+26352outputs (0major+931280minor)pagefaults 0swaps
n: 80000
12.44user 1.57system 0:14.02elapsed 99%CPU (0avgtext+0avgdata 6603432maxresident)k
0inputs+35096outputs (0major+1624635minor)pagefaults 0swaps
n: 100000
16.29user 2.28system 0:18.59elapsed 99%CPU (0avgtext+0avgdata 10212808maxresident)k
0inputs+43848outputs (0major+2522200minor)pagefaults 0swaps
```
1 parent 7c760b2 commit bb4121e

File tree

1 file changed

+150
-68
lines changed

1 file changed

+150
-68
lines changed

llvm/lib/Transforms/Coroutines/CoroFrame.cpp

Lines changed: 150 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -98,24 +98,57 @@ class SuspendCrossingInfo {
9898
bool Suspend = false;
9999
bool End = false;
100100
bool KillLoop = false;
101-
bool Changed = false;
102101
};
103102
SmallVector<BlockData, SmallVectorThreshold> Block;
104103

105104
iterator_range<pred_iterator> predecessors(BlockData const &BD) const {
106105
BasicBlock *BB = Mapping.indexToBlock(&BD - &Block[0]);
107106
return llvm::predecessors(BB);
108107
}
108+
size_t pred_size(BlockData const &BD) const {
109+
BasicBlock *BB = Mapping.indexToBlock(&BD - &Block[0]);
110+
return llvm::pred_size(BB);
111+
}
112+
iterator_range<succ_iterator> successors(BlockData const &BD) const {
113+
BasicBlock *BB = Mapping.indexToBlock(&BD - &Block[0]);
114+
return llvm::successors(BB);
115+
}
109116

110117
BlockData &getBlockData(BasicBlock *BB) {
111118
return Block[Mapping.blockToIndex(BB)];
112119
}
113120

114-
/// Compute the BlockData for the current function in one iteration.
115-
/// Returns whether the BlockData changes in this iteration.
116-
/// Initialize - Whether this is the first iteration, we can optimize
117-
/// the initial case a little bit by manual loop switch.
118-
template <bool Initialize = false> bool computeBlockData();
121+
/// This algorithm is based on topological sorting. As we know, topological
122+
/// sorting is typically used on Directed Acyclic Graph (DAG). However, a
123+
/// Control Flow Graph (CFG) may not always be a DAG, as it can contain back
124+
/// edges or loops. To handle this, we need to break the back edge when we
125+
/// encounter it in order to ensure a valid topological sorting.
126+
/// Why do we need an extra traversal when a CFG contains a back edge?
127+
/// Firstly, we need to figure out how the Consumes information propagates
128+
/// along the back edge. For example,
129+
///
130+
/// A -> B -> C -> D -> H
131+
///      ^         |
132+
///      |         v
133+
///      G <- F <- E
134+
///
135+
/// Following the direction of the arrow, we can obtain the traversal
136+
/// sequences: A, B, C, D, H, E, F, G or A, B, C, D, E, H, F, G. We know that
137+
/// there is a path from C to G after the first traversal. However, we are
138+
/// uncertain about the existence of a path from G to C, as the Consumes info
139+
/// of G has not yet propagated to C (via B). Therefore, we need a second
140+
/// traversal to propagate G's Consumes info to C (via B) and its successors.
141+
/// The second traversal allows us to obtain the complete Consumes info, which
142+
/// is needed because the computation of the Kills info depends on it.
143+
144+
/// The parameter "EntryNo" represents the index associated with the entry
145+
/// block.
146+
/// The parameter "BlockPredecessorsNum" represents the number of predecessors
147+
/// for each block.
148+
/// Returns true if there exist back edges in the CFG.
149+
template <bool HasBackEdge = false>
150+
bool collectConsumeKillInfo(size_t EntryNo,
151+
const SmallVector<size_t> &BlockPredecessorsNum);
119152

120153
public:
121154
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -223,84 +256,132 @@ LLVM_DUMP_METHOD void SuspendCrossingInfo::dump() const {
223256
}
224257
#endif
225258

226-
template <bool Initialize> bool SuspendCrossingInfo::computeBlockData() {
227-
const size_t N = Mapping.size();
228-
bool Changed = false;
229-
230-
for (size_t I = 0; I < N; ++I) {
231-
auto &B = Block[I];
259+
template <bool HasBackEdge>
260+
bool SuspendCrossingInfo::collectConsumeKillInfo(
261+
size_t EntryNo, const SmallVector<size_t> &BlockPredecessorsNum) {
262+
bool FoundBackEdge = false;
263+
SmallVector<size_t> UnvisitedBlockPredNum = BlockPredecessorsNum;
264+
// BlockNo Queue with BlockPredNum[BlockNo] equal to zero.
265+
std::queue<size_t> CandidateQueue;
266+
// For blocks that maybe has a back edge.
267+
DenseSet<size_t> MaybeBackEdgeSet;
268+
// Visit BlockNo
269+
auto visit = [&](size_t BlockNo) {
270+
switch (UnvisitedBlockPredNum[BlockNo]) {
271+
// Already visited, not visit again.
272+
case 0:
273+
break;
274+
// If predecessors number of BlockNo is 1, it means all predecessors of
275+
// BlockNo have propagated its info to BlockNo. So add BlockNo to
276+
// CandidateQueue.
277+
case 1: {
278+
CandidateQueue.push(BlockNo);
279+
MaybeBackEdgeSet.erase(BlockNo);
280+
UnvisitedBlockPredNum[BlockNo] = 0;
281+
break;
282+
}
283+
// If predecessors number of BlockNo bigger than 1, it means BlockNo not
284+
// collect full Consumes/Kills info yet. So decrease
285+
// UnvisitedBlockPredNum[BlockNo] and insert BlockNo into MaybeBackEdgeSet.
286+
default: {
287+
UnvisitedBlockPredNum[BlockNo]--;
288+
MaybeBackEdgeSet.insert(BlockNo);
289+
break;
290+
}
291+
}
292+
};
232293

233-
// We don't need to count the predecessors when initialization.
234-
if constexpr (!Initialize)
235-
// If all the predecessors of the current Block don't change,
236-
// the BlockData for the current block must not change too.
237-
if (all_of(predecessors(B), [this](BasicBlock *BB) {
238-
return !Block[Mapping.blockToIndex(BB)].Changed;
239-
})) {
240-
B.Changed = false;
241-
continue;
294+
CandidateQueue.push(EntryNo);
295+
296+
// Topological sorting.
297+
while (!CandidateQueue.empty()) {
298+
auto &B = Block[CandidateQueue.front()];
299+
CandidateQueue.pop();
300+
for (BasicBlock *SI : successors(B)) {
301+
auto SuccNo = Mapping.blockToIndex(SI);
302+
auto &S = Block[SuccNo];
303+
304+
// Propagate Kills and Consumes from predecessors into S.
305+
S.Consumes |= B.Consumes;
306+
S.Kills |= B.Kills;
307+
308+
if (B.Suspend)
309+
S.Kills |= B.Consumes;
310+
311+
if (S.Suspend) {
312+
// If block S is a suspend block, it should kill all of the blocks
313+
// it consumes.
314+
S.Kills |= S.Consumes;
315+
} else if (S.End) {
316+
// If block S is an end block, it should not propagate kills as the
317+
// blocks following coro.end() are reached during initial invocation
318+
// of the coroutine while all the data are still available on the
319+
// stack or in the registers.
320+
S.Kills.reset();
321+
} else {
322+
// This is reached when S block it not Suspend nor coro.end and it
323+
// need to make sure that it is not in the kill set.
324+
S.KillLoop |= S.Kills[SuccNo];
325+
S.Kills.reset(SuccNo);
242326
}
243-
244-
// Saved Consumes and Kills bitsets so that it is easy to see
245-
// if anything changed after propagation.
246-
auto SavedConsumes = B.Consumes;
247-
auto SavedKills = B.Kills;
248-
249-
for (BasicBlock *PI : predecessors(B)) {
250-
auto PrevNo = Mapping.blockToIndex(PI);
251-
auto &P = Block[PrevNo];
252-
253-
// Propagate Kills and Consumes from predecessors into B.
254-
B.Consumes |= P.Consumes;
255-
B.Kills |= P.Kills;
256-
257-
// If block P is a suspend block, it should propagate kills into block
258-
// B for every block P consumes.
259-
if (P.Suspend)
260-
B.Kills |= P.Consumes;
327+
// visit SuccNo.
328+
visit(SuccNo);
261329
}
262330

263-
if (B.Suspend) {
264-
// If block S is a suspend block, it should kill all of the blocks it
265-
// consumes.
266-
B.Kills |= B.Consumes;
267-
} else if (B.End) {
268-
// If block B is an end block, it should not propagate kills as the
269-
// blocks following coro.end() are reached during initial invocation
270-
// of the coroutine while all the data are still available on the
271-
// stack or in the registers.
272-
B.Kills.reset();
273-
} else {
274-
// This is reached when B block it not Suspend nor coro.end and it
275-
// need to make sure that it is not in the kill set.
276-
B.KillLoop |= B.Kills[I];
277-
B.Kills.reset(I);
278-
}
331+
// If the CandidateQueue is empty but the MaybeBackEdgeSet is not, it
332+
// indicates the presence of a back edge that needs to be addressed. In such
333+
// cases, it is necessary to break the back edge.
334+
if (CandidateQueue.empty() && !MaybeBackEdgeSet.empty()) {
335+
FoundBackEdge = true;
336+
size_t CandidateNo = -1;
337+
if constexpr (HasBackEdge) {
338+
auto IsCandidate = [this](size_t I) {
339+
for (BasicBlock *PI : llvm::predecessors(Mapping.indexToBlock(I))) {
340+
auto PredNo = Mapping.blockToIndex(PI);
341+
auto &P = Block[PredNo];
342+
// The node I can reach its predecessor. So we found a loop.
343+
if (P.Consumes[I])
344+
return true;
345+
}
346+
347+
return false;
348+
};
279349

280-
if constexpr (!Initialize) {
281-
B.Changed = (B.Kills != SavedKills) || (B.Consumes != SavedConsumes);
282-
Changed |= B.Changed;
350+
for (auto I : MaybeBackEdgeSet) {
351+
if (IsCandidate(I)) {
352+
CandidateNo = I;
353+
break;
354+
}
355+
}
356+
assert(CandidateNo != size_t(-1) && "We collected the wrong backegdes");
357+
} else
358+
// When the value of HasBackEdge is false and we don't have any
359+
// information about back edges, we can simply select one block from the
360+
// MaybeBackEdgeSet.
361+
CandidateNo = *(MaybeBackEdgeSet.begin());
362+
CandidateQueue.push(CandidateNo);
363+
MaybeBackEdgeSet.erase(CandidateNo);
364+
UnvisitedBlockPredNum[CandidateNo] = 0;
283365
}
284366
}
285-
286-
if constexpr (Initialize)
287-
return true;
288-
289-
return Changed;
367+
return FoundBackEdge;
290368
}
291369

292370
SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape)
293371
: Mapping(F) {
294372
const size_t N = Mapping.size();
295373
Block.resize(N);
296374

375+
size_t EntryNo = Mapping.blockToIndex(&(F.getEntryBlock()));
376+
SmallVector<size_t> BlockPredecessorsNum(N, 0);
377+
297378
// Initialize every block so that it consumes itself
298379
for (size_t I = 0; I < N; ++I) {
299380
auto &B = Block[I];
300381
B.Consumes.resize(N);
301382
B.Kills.resize(N);
302383
B.Consumes.set(I);
303-
B.Changed = true;
384+
BlockPredecessorsNum[I] = pred_size(B);
304385
}
305386

306387
// Mark all CoroEnd Blocks. We do not propagate Kills beyond coro.ends as
@@ -325,10 +406,11 @@ SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape)
325406
markSuspendBlock(Save);
326407
}
327408

328-
computeBlockData</*Initialize=*/true>();
329-
330-
while (computeBlockData())
331-
;
409+
// We should collect the Consumes and Kills information initially. If there is
410+
// a back edge present, it is necessary to perform the collection process
411+
// again.
412+
if (collectConsumeKillInfo(EntryNo, BlockPredecessorsNum))
413+
collectConsumeKillInfo</*HasBackEdge*/ true>(EntryNo, BlockPredecessorsNum);
332414

333415
LLVM_DEBUG(dump());
334416
}

0 commit comments

Comments
 (0)