@@ -2125,7 +2125,6 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
2125
2125
insert_or_assign (AugmentedCachedFunctions, tup,
2126
2126
AugmentedReturn (gutils->newFunc , nullptr , {}, returnMapping,
2127
2127
uncacheable_args_map, can_modref_map));
2128
- AugmentedCachedFinished[tup] = false ;
2129
2128
2130
2129
auto getIndex = [&](Instruction *I, CacheType u) -> unsigned {
2131
2130
return gutils->getIndex (
@@ -2708,7 +2707,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
2708
2707
AugmentedCachedFunctions.find (tup)->second .fn = NewF;
2709
2708
if (recursive || (omp && !noTape))
2710
2709
AugmentedCachedFunctions.find (tup)->second .tapeType = tapeType;
2711
- insert_or_assign (AugmentedCachedFinished, tup, true ) ;
2710
+ AugmentedCachedFunctions. find ( tup)-> second . isComplete = true ;
2712
2711
2713
2712
for (auto pair : gfnusers) {
2714
2713
auto GV = pair.first ;
@@ -3226,8 +3225,14 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
3226
3225
3227
3226
if (key.retType != DIFFE_TYPE::CONSTANT)
3228
3227
assert (!key.todiff ->getReturnType ()->isVoidTy ());
3228
+
3229
+ Function *prevFunction = nullptr ;
3229
3230
if (ReverseCachedFunctions.find (key) != ReverseCachedFunctions.end ()) {
3230
- return ReverseCachedFunctions.find (key)->second ;
3231
+ prevFunction = ReverseCachedFunctions.find (key)->second ;
3232
+ if (!hasMetadata (prevFunction, " enzyme_placeholder" ))
3233
+ return prevFunction;
3234
+ if (augmenteddata && !augmenteddata->isComplete )
3235
+ return prevFunction;
3231
3236
}
3232
3237
3233
3238
if (key.returnUsed )
@@ -3641,6 +3646,14 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
3641
3646
insert_or_assign2<ReverseCacheKey, Function *>(ReverseCachedFunctions, key,
3642
3647
gutils->newFunc );
3643
3648
3649
+ if (augmenteddata && !augmenteddata->isComplete ) {
3650
+ auto nf = gutils->newFunc ;
3651
+ delete gutils;
3652
+ assert (!prevFunction);
3653
+ nf->setMetadata (" enzyme_placeholder" , MDTuple::get (nf->getContext (), {}));
3654
+ return nf;
3655
+ }
3656
+
3644
3657
const SmallPtrSet<BasicBlock *, 4 > guaranteedUnreachable =
3645
3658
getGuaranteedUnreachable (gutils->oldFunc );
3646
3659
@@ -4020,6 +4033,11 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
4020
4033
if (Arch == Triple::nvptx || Arch == Triple::nvptx64)
4021
4034
PPC.ReplaceReallocs (nf, /* mem2reg*/ true );
4022
4035
4036
+ if (prevFunction) {
4037
+ prevFunction->replaceAllUsesWith (nf);
4038
+ prevFunction->eraseFromParent ();
4039
+ }
4040
+
4023
4041
// Do not run post processing optimizations if the body of an openmp
4024
4042
// parallel so the adjointgenerator can successfully extract the allocation
4025
4043
// and frees and hoist them into the parent. Optimizing before then may
@@ -4714,6 +4732,5 @@ llvm::Function *EnzymeLogic::CreateBatch(Function *tobatch, unsigned width,
4714
4732
// Resets the cached state held by EnzymeLogic: calls PPC.clear() and empties
// the AugmentedCachedFunctions and ReverseCachedFunctions caches.
void EnzymeLogic::clear () {
4715
4733
PPC.clear ();
4716
4734
AugmentedCachedFunctions.clear ();
4717
- AugmentedCachedFinished.clear ();
4718
4735
ReverseCachedFunctions.clear ();
4719
4736
}
0 commit comments