Skip to content

Commit b14088f

Browse files
authored
Support declaring a custom allocation / free function. (rust-lang#937)
* starting custom allocator * Fix custom alloc
1 parent ccfb103 commit b14088f

File tree

8 files changed

+501
-207
lines changed

8 files changed

+501
-207
lines changed

enzyme/Enzyme/AdjointGenerator.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -10738,7 +10738,7 @@ class AdjointGenerator
1073810738
bb,
1073910739
[&](Value *anti) {
1074010740
zeroKnownAllocation(bb, anti, args, funcName,
10741-
gutils->TLI);
10741+
gutils->TLI, orig);
1074210742
},
1074310743
anti);
1074410744
}
@@ -10760,7 +10760,7 @@ class AdjointGenerator
1076010760
assert(tofree->getType());
1076110761
auto rule = [&](Value *tofree) {
1076210762
auto CI = freeKnownAllocation(Builder2, tofree, funcName, dbgLoc,
10763-
gutils->TLI);
10763+
gutils->TLI, orig, gutils);
1076410764
if (CI)
1076510765
#if LLVM_VERSION_MAJOR >= 14
1076610766
CI->addAttributeAtIndex(AttributeList::FirstArgIndex,
@@ -10978,7 +10978,7 @@ class AdjointGenerator
1097810978
getReverseBuilder(Builder2);
1097910979
auto dbgLoc = gutils->getNewFromOriginal(orig->getDebugLoc());
1098010980
freeKnownAllocation(Builder2, lookup(newCall, Builder2), funcName,
10981-
dbgLoc, gutils->TLI);
10981+
dbgLoc, gutils->TLI, orig, gutils);
1098210982
if (Mode == DerivativeMode::ReverseModeGradient &&
1098310983
found->second.LI && found->second.LI->contains(orig))
1098410984
gutils->rematerializedPrimalOrShadowAllocations.push_back(
@@ -11063,7 +11063,7 @@ class AdjointGenerator
1106311063
getReverseBuilder(Builder2);
1106411064
auto dbgLoc = gutils->getNewFromOriginal(orig->getDebugLoc());
1106511065
freeKnownAllocation(Builder2, lookup(nop, Builder2), funcName, dbgLoc,
11066-
gutils->TLI);
11066+
gutils->TLI, orig, gutils);
1106711067
}
1106811068
} else if (Mode == DerivativeMode::ReverseModeGradient ||
1106911069
Mode == DerivativeMode::ReverseModeCombined ||

enzyme/Enzyme/Enzyme.cpp

+106-2
Original file line numberDiff line numberDiff line change
@@ -391,15 +391,15 @@ handleFunctionLike(llvm::Module &M, llvm::GlobalVariable &g,
391391
Attribute::get(g.getContext(), "enzyme_math", nameVal));
392392
} else {
393393
llvm::errs() << M << "\n";
394-
llvm::errs() << "Param of __enzyme_inactivefn must be a "
394+
llvm::errs() << "Param of __enzyme_function_like must be a "
395395
"function"
396396
<< g << "\n"
397397
<< *V << "\n";
398398
llvm_unreachable("__enzyme_inactivefn");
399399
}
400400
} else {
401401
llvm::errs() << M << "\n";
402-
llvm::errs() << "Use of __enzyme_inactivefn must be a "
402+
llvm::errs() << "Use of __enzyme_function_like must be a "
403403
"constant function "
404404
<< g << "\n";
405405
llvm_unreachable("__enzyme_register_gradient");
@@ -408,6 +408,108 @@ handleFunctionLike(llvm::Module &M, llvm::GlobalVariable &g,
408408
}
409409
}
410410

