Skip to content

Commit 58d27d5

Browse files
authored
Fix antialloca (rust-lang#651)
* Fix antialloca
1 parent 1e083a3 commit 58d27d5

File tree

1 file changed

+32
-22
lines changed

1 file changed

+32
-22
lines changed

enzyme/Enzyme/GradientUtils.cpp

+32-22
Original file line numberDiff line numberDiff line change
@@ -3841,32 +3841,42 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM,
38413841

38423842
if (isa<Argument>(oval) && cast<Argument>(oval)->hasByValAttr()) {
38433843
IRBuilder<> bb(inversionAllocs);
3844-
AllocaInst *antialloca = bb.CreateAlloca(
3845-
oval->getType()->getPointerElementType(),
3846-
cast<PointerType>(oval->getType())->getPointerAddressSpace(), nullptr,
3847-
oval->getName() + "'ipa");
3848-
invertedPointers.insert(std::make_pair(
3849-
(const Value *)oval, InvertedPointerVH(this, antialloca)));
3850-
auto dst_arg =
3851-
bb.CreateBitCast(antialloca, Type::getInt8PtrTy(oval->getContext()));
3852-
auto val_arg = ConstantInt::get(Type::getInt8Ty(oval->getContext()), 0);
3853-
auto len_arg =
3854-
ConstantInt::get(Type::getInt64Ty(oval->getContext()),
3855-
M->getDataLayout().getTypeAllocSizeInBits(
3856-
oval->getType()->getPointerElementType()) /
3857-
8);
3858-
auto volatile_arg = ConstantInt::getFalse(oval->getContext());
3844+
3845+
auto rule1 = [&]() {
3846+
AllocaInst *antialloca = bb.CreateAlloca(
3847+
oval->getType()->getPointerElementType(),
3848+
cast<PointerType>(oval->getType())->getPointerAddressSpace(), nullptr,
3849+
oval->getName() + "'ipa");
3850+
3851+
auto dst_arg =
3852+
bb.CreateBitCast(antialloca, Type::getInt8PtrTy(oval->getContext()));
3853+
auto val_arg = ConstantInt::get(Type::getInt8Ty(oval->getContext()), 0);
3854+
auto len_arg =
3855+
ConstantInt::get(Type::getInt64Ty(oval->getContext()),
3856+
M->getDataLayout().getTypeAllocSizeInBits(
3857+
oval->getType()->getPointerElementType()) /
3858+
8);
3859+
auto volatile_arg = ConstantInt::getFalse(oval->getContext());
38593860

38603861
#if LLVM_VERSION_MAJOR == 6
3861-
auto align_arg = ConstantInt::get(Type::getInt32Ty(oval->getContext()),
3862-
antialloca->getAlignment());
3863-
Value *args[] = {dst_arg, val_arg, len_arg, align_arg, volatile_arg};
3862+
auto align_arg = ConstantInt::get(Type::getInt32Ty(oval->getContext()),
3863+
antialloca->getAlignment());
3864+
Value *args[] = {dst_arg, val_arg, len_arg, align_arg, volatile_arg};
38643865
#else
3865-
Value *args[] = {dst_arg, val_arg, len_arg, volatile_arg};
3866+
Value *args[] = {dst_arg, val_arg, len_arg, volatile_arg};
38663867
#endif
3867-
Type *tys[] = {dst_arg->getType(), len_arg->getType()};
3868-
cast<CallInst>(bb.CreateCall(
3869-
Intrinsic::getDeclaration(M, Intrinsic::memset, tys), args));
3868+
Type *tys[] = {dst_arg->getType(), len_arg->getType()};
3869+
cast<CallInst>(bb.CreateCall(
3870+
Intrinsic::getDeclaration(M, Intrinsic::memset, tys), args));
3871+
3872+
return antialloca;
3873+
};
3874+
3875+
Value *antialloca = applyChainRule(oval->getType(), bb, rule1);
3876+
3877+
invertedPointers.insert(std::make_pair(
3878+
(const Value *)oval, InvertedPointerVH(this, antialloca)));
3879+
38703880
return antialloca;
38713881
} else if (auto arg = dyn_cast<GlobalVariable>(oval)) {
38723882
if (!hasMetadata(arg, "enzyme_shadow")) {

0 commit comments

Comments
 (0)