@@ -90,6 +90,8 @@ cl::opt<bool> nonmarkedglobals_inactiveloads(
90
90
cl::opt<bool > EnzymeJuliaAddrLoad (
91
91
" enzyme-julia-addr-load" , cl::init(false ), cl::Hidden,
92
92
cl::desc(" Mark all loads resulting in an addr(13)* to be legal to redo" ));
93
+
94
+ LLVMValueRef (*EnzymeFixupReturn)(LLVMBuilderRef, LLVMValueRef) = nullptr;
93
95
}
94
96
95
97
struct CacheAnalysis {
@@ -2630,13 +2632,17 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
2630
2632
if (auto ggep = dyn_cast<GetElementPtrInst>(gep)) {
2631
2633
ggep->setIsInBounds (true );
2632
2634
}
2635
+ if (EnzymeFixupReturn)
2636
+ actualrv = unwrap (EnzymeFixupReturn (wrap (&ib), wrap (actualrv)));
2633
2637
auto storeinst = ib.CreateStore (actualrv, gep);
2634
2638
PostCacheStore (storeinst, ib);
2635
2639
}
2636
2640
2637
2641
if (shadowReturnUsed) {
2638
2642
assert (invertedRetPs[ri]);
2639
- if (!isa<UndefValue>(invertedRetPs[ri])) {
2643
+ Value *shadowRV = invertedRetPs[ri];
2644
+
2645
+ if (!isa<UndefValue>(shadowRV)) {
2640
2646
Value *gep =
2641
2647
removeStruct
2642
2648
? ret
@@ -2648,15 +2654,13 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
2648
2654
if (auto ggep = dyn_cast<GetElementPtrInst>(gep)) {
2649
2655
ggep->setIsInBounds (true );
2650
2656
}
2651
- if (isa<ConstantExpr>(invertedRetPs[ri]) ||
2652
- isa<ConstantData>(invertedRetPs[ri])) {
2653
- auto storeinst = ib.CreateStore (invertedRetPs[ri], gep);
2654
- PostCacheStore (storeinst, ib);
2655
- } else {
2656
- assert (VMap[invertedRetPs[ri]]);
2657
- auto storeinst = ib.CreateStore (VMap[invertedRetPs[ri]], gep);
2658
- PostCacheStore (storeinst, ib);
2657
+ if (!(isa<ConstantExpr>(shadowRV) || isa<ConstantData>(shadowRV))) {
2658
+ shadowRV = VMap[shadowRV];
2659
2659
}
2660
+ if (EnzymeFixupReturn)
2661
+ shadowRV = unwrap (EnzymeFixupReturn (wrap (&ib), wrap (shadowRV)));
2662
+ auto storeinst = ib.CreateStore (shadowRV, gep);
2663
+ PostCacheStore (storeinst, ib);
2660
2664
}
2661
2665
}
2662
2666
if (noReturn)
0 commit comments