Skip to content

Commit c1a1f22

Browse files
authored
Custom alloc zeroing (rust-lang#908)
* Custom zero * Enable on differentials * Cleanup
1 parent c81bf98 commit c1a1f22

File tree

6 files changed

+20
-12
lines changed

6 files changed

+20
-12
lines changed

enzyme/Enzyme/CApi.cpp

-5
Original file line numberDiff line numberDiff line change
@@ -685,11 +685,6 @@ void EnzymeMoveBefore(LLVMValueRef inst1, LLVMValueRef inst2,
685685
}
686686
}
687687

688-
// C API helper: tags the instruction wrapped by inst1 with an empty
// "enzyme_formemset" metadata node. The diff markers on this hunk show the
// function being deleted by this commit.
// NOTE(review): inst1 is expected to wrap an llvm::Instruction; cast<> is
// presumably a checked cast that does not tolerate other values -- confirm.
void EnzymeSetForMemSet(LLVMValueRef inst1) {
689-
Instruction *I1 = cast<Instruction>(unwrap(inst1));
690-
// The node carries no operands; only the presence of the metadata kind is set.
I1->setMetadata("enzyme_formemset", MDNode::get(I1->getContext(), {}));
691-
}
692-
693688
void EnzymeSetMustCache(LLVMValueRef inst1) {
694689
Instruction *I1 = cast<Instruction>(unwrap(inst1));
695690
I1->setMetadata("enzyme_mustcache", MDNode::get(I1->getContext(), {}));

enzyme/Enzyme/EnzymeLogic.cpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -2513,9 +2513,6 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
25132513
IRBuilder<> ib(NewF->getEntryBlock().getFirstNonPHI());
25142514

25152515
AllocaInst *ret = noReturn ? nullptr : ib.CreateAlloca(RetType);
2516-
if (!noReturn && EnzymeZeroCache) {
2517-
ib.CreateStore(Constant::getNullValue(RetType), ret);
2518-
}
25192516

25202517
if (!noTape) {
25212518
Value *tapeMemory;
@@ -2573,6 +2570,10 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
25732570
#endif
25742571
cast<GetElementPtrInst>(tapeMemory)->setIsInBounds(true);
25752572
}
2573+
if (EnzymeZeroCache) {
2574+
ZeroMemory(ib, tapeType, tapeMemory,
2575+
/*isTape*/ true);
2576+
}
25762577
}
25772578

25782579
unsigned i = 0;

enzyme/Enzyme/GradientUtils.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -2218,9 +2218,8 @@ Value *GradientUtils::cacheForReverse(IRBuilder<> &BuilderQ, Value *malloc,
22182218
toadd = scopeAllocs[found2->second.first][0];
22192219
for (auto u : toadd->users()) {
22202220
if (auto ci = dyn_cast<CastInst>(u)) {
2221-
if (hasMetadata(ci, "enzyme_formemset"))
2222-
continue;
22232221
toadd = ci;
2222+
break;
22242223
}
22252224
}
22262225

enzyme/Enzyme/GradientUtils.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -1650,8 +1650,8 @@ class DiffeGradientUtils final : public GradientUtils {
16501650
#else
16511651
differentials[val]->setAlignment(Alignment);
16521652
#endif
1653-
entryBuilder.CreateStore(Constant::getNullValue(type),
1654-
differentials[val]);
1653+
ZeroMemory(entryBuilder, type, differentials[val],
1654+
/*isTape*/ false);
16551655
}
16561656
#if LLVM_VERSION_MAJOR >= 15
16571657
if (val->getContext().supportsTypedPointers()) {

enzyme/Enzyme/Utils.cpp

+11
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ LLVMValueRef (*CustomAllocator)(LLVMBuilderRef, LLVMTypeRef,
5050
/*Count*/ LLVMValueRef,
5151
/*Align*/ LLVMValueRef, uint8_t,
5252
LLVMValueRef *) = nullptr;
53+
// Optional hook used by ZeroMemory to emit zero-initialization of a memory
// location. Null by default; when non-null, ZeroMemory calls it instead of
// emitting a store of the type's null value.
// Arguments, in order: the builder, the type being zeroed, the pointer to
// zero, and a uint8_t flag that ZeroMemory fills from its isTape argument.
LLVMValueRef (*CustomZero)(LLVMBuilderRef, LLVMTypeRef,
54+
/*Ptr*/ LLVMValueRef, uint8_t) = nullptr;
5355
LLVMValueRef (*CustomDeallocator)(LLVMBuilderRef, LLVMValueRef) = nullptr;
5456
void (*CustomRuntimeInactiveError)(LLVMBuilderRef, LLVMValueRef,
5557
LLVMValueRef) = nullptr;
@@ -58,6 +60,15 @@ LLVMValueRef *(*EnzymePostCacheStore)(LLVMValueRef, LLVMBuilderRef,
5860
LLVMTypeRef (*EnzymeDefaultTapeType)(LLVMContextRef) = nullptr;
5961
}
6062

63+
// Emits, at Builder's current insertion point, code that zeroes the memory
// pointed to by obj, treated as holding a value of type T.
//
// Builder: IRBuilder used to create the instructions.
// T:       type whose null value is written (default path) or that is handed
//          to the CustomZero hook.
// obj:     pointer to the memory being zeroed.
// isTape:  forwarded to CustomZero only; the default path ignores it.
void ZeroMemory(llvm::IRBuilder<> &Builder, llvm::Type *T, llvm::Value *obj,
64+
bool isTape) {
65+
if (CustomZero) {
66+
// A hook is installed: delegate entirely to it through the C API handles.
// The LLVMValueRef it returns is discarded.
CustomZero(wrap(&Builder), wrap(T), wrap(obj), isTape);
67+
} else {
68+
// No hook: store the null constant of T to obj.
Builder.CreateStore(Constant::getNullValue(T), obj);
69+
}
70+
}
71+
6172
llvm::SmallVector<llvm::Instruction *, 2> PostCacheStore(llvm::StoreInst *SI,
6273
llvm::IRBuilder<> &B) {
6374
SmallVector<llvm::Instruction *, 2> res;

enzyme/Enzyme/Utils.h

+2
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ llvm::Value *CreateAllocation(llvm::IRBuilder<> &B, llvm::Type *T,
9292
llvm::Instruction **ZeroMem = nullptr,
9393
bool isDefault = false);
9494
llvm::CallInst *CreateDealloc(llvm::IRBuilder<> &B, llvm::Value *ToFree);
95+
/// Emit code via \p Builder that zeroes the memory pointed to by \p obj,
/// treated as a value of type \p T. \p isTape tells the implementation
/// whether the memory is tape storage; in this commit callers pass true for
/// the tape and false for differential allocations.
void ZeroMemory(llvm::IRBuilder<> &Builder, llvm::Type *T, llvm::Value *obj,
96+
bool isTape);
9597

9698
llvm::Value *CreateReAllocation(llvm::IRBuilder<> &B, llvm::Value *prev,
9799
llvm::Type *T, llvm::Value *OuterCount,

0 commit comments

Comments
 (0)