Skip to content

Commit 8b32282

Browse files
committed
feat: add test for generics in generated function
1 parent 56a0c7d commit 8b32282

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

tests/pretty/autodiff/autodiff_forward.pp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131

3232
// We want to make sure that we can use the macro for functions defined inside of functions
3333

34+
// Make sure we can handle generics
35+
3436
::core::panicking::panic("not implemented")
3537
}
3638
#[rustc_autodiff(Forward, 1, Dual, Const, Dual)]
@@ -181,4 +183,16 @@
181183
::core::hint::black_box(<f32>::default())
182184
}
183185
}
186+
#[rustc_autodiff]
187+
#[inline(never)]
188+
pub fn f10<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T { *x * *x }
189+
#[rustc_autodiff(Reverse, 1, Duplicated, Active)]
190+
#[inline(never)]
191+
pub fn d_square<T: std::ops::Mul<Output = T> +
192+
Copy>(x: &T, dx_0: &mut T, dret: T) -> T {
193+
unsafe { asm!("NOP", options(pure, nomem)); };
194+
::core::hint::black_box(f10(x));
195+
::core::hint::black_box((dx_0, dret));
196+
::core::hint::black_box(f10(x))
197+
}
184198
fn main() {}

tests/pretty/autodiff/autodiff_forward.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,10 @@ pub fn f9() {
6363
}
6464
}
6565

66+
// Make sure we can handle generics
67+
#[autodiff(d_square, Reverse, Duplicated, Active)]
68+
pub fn f10<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
69+
*x * *x
70+
}
71+
6672
fn main() {}

0 commit comments

Comments
 (0)