Skip to content

Commit 7018392

Browse files
committed
remove noinline attribute and add alwaysinline after AD pass
1 parent 9bc0401 commit 7018392

File tree

8 files changed

+104
-10
lines changed

8 files changed

+104
-10
lines changed

compiler/rustc_codegen_llvm/src/attributes.rs

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
//! Set and unset common attributes on LLVM values.
2-
32
use rustc_attr_parsing::{InlineAttr, InstructionSetAttr, OptimizeAttr};
43
use rustc_codegen_ssa::traits::*;
54
use rustc_hir::def_id::DefId;
@@ -32,15 +31,15 @@ pub(crate) fn has_attr(llfn: &Value, idx: AttributePlace, attr: AttributeKind) -
3231
llvm::HasAttributeAtIndex(llfn, idx, attr)
3332
}
3433

35-
pub(crate) fn has_string_attr(llfn: &Value, name: *const i8) -> bool {
34+
pub(crate) fn has_string_attr(llfn: &Value, name: &str) -> bool {
3635
llvm::HasStringAttribute(llfn, name)
3736
}
3837

3938
pub(crate) fn remove_from_llfn(llfn: &Value, place: AttributePlace, kind: AttributeKind) {
4039
llvm::RemoveRustEnumAttributeAtIndex(llfn, place, kind);
4140
}
4241

43-
pub(crate) fn remove_string_attr_from_llfn(llfn: &Value, name: *const i8) {
42+
pub(crate) fn remove_string_attr_from_llfn(llfn: &Value, name: &str) {
4443
llvm::RemoveStringAttrFromFn(llfn, name);
4544
}
4645

compiler/rustc_codegen_llvm/src/back/lto.rs

+27-1
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@ use crate::back::write::{
2828
use crate::errors::{
2929
DynamicLinkingWithLTO, LlvmError, LtoBitcodeFromRlib, LtoDisallowed, LtoDylib, LtoProcMacro,
3030
};
31+
use crate::llvm::AttributePlace::Function;
3132
use crate::llvm::{self, build_string};
32-
use crate::{LlvmCodegenBackend, ModuleLlvm};
33+
use crate::{LlvmCodegenBackend, ModuleLlvm, SimpleCx, attributes};
3334

3435
/// We keep track of the computed LTO cache keys from the previous
3536
/// session to determine which CGUs we can reuse.
@@ -666,6 +667,31 @@ pub(crate) fn run_pass_manager(
666667
}
667668

668669
if cfg!(llvm_enzyme) && enable_ad && !thin {
670+
let cx =
671+
SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);
672+
673+
for function in cx.get_functions() {
674+
let enzyme_marker = "enzyme_marker";
675+
if attributes::has_string_attr(function, enzyme_marker) {
676+
// Sanity check: Ensure 'noinline' is present before replacing it.
677+
assert!(
678+
!attributes::has_attr(function, Function, llvm::AttributeKind::NoInline),
679+
"Expected __enzyme function to have 'noinline' before adding 'alwaysinline'"
680+
);
681+
682+
attributes::remove_from_llfn(function, Function, llvm::AttributeKind::NoInline);
683+
attributes::remove_string_attr_from_llfn(function, enzyme_marker);
684+
685+
assert!(
686+
!attributes::has_string_attr(function, enzyme_marker),
687+
"Expected function to not have 'enzyme_marker'"
688+
);
689+
690+
let always_inline = llvm::AttributeKind::AlwaysInline.create_attr(cx.llcx);
691+
attributes::apply_to_llfn(function, Function, &[always_inline]);
692+
}
693+
}
694+
669695
let opt_stage = llvm::OptStage::FatLTO;
670696
let stage = write::AutodiffStage::PostAD;
671697
if !config.autodiff.contains(&config::AutoDiff::NoPostopt) {

compiler/rustc_codegen_llvm/src/context.rs

+10
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,16 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
698698
llvm::LLVMMDStringInContext2(self.llcx(), name.as_ptr() as *const c_char, name.len())
699699
})
700700
}
701+
702+
pub(crate) fn get_functions(&self) -> Vec<&'ll Value> {
703+
let mut functions = vec![];
704+
let mut func = unsafe { llvm::LLVMGetFirstFunction(self.llmod()) };
705+
while let Some(f) = func {
706+
functions.push(f);
707+
func = unsafe { llvm::LLVMGetNextFunction(f) }
708+
}
709+
functions
710+
}
701711
}
702712

