Skip to content

Commit 6344202

Browse files
authored
Fast realloc (rust-lang#233)
1 parent 3aef65b commit 6344202

16 files changed

+447
-172
lines changed

enzyme/Enzyme/CacheUtility.cpp

+21-24
Original file line numberDiff line numberDiff line change
@@ -695,8 +695,6 @@ AllocaInst *CacheUtility::createCacheForScope(LimitContext ctx, Type *T,
695695
}
696696

697697
Type *BPTy = Type::getInt8PtrTy(T->getContext());
698-
auto realloc = newFunc->getParent()->getOrInsertFunction(
699-
"realloc", BPTy, BPTy, Type::getInt64Ty(T->getContext()));
700698

701699
Value *storeInto = alloc;
702700

@@ -799,32 +797,31 @@ AllocaInst *CacheUtility::createCacheForScope(LimitContext ctx, Type *T,
799797

800798
IRBuilder<> build(containedloops.back().first.incvar->getNextNode());
801799
Value *allocation = build.CreateLoad(storeInto);
802-
Value *realloc_size = nullptr;
803-
if (isa<ConstantInt>(sublimits[i].first) &&
804-
cast<ConstantInt>(sublimits[i].first)->isOne()) {
805-
realloc_size = containedloops.back().first.incvar;
806-
} else {
807-
realloc_size = build.CreateMul(containedloops.back().first.incvar,
808-
sublimits[i].first, "", /*NUW*/ true,
809-
/*NSW*/ true);
810-
}
811800

812-
Value *idxs[2] = {
813-
build.CreatePointerCast(allocation, BPTy),
814-
build.CreateMul(
815-
ConstantInt::get(size->getType(),
816-
newFunc->getParent()
817-
->getDataLayout()
818-
.getTypeAllocSizeInBits(myType) /
819-
8),
820-
realloc_size, "", /*NUW*/ true, /*NSW*/ true)};
801+
Value *tsize = ConstantInt::get(
802+
size->getType(),
803+
newFunc->getParent()->getDataLayout().getTypeAllocSizeInBits(
804+
myType) /
805+
8);
821806

807+
Value *idxs[] = {
808+
/*ptr*/
809+
build.CreatePointerCast(allocation, BPTy),
810+
/*incrementing value to increase when it goes past a power of two*/
811+
containedloops.back().first.incvar,
812+
/*buffer size (element x subloops)*/
813+
build.CreateMul(tsize, sublimits[i].first, "", /*NUW*/ true,
814+
/*NSW*/ true)};
815+
816+
assert(cast<PointerType>(allocation->getType())->getElementType() ==
817+
myType);
822818
Value *realloccall = nullptr;
823-
allocation = build.CreatePointerCast(
824-
realloccall =
825-
build.CreateCall(realloc, idxs, name + "_realloccache"),
826-
allocation->getType(), name + "_realloccast");
819+
820+
realloccall = build.CreateCall(
821+
getOrInsertExponentialAllocator(*newFunc->getParent()), idxs,
822+
name + "_realloccache");
827823
scopeAllocs[alloc].push_back(cast<CallInst>(realloccall));
824+
allocation = build.CreateBitCast(realloccall, allocation->getType());
828825
storealloc = build.CreateStore(allocation, storeInto);
829826
// Unlike the static case we can not mark the memory as invariant
830827
// since we are reloading/storing based off the number of loop

enzyme/Enzyme/EnzymeLogic.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -2068,6 +2068,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
20682068
user->setCalledFunction(NewF);
20692069
}
20702070
}
2071+
PPC.AlwaysInline(gutils->newFunc);
20712072
auto Arch = llvm::Triple(NewF->getParent()->getTargetTriple()).getArch();
20722073
if (Arch == Triple::nvptx || Arch == Triple::nvptx64)
20732074
PPC.ReplaceReallocs(NewF, /*mem2reg*/ true);
@@ -3210,6 +3211,7 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
32103211
PreservedAnalyses PA;
32113212
PPC.FAM.invalidate(*gutils->newFunc, PA);
32123213
}
3214+
PPC.AlwaysInline(gutils->newFunc);
32133215
if (Arch == Triple::nvptx || Arch == Triple::nvptx64)
32143216
PPC.ReplaceReallocs(gutils->newFunc, /*mem2reg*/ true);
32153217
if (PostOpt)

enzyme/Enzyme/FunctionUtils.cpp

+29
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,35 @@ OldAllocationSize(Value *Ptr, CallInst *Loc, Function *NewF, IntegerType *T,
489489
return AI;
490490
}
491491

