Skip to content

Commit ca49313

Browse files
authored
Julia custom alloc fixups (rust-lang#780)
* Julia custom alloc fixups
* Fix tape usage in rev
* fix
1 parent 54263c2 commit ca49313

File tree

5 files changed

+61
-22
lines changed

5 files changed

+61
-22
lines changed

enzyme/Enzyme/CApi.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -588,11 +588,22 @@ const char *EnzymeGradientUtilsInvertedPointersToString(GradientUtils *gutils,
588588

589589
/// Release a C string by applying `delete[]` to it.
/// NOTE(review): presumably intended for strings handed out by the
/// `...ToString` C API helpers -- confirm against their allocation site.
void EnzymeStringFree(const char *cstr) {
  delete[] cstr;
}
590590

591-
void EnzymeMoveBefore(LLVMValueRef inst1, LLVMValueRef inst2) {
591+
void EnzymeMoveBefore(LLVMValueRef inst1, LLVMValueRef inst2,
592+
LLVMBuilderRef B) {
592593
Instruction *I1 = cast<Instruction>(unwrap(inst1));
593594
Instruction *I2 = cast<Instruction>(unwrap(inst2));
594-
if (I1 != I2)
595+
if (I1 != I2) {
596+
if (B != nullptr) {
597+
IRBuilder<> &BR = *unwrap(B);
598+
if (I1->getIterator() == BR.GetInsertPoint()) {
599+
if (I2->getNextNode() == nullptr)
600+
BR.SetInsertPoint(I1->getParent());
601+
else
602+
BR.SetInsertPoint(I1->getNextNode());
603+
}
604+
}
595605
I1->moveBefore(I2);
606+
}
596607
}
597608

598609
void EnzymeSetMustCache(LLVMValueRef inst1) {

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2312,10 +2312,6 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
23122312
report_fatal_error("function failed verification (2)");
23132313
}
23142314

2315-
StructType *sty = cast<StructType>(gutils->newFunc->getReturnType());
2316-
SmallVector<Type *, 4> RetTypes(sty->elements().begin(),
2317-
sty->elements().end());
2318-
23192315
SmallVector<Type *, 4> MallocTypes;
23202316

23212317
for (auto a : gutils->getTapeValues()) {
@@ -2338,6 +2334,30 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
23382334
forceAnonymousTape;
23392335
bool noTape = MallocTypes.size() == 0 && !forceAnonymousTape;
23402336

2337+
StructType *sty = cast<StructType>(gutils->newFunc->getReturnType());
2338+
SmallVector<Type *, 4> RetTypes(sty->elements().begin(),
2339+
sty->elements().end());
2340+
if (!noTape) {
2341+
if (recursive && !omp) {
2342+
auto size =
2343+
gutils->newFunc->getParent()->getDataLayout().getTypeAllocSizeInBits(
2344+
tapeType);
2345+
if (size != 0) {
2346+
auto i64 = Type::getInt64Ty(gutils->newFunc->getContext());
2347+
BasicBlock *BB = BasicBlock::Create(gutils->newFunc->getContext(),
2348+
"entry", gutils->newFunc);
2349+
IRBuilder<> B(BB);
2350+
2351+
CallInst *malloccall;
2352+
CreateAllocation(B, tapeType, ConstantInt::get(i64, 1), "tapemem",
2353+
&malloccall, nullptr);
2354+
RetTypes[returnMapping.find(AugmentedStruct::Tape)->second] =
2355+
malloccall->getType();
2356+
BB->eraseFromParent();
2357+
}
2358+
}
2359+
}
2360+
23412361
int oldretIdx = -1;
23422362
if (returnMapping.find(AugmentedStruct::Return) != returnMapping.end()) {
23432363
oldretIdx = returnMapping[AugmentedStruct::Return];
@@ -2367,8 +2387,6 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
23672387
}
23682388
RetTypes.erase(RetTypes.begin() + tidx);
23692389
} else if (recursive) {
2370-
assert(RetTypes[returnMapping.find(AugmentedStruct::Tape)->second] ==
2371-
Type::getInt8PtrTy(nf->getContext()));
23722390
} else {
23732391
RetTypes[returnMapping.find(AugmentedStruct::Tape)->second] = tapeType;
23742392
}
@@ -2440,7 +2458,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
24402458

24412459
IRBuilder<> ib(NewF->getEntryBlock().getFirstNonPHI());
24422460

2443-
Value *ret = noReturn ? nullptr : ib.CreateAlloca(RetType);
2461+
AllocaInst *ret = noReturn ? nullptr : ib.CreateAlloca(RetType);
24442462
if (!noReturn && EnzymeZeroCache) {
24452463
ib.CreateStore(Constant::getNullValue(RetType), ret);
24462464
}
@@ -2455,10 +2473,9 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
24552473
if (size != 0) {
24562474
CallInst *malloccall = nullptr;
24572475
Instruction *zero = nullptr;
2458-
IRBuilder<> Builder(NewF->getEntryBlock().getFirstNonPHI());
2459-
tapeMemory = CreateAllocation(
2460-
Builder, tapeType, ConstantInt::get(i64, 1), "tapemem", &malloccall,
2461-
EnzymeZeroCache ? &zero : nullptr);
2476+
tapeMemory =
2477+
CreateAllocation(ib, tapeType, ConstantInt::get(i64, 1), "tapemem",
2478+
&malloccall, EnzymeZeroCache ? &zero : nullptr);
24622479
memory = malloccall;
24632480
} else {
24642481
memory =
@@ -3690,7 +3707,10 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
36903707
IRBuilder<> BuilderZ(gutils->inversionAllocs);
36913708
if (!augmenteddata->tapeType->isEmptyTy()) {
36923709
auto tapep = BuilderZ.CreatePointerCast(
3693-
additionalValue, PointerType::getUnqual(augmenteddata->tapeType));
3710+
additionalValue,
3711+
PointerType::get(augmenteddata->tapeType,
3712+
cast<PointerType>(additionalValue->getType())
3713+
->getAddressSpace()));
36943714
#if LLVM_VERSION_MAJOR > 7
36953715
LoadInst *truetape =
36963716
BuilderZ.CreateLoad(augmenteddata->tapeType, tapep, "truetape");
@@ -3701,7 +3721,7 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
37013721
MDNode::get(truetape->getContext(), {}));
37023722

37033723
if (!omp && gutils->FreeMemory) {
3704-
CreateDealloc(BuilderZ, tapep);
3724+
CreateDealloc(BuilderZ, additionalValue);
37053725
}
37063726
additionalValue = truetape;
37073727
} else {

enzyme/Enzyme/FunctionUtils.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -379,11 +379,14 @@ void RecursivelyReplaceAddressSpace(Value *AI, Value *rep, bool legal) {
379379
/// pass. Specifically if we're not topLevel all allocations must be upgraded
380380
/// Even if topLevel any allocations that aren't in the entry block (and
381381
/// therefore may not be reachable in the reverse pass) must be upgraded.
382-
static inline void UpgradeAllocasToMallocs(Function *NewF,
383-
DerivativeMode mode) {
382+
static inline void
383+
UpgradeAllocasToMallocs(Function *NewF, DerivativeMode mode,
384+
SmallPtrSetImpl<llvm::BasicBlock *> &Unreachable) {
384385
SmallVector<AllocaInst *, 4> ToConvert;
385386

386387
for (auto &BB : *NewF) {
388+
if (Unreachable.count(&BB))
389+
continue;
387390
for (auto &I : BB) {
388391
if (auto AI = dyn_cast<AllocaInst>(&I)) {
389392
bool UsableEverywhere = AI->getParent() == &NewF->getEntryBlock();
@@ -1653,7 +1656,8 @@ Function *PreProcessCache::preprocessForClone(Function *F,
16531656
mode == DerivativeMode::ReverseModeCombined) {
16541657
// For subfunction calls upgrade stack allocations to mallocs
16551658
// to ensure availability in the reverse pass
1656-
UpgradeAllocasToMallocs(NewF, mode);
1659+
auto unreachable = getGuaranteedUnreachable(NewF);
1660+
UpgradeAllocasToMallocs(NewF, mode, unreachable);
16571661
}
16581662

16591663
CanonicalizeLoops(NewF, FAM);

enzyme/Enzyme/Utils.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,14 @@ Value *CreateAllocation(IRBuilder<> &Builder, llvm::Type *T, Value *Count,
128128
*caller = malloccall;
129129
}
130130
if (ZeroMem) {
131+
auto PT = cast<PointerType>(malloccall->getType());
132+
Value *tozero = malloccall;
133+
if (!PT->getElementType()->isIntegerTy(8))
134+
tozero = Builder.CreatePointerCast(
135+
tozero, PointerType::get(Type::getInt8Ty(PT->getContext()),
136+
PT->getAddressSpace()));
131137
Value *args[] = {
132-
malloccall,
133-
ConstantInt::get(Type::getInt8Ty(malloccall->getContext()), 0),
138+
tozero, ConstantInt::get(Type::getInt8Ty(malloccall->getContext()), 0),
134139
Builder.CreateMul(Align, Count, "", true, true),
135140
ConstantInt::getFalse(malloccall->getContext())};
136141
Type *tys[] = {args[0]->getType(), args[2]->getType()};

enzyme/test/Enzyme/ReverseMode/loadcall.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,7 @@ attributes #4 = { nounwind }
7272
; CHECK-NEXT: entry:
7373
; CHECK-NEXT: %0 = bitcast i8* %tapeArg to double**
7474
; CHECK-NEXT: %[[callp:.+]] = load double*, double** %0{{(, align 8)?}}, !enzyme_mustcache !
75-
; CHECK-NEXT: %[[t0:.+]] = bitcast double** %0 to i8*
76-
; CHECK-NEXT: tail call void @free(i8* nonnull %[[t0]])
75+
; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg)
7776
; CHECK-NEXT: %"loaded'de" = alloca double
7877
; CHECK-NEXT: store double 0.000000e+00, double* %"loaded'de"
7978
; CHECK-NEXT: br label %invertentry

0 commit comments

Comments
 (0)