Skip to content

Commit 63fa872

Browse files
authored
Create a opt dbg helper (#134)
* Implement a opt dbg helper to create compiler-explorer MWE bug reproducer.
1 parent 269d384 commit 63fa872

File tree

2 files changed

+115
-5
lines changed

2 files changed

+115
-5
lines changed

compiler/rustc_codegen_llvm/src/back/write.rs

+11-5
Original file line numberDiff line numberDiff line change
@@ -1093,24 +1093,30 @@ pub(crate) unsafe fn differentiate(
10931093
}
10941094

10951095
// Before dumping the module, we want all the tt to become part of the module.
1096-
for item in &diff_items {
1096+
for (i, item) in diff_items.iter().enumerate() {
10971097
let llvm_data_layout = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) };
10981098
let llvm_data_layout =
10991099
std::str::from_utf8(unsafe { CStr::from_ptr(llvm_data_layout) }.to_bytes())
11001100
.expect("got a non-UTF8 data-layout from LLVM");
1101-
//let input_tts: Vec<TypeTree> =
1102-
// item.inputs.iter().map(|x| to_enzyme_typetree(x.clone(), llvm_data_layout, llcx)).collect();
1103-
//let output_tt = to_enzyme_typetree(item.output, llvm_data_layout, llcx);
11041101
let tt: FncTree = FncTree {
11051102
args: item.inputs.clone(),
11061103
ret: item.output.clone(),
11071104
};
11081105
let name = CString::new(item.source.clone()).unwrap();
11091106
let fn_def: &llvm::Value = llvm::LLVMGetNamedFunction(llmod, name.as_ptr()).unwrap();
11101107
crate::builder::add_tt2(llmod, llcx, fn_def, tt);
1108+
1109+
// Before dumping the module, we also might want to add dummy functions, which will
1110+
// trigger the LLVMEnzyme pass to run on them, if we invoke the opt binary.
1111+
// This is super helpfull if we want to create a MWE bug reproducer, e.g. to run in
1112+
// Enzyme's compiler explorer. TODO: Can we run llvm-extract on the module to remove all other functions?
1113+
if std::env::var("ENZYME_OPT").is_ok() {
1114+
dbg!("Enable extra debug helper to debug Enzyme through the opt plugin");
1115+
crate::builder::add_opt_dbg_helper(llmod, llcx, fn_def, item.attrs.clone(), i);
1116+
}
11111117
}
11121118

1113-
if std::env::var("ENZYME_PRINT_MOD_BEFORE").is_ok() {
1119+
if std::env::var("ENZYME_PRINT_MOD_BEFORE").is_ok() || std::env::var("ENZYME_OPT").is_ok(){
11141120
unsafe {
11151121
LLVMDumpModule(llmod);
11161122
}

compiler/rustc_codegen_llvm/src/builder.rs

+104
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use crate::type_::Type;
88
use crate::type_of::LayoutLlvmExt;
99
use crate::value::Value;
1010
use libc::{c_char, c_uint};
11+
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity};
1112
use rustc_codegen_ssa::common::{IntPredicate, RealPredicate, SynchronizationScope, TypeKind};
1213
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
1314
use rustc_codegen_ssa::mir::place::PlaceRef;
@@ -30,6 +31,7 @@ use std::iter;
3031
use std::ops::Deref;
3132
use std::ptr;
3233

34+
use rustc_ast::expand::autodiff_attrs::DiffMode;
3335
use crate::typetree::to_enzyme_typetree;
3436
use rustc_ast::expand::typetree::{TypeTree, FncTree};
3537

@@ -136,6 +138,7 @@ macro_rules! builder_methods_for_value_instructions {
136138
})+
137139
}
138140
}
141+
139142
pub fn add_tt2<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, fn_def: &'ll Value, tt: FncTree) {
140143
let inputs = tt.args;
141144
let ret_tt: TypeTree = tt.ret;
@@ -180,6 +183,107 @@ pub fn add_tt2<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, fn_def:
180183
}
181184
}
182185