492+
void PreProcessCache::AlwaysInline(Function *NewF) {
493+
PreservedAnalyses PA;
494+
PA.preserve<AssumptionAnalysis>();
495+
PA.preserve<TargetLibraryAnalysis>();
496+
FAM.invalidate(*NewF, PA);
497+
SmallVector<CallInst *, 2> ToInline;
498+
// TODO this logic should be combined with the dynamic loop emission
499+
// to minimize the number of branches if the realloc is used for multiple
500+
// values with the same bound.
501+
for (auto &BB : *NewF) {
502+
for (auto &I : BB) {
503+
if (auto CI = dyn_cast<CallInst>(&I)) {
504+
if (!CI->getCalledFunction())
505+
continue;
506+
if (CI->getCalledFunction()->hasFnAttribute(Attribute::AlwaysInline))
507+
ToInline.push_back(CI);
508+
}
509+
}
510+
}
511+
for (auto CI : ToInline) {
512+
InlineFunctionInfo IFI;
513+
#if LLVM_VERSION_MAJOR >= 11
514+
InlineFunction(*CI, IFI);
515+
#else
516+
InlineFunction(CI, IFI);
517+
#endif
518+
}
519+
}
520+
492521
/// Calls to realloc with an appropriate implementation
493522
void PreProcessCache::ReplaceReallocs(Function *NewF, bool mem2reg) {
494523
if (mem2reg) {

enzyme/Enzyme/FunctionUtils.h

+1
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ class PreProcessCache {
7171
llvm::Type *additionalArg = nullptr);
7272

7373
void ReplaceReallocs(llvm::Function *NewF, bool mem2reg = false);
74+
void AlwaysInline(llvm::Function *NewF);
7475
void optimizeIntermediate(llvm::Function *F);
7576

7677
void clear();

enzyme/Enzyme/Utils.cpp

+72
Original file line numberDiff line numberDiff line change
@@ -416,3 +416,75 @@ llvm::Value *getOrInsertOpFloatSum(llvm::Module &M, llvm::Type *OpPtr,
416416
B2.CreateCall(M.getFunction(name + "initializer"));
417417
return GV;
418418
}
419+
420+
Function *getOrInsertExponentialAllocator(Module &M) {
421+
Type *BPTy = Type::getInt8PtrTy(M.getContext());
422+
Type *types[] = {BPTy, Type::getInt64Ty(M.getContext()),
423+
Type::getInt64Ty(M.getContext())};
424+
std::string name = "__enzyme_exponentialallocation";
425+
FunctionType *FT =
426+
FunctionType::get(Type::getInt8PtrTy(M.getContext()), types, false);
427+
428+
#if LLVM_VERSION_MAJOR >= 9
429+
Function *F = cast<Function>(M.getOrInsertFunction(name, FT).getCallee());
430+
#else
431+
Function *F = cast<Function>(M.getOrInsertFunction(name, FT));
432+
#endif
433+
434+
if (!F->empty())
435+
return F;
436+
437+
F->setLinkage(Function::LinkageTypes::InternalLinkage);
438+
F->addFnAttr(Attribute::AlwaysInline);
439+
F->addFnAttr(Attribute::NoUnwind);
440+
BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F);
441+
BasicBlock *grow = BasicBlock::Create(M.getContext(), "grow", F);
442+
BasicBlock *ok = BasicBlock::Create(M.getContext(), "ok", F);
443+
444+
IRBuilder<> B(entry);
445+
446+
Argument *ptr = F->arg_begin();
447+
ptr->setName("ptr");
448+
Argument *size = ptr + 1;
449+
size->setName("size");
450+
Argument *tsize = size + 1;
451+
tsize->setName("tsize");
452+
453+
Value *hasOne = B.CreateICmpNE(
454+
B.CreateAnd(size, ConstantInt::get(size->getType(), 1, false)),
455+
ConstantInt::get(size->getType(), 0, false));
456+
auto popCnt = Intrinsic::getDeclaration(&M, Intrinsic::ctpop,
457+
std::vector<Type *>({types[1]}));
458+
459+
B.CreateCondBr(
460+
B.CreateAnd(
461+
B.CreateICmpULT(B.CreateCall(popCnt, std::vector<Value *>({size})),
462+
ConstantInt::get(types[1], 3, false)),
463+
hasOne),
464+
grow, ok);
465+
466+
B.SetInsertPoint(grow);
467+
468+
auto lz = B.CreateCall(
469+
Intrinsic::getDeclaration(&M, Intrinsic::ctlz,
470+
std::vector<Type *>({types[1]})),
471+
std::vector<Value *>({size, ConstantInt::getTrue(M.getContext())}));
472+
Value *next =
473+
B.CreateShl(tsize, B.CreateSub(ConstantInt::get(types[1], 64, false), lz,
474+
"", true, true));
475+
476+
auto reallocF = M.getOrInsertFunction("realloc", BPTy, BPTy,
477+
Type::getInt64Ty(M.getContext()));
478+
479+
Value *args[] = {B.CreatePointerCast(ptr, BPTy), next};
480+
Value *gVal =
481+
B.CreatePointerCast(B.CreateCall(reallocF, args), ptr->getType());
482+
483+
B.CreateBr(ok);
484+
B.SetInsertPoint(ok);
485+
auto phi = B.CreatePHI(ptr->getType(), 2);
486+
phi->addIncoming(gVal, grow);
487+
phi->addIncoming(ptr, entry);
488+
B.CreateRet(phi);
489+
return F;
490+
}

enzyme/Enzyme/Utils.h

+1
Original file line numberDiff line numberDiff line change
@@ -796,4 +796,5 @@ enum class MPI_CallType {
796796
llvm::Value *getOrInsertOpFloatSum(llvm::Module &M, llvm::Type *OpPtr,
797797
ConcreteType CT, llvm::Type *intType,
798798
llvm::IRBuilder<> &B2);
799+
llvm::Function *getOrInsertExponentialAllocator(llvm::Module &M);
799800
#endif

enzyme/test/Enzyme/ReverseMode/allocacache.ll

+4-4
Original file line numberDiff line numberDiff line change
@@ -412,18 +412,18 @@ entry:
412412
; Function Attrs: argmemonly nounwind
413413
declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1) #1
414414

415-
attributes #0 = { alwaysinline norecurse nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "min-legal-vector-width"="128" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" }
415+
attributes #0 = { norecurse nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "min-legal-vector-width"="128" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" }
416416
attributes #1 = { argmemonly nounwind }
417-
attributes #2 = { alwaysinline "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" }
417+
attributes #2 = { "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" }
418418
attributes #3 = { noinline nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "min-legal-vector-width"="128" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" }
419419
attributes #4 = { nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" }
420420
attributes #5 = { nounwind readnone speculatable }
421421
attributes #6 = { noreturn nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" }
422422
attributes #7 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "min-legal-vector-width"="128" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" }
423-
attributes #8 = { alwaysinline inlinehint norecurse nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "min-legal-vector-width"="128" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" }
423+
attributes #8 = { inlinehint norecurse nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "min-legal-vector-width"="128" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" }
424424
attributes #9 = { nounwind }
425425
attributes #10 = { noreturn nounwind }
426-
attributes #11 = { alwaysinline cold }
426+
attributes #11 = { cold }
427427

428428
!llvm.module.flags = !{!0}
429429
!llvm.ident = !{!1}

enzyme/test/Enzyme/ReverseMode/cppllist.ll

+27-10
Original file line numberDiff line numberDiff line change
@@ -181,9 +181,9 @@ attributes #8 = { builtin nounwind }
181181
; CHECK-NEXT: %[[nodevar:.+]] = phi %class.node* [ %"'ipc.i", %for.body.i ], [ null, %entry ]
182182
; CHECK-NEXT: %list.09.i = phi %class.node* [ %[[bcnode:.+]], %for.body.i ], [ null, %entry ]
183183
; CHECK-NEXT: %[[ivnext]] = add nuw nsw i64 %[[iv]], 1
184-
; CHECK-NEXT: %call.i = call noalias nonnull dereferenceable(16) dereferenceable_or_null(16) i8* @_Znwm(i64 16) #10
185-
; CHECK-NEXT: %"call'mi.i" = call noalias nonnull dereferenceable(16) dereferenceable_or_null(16) i8* @_Znwm(i64 16) #10
186-
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* nonnull {{(align 1 )?}}dereferenceable(16) dereferenceable_or_null(16) %"call'mi.i", i8 0, i64 16, {{(i32 1, )?}}i1 false) #5
184+
; CHECK-NEXT: %call.i = call noalias nonnull dereferenceable(16) dereferenceable_or_null(16) i8* @_Znwm(i64 16)
185+
; CHECK-NEXT: %"call'mi.i" = call noalias nonnull dereferenceable(16) dereferenceable_or_null(16) i8* @_Znwm(i64 16)
186+
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* nonnull {{(align 1 )?}}dereferenceable(16) dereferenceable_or_null(16) %"call'mi.i", i8 0, i64 16, {{(i32 1, )?}}i1 false)
187187
; CHECK-NEXT: %"'ipc.i" = bitcast i8* %"call'mi.i" to %class.node*
188188
; CHECK-NEXT: %[[bcnode]] = bitcast i8* %call.i to %class.node*
189189
; CHECK-NEXT: %value.i.i = bitcast i8* %call.i to double*
@@ -241,13 +241,30 @@ attributes #8 = { builtin nounwind }
241241
; CHECK-NEXT: br i1 %[[cmp]], label %invertentry, label %for.body
242242

