|
| 1 | +//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat |
| 2 | +//@ no-prefer-dynamic |
| 3 | +//@ needs-enzyme |
| 4 | +// |
| 5 | +// In Enzyme, we test against a large range of LLVM versions (5+) and don't have overly many |
| 6 | +// breakages. One benefit is that we match the IR generated by Enzyme only after running it |
| 7 | +// through LLVM's O3 pipeline, which will remove most of the noise. |
| 8 | +// However, our integration test could also be affected by changes in how rustc lowers MIR into |
| 9 | +// LLVM-IR, which could cause additional noise and thus breakages. If that's the case, we should |
| 10 | +// reduce this test to only match the first lines and the ret instructions. |
| 11 | +// |
| 12 | +// The function tested here has 4 inputs and 5 outputs, so we could either call forward-mode |
| 13 | +// autodiff 4 times, or reverse mode 5 times. Since a forward-mode call is usually faster than |
| 14 | +// reverse mode, we prefer it here. This file also tests a new optimization (batch mode), which |
| 15 | +// allows us to call forward-mode autodiff only once, and get all 5 outputs in a single call. |
| 16 | +// |
| 17 | +// We support 2 different batch modes. `d_square2` has the same interface as scalar forward-mode, |
| 18 | +// but each shadow argument is `width` times larger (thus 16 and 20 elements here). |
| 19 | +// `d_square3` instead takes `width` (4) shadow arguments, which are all the same size as the |
| 20 | +// original function arguments. |
| 21 | +// |
| 22 | +// FIXME(autodiff): We currently can't test `d_square1` and `d_square3` in the same file, since they |
| 23 | +// generate the same dummy functions which get merged by LLVM, breaking pieces of our pipeline which |
| 24 | +// try to rewrite the dummy functions later. We should consider to change to pure declarations both |
| 25 | +// in our frontend and in the llvm backend to avoid these issues. |
| 26 | + |
| 27 | +#![feature(autodiff)] |
| 28 | + |
| 29 | +use std::autodiff::autodiff; |
| 30 | + |
| 31 | +#[no_mangle] |
| 32 | +//#[autodiff(d_square1, Forward, Dual, Dual)] |
| 33 | +#[autodiff(d_square2, Forward, 4, Dualv, Dualv)] |
| 34 | +#[autodiff(d_square3, Forward, 4, Dual, Dual)] |
| 35 | +fn square(x: &[f32], y: &mut [f32]) { |
| 36 | + assert!(x.len() >= 4); |
| 37 | + assert!(y.len() >= 5); |
| 38 | + y[0] = 4.3 * x[0] + 1.2 * x[1] + 3.4 * x[2] + 2.1 * x[3]; |
| 39 | + y[1] = 2.3 * x[0] + 4.5 * x[1] + 1.7 * x[2] + 6.4 * x[3]; |
| 40 | + y[2] = 1.1 * x[0] + 3.3 * x[1] + 2.5 * x[2] + 4.7 * x[3]; |
| 41 | + y[3] = 5.2 * x[0] + 1.4 * x[1] + 2.6 * x[2] + 3.8 * x[3]; |
| 42 | + y[4] = 1.0 * x[0] + 2.0 * x[1] + 3.0 * x[2] + 4.0 * x[3]; |
| 43 | +} |
| 44 | + |
| 45 | +fn main() { |
| 46 | + let x1 = std::hint::black_box(vec![0.0, 1.0, 2.0, 3.0]); |
| 47 | + |
| 48 | + let dx1 = std::hint::black_box(vec![1.0; 12]); |
| 49 | + |
| 50 | + let z1 = std::hint::black_box(vec![1.0, 0.0, 0.0, 0.0]); |
| 51 | + let z2 = std::hint::black_box(vec![0.0, 1.0, 0.0, 0.0]); |
| 52 | + let z3 = std::hint::black_box(vec![0.0, 0.0, 1.0, 0.0]); |
| 53 | + let z4 = std::hint::black_box(vec![0.0, 0.0, 0.0, 1.0]); |
| 54 | + |
| 55 | + let z5 = std::hint::black_box(vec![ |
| 56 | + 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, |
| 57 | + ]); |
| 58 | + |
| 59 | + let mut y1 = std::hint::black_box(vec![0.0; 5]); |
| 60 | + let mut y2 = std::hint::black_box(vec![0.0; 5]); |
| 61 | + let mut y3 = std::hint::black_box(vec![0.0; 5]); |
| 62 | + let mut y4 = std::hint::black_box(vec![0.0; 5]); |
| 63 | + |
| 64 | + let mut y5 = std::hint::black_box(vec![0.0; 5]); |
| 65 | + |
| 66 | + let mut y6 = std::hint::black_box(vec![0.0; 5]); |
| 67 | + |
| 68 | + let mut dy1_1 = std::hint::black_box(vec![0.0; 5]); |
| 69 | + let mut dy1_2 = std::hint::black_box(vec![0.0; 5]); |
| 70 | + let mut dy1_3 = std::hint::black_box(vec![0.0; 5]); |
| 71 | + let mut dy1_4 = std::hint::black_box(vec![0.0; 5]); |
| 72 | + |
| 73 | + let mut dy2 = std::hint::black_box(vec![0.0; 20]); |
| 74 | + |
| 75 | + let mut dy3_1 = std::hint::black_box(vec![0.0; 5]); |
| 76 | + let mut dy3_2 = std::hint::black_box(vec![0.0; 5]); |
| 77 | + let mut dy3_3 = std::hint::black_box(vec![0.0; 5]); |
| 78 | + let mut dy3_4 = std::hint::black_box(vec![0.0; 5]); |
| 79 | + |
| 80 | + // scalar. |
| 81 | + //d_square1(&x1, &z1, &mut y1, &mut dy1_1); |
| 82 | + //d_square1(&x1, &z2, &mut y2, &mut dy1_2); |
| 83 | + //d_square1(&x1, &z3, &mut y3, &mut dy1_3); |
| 84 | + //d_square1(&x1, &z4, &mut y4, &mut dy1_4); |
| 85 | + |
| 86 | + // assert y1 == y2 == y3 == y4 |
| 87 | + //for i in 0..5 { |
| 88 | + // assert_eq!(y1[i], y2[i]); |
| 89 | + // assert_eq!(y1[i], y3[i]); |
| 90 | + // assert_eq!(y1[i], y4[i]); |
| 91 | + //} |
| 92 | + |
| 93 | + // batch mode A) |
| 94 | + d_square2(&x1, &z5, &mut y5, &mut dy2); |
| 95 | + |
| 96 | + // assert y1 == y2 == y3 == y4 == y5 |
| 97 | + //for i in 0..5 { |
| 98 | + // assert_eq!(y1[i], y5[i]); |
| 99 | + //} |
| 100 | + |
| 101 | + // batch mode B) |
| 102 | + d_square3(&x1, &z1, &z2, &z3, &z4, &mut y6, &mut dy3_1, &mut dy3_2, &mut dy3_3, &mut dy3_4); |
| 103 | + for i in 0..5 { |
| 104 | + assert_eq!(y5[i], y6[i]); |
| 105 | + } |
| 106 | + |
| 107 | + for i in 0..5 { |
| 108 | + assert_eq!(dy2[0..5][i], dy3_1[i]); |
| 109 | + assert_eq!(dy2[5..10][i], dy3_2[i]); |
| 110 | + assert_eq!(dy2[10..15][i], dy3_3[i]); |
| 111 | + assert_eq!(dy2[15..20][i], dy3_4[i]); |
| 112 | + } |
| 113 | +} |
0 commit comments