186+
#[allow(unused)]
187+
pub fn add_opt_dbg_helper<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, val: &'ll Value, attrs: AutoDiffAttrs, i: usize) {
188+
//pub mode: DiffMode,
189+
//pub ret_activity: DiffActivity,
190+
//pub input_activity: Vec<DiffActivity>,
191+
let inputs = attrs.input_activity;
192+
let outputs = attrs.ret_activity;
193+
let ad_name = match attrs.mode {
194+
DiffMode::Forward => "__enzyme_fwddiff",
195+
DiffMode::Reverse => "__enzyme_autodiff",
196+
DiffMode::ForwardFirst => "__enzyme_fwddiff",
197+
DiffMode::ReverseFirst => "__enzyme_autodiff",
198+
_ => panic!("Why are we here?"),
199+
};
200+
201+
// Assuming that our val is the fnc square, want to generate the following llvm-ir:
202+
// declare double @__enzyme_autodiff(...)
203+
//
204+
// define double @dsquare(double %x) {
205+
// entry:
206+
// %0 = tail call double (...) @__enzyme_autodiff(double (double)* nonnull @square, double %x)
207+
// ret double %0
208+
// }
209+
210+
let mut final_num_args;
211+
unsafe {
212+
let fn_ty = llvm::LLVMRustGetFunctionType(val);
213+
let ret_ty = llvm::LLVMGetReturnType(fn_ty);
214+
215+
// First we add the declaration of the __enzyme function
216+
let enzyme_ty = llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True);
217+
let ad_fn = llvm::LLVMRustGetOrInsertFunction(
218+
llmod,
219+
ad_name.as_ptr() as *const c_char,
220+
ad_name.len().try_into().unwrap(),
221+
enzyme_ty,
222+
);
223+
224+
let wrapper_name = String::from("enzyme_opt_helper_") + i.to_string().as_str();
225+
let wrapper_fn = llvm::LLVMRustGetOrInsertFunction(
226+
llmod,
227+
wrapper_name.as_ptr() as *const c_char,
228+
wrapper_name.len().try_into().unwrap(),
229+
fn_ty,
230+
);
231+
let entry = llvm::LLVMAppendBasicBlockInContext(llcx, wrapper_fn, "entry".as_ptr() as *const c_char);
232+
let builder = llvm::LLVMCreateBuilderInContext(llcx);
233+
llvm::LLVMPositionBuilderAtEnd(builder, entry);
234+
let num_args = llvm::LLVMCountParams(wrapper_fn);
235+
let mut args = Vec::with_capacity(num_args as usize + 1);
236+
args.push(val);
237+
// metadata !"enzyme_const"
238+
let enzyme_const = llvm::LLVMMDStringInContext(llcx, "enzyme_const".as_ptr() as *const c_char, 12);
239+
let enzyme_out = llvm::LLVMMDStringInContext(llcx, "enzyme_out".as_ptr() as *const c_char, 10);
240+
let enzyme_dup = llvm::LLVMMDStringInContext(llcx, "enzyme_dup".as_ptr() as *const c_char, 10);
241+
let enzyme_dupnoneed = llvm::LLVMMDStringInContext(llcx, "enzyme_dupnoneed".as_ptr() as *const c_char, 16);
242+
final_num_args = num_args * 2 + 1;
243+
for i in 0..num_args {
244+
let arg = llvm::LLVMGetParam(wrapper_fn, i);
245+
let activity = inputs[i as usize];
246+
let (activity, duplicated): (&Value, bool) = match activity {
247+
DiffActivity::None => panic!(),
248+
DiffActivity::Const => (enzyme_const, false),
249+
DiffActivity::Active => (enzyme_out, false),
250+
DiffActivity::ActiveOnly => (enzyme_out, false),
251+
DiffActivity::Dual => (enzyme_dup, true),
252+
DiffActivity::DualOnly => (enzyme_dupnoneed, true),
253+
DiffActivity::Duplicated => (enzyme_dup, true),
254+
DiffActivity::DuplicatedOnly => (enzyme_dupnoneed, true),
255+
DiffActivity::FakeActivitySize => (enzyme_const, false),
256+
};
257+
args.push(activity);
258+
args.push(arg);
259+
if duplicated {
260+
final_num_args += 1;
261+
args.push(arg);
262+
}
263+
}
264+
265+
// declare void @__enzyme_autodiff(...)
266+
267+
// define void @enzyme_opt_helper_0(ptr %0, ptr %1) {
268+
// call void (...) @__enzyme_autodiff(ptr @ffff, ptr %0, ptr %1)
269+
// ret void
270+
// }
271+
272+
let call = llvm::LLVMBuildCall2(builder, enzyme_ty, ad_fn, args.as_mut_ptr(), final_num_args as usize, ad_name.as_ptr() as *const c_char);
273+
let void_ty = llvm::LLVMVoidTypeInContext(llcx);
274+
if llvm::LLVMTypeOf(call) != void_ty {
275+
llvm::LLVMBuildRet(builder, call);
276+
} else {
277+
llvm::LLVMBuildRetVoid(builder);
278+
}
279+
llvm::LLVMDisposeBuilder(builder);
280+
281+
let _fnc_ok =
282+
llvm::LLVMVerifyFunction(wrapper_fn, llvm::LLVMVerifierFailureAction::LLVMAbortProcessAction);
283+
}
284+
285+
}
286+
183287
fn add_tt<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context,val: &'ll Value, tt: FncTree) {
184288
let inputs = tt.args;
185289
let _ret: TypeTree = tt.ret;

0 commit comments

Comments
 (0)