243243
; CHECK: for.body:
244-
; CHECK-NEXT: %[[rawcache:.+]] = phi i8* [ %[[realloccache:.+]], %for.body ], [ null, %entry ]
245-
; CHECK-NEXT: %[[preidx:.+]] = phi i64 [ %[[postidx:.+]], %for.body ], [ 0, %entry ]
246-
; CHECK-NEXT: %[[cur:.+]] = phi %class.node* [ %"'ipl", %for.body ], [ %"node'", %entry ]
247-
; CHECK-NEXT: %val.08 = phi %class.node* [ %[[loadst:.+]], %for.body ], [ %node, %entry ]
244+
; CHECK-NEXT: %[[rawcache:.+]] = phi i8* [ %[[realloccache:.+]], %[[mergeblk:.+]] ], [ null, %entry ]
245+
; CHECK-NEXT: %[[preidx:.+]] = phi i64 [ %[[postidx:.+]], %[[mergeblk:.+]] ], [ 0, %entry ]
246+
; CHECK-NEXT: %[[cur:.+]] = phi %class.node* [ %"'ipl", %[[mergeblk:.+]] ], [ %"node'", %entry ]
247+
; CHECK-NEXT: %val.08 = phi %class.node* [ %[[loadst:.+]], %[[mergeblk:.+]] ], [ %node, %entry ]
248248
; CHECK-NEXT: %[[postidx]] = add nuw nsw i64 %[[preidx]], 1
249-
; CHECK-NEXT: %[[nextrealloc:.+]] = shl nuw nsw i64 %[[postidx]], 3
250-
; CHECK-NEXT: %[[realloccache]] = call i8* @realloc(i8* %[[rawcache]], i64 %[[nextrealloc]])
249+
250+
; CHECK-NEXT: %[[nexttrunc0:.+]] = and i64 %[[postidx]], 1
251+
; CHECK-NEXT: %[[nexttrunc:.+]] = icmp ne i64 %[[nexttrunc0]], 0
252+
; CHECK-NEXT: %[[popcnt:.+]] = call i64 @llvm.ctpop.i64(i64 %iv.next)
253+
; CHECK-NEXT: %[[le2:.+]] = icmp ult i64 %[[popcnt:.+]], 3
254+
; CHECK-NEXT: %[[shouldgrow:.+]] = and i1 %[[le2]], %[[nexttrunc]]
255+
; CHECK-NEXT: br i1 %[[shouldgrow]], label %grow.i, label %[[mergeblk]]
256+
257+
; CHECK: grow.i:
258+
; CHECK-NEXT: %[[ctlz:.+]] = call i64 @llvm.ctlz.i64(i64 %[[postidx]], i1 true)
259+
; CHECK-NEXT: %[[maxbit:.+]] = sub nuw nsw i64 64, %[[ctlz]]
260+
; CHECK-NEXT: %[[numbytes:.+]] = shl i64 8, %[[maxbit]]
261+
; CHECK-NEXT: %[[growalloc:.+]] = call i8* @realloc(i8* %[[rawcache]], i64 %[[numbytes]])
262+
; CHECK-NEXT: br label %[[mergeblk]]
263+
264+
; CHECK: [[mergeblk]]:
265+
; CHECK-NEXT: %[[realloccache]] = phi i8* [ %[[growalloc]], %grow.i ], [ %[[rawcache]], %for.body ]
266+
267+
251268
; CHECK-NEXT: %[[reallocbc:.+]] = bitcast i8* %[[realloccache]] to %class.node**
252269
; CHECK-NEXT: %[[geptostore:.+]] = getelementptr inbounds %class.node*, %class.node** %[[reallocbc]], i64 %[[preidx]]
253270
; CHECK-NEXT: store %class.node* %[[cur]], %class.node** %[[geptostore]]
@@ -266,7 +283,7 @@ attributes #8 = { builtin nounwind }
266283
; CHECK-NEXT: br label %invertentry
267284

268285
; CHECK: [[antiloop]]:
269-
; CHECK-NEXT: %[[antivar:.+]] = phi i64 [ %[[subidx:.+]], %incinvertfor.body ], [ %[[preidx]], %for.body ]
286+
; CHECK-NEXT: %[[antivar:.+]] = phi i64 [ %[[subidx:.+]], %incinvertfor.body ], [ %[[preidx]], %[[mergeblk]] ]
270287
; CHECK-NEXT: %[[structptr:.+]] = getelementptr inbounds %class.node*, %class.node** %[[reallocbc]], i64 %[[antivar]]
271288
; CHECK-NEXT: %[[struct:.+]] = load %class.node*, %class.node** %[[structptr]]
272289
; CHECK-NEXT: %[[valueipg:.+]] = getelementptr inbounds %class.node, %class.node* %[[struct]], i64 0, i32 0

0 commit comments

Comments
 (0)