Skip to content

Commit 5a6e056

Browse files
authored
Fix recursive cache reverse (rust-lang#826)
1 parent 78145b6 commit 5a6e056

File tree

3 files changed

+52
-6
lines changed

3 files changed

+52
-6
lines changed

enzyme/Enzyme/EnzymeLogic.cpp

+21-4
Original file line numberDiff line numberDiff line change
@@ -2125,7 +2125,6 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
21252125
insert_or_assign(AugmentedCachedFunctions, tup,
21262126
AugmentedReturn(gutils->newFunc, nullptr, {}, returnMapping,
21272127
uncacheable_args_map, can_modref_map));
2128-
AugmentedCachedFinished[tup] = false;
21292128

21302129
auto getIndex = [&](Instruction *I, CacheType u) -> unsigned {
21312130
return gutils->getIndex(
@@ -2708,7 +2707,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
27082707
AugmentedCachedFunctions.find(tup)->second.fn = NewF;
27092708
if (recursive || (omp && !noTape))
27102709
AugmentedCachedFunctions.find(tup)->second.tapeType = tapeType;
2711-
insert_or_assign(AugmentedCachedFinished, tup, true);
2710+
AugmentedCachedFunctions.find(tup)->second.isComplete = true;
27122711

27132712
for (auto pair : gfnusers) {
27142713
auto GV = pair.first;
@@ -3226,8 +3225,14 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
32263225

32273226
if (key.retType != DIFFE_TYPE::CONSTANT)
32283227
assert(!key.todiff->getReturnType()->isVoidTy());
3228+
3229+
Function *prevFunction = nullptr;
32293230
if (ReverseCachedFunctions.find(key) != ReverseCachedFunctions.end()) {
3230-
return ReverseCachedFunctions.find(key)->second;
3231+
prevFunction = ReverseCachedFunctions.find(key)->second;
3232+
if (!hasMetadata(prevFunction, "enzyme_placeholder"))
3233+
return prevFunction;
3234+
if (augmenteddata && !augmenteddata->isComplete)
3235+
return prevFunction;
32313236
}
32323237

32333238
if (key.returnUsed)
@@ -3641,6 +3646,14 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
36413646
insert_or_assign2<ReverseCacheKey, Function *>(ReverseCachedFunctions, key,
36423647
gutils->newFunc);
36433648

3649+
if (augmenteddata && !augmenteddata->isComplete) {
3650+
auto nf = gutils->newFunc;
3651+
delete gutils;
3652+
assert(!prevFunction);
3653+
nf->setMetadata("enzyme_placeholder", MDTuple::get(nf->getContext(), {}));
3654+
return nf;
3655+
}
3656+
36443657
const SmallPtrSet<BasicBlock *, 4> guaranteedUnreachable =
36453658
getGuaranteedUnreachable(gutils->oldFunc);
36463659

@@ -4020,6 +4033,11 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
40204033
if (Arch == Triple::nvptx || Arch == Triple::nvptx64)
40214034
PPC.ReplaceReallocs(nf, /*mem2reg*/ true);
40224035

4036+
if (prevFunction) {
4037+
prevFunction->replaceAllUsesWith(nf);
4038+
prevFunction->eraseFromParent();
4039+
}
4040+
40234041
// Do not run post processing optimizations if the body of an openmp
40244042
// parallel so the adjointgenerator can successfully extract the allocation
40254043
// and frees and hoist them into the parent. Optimizing before then may
@@ -4714,6 +4732,5 @@ llvm::Function *EnzymeLogic::CreateBatch(Function *tobatch, unsigned width,
47144732
void EnzymeLogic::clear() {
47154733
PPC.clear();
47164734
AugmentedCachedFunctions.clear();
4717-
AugmentedCachedFinished.clear();
47184735
ReverseCachedFunctions.clear();
47194736
}

enzyme/Enzyme/EnzymeLogic.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ class AugmentedReturn {
117117

118118
std::set<ssize_t> tapeIndiciesToFree;
119119

120+
bool isComplete;
121+
120122
AugmentedReturn(
121123
llvm::Function *fn, llvm::Type *tapeType,
122124
std::map<std::pair<llvm::Instruction *, CacheType>, int> tapeIndices,
@@ -126,7 +128,7 @@ class AugmentedReturn {
126128
std::map<llvm::Instruction *, bool> can_modref_map)
127129
: fn(fn), tapeType(tapeType), tapeIndices(tapeIndices), returns(returns),
128130
uncacheable_args_map(uncacheable_args_map),
129-
can_modref_map(can_modref_map) {}
131+
can_modref_map(can_modref_map), isComplete(false) {}
130132
};
131133

132134
struct ReverseCacheKey {
@@ -329,7 +331,6 @@ class EnzymeLogic {
329331
};
330332

331333
std::map<AugmentedCacheKey, AugmentedReturn> AugmentedCachedFunctions;
332-
std::map<AugmentedCacheKey, bool> AugmentedCachedFinished;
333334

334335
/// Create an augmented forward pass.
335336
/// \p todiff is the function to differentiate
+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -adce -simplifycfg -S | FileCheck %s
2+
3+
%DynamicsStruct = type { i8**, void (%DynamicsStruct*)** }
4+
5+
@someGlobal = internal constant i8* bitcast (void (%DynamicsStruct*)* @asdf to i8*)
6+
7+
define internal void @asdf(%DynamicsStruct* %arg) {
8+
bb:
9+
%i = getelementptr inbounds %DynamicsStruct, %DynamicsStruct* %arg, i64 0, i32 0
10+
store i8** @someGlobal, i8*** %i, align 8
11+
%i5 = getelementptr inbounds %DynamicsStruct, %DynamicsStruct* %arg, i64 0, i32 1
12+
%i6 = load void (%DynamicsStruct*)**, void (%DynamicsStruct*)*** %i5, align 8
13+
%i8 = load void (%DynamicsStruct*)*, void (%DynamicsStruct*)** %i6, align 8
14+
tail call void %i8(%DynamicsStruct* %arg)
15+
ret void
16+
}
17+
18+
declare i8* @_Z17__enzyme_virtualreversePv(...)
19+
20+
define internal void @_Z19testSensitivitiesADv() {
21+
bb40:
22+
call i8* (...) @_Z17__enzyme_virtualreversePv(void (%DynamicsStruct*)* @asdf)
23+
ret void
24+
}
25+
26+
; CHECK: define internal i8* @augmented_asdf(%DynamicsStruct* %arg, %DynamicsStruct* %"arg'")
27+
28+
; CHECK: define internal void @diffeasdf.1(%DynamicsStruct* %arg, %DynamicsStruct* %"arg'", i8* %tapeArg)

0 commit comments

Comments
 (0)