@@ -2312,10 +2312,6 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
2312
2312
report_fatal_error (" function failed verification (2)" );
2313
2313
}
2314
2314
2315
- StructType *sty = cast<StructType>(gutils->newFunc ->getReturnType ());
2316
- SmallVector<Type *, 4 > RetTypes (sty->elements ().begin (),
2317
- sty->elements ().end ());
2318
-
2319
2315
SmallVector<Type *, 4 > MallocTypes;
2320
2316
2321
2317
for (auto a : gutils->getTapeValues ()) {
@@ -2338,6 +2334,30 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
2338
2334
forceAnonymousTape;
2339
2335
bool noTape = MallocTypes.size () == 0 && !forceAnonymousTape;
2340
2336
2337
+ StructType *sty = cast<StructType>(gutils->newFunc ->getReturnType ());
2338
+ SmallVector<Type *, 4 > RetTypes (sty->elements ().begin (),
2339
+ sty->elements ().end ());
2340
+ if (!noTape) {
2341
+ if (recursive && !omp) {
2342
+ auto size =
2343
+ gutils->newFunc ->getParent ()->getDataLayout ().getTypeAllocSizeInBits (
2344
+ tapeType);
2345
+ if (size != 0 ) {
2346
+ auto i64 = Type::getInt64Ty (gutils->newFunc ->getContext ());
2347
+ BasicBlock *BB = BasicBlock::Create (gutils->newFunc ->getContext (),
2348
+ " entry" , gutils->newFunc );
2349
+ IRBuilder<> B (BB);
2350
+
2351
+ CallInst *malloccall;
2352
+ CreateAllocation (B, tapeType, ConstantInt::get (i64 , 1 ), " tapemem" ,
2353
+ &malloccall, nullptr );
2354
+ RetTypes[returnMapping.find (AugmentedStruct::Tape)->second ] =
2355
+ malloccall->getType ();
2356
+ BB->eraseFromParent ();
2357
+ }
2358
+ }
2359
+ }
2360
+
2341
2361
int oldretIdx = -1 ;
2342
2362
if (returnMapping.find (AugmentedStruct::Return) != returnMapping.end ()) {
2343
2363
oldretIdx = returnMapping[AugmentedStruct::Return];
@@ -2367,8 +2387,6 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
2367
2387
}
2368
2388
RetTypes.erase (RetTypes.begin () + tidx);
2369
2389
} else if (recursive) {
2370
- assert (RetTypes[returnMapping.find (AugmentedStruct::Tape)->second ] ==
2371
- Type::getInt8PtrTy (nf->getContext ()));
2372
2390
} else {
2373
2391
RetTypes[returnMapping.find (AugmentedStruct::Tape)->second ] = tapeType;
2374
2392
}
@@ -2440,7 +2458,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
2440
2458
2441
2459
IRBuilder<> ib (NewF->getEntryBlock ().getFirstNonPHI ());
2442
2460
2443
- Value *ret = noReturn ? nullptr : ib.CreateAlloca (RetType);
2461
+ AllocaInst *ret = noReturn ? nullptr : ib.CreateAlloca (RetType);
2444
2462
if (!noReturn && EnzymeZeroCache) {
2445
2463
ib.CreateStore (Constant::getNullValue (RetType), ret);
2446
2464
}
@@ -2455,10 +2473,9 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
2455
2473
if (size != 0 ) {
2456
2474
CallInst *malloccall = nullptr ;
2457
2475
Instruction *zero = nullptr ;
2458
- IRBuilder<> Builder (NewF->getEntryBlock ().getFirstNonPHI ());
2459
- tapeMemory = CreateAllocation (
2460
- Builder, tapeType, ConstantInt::get (i64 , 1 ), " tapemem" , &malloccall,
2461
- EnzymeZeroCache ? &zero : nullptr );
2476
+ tapeMemory =
2477
+ CreateAllocation (ib, tapeType, ConstantInt::get (i64 , 1 ), " tapemem" ,
2478
+ &malloccall, EnzymeZeroCache ? &zero : nullptr );
2462
2479
memory = malloccall;
2463
2480
} else {
2464
2481
memory =
@@ -3690,7 +3707,10 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
3690
3707
IRBuilder<> BuilderZ (gutils->inversionAllocs );
3691
3708
if (!augmenteddata->tapeType ->isEmptyTy ()) {
3692
3709
auto tapep = BuilderZ.CreatePointerCast (
3693
- additionalValue, PointerType::getUnqual (augmenteddata->tapeType ));
3710
+ additionalValue,
3711
+ PointerType::get (augmenteddata->tapeType ,
3712
+ cast<PointerType>(additionalValue->getType ())
3713
+ ->getAddressSpace ()));
3694
3714
#if LLVM_VERSION_MAJOR > 7
3695
3715
LoadInst *truetape =
3696
3716
BuilderZ.CreateLoad (augmenteddata->tapeType , tapep, " truetape" );
@@ -3701,7 +3721,7 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
3701
3721
MDNode::get (truetape->getContext (), {}));
3702
3722
3703
3723
if (!omp && gutils->FreeMemory ) {
3704
- CreateDealloc (BuilderZ, tapep );
3724
+ CreateDealloc (BuilderZ, additionalValue );
3705
3725
}
3706
3726
additionalValue = truetape;
3707
3727
} else {
0 commit comments