Skip to content

Commit 9eefd84

Browse files
authored
Handle non analyzed blocks in loop rematerialization (rust-lang#957)
1 parent e08703f commit 9eefd84

File tree

1 file changed

+32
-9
lines changed

1 file changed

+32
-9
lines changed

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2891,7 +2891,14 @@ BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB,
28912891
// reverse of the branching block.
28922892
if (rB == origLI->getHeader())
28932893
return reverseBlocks[getNewFromOriginal(B)].front();
2894-
return origToNewForward[rB];
2894+
auto found = origToNewForward.find(rB);
2895+
if (found == origToNewForward.end()) {
2896+
llvm::errs() << *newFunc << "\n";
2897+
llvm::errs() << *origLI << "\n";
2898+
llvm::errs() << *rB << "\n";
2899+
}
2900+
assert(found != origToNewForward.end());
2901+
return found->second;
28952902
};
28962903

28972904
// TODO clone terminator
@@ -2900,20 +2907,36 @@ BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB,
29002907
if (notForAnalysis.count(B)) {
29012908
NB.CreateUnreachable();
29022909
} else if (auto BI = dyn_cast<BranchInst>(TI)) {
2903-
if (BI->isUnconditional())
2904-
NB.CreateBr(remap(BI->getSuccessor(0)));
2905-
else
2906-
NB.CreateCondBr(lookupM(getNewFromOriginal(BI->getCondition()),
2907-
NB, available),
2908-
remap(BI->getSuccessor(0)),
2909-
remap(BI->getSuccessor(1)));
2910+
if (BI->isUnconditional()) {
2911+
if (notForAnalysis.count(BI->getSuccessor(0)))
2912+
NB.CreateUnreachable();
2913+
else
2914+
NB.CreateBr(remap(BI->getSuccessor(0)));
2915+
} else {
2916+
if (notForAnalysis.count(BI->getSuccessor(0))) {
2917+
if (notForAnalysis.count(BI->getSuccessor(1))) {
2918+
NB.CreateUnreachable();
2919+
} else {
2920+
NB.CreateBr(remap(BI->getSuccessor(1)));
2921+
}
2922+
} else if (notForAnalysis.count(BI->getSuccessor(1))) {
2923+
NB.CreateBr(remap(BI->getSuccessor(0)));
2924+
} else {
2925+
NB.CreateCondBr(
2926+
lookupM(getNewFromOriginal(BI->getCondition()), NB,
2927+
available),
2928+
remap(BI->getSuccessor(0)), remap(BI->getSuccessor(1)));
2929+
}
2930+
}
29102931
} else if (auto SI = dyn_cast<SwitchInst>(TI)) {
29112932
auto NSI = NB.CreateSwitch(
29122933
lookupM(getNewFromOriginal(BI->getCondition()), NB,
29132934
available),
29142935
remap(SI->getDefaultDest()));
29152936
for (auto cas : SI->cases()) {
2916-
NSI->addCase(cas.getCaseValue(), remap(cas.getCaseSuccessor()));
2937+
if (!notForAnalysis.count(cas.getCaseSuccessor()))
2938+
NSI->addCase(cas.getCaseValue(),
2939+
remap(cas.getCaseSuccessor()));
29172940
}
29182941
} else {
29192942
assert(isa<UnreachableInst>(TI));

0 commit comments

Comments
 (0)