411+
static void
412+
handleAllocationLike(llvm::Module &M, llvm::GlobalVariable &g,
413+
SmallVectorImpl<GlobalVariable *> &globalsToErase) {
414+
if (g.hasInitializer()) {
415+
if (auto CA = dyn_cast<ConstantAggregate>(g.getInitializer())) {
416+
if (CA->getNumOperands() != 4) {
417+
llvm::errs() << M << "\n";
418+
llvm::errs() << "Use of "
419+
<< "enzyme_allocation_like"
420+
<< " must be a "
421+
"constant of size at least "
422+
<< 4 << " " << g << "\n";
423+
llvm_unreachable("enzyme_allocation_like");
424+
}
425+
Value *V = CA->getOperand(0);
426+
Value *name = CA->getOperand(1);
427+
while (auto CE = dyn_cast<ConstantExpr>(V)) {
428+
V = CE->getOperand(0);
429+
}
430+
while (auto CE = dyn_cast<ConstantExpr>(name)) {
431+
name = CE->getOperand(0);
432+
}
433+
Value *deallocind = CA->getOperand(2);
434+
while (auto CE = dyn_cast<ConstantExpr>(deallocind)) {
435+
deallocind = CE->getOperand(0);
436+
}
437+
Value *deallocfn = CA->getOperand(3);
438+
while (auto CE = dyn_cast<ConstantExpr>(deallocfn)) {
439+
deallocfn = CE->getOperand(0);
440+
}
441+
size_t index = 0;
442+
if (auto CI = dyn_cast<ConstantInt>(name)) {
443+
index = CI->getZExtValue();
444+
} else {
445+
llvm::errs() << *name << "\n";
446+
llvm::errs() << "Use of "
447+
<< "enzyme_allocation_like"
448+
<< "requires an integer index"
449+
<< "\n";
450+
llvm_unreachable("enzyme_allocation_like");
451+
}
452+
453+
StringRef deallocIndStr;
454+
bool foundInd = false;
455+
if (auto GV = dyn_cast<GlobalVariable>(deallocind))
456+
if (GV->isConstant())
457+
if (auto C = GV->getInitializer())
458+
if (auto CA = dyn_cast<ConstantDataArray>(C))
459+
if (CA->getType()->getElementType()->isIntegerTy(8) &&
460+
CA->isCString()) {
461+
deallocIndStr = CA->getAsCString();
462+
foundInd = true;
463+
}
464+
465+
if (!foundInd) {
466+
llvm::errs() << *deallocind << "\n";
467+
llvm::errs() << "Use of "
468+
<< "enzyme_allocation_like"
469+
<< "requires a deallocation index string"
470+
<< "\n";
471+
llvm_unreachable("enzyme_allocation_like");
472+
}
473+
if (auto F = dyn_cast<Function>(V)) {
474+
F->addAttribute(AttributeList::FunctionIndex,
475+
Attribute::get(g.getContext(), "enzyme_allocator",
476+
std::to_string(index)));
477+
} else {
478+
llvm::errs() << M << "\n";
479+
llvm::errs() << "Param of __enzyme_allocation_like must be a "
480+
"function"
481+
<< g << "\n"
482+
<< *V << "\n";
483+
llvm_unreachable("__enzyme_allocation_like");
484+
}
485+
cast<Function>(V)->addAttribute(
486+
AttributeList::FunctionIndex,
487+
Attribute::get(g.getContext(), "enzyme_deallocator", deallocIndStr));
488+
489+
if (auto F = dyn_cast<Function>(deallocfn)) {
490+
cast<Function>(V)->setMetadata(
491+
"enzyme_deallocator_fn",
492+
llvm::MDTuple::get(F->getContext(),
493+
{llvm::ValueAsMetadata::get(F)}));
494+
} else {
495+
llvm::errs() << M << "\n";
496+
llvm::errs() << "Free fn of __enzyme_allocation_like must be a "
497+
"function"
498+
<< g << "\n"
499+
<< *deallocfn << "\n";
500+
llvm_unreachable("__enzyme_allocation_like");
501+
}
502+
} else {
503+
llvm::errs() << M << "\n";
504+
llvm::errs() << "Use of __enzyme_allocation_like must be a "
505+
"constant function "
506+
<< g << "\n";
507+
llvm_unreachable("__enzyme_allocation_like");
508+
}
509+
globalsToErase.push_back(&g);
510+
}
511+
}
512+
411513
static void handleKnownFunctions(llvm::Function &F) {
412514
if (F.getName() == "memcmp") {
413515
F.addFnAttr(Attribute::ReadOnly);
@@ -2452,6 +2554,8 @@ class Enzyme final : public ModulePass {
24522554
handleInactiveFunction(M, g, globalsToErase);
24532555
} else if (g.getName().contains("__enzyme_function_like")) {
24542556
handleFunctionLike(M, g, globalsToErase);
2557+
} else if (g.getName().contains("__enzyme_allocation_like")) {
2558+
handleAllocationLike(M, g, globalsToErase);
24552559
}
24562560
}
24572561
for (auto g : globalsToErase) {

0 commit comments

Comments
 (0)