@@ -391,15 +391,15 @@ handleFunctionLike(llvm::Module &M, llvm::GlobalVariable &g,
391
391
Attribute::get (g.getContext (), " enzyme_math" , nameVal));
392
392
} else {
393
393
llvm::errs () << M << " \n " ;
394
- llvm::errs () << " Param of __enzyme_inactivefn must be a "
394
+ llvm::errs () << " Param of __enzyme_function_like must be a "
395
395
" function"
396
396
<< g << " \n "
397
397
<< *V << " \n " ;
398
398
llvm_unreachable (" __enzyme_inactivefn" );
399
399
}
400
400
} else {
401
401
llvm::errs () << M << " \n " ;
402
- llvm::errs () << " Use of __enzyme_inactivefn must be a "
402
+ llvm::errs () << " Use of __enzyme_function_like must be a "
403
403
" constant function "
404
404
<< g << " \n " ;
405
405
llvm_unreachable (" __enzyme_register_gradient" );
@@ -408,6 +408,108 @@ handleFunctionLike(llvm::Module &M, llvm::GlobalVariable &g,
408
408
}
409
409
}
410
410
411
+ static void
412
+ handleAllocationLike (llvm::Module &M, llvm::GlobalVariable &g,
413
+ SmallVectorImpl<GlobalVariable *> &globalsToErase) {
414
+ if (g.hasInitializer ()) {
415
+ if (auto CA = dyn_cast<ConstantAggregate>(g.getInitializer ())) {
416
+ if (CA->getNumOperands () != 4 ) {
417
+ llvm::errs () << M << " \n " ;
418
+ llvm::errs () << " Use of "
419
+ << " enzyme_allocation_like"
420
+ << " must be a "
421
+ " constant of size at least "
422
+ << 4 << " " << g << " \n " ;
423
+ llvm_unreachable (" enzyme_allocation_like" );
424
+ }
425
+ Value *V = CA->getOperand (0 );
426
+ Value *name = CA->getOperand (1 );
427
+ while (auto CE = dyn_cast<ConstantExpr>(V)) {
428
+ V = CE->getOperand (0 );
429
+ }
430
+ while (auto CE = dyn_cast<ConstantExpr>(name)) {
431
+ name = CE->getOperand (0 );
432
+ }
433
+ Value *deallocind = CA->getOperand (2 );
434
+ while (auto CE = dyn_cast<ConstantExpr>(deallocind)) {
435
+ deallocind = CE->getOperand (0 );
436
+ }
437
+ Value *deallocfn = CA->getOperand (3 );
438
+ while (auto CE = dyn_cast<ConstantExpr>(deallocfn)) {
439
+ deallocfn = CE->getOperand (0 );
440
+ }
441
+ size_t index = 0 ;
442
+ if (auto CI = dyn_cast<ConstantInt>(name)) {
443
+ index = CI->getZExtValue ();
444
+ } else {
445
+ llvm::errs () << *name << " \n " ;
446
+ llvm::errs () << " Use of "
447
+ << " enzyme_allocation_like"
448
+ << " requires an integer index"
449
+ << " \n " ;
450
+ llvm_unreachable (" enzyme_allocation_like" );
451
+ }
452
+
453
+ StringRef deallocIndStr;
454
+ bool foundInd = false ;
455
+ if (auto GV = dyn_cast<GlobalVariable>(deallocind))
456
+ if (GV->isConstant ())
457
+ if (auto C = GV->getInitializer ())
458
+ if (auto CA = dyn_cast<ConstantDataArray>(C))
459
+ if (CA->getType ()->getElementType ()->isIntegerTy (8 ) &&
460
+ CA->isCString ()) {
461
+ deallocIndStr = CA->getAsCString ();
462
+ foundInd = true ;
463
+ }
464
+
465
+ if (!foundInd) {
466
+ llvm::errs () << *deallocind << " \n " ;
467
+ llvm::errs () << " Use of "
468
+ << " enzyme_allocation_like"
469
+ << " requires a deallocation index string"
470
+ << " \n " ;
471
+ llvm_unreachable (" enzyme_allocation_like" );
472
+ }
473
+ if (auto F = dyn_cast<Function>(V)) {
474
+ F->addAttribute (AttributeList::FunctionIndex,
475
+ Attribute::get (g.getContext (), " enzyme_allocator" ,
476
+ std::to_string (index )));
477
+ } else {
478
+ llvm::errs () << M << " \n " ;
479
+ llvm::errs () << " Param of __enzyme_allocation_like must be a "
480
+ " function"
481
+ << g << " \n "
482
+ << *V << " \n " ;
483
+ llvm_unreachable (" __enzyme_allocation_like" );
484
+ }
485
+ cast<Function>(V)->addAttribute (
486
+ AttributeList::FunctionIndex,
487
+ Attribute::get (g.getContext (), " enzyme_deallocator" , deallocIndStr));
488
+
489
+ if (auto F = dyn_cast<Function>(deallocfn)) {
490
+ cast<Function>(V)->setMetadata (
491
+ " enzyme_deallocator_fn" ,
492
+ llvm::MDTuple::get (F->getContext (),
493
+ {llvm::ValueAsMetadata::get (F)}));
494
+ } else {
495
+ llvm::errs () << M << " \n " ;
496
+ llvm::errs () << " Free fn of __enzyme_allocation_like must be a "
497
+ " function"
498
+ << g << " \n "
499
+ << *deallocfn << " \n " ;
500
+ llvm_unreachable (" __enzyme_allocation_like" );
501
+ }
502
+ } else {
503
+ llvm::errs () << M << " \n " ;
504
+ llvm::errs () << " Use of __enzyme_allocation_like must be a "
505
+ " constant function "
506
+ << g << " \n " ;
507
+ llvm_unreachable (" __enzyme_allocation_like" );
508
+ }
509
+ globalsToErase.push_back (&g);
510
+ }
511
+ }
512
+
411
513
static void handleKnownFunctions (llvm::Function &F) {
412
514
if (F.getName () == " memcmp" ) {
413
515
F.addFnAttr (Attribute::ReadOnly);
@@ -2452,6 +2554,8 @@ class Enzyme final : public ModulePass {
2452
2554
handleInactiveFunction (M, g, globalsToErase);
2453
2555
} else if (g.getName ().contains (" __enzyme_function_like" )) {
2454
2556
handleFunctionLike (M, g, globalsToErase);
2557
+ } else if (g.getName ().contains (" __enzyme_allocation_like" )) {
2558
+ handleAllocationLike (M, g, globalsToErase);
2455
2559
}
2456
2560
}
2457
2561
for (auto g : globalsToErase) {
0 commit comments