Skip to content

Commit 56a0c7d

Browse files
committed
feat: propagate generics to generated function
1 parent 16c1c54 commit 56a0c7d

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,10 @@ mod llvm_enzyme {
7373
}
7474

7575
// Get information about the function the macro is applied to
76-
fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident)> {
76+
fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident, Generics)> {
7777
match &iitem.kind {
78-
ItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
79-
Some((iitem.vis.clone(), sig.clone(), ident.clone()))
78+
ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
79+
Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
8080
}
8181
_ => None,
8282
}
@@ -210,16 +210,18 @@ mod llvm_enzyme {
210210
}
211211
let dcx = ecx.sess.dcx();
212212

213-
// first get information about the annotable item:
214-
let Some((vis, sig, primal)) = (match &item {
213+
// first get information about the annotable item: visibility, signature, name and generic
214+
// parameters.
215+
// these will be used to generate the differentiated version of the function
216+
let Some((vis, sig, primal, generics)) = (match &item {
215217
Annotatable::Item(iitem) => extract_item_info(iitem),
216218
Annotatable::Stmt(stmt) => match &stmt.kind {
217219
ast::StmtKind::Item(iitem) => extract_item_info(iitem),
218220
_ => None,
219221
},
220222
Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind {
221-
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
222-
Some((assoc_item.vis.clone(), sig.clone(), ident.clone()))
223+
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
224+
Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
223225
}
224226
_ => None,
225227
},
@@ -310,7 +312,7 @@ mod llvm_enzyme {
310312
defaultness: ast::Defaultness::Final,
311313
sig: d_sig,
312314
ident: first_ident(&meta_item_vec[0]),
313-
generics: Generics::default(),
315+
generics,
314316
contract: None,
315317
body: Some(d_body),
316318
define_opaque: None,

0 commit comments

Comments
 (0)