14
14
use std:: autodiff:: autodiff;
15
15
16
16
#[ no_mangle]
17
- #[ autodiff( d_square1, Forward , Dual , Dual ) ]
17
+ // #[autodiff(d_square1, Forward, Dual, Dual)]
18
18
#[ autodiff( d_square2, Forward , 4 , Dualv , Dualv ) ]
19
19
#[ autodiff( d_square3, Forward , 4 , Dual , Dual ) ]
20
20
fn square ( x : & [ f32 ] , y : & mut [ f32 ] ) {
@@ -28,7 +28,7 @@ fn square(x: &[f32], y: &mut [f32]) {
28
28
}
29
29
30
30
fn main ( ) {
31
- let x1 = std:: hint:: black_box ( vec ! [ 0.0 , 1.0 , 2.0 , 3.0 , 4.0 ] ) ;
31
+ let x1 = std:: hint:: black_box ( vec ! [ 0.0 , 1.0 , 2.0 , 3.0 ] ) ;
32
32
33
33
let mut dx1 = std:: hint:: black_box ( vec ! [ 1.0 ; 12 ] ) ;
34
34
@@ -66,37 +66,42 @@ fn main() {
66
66
let result = std:: hint:: black_box ( x1. iter ( ) . map ( |x| 2.0 * x) . collect :: < Vec < _ > > ( ) ) ;
67
67
68
68
// scalar.
69
- d_square1 ( & x1, & z1, & mut y1, & mut dy1_1) ;
70
- d_square1 ( & x1, & z2, & mut y2, & mut dy1_2) ;
71
- d_square1 ( & x1, & z3, & mut y3, & mut dy1_3) ;
72
- d_square1 ( & x1, & z4, & mut y4, & mut dy1_4) ;
69
+ // d_square1(&x1, &z1, &mut y1, &mut dy1_1);
70
+ // d_square1(&x1, &z2, &mut y2, &mut dy1_2);
71
+ // d_square1(&x1, &z3, &mut y3, &mut dy1_3);
72
+ // d_square1(&x1, &z4, &mut y4, &mut dy1_4);
73
73
74
74
// assert y1 == y2 == y3 == y4
75
- for i in 0 ..5 {
76
- assert_eq ! ( y1[ i] , y2[ i] ) ;
77
- assert_eq ! ( y1[ i] , y3[ i] ) ;
78
- assert_eq ! ( y1[ i] , y4[ i] ) ;
79
- }
75
+ // for i in 0..5 {
76
+ // assert_eq!(y1[i], y2[i]);
77
+ // assert_eq!(y1[i], y3[i]);
78
+ // assert_eq!(y1[i], y4[i]);
79
+ // }
80
80
81
81
// batch mode A)
82
82
//dx1 = std::hint::black_box(vec![1.0; 12]);
83
83
d_square2 ( & x1, & z5, & mut y5, & mut dy2) ;
84
84
85
85
// assert y1 == y2 == y3 == y4 == y5
86
- for i in 0 ..5 {
87
- assert_eq ! ( y1[ i] , y5[ i] ) ;
88
- }
86
+ // for i in 0..5 {
87
+ // assert_eq!(y1[i], y5[i]);
88
+ // }
89
89
90
90
// batch mode B)
91
91
d_square3 ( & x1, & z1, & z2, & z3, & z4, & mut y6, & mut dy3_1, & mut dy3_2, & mut dy3_3, & mut dy3_4) ;
92
92
for i in 0 ..5 {
93
- assert_eq ! ( y1 [ i] , y6[ i] ) ;
93
+ assert_eq ! ( y5 [ i] , y6[ i] ) ;
94
94
}
95
95
96
+ dbg ! ( & dy2) ;
97
+ dbg ! ( & dy3_1) ;
98
+ dbg ! ( & dy3_2) ;
99
+ dbg ! ( & dy3_3) ;
100
+ dbg ! ( & dy3_4) ;
96
101
for i in 0 ..5 {
97
- assert_eq ! ( dy1_1 [ i] , dy3_1[ i] ) ;
98
- assert_eq ! ( dy1_2 [ i] , dy3_2[ i] ) ;
99
- assert_eq ! ( dy1_3 [ i] , dy3_3[ i] ) ;
100
- assert_eq ! ( dy1_4 [ i] , dy3_4[ i] ) ;
102
+ assert_eq ! ( dy2 [ 0 .. 5 ] [ i] , dy3_1[ i] ) ;
103
+ assert_eq ! ( dy2 [ 5 .. 10 ] [ i] , dy3_2[ i] ) ;
104
+ assert_eq ! ( dy2 [ 10 .. 15 ] [ i] , dy3_3[ i] ) ;
105
+ assert_eq ! ( dy2 [ 15 .. 20 ] [ i] , dy3_4[ i] ) ;
101
106
}
102
107
}
0 commit comments