Skip to content

Commit 5980f31

Browse files
authored
Custom return fixup (rust-lang#912)
1 parent 952a4f5 commit 5980f31

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

enzyme/Enzyme/EnzymeLogic.cpp

+13-9
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ cl::opt<bool> nonmarkedglobals_inactiveloads(
9090
cl::opt<bool> EnzymeJuliaAddrLoad(
9191
"enzyme-julia-addr-load", cl::init(false), cl::Hidden,
9292
cl::desc("Mark all loads resulting in an addr(13)* to be legal to redo"));
93+
94+
LLVMValueRef (*EnzymeFixupReturn)(LLVMBuilderRef, LLVMValueRef) = nullptr;
9395
}
9496

9597
struct CacheAnalysis {
@@ -2630,13 +2632,17 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
26302632
if (auto ggep = dyn_cast<GetElementPtrInst>(gep)) {
26312633
ggep->setIsInBounds(true);
26322634
}
2635+
if (EnzymeFixupReturn)
2636+
actualrv = unwrap(EnzymeFixupReturn(wrap(&ib), wrap(actualrv)));
26332637
auto storeinst = ib.CreateStore(actualrv, gep);
26342638
PostCacheStore(storeinst, ib);
26352639
}
26362640

26372641
if (shadowReturnUsed) {
26382642
assert(invertedRetPs[ri]);
2639-
if (!isa<UndefValue>(invertedRetPs[ri])) {
2643+
Value *shadowRV = invertedRetPs[ri];
2644+
2645+
if (!isa<UndefValue>(shadowRV)) {
26402646
Value *gep =
26412647
removeStruct
26422648
? ret
@@ -2648,15 +2654,13 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
26482654
if (auto ggep = dyn_cast<GetElementPtrInst>(gep)) {
26492655
ggep->setIsInBounds(true);
26502656
}
2651-
if (isa<ConstantExpr>(invertedRetPs[ri]) ||
2652-
isa<ConstantData>(invertedRetPs[ri])) {
2653-
auto storeinst = ib.CreateStore(invertedRetPs[ri], gep);
2654-
PostCacheStore(storeinst, ib);
2655-
} else {
2656-
assert(VMap[invertedRetPs[ri]]);
2657-
auto storeinst = ib.CreateStore(VMap[invertedRetPs[ri]], gep);
2658-
PostCacheStore(storeinst, ib);
2657+
if (!(isa<ConstantExpr>(shadowRV) || isa<ConstantData>(shadowRV))) {
2658+
shadowRV = VMap[shadowRV];
26592659
}
2660+
if (EnzymeFixupReturn)
2661+
shadowRV = unwrap(EnzymeFixupReturn(wrap(&ib), wrap(shadowRV)));
2662+
auto storeinst = ib.CreateStore(shadowRV, gep);
2663+
PostCacheStore(storeinst, ib);
26602664
}
26612665
}
26622666
if (noReturn)

0 commit comments

Comments
 (0)