Skip to content

Commit 264249f

Browse files
authored
Rollup merge of #140104 - Shourya742:2025-04-21-auto-diff-fails-on-impl-block, r=ZuseZ4
Fix auto diff failing on inherent impl blocks closes: #139557 r? ``@ZuseZ4``
2 parents 107f04d + b8ca007 commit 264249f

File tree

7 files changed

+71
-8
lines changed

7 files changed

+71
-8
lines changed

Diff for: compiler/rustc_builtin_macros/src/autodiff.rs

+6-8
Original file line numberDiff line numberDiff line change
@@ -217,14 +217,12 @@ mod llvm_enzyme {
217217
ast::StmtKind::Item(iitem) => extract_item_info(iitem),
218218
_ => None,
219219
},
220-
Annotatable::AssocItem(assoc_item, Impl { of_trait: false }) => {
221-
match &assoc_item.kind {
222-
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
223-
Some((assoc_item.vis.clone(), sig.clone(), ident.clone()))
224-
}
225-
_ => None,
220+
Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind {
221+
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
222+
Some((assoc_item.vis.clone(), sig.clone(), ident.clone()))
226223
}
227-
}
224+
_ => None,
225+
},
228226
_ => None,
229227
}) else {
230228
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
@@ -365,7 +363,7 @@ mod llvm_enzyme {
365363
}
366364
Annotatable::Item(iitem.clone())
367365
}
368-
Annotatable::AssocItem(ref mut assoc_item, i @ Impl { of_trait: false }) => {
366+
Annotatable::AssocItem(ref mut assoc_item, i @ Impl { .. }) => {
369367
if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
370368
assoc_item.attrs.push(attr);
371369
}
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

Diff for: tests/pretty/autodiff/inherent_impl.pp

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#![feature(prelude_import)]
2+
#![no_std]
3+
//@ needs-enzyme
4+
5+
#![feature(autodiff)]
6+
#[prelude_import]
7+
use ::std::prelude::rust_2015::*;
8+
#[macro_use]
9+
extern crate std;
10+
//@ pretty-mode:expanded
11+
//@ pretty-compare-only
12+
//@ pp-exact:inherent_impl.pp
13+
14+
use std::autodiff::autodiff;
15+
16+
struct Foo {
17+
a: f64,
18+
}
19+
20+
trait MyTrait {
21+
fn f(&self, x: f64)
22+
-> f64;
23+
fn df(&self, x: f64, seed: f64)
24+
-> (f64, f64);
25+
}
26+
27+
impl MyTrait for Foo {
28+
#[rustc_autodiff]
29+
#[inline(never)]
30+
fn f(&self, x: f64) -> f64 {
31+
self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln())
32+
}
33+
#[rustc_autodiff(Reverse, 1, Const, Active, Active)]
34+
#[inline(never)]
35+
fn df(&self, x: f64, dret: f64) -> (f64, f64) {
36+
unsafe { asm!("NOP", options(pure, nomem)); };
37+
::core::hint::black_box(self.f(x));
38+
::core::hint::black_box((dret,));
39+
::core::hint::black_box((self.f(x), f64::default()))
40+
}
41+
}

Diff for: tests/pretty/autodiff/inherent_impl.rs

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
//@ needs-enzyme
2+
3+
#![feature(autodiff)]
4+
//@ pretty-mode:expanded
5+
//@ pretty-compare-only
6+
//@ pp-exact:inherent_impl.pp
7+
8+
use std::autodiff::autodiff;
9+
10+
struct Foo {
11+
a: f64,
12+
}
13+
14+
trait MyTrait {
15+
fn f(&self, x: f64) -> f64;
16+
fn df(&self, x: f64, seed: f64) -> (f64, f64);
17+
}
18+
19+
impl MyTrait for Foo {
20+
#[autodiff(df, Reverse, Const, Active, Active)]
21+
fn f(&self, x: f64) -> f64 {
22+
self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln())
23+
}
24+
}

0 commit comments

Comments
 (0)