@@ -3841,32 +3841,42 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM,
3841
3841
3842
3842
if (isa<Argument>(oval) && cast<Argument>(oval)->hasByValAttr ()) {
3843
3843
IRBuilder<> bb (inversionAllocs);
3844
- AllocaInst *antialloca = bb.CreateAlloca (
3845
- oval->getType ()->getPointerElementType (),
3846
- cast<PointerType>(oval->getType ())->getPointerAddressSpace (), nullptr ,
3847
- oval->getName () + " 'ipa" );
3848
- invertedPointers.insert (std::make_pair (
3849
- (const Value *)oval, InvertedPointerVH (this , antialloca)));
3850
- auto dst_arg =
3851
- bb.CreateBitCast (antialloca, Type::getInt8PtrTy (oval->getContext ()));
3852
- auto val_arg = ConstantInt::get (Type::getInt8Ty (oval->getContext ()), 0 );
3853
- auto len_arg =
3854
- ConstantInt::get (Type::getInt64Ty (oval->getContext ()),
3855
- M->getDataLayout ().getTypeAllocSizeInBits (
3856
- oval->getType ()->getPointerElementType ()) /
3857
- 8 );
3858
- auto volatile_arg = ConstantInt::getFalse (oval->getContext ());
3844
+
3845
+ auto rule1 = [&]() {
3846
+ AllocaInst *antialloca = bb.CreateAlloca (
3847
+ oval->getType ()->getPointerElementType (),
3848
+ cast<PointerType>(oval->getType ())->getPointerAddressSpace (), nullptr ,
3849
+ oval->getName () + " 'ipa" );
3850
+
3851
+ auto dst_arg =
3852
+ bb.CreateBitCast (antialloca, Type::getInt8PtrTy (oval->getContext ()));
3853
+ auto val_arg = ConstantInt::get (Type::getInt8Ty (oval->getContext ()), 0 );
3854
+ auto len_arg =
3855
+ ConstantInt::get (Type::getInt64Ty (oval->getContext ()),
3856
+ M->getDataLayout ().getTypeAllocSizeInBits (
3857
+ oval->getType ()->getPointerElementType ()) /
3858
+ 8 );
3859
+ auto volatile_arg = ConstantInt::getFalse (oval->getContext ());
3859
3860
3860
3861
#if LLVM_VERSION_MAJOR == 6
3861
- auto align_arg = ConstantInt::get (Type::getInt32Ty (oval->getContext ()),
3862
- antialloca->getAlignment ());
3863
- Value *args[] = {dst_arg, val_arg, len_arg, align_arg, volatile_arg};
3862
+ auto align_arg = ConstantInt::get (Type::getInt32Ty (oval->getContext ()),
3863
+ antialloca->getAlignment ());
3864
+ Value *args[] = {dst_arg, val_arg, len_arg, align_arg, volatile_arg};
3864
3865
#else
3865
- Value *args[] = {dst_arg, val_arg, len_arg, volatile_arg};
3866
+ Value *args[] = {dst_arg, val_arg, len_arg, volatile_arg};
3866
3867
#endif
3867
- Type *tys[] = {dst_arg->getType (), len_arg->getType ()};
3868
- cast<CallInst>(bb.CreateCall (
3869
- Intrinsic::getDeclaration (M, Intrinsic::memset , tys), args));
3868
+ Type *tys[] = {dst_arg->getType (), len_arg->getType ()};
3869
+ cast<CallInst>(bb.CreateCall (
3870
+ Intrinsic::getDeclaration (M, Intrinsic::memset , tys), args));
3871
+
3872
+ return antialloca;
3873
+ };
3874
+
3875
+ Value *antialloca = applyChainRule (oval->getType (), bb, rule1);
3876
+
3877
+ invertedPointers.insert (std::make_pair (
3878
+ (const Value *)oval, InvertedPointerVH (this , antialloca)));
3879
+
3870
3880
return antialloca;
3871
3881
} else if (auto arg = dyn_cast<GlobalVariable>(oval)) {
3872
3882
if (!hasMetadata (arg, " enzyme_shadow" )) {
0 commit comments