@@ -8,6 +8,7 @@ use crate::type_::Type;
8
8
use crate :: type_of:: LayoutLlvmExt ;
9
9
use crate :: value:: Value ;
10
10
use libc:: { c_char, c_uint} ;
11
+ use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , DiffActivity } ;
11
12
use rustc_codegen_ssa:: common:: { IntPredicate , RealPredicate , SynchronizationScope , TypeKind } ;
12
13
use rustc_codegen_ssa:: mir:: operand:: { OperandRef , OperandValue } ;
13
14
use rustc_codegen_ssa:: mir:: place:: PlaceRef ;
@@ -30,6 +31,7 @@ use std::iter;
30
31
use std:: ops:: Deref ;
31
32
use std:: ptr;
32
33
34
+ use rustc_ast:: expand:: autodiff_attrs:: DiffMode ;
33
35
use crate :: typetree:: to_enzyme_typetree;
34
36
use rustc_ast:: expand:: typetree:: { TypeTree , FncTree } ;
35
37
@@ -136,6 +138,7 @@ macro_rules! builder_methods_for_value_instructions {
136
138
} ) +
137
139
}
138
140
}
141
+
139
142
pub fn add_tt2 < ' ll > ( llmod : & ' ll llvm:: Module , llcx : & ' ll llvm:: Context , fn_def : & ' ll Value , tt : FncTree ) {
140
143
let inputs = tt. args ;
141
144
let ret_tt: TypeTree = tt. ret ;
@@ -180,6 +183,107 @@ pub fn add_tt2<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, fn_def:
180
183
}
181
184
}
182
185
186
+ #[ allow( unused) ]
187
+ pub fn add_opt_dbg_helper < ' ll > ( llmod : & ' ll llvm:: Module , llcx : & ' ll llvm:: Context , val : & ' ll Value , attrs : AutoDiffAttrs , i : usize ) {
188
+ //pub mode: DiffMode,
189
+ //pub ret_activity: DiffActivity,
190
+ //pub input_activity: Vec<DiffActivity>,
191
+ let inputs = attrs. input_activity ;
192
+ let outputs = attrs. ret_activity ;
193
+ let ad_name = match attrs. mode {
194
+ DiffMode :: Forward => "__enzyme_fwddiff" ,
195
+ DiffMode :: Reverse => "__enzyme_autodiff" ,
196
+ DiffMode :: ForwardFirst => "__enzyme_fwddiff" ,
197
+ DiffMode :: ReverseFirst => "__enzyme_autodiff" ,
198
+ _ => panic ! ( "Why are we here?" ) ,
199
+ } ;
200
+
201
+ // Assuming that our val is the fnc square, want to generate the following llvm-ir:
202
+ // declare double @__enzyme_autodiff(...)
203
+ //
204
+ // define double @dsquare(double %x) {
205
+ // entry:
206
+ // %0 = tail call double (...) @__enzyme_autodiff(double (double)* nonnull @square, double %x)
207
+ // ret double %0
208
+ // }
209
+
210
+ let mut final_num_args;
211
+ unsafe {
212
+ let fn_ty = llvm:: LLVMRustGetFunctionType ( val) ;
213
+ let ret_ty = llvm:: LLVMGetReturnType ( fn_ty) ;
214
+
215
+ // First we add the declaration of the __enzyme function
216
+ let enzyme_ty = llvm:: LLVMFunctionType ( ret_ty, ptr:: null ( ) , 0 , True ) ;
217
+ let ad_fn = llvm:: LLVMRustGetOrInsertFunction (
218
+ llmod,
219
+ ad_name. as_ptr ( ) as * const c_char ,
220
+ ad_name. len ( ) . try_into ( ) . unwrap ( ) ,
221
+ enzyme_ty,
222
+ ) ;
223
+
224
+ let wrapper_name = String :: from ( "enzyme_opt_helper_" ) + i. to_string ( ) . as_str ( ) ;
225
+ let wrapper_fn = llvm:: LLVMRustGetOrInsertFunction (
226
+ llmod,
227
+ wrapper_name. as_ptr ( ) as * const c_char ,
228
+ wrapper_name. len ( ) . try_into ( ) . unwrap ( ) ,
229
+ fn_ty,
230
+ ) ;
231
+ let entry = llvm:: LLVMAppendBasicBlockInContext ( llcx, wrapper_fn, "entry" . as_ptr ( ) as * const c_char ) ;
232
+ let builder = llvm:: LLVMCreateBuilderInContext ( llcx) ;
233
+ llvm:: LLVMPositionBuilderAtEnd ( builder, entry) ;
234
+ let num_args = llvm:: LLVMCountParams ( wrapper_fn) ;
235
+ let mut args = Vec :: with_capacity ( num_args as usize + 1 ) ;
236
+ args. push ( val) ;
237
+ // metadata !"enzyme_const"
238
+ let enzyme_const = llvm:: LLVMMDStringInContext ( llcx, "enzyme_const" . as_ptr ( ) as * const c_char , 12 ) ;
239
+ let enzyme_out = llvm:: LLVMMDStringInContext ( llcx, "enzyme_out" . as_ptr ( ) as * const c_char , 10 ) ;
240
+ let enzyme_dup = llvm:: LLVMMDStringInContext ( llcx, "enzyme_dup" . as_ptr ( ) as * const c_char , 10 ) ;
241
+ let enzyme_dupnoneed = llvm:: LLVMMDStringInContext ( llcx, "enzyme_dupnoneed" . as_ptr ( ) as * const c_char , 16 ) ;
242
+ final_num_args = num_args * 2 + 1 ;
243
+ for i in 0 ..num_args {
244
+ let arg = llvm:: LLVMGetParam ( wrapper_fn, i) ;
245
+ let activity = inputs[ i as usize ] ;
246
+ let ( activity, duplicated) : ( & Value , bool ) = match activity {
247
+ DiffActivity :: None => panic ! ( ) ,
248
+ DiffActivity :: Const => ( enzyme_const, false ) ,
249
+ DiffActivity :: Active => ( enzyme_out, false ) ,
250
+ DiffActivity :: ActiveOnly => ( enzyme_out, false ) ,
251
+ DiffActivity :: Dual => ( enzyme_dup, true ) ,
252
+ DiffActivity :: DualOnly => ( enzyme_dupnoneed, true ) ,
253
+ DiffActivity :: Duplicated => ( enzyme_dup, true ) ,
254
+ DiffActivity :: DuplicatedOnly => ( enzyme_dupnoneed, true ) ,
255
+ DiffActivity :: FakeActivitySize => ( enzyme_const, false ) ,
256
+ } ;
257
+ args. push ( activity) ;
258
+ args. push ( arg) ;
259
+ if duplicated {
260
+ final_num_args += 1 ;
261
+ args. push ( arg) ;
262
+ }
263
+ }
264
+
265
+ // declare void @__enzyme_autodiff(...)
266
+
267
+ // define void @enzyme_opt_helper_0(ptr %0, ptr %1) {
268
+ // call void (...) @__enzyme_autodiff(ptr @ffff, ptr %0, ptr %1)
269
+ // ret void
270
+ // }
271
+
272
+ let call = llvm:: LLVMBuildCall2 ( builder, enzyme_ty, ad_fn, args. as_mut_ptr ( ) , final_num_args as usize , ad_name. as_ptr ( ) as * const c_char ) ;
273
+ let void_ty = llvm:: LLVMVoidTypeInContext ( llcx) ;
274
+ if llvm:: LLVMTypeOf ( call) != void_ty {
275
+ llvm:: LLVMBuildRet ( builder, call) ;
276
+ } else {
277
+ llvm:: LLVMBuildRetVoid ( builder) ;
278
+ }
279
+ llvm:: LLVMDisposeBuilder ( builder) ;
280
+
281
+ let _fnc_ok =
282
+ llvm:: LLVMVerifyFunction ( wrapper_fn, llvm:: LLVMVerifierFailureAction :: LLVMAbortProcessAction ) ;
283
+ }
284
+
285
+ }
286
+
183
287
fn add_tt < ' ll > ( llmod : & ' ll llvm:: Module , llcx : & ' ll llvm:: Context , val : & ' ll Value , tt : FncTree ) {
184
288
let inputs = tt. args ;
185
289
let _ret: TypeTree = tt. ret ;
0 commit comments