703713
impl<'ll, 'tcx> MiscCodegenMethods<'tcx> for CodegenCx<'ll, 'tcx> {

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,12 @@ unsafe extern "C" {
1919
pub(crate) fn LLVMRustVerifyFunction(V: &Value, action: LLVMRustVerifierFailureAction) -> Bool;
2020
pub(crate) fn LLVMRustHasAttributeAtIndex(V: &Value, i: c_uint, Kind: AttributeKind) -> bool;
2121
pub(crate) fn LLVMRustGetArrayNumElements(Ty: &Type) -> u64;
22-
pub(crate) fn LLVMRustHasFnAttribute(F: &Value, Name: *const c_char) -> bool;
23-
pub(crate) fn LLVMRustRemoveFnAttribute(F: &Value, Name: *const c_char);
22+
pub(crate) fn LLVMRustHasFnAttribute(
23+
F: &Value,
24+
Name: *const c_char,
25+
NameLen: libc::size_t,
26+
) -> bool;
27+
pub(crate) fn LLVMRustRemoveFnAttribute(F: &Value, Name: *const c_char, NameLen: libc::size_t);
2428
pub(crate) fn LLVMGetFirstFunction(M: &Module) -> Option<&Value>;
2529
pub(crate) fn LLVMGetNextFunction(Fn: &Value) -> Option<&Value>;
2630
pub(crate) fn LLVMRustRemoveEnumAttributeAtIndex(

compiler/rustc_codegen_llvm/src/llvm/mod.rs

+26
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,32 @@ pub(crate) fn AddFunctionAttributes<'ll>(
4141
}
4242
}
4343

44+
pub(crate) fn HasAttributeAtIndex<'ll>(
45+
llfn: &'ll Value,
46+
idx: AttributePlace,
47+
kind: AttributeKind,
48+
) -> bool {
49+
unsafe { LLVMRustHasAttributeAtIndex(llfn, idx.as_uint(), kind) }
50+
}
51+
52+
pub(crate) fn HasStringAttribute<'ll>(llfn: &'ll Value, name: &str) -> bool {
53+
unsafe { LLVMRustHasFnAttribute(llfn, name.as_c_char_ptr(), name.len()) }
54+
}
55+
56+
pub(crate) fn RemoveStringAttrFromFn<'ll>(llfn: &'ll Value, name: &str) {
57+
unsafe { LLVMRustRemoveFnAttribute(llfn, name.as_c_char_ptr(), name.len()) }
58+
}
59+
60+
pub(crate) fn RemoveRustEnumAttributeAtIndex(
61+
llfn: &Value,
62+
place: AttributePlace,
63+
kind: AttributeKind,
64+
) {
65+
unsafe {
66+
LLVMRustRemoveEnumAttributeAtIndex(llfn, place.as_uint(), kind);
67+
}
68+
}
69+
4470
pub(crate) fn AddCallSiteAttributes<'ll>(
4571
callsite: &'ll Value,
4672
idx: AttributePlace,

compiler/rustc_codegen_llvm/src/type_.rs

+4
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,10 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
128128
(**self).borrow().llcx
129129
}
130130

131+
pub(crate) fn llmod(&self) -> &'ll llvm::Module {
132+
(**self).borrow().llmod
133+
}
134+
131135
pub(crate) fn isize_ty(&self) -> &'ll Type {
132136
(**self).borrow().isize_ty
133137
}

compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp

+6-4
Original file line numberDiff line numberDiff line change
@@ -979,16 +979,18 @@ LLVMRustRemoveEnumAttributeAtIndex(LLVMValueRef F, size_t index,
979979
LLVMRemoveEnumAttributeAtIndex(F, index, fromRust(RustAttr));
980980
}
981981

982-
extern "C" bool LLVMRustHasFnAttribute(LLVMValueRef F, const char *Name) {
982+
extern "C" bool LLVMRustHasFnAttribute(LLVMValueRef F, const char *Name,
983+
size_t NameLen) {
983984
if (auto *Fn = dyn_cast<Function>(unwrap<Value>(F))) {
984-
return Fn->hasFnAttribute(Name);
985+
return Fn->hasFnAttribute(StringRef(Name, NameLen));
985986
}
986987
return false;
987988
}
988989

989-
extern "C" void LLVMRustRemoveFnAttribute(LLVMValueRef Fn, const char *Name) {
990+
extern "C" void LLVMRustRemoveFnAttribute(LLVMValueRef Fn, const char *Name,
991+
size_t NameLen) {
990992
if (auto *F = dyn_cast<Function>(unwrap<Value>(Fn))) {
991-
F->removeFnAttr(Name);
993+
F->removeFnAttr(StringRef(Name, NameLen));
992994
}
993995
}
994996

tests/codegen/autodiff/inline.rs

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat -Zautodiff=NoPostopt
2+
//@ no-prefer-dynamic
3+
//@ needs-enzyme
4+
5+
#![feature(autodiff)]
6+
7+
use std::autodiff::autodiff;
8+
9+
#[autodiff(d_square, Reverse, Duplicated, Active)]
10+
fn square(x: &f64) -> f64 {
11+
x * x
12+
}
13+
14+
// CHECK: ; inline::d_square
15+
// CHECK-NEXT: ; Function Attrs: alwaysinline
16+
// CHECK-NOT: noinline
17+
// CHECK-NEXT: define internal fastcc void @_ZN6inline8d_square17h021c74e92c259cdeE
18+
fn main() {
19+
let x = std::hint::black_box(3.0);
20+
let mut dx1 = std::hint::black_box(1.0);
21+
let _ = d_square(&x, &mut dx1, 1.0);
22+
assert_eq!(dx1, 6.0);
23+
}

0 commit comments

Comments
 (0)