Skip to content

Commit 7961e78

Browse files
authored
Fix header rematerialization (rust-lang#604)
* Fix rematerialization of loop header * Fix early version tests * Fix 13 bug
1 parent 9e76494 commit 7961e78

File tree

4 files changed

+338
-4
lines changed

4 files changed

+338
-4
lines changed

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1361,10 +1361,12 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
13611361
}
13621362
IRBuilder<> B(blocks[i]);
13631363

1364-
for (auto pair : unwrap_cache[oldB])
1365-
unwrap_cache[blocks[i]].insert(pair);
1366-
for (auto pair : lookup_cache[oldB])
1367-
lookup_cache[blocks[i]].insert(pair);
1364+
if (!prevIteration.count(PB)) {
1365+
for (auto pair : unwrap_cache[oldB])
1366+
unwrap_cache[blocks[i]].insert(pair);
1367+
for (auto pair : lookup_cache[oldB])
1368+
lookup_cache[blocks[i]].insert(pair);
1369+
}
13681370

13691371
if (auto inst =
13701372
dyn_cast<Instruction>(phi->getIncomingValueForBlock(PB))) {
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S -gvn -dse -dse | FileCheck %s
2+
3+
source_filename = "<source>"
4+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
5+
target triple = "x86_64-unknown-linux-gnu"
6+
7+
@.str = private unnamed_addr constant [11 x i8] c"dtheta=%d\0A\00", align 1
8+
@.str.1 = private unnamed_addr constant [23 x i8] c"dout[%d]=%f answer=%d\0A\00", align 1
9+
10+
; Helper that returns the product (mul nsw) of its two i32 arguments.
; It is called from the exit block of @_Z15integrate_imagedPd.
define i32 @_Z18evaluate_integrandii(i32 %arg, i32 %arg1) {
11+
bb:
12+
%i = mul nsw i32 %arg1, %arg
13+
ret i32 %i
14+
}
15+
16+
; Function under differentiation.
; A single-block loop (%bb5) runs until the incremented counter %i19 equals 10.
; Each iteration scales %arg1[%i6] by the integer product of the truncated
; loop-carried values %i7 and %i8, then updates %i7 (divide by 0.8) and
; %i8 (0.25 times the updated %i7).
; The exit block (%bb2) reuses %i7, %i9 and %i10, all of which are defined in
; the loop header. NOTE(review): presumably this use of header values outside
; the loop is the rematerialization case the commit targets; confirm.
define dso_local double @_Z15integrate_imagedPd(double %arg, double* nocapture %arg1) {
17+
bb:
18+
br label %bb5
19+
20+
; Loop header and body: %i6 is the counter, %i7 and %i8 are loop-carried doubles
; starting at %arg and 1.0 respectively.
bb5: ; preds = %bb5, %bb
21+
%i6 = phi i64 [ 0, %bb ], [ %i19, %bb5 ]
22+
%i7 = phi double [ %arg, %bb ], [ %i17, %bb5 ]
23+
%i8 = phi double [ 1.000000e+00, %bb ], [ %i18, %bb5 ]
24+
%i9 = fptosi double %i7 to i32
25+
%i10 = fptosi double %i8 to i32
26+
%i11 = mul nsw i32 %i9, %i10
27+
%i12 = sitofp i32 %i11 to double
28+
%i13 = getelementptr inbounds double, double* %arg1, i64 %i6
29+
%i14 = load double, double* %i13, align 8
30+
%i15 = fmul double %i14, %i12
31+
store double %i15, double* %i13, align 8
32+
%i17 = fdiv double %i7, 8.000000e-01
33+
%i18 = fmul double %i17, 2.500000e-01
34+
%i19 = add nuw nsw i64 %i6, 1
35+
%i20 = icmp eq i64 %i19, 10
36+
br i1 %i20, label %bb2, label %bb5
37+
38+
; Exit block: the return value is %i7 times the helper's integer result,
; where the helper receives the final iteration's %i9 and %i10.
bb2: ; preds = %bb5
39+
%i = tail call i32 @_Z18evaluate_integrandii(i32 %i9, i32 %i10)
40+
%i3 = sitofp i32 %i to double
41+
%i4 = fmul double %i7, %i3
42+
ret double %i4
43+
}
44+
45+
; Function Attrs: nofree nounwind
46+
declare dso_local i32 @printf(i8* nocapture readonly, ...)
47+
48+
; Function Attrs: norecurse uwtable mustprogress
49+
; Driver for the test.
; Allocates two [10 x double] stack arrays (%i and %i1), stores 1.0 into every
; element of %i1, calls __enzyme_autodiff on @_Z15integrate_imagedPd with the
; scalar 200.0 and pointers to both arrays, then prints each element of %i1
; next to a hard-coded integer.
; %i is never written before the call, and the bitcasts %i2 / %i3 are unused.
define dso_local i32 @main() {
50+
bb:
51+
%i = alloca [10 x double], align 16
52+
%i1 = alloca [10 x double], align 16
53+
%i2 = bitcast [10 x double]* %i to i8*
54+
%i3 = bitcast [10 x double]* %i1 to i8*
55+
; Fill %i1[0..9] with 1.0, one store per element.
%i4 = getelementptr inbounds [10 x double], [10 x double]* %i1, i64 0, i64 0
56+
store double 1.000000e+00, double* %i4, align 16
57+
%i5 = getelementptr inbounds [10 x double], [10 x double]* %i1, i64 0, i64 1
58+
store double 1.000000e+00, double* %i5, align 8
59+
%i6 = getelementptr inbounds [10 x double], [10 x double]* %i1, i64 0, i64 2
60+
store double 1.000000e+00, double* %i6, align 16
61+
%i7 = getelementptr inbounds [10 x double], [10 x double]* %i1, i64 0, i64 3
62+
store double 1.000000e+00, double* %i7, align 8
63+
%i8 = getelementptr inbounds [10 x double], [10 x double]* %i1, i64 0, i64 4
64+
store double 1.000000e+00, double* %i8, align 16
65+
%i9 = getelementptr inbounds [10 x double], [10 x double]* %i1, i64 0, i64 5
66+
store double 1.000000e+00, double* %i9, align 8
67+
%i10 = getelementptr inbounds [10 x double], [10 x double]* %i1, i64 0, i64 6
68+
store double 1.000000e+00, double* %i10, align 16
69+
%i11 = getelementptr inbounds [10 x double], [10 x double]* %i1, i64 0, i64 7
70+
store double 1.000000e+00, double* %i11, align 8
71+
%i12 = getelementptr inbounds [10 x double], [10 x double]* %i1, i64 0, i64 8
72+
store double 1.000000e+00, double* %i12, align 16
73+
%i13 = getelementptr inbounds [10 x double], [10 x double]* %i1, i64 0, i64 9
74+
store double 1.000000e+00, double* %i13, align 8
75+
%i14 = getelementptr inbounds [10 x double], [10 x double]* %i, i64 0, i64 0
76+
; NOTE(review): presumably %i14 is the primal array and %i4 its shadow
; (gradient) array; confirm against the Enzyme calling convention.
call void (double (double, double*)*, ...) @_Z17__enzyme_autodiffPFddPdEz(double (double, double*)* nonnull @_Z15integrate_imagedPd, double 2.000000e+02, double* nonnull %i14, double* nonnull %i4)
77+
; Print each element of %i1 with its index and a hard-coded integer
; (the "answer" field of the format string @.str.1).
%i15 = load double, double* %i4, align 16
78+
%i16 = call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([23 x i8], [23 x i8]* @.str.1, i64 0, i64 0), i32 0, double %i15, i32 200)
79+
%i17 = load double, double* %i5, align 8
80+
%i18 = call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([23 x i8], [23 x i8]* @.str.1, i64 0, i64 0), i32 1, double %i17, i32 15500)
81+
%i19 = load double, double* %i6, align 16
82+
%i20 = call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([23 x i8], [23 x i8]* @.str.1, i64 0, i64 0), i32 2, double %i19, i32 24336)
83+
%i21 = load double, double* %i7, align 8
84+
%i22 = call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([23 x i8], [23 x i8]* @.str.1, i64 0, i64 0), i32 3, double %i21, i32 37830)
85+
%i23 = load double, double* %i8, align 16
86+
%i24 = call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([23 x i8], [23 x i8]* @.str.1, i64 0, i64 0), i32 4, double %i23, i32 59536)
87+
%i25 = load double, double* %i9, align 8
88+
%i26 = call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([23 x i8], [23 x i8]* @.str.1, i64 0, i64 0), i32 5, double %i25, i32 92720)
89+
%i27 = load double, double* %i10, align 16
90+
%i28 = call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([23 x i8], [23 x i8]* @.str.1, i64 0, i64 0), i32 6, double %i27, i32 144780)
91+
%i29 = load double, double* %i11, align 8
92+
%i30 = call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([23 x i8], [23 x i8]* @.str.1, i64 0, i64 0), i32 7, double %i29, i32 226814)
93+
%i31 = load double, double* %i12, align 16
94+
%i32 = call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([23 x i8], [23 x i8]* @.str.1, i64 0, i64 0), i32 8, double %i31, i32 355216)
95+
%i33 = load double, double* %i13, align 8
96+
%i34 = call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([23 x i8], [23 x i8]* @.str.1, i64 0, i64 0), i32 9, double %i33, i32 554280)
97+
ret i32 0
98+
}
99+
100+
declare void @_Z17__enzyme_autodiffPFddPdEz(double (double, double*)*, ...)
101+
102+
; CHECK: define internal { double } @diffe_Z15integrate_imagedPd(double %arg, double* nocapture %arg1, double* nocapture %"arg1'", double %differeturn)
103+
; CHECK-NEXT: bb:
104+
; CHECK-NEXT: %malloccall = tail call noalias nonnull dereferenceable(80) dereferenceable_or_null(80) i8* @malloc(i64 80)
105+
; CHECK-NEXT: %i7_malloccache = bitcast i8* %malloccall to double*
106+
; CHECK-NEXT: br label %bb5
107+
108+
; CHECK: bb5: ; preds = %bb5, %bb
109+
; CHECK-NEXT: %iv = phi i64 [ %iv.next, %bb5 ], [ 0, %bb ]
110+
; CHECK-NEXT: %i7 = phi double [ %arg, %bb ], [ %i17, %bb5 ]
111+
; CHECK-NEXT: %i8 = phi double [ 1.000000e+00, %bb ], [ %i18, %bb5 ]
112+
; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1
113+
; CHECK-NEXT: %i9 = fptosi double %i7 to i32
114+
; CHECK-NEXT: %i10 = fptosi double %i8 to i32
115+
; CHECK-NEXT: %i11 = mul nsw i32 %i9, %i10
116+
; CHECK-NEXT: %i12 = sitofp i32 %i11 to double
117+
; CHECK-NEXT: %i13 = getelementptr inbounds double, double* %arg1, i64 %iv
118+
; CHECK-NEXT: %i14 = load double, double* %i13, align 8
119+
; CHECK-NEXT: %i15 = fmul double %i14, %i12
120+
; CHECK-NEXT: store double %i15, double* %i13, align 8
121+
; CHECK-NEXT: %0 = getelementptr inbounds double, double* %i7_malloccache, i64 %iv
122+
; CHECK-NEXT: store double %i7, double* %0, align 8, !invariant.group !0
123+
; CHECK-NEXT: %i17 = fdiv double %i7, 8.000000e-01
124+
; CHECK-NEXT: %i18 = fmul double %i17, 2.500000e-01
125+
; CHECK-NEXT: %i20 = icmp eq i64 %iv.next, 10
126+
; CHECK-NEXT: br i1 %i20, label %bb2, label %bb5
127+
128+
; CHECK: bb2: ; preds = %bb5
129+
; CHECK-NEXT: %i = tail call i32 @_Z18evaluate_integrandii(i32 %i9, i32 %i10)
130+
; CHECK-NEXT: %i3 = sitofp i32 %i to double
131+
; CHECK-NEXT: %m0diffei7 = fmul fast double %differeturn, %i3
132+
; CHECK-NEXT: br label %invertbb5
133+
134+
; CHECK: invertbb: ; preds = %invertbb5_phimerge
135+
; CHECK-NEXT: %1 = insertvalue { double } undef, double %14, 0
136+
; CHECK-NEXT: tail call void @free(i8* nonnull %malloccall)
137+
; CHECK-NEXT: ret { double } %1
138+
139+
; CHECK: invertbb5: ; preds = %bb2, %incinvertbb5
140+
; CHECK-NEXT: %"i7'de.0" = phi double [ %m0diffei7, %bb2 ], [ 0.000000e+00, %incinvertbb5 ]
141+
; CHECK-NEXT: %"i17'de.0" = phi double [ 0.000000e+00, %bb2 ], [ %12, %incinvertbb5 ]
142+
; CHECK-NEXT: %"arg'de.0" = phi double [ 0.000000e+00, %bb2 ], [ %14, %incinvertbb5 ]
143+
; CHECK-NEXT: %"iv'ac.0" = phi i64 [ 9, %bb2 ], [ %15, %incinvertbb5 ]
144+
; CHECK-NEXT: %d0diffei7 = fdiv fast double %"i17'de.0", 8.000000e-01
145+
; CHECK-NEXT: %2 = fadd fast double %"i7'de.0", %d0diffei7
146+
; CHECK-NEXT: %"i13'ipg_unwrap" = getelementptr inbounds double, double* %"arg1'", i64 %"iv'ac.0"
147+
; CHECK-NEXT: %3 = load double, double* %"i13'ipg_unwrap", align 8
148+
; DCE-NEXT: store double 0.000000e+00, double* %"i13'ipg_unwrap", align 8
149+
; CHECK: %4 = getelementptr inbounds double, double* %i7_malloccache, i64 %"iv'ac.0"
150+
; CHECK-NEXT: %5 = load double, double* %4, align 8, !invariant.group !0
151+
; CHECK-NEXT: %i9_unwrap = fptosi double %5 to i32
152+
; CHECK-NEXT: %6 = icmp ne i64 %"iv'ac.0", 0
153+
; CHECK-NEXT: br i1 %6, label %invertbb5_phirc, label %invertbb5_phimerge
154+
155+
; CHECK: invertbb5_phirc: ; preds = %invertbb5
156+
; CHECK-NEXT: %7 = sub nuw i64 %"iv'ac.0", 1
157+
; CHECK-NEXT: %8 = getelementptr inbounds double, double* %i7_malloccache, i64 %7
158+
; CHECK-NEXT: %9 = load double, double* %8, align 8, !invariant.group !0
159+
; CHECK-NEXT: %i17_unwrap = fdiv double %9, 8.000000e-01
160+
; CHECK-NEXT: %i18_unwrap = fmul double %i17_unwrap, 2.500000e-01
161+
; CHECK-NEXT: br label %invertbb5_phimerge
162+
163+
; CHECK: invertbb5_phimerge: ; preds = %invertbb5, %invertbb5_phirc
164+
; CHECK-NEXT: %10 = phi {{(fast )?}}double [ %i18_unwrap, %invertbb5_phirc ], [ 1.000000e+00, %invertbb5 ]
165+
; CHECK-NEXT: %i10_unwrap = fptosi double %10 to i32
166+
; CHECK-NEXT: %i11_unwrap = mul nsw i32 %i9_unwrap, %i10_unwrap
167+
; CHECK-NEXT: %i12_unwrap = sitofp i32 %i11_unwrap to double
168+
; CHECK-NEXT: %m0diffei14 = fmul fast double %3, %i12_unwrap
169+
; CHECK-NEXT: store double %m0diffei14, double* %"i13'ipg_unwrap", align 8
170+
; CHECK-NEXT: %11 = icmp eq i64 %"iv'ac.0", 0
171+
; CHECK-NEXT: %12 = select {{(fast )?}}i1 %11, double 0.000000e+00, double %2
172+
; CHECK-NEXT: %13 = fadd fast double %"arg'de.0", %2
173+
; CHECK-NEXT: %14 = select {{(fast )?}}i1 %11, double %13, double %"arg'de.0"
174+
; CHECK-NEXT: br i1 %11, label %invertbb, label %incinvertbb5
175+
176+
; CHECK: incinvertbb5: ; preds = %invertbb5_phimerge
177+
; CHECK-NEXT: %15 = add nsw i64 %"iv'ac.0", -1
178+
; CHECK-NEXT: br label %invertbb5
179+
; CHECK-NEXT: }
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// RUN: %clang -std=c11 -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
2+
// RUN: %clang -std=c11 -fno-unroll-loops -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
3+
// RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
4+
// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
5+
// RUN: %clang -std=c11 -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme --enzyme-inline=1 -S | %lli -
6+
// RUN: %clang -std=c11 -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme --enzyme-inline=1 -S | %lli -
7+
// RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme --enzyme-inline=1 -S | %lli -
8+
// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme --enzyme-inline=1 -S | %lli -
9+
10+
#include <stdio.h>
11+
#include <math.h>
12+
#include <assert.h>
13+
14+
#include "test_utils.h"
15+
16+
17+
#include <stdio.h>
18+
19+
// Returns the product nr * dtheta.
// Marked noinline, so the call in integrate_image stays a real call.
__attribute__((noinline))
20+
int evaluate_integrand(const int nr,
21+
const int dtheta)
22+
{
23+
return nr * dtheta;
24+
}
25+
26+
// Function differentiated by the test.
//
// Runs a fixed 10-iteration loop. Each iteration truncates dr and dtheta to
// int, multiplies out[k] in place by the product of those two ints, prints
// that product, and records sum * dr in I_estimate. It then grows dr
// (dr /= 0.8) and derives the next dtheta from the updated dr.
//
// dr:  starting value; modified locally on every iteration.
// out: array that is read and written at indices 0..9.
// Returns the I_estimate computed by the last iteration.
double integrate_image(double dr, double* out)
27+
{
28+
double dtheta = 1;
29+
30+
{
31+
// Assigned on every iteration; the loop bound is the constant 10, so it is
// always set before the return below.
double I_estimate;
32+
33+
for (int k=0; k<10; k++)
34+
{
35+
36+
// Both step sizes are truncated toward zero before use.
int nr = (int)(dr);
37+
int ntheta = (int)(dtheta);
38+
39+
double sum = evaluate_integrand(nr, ntheta);
40+
41+
out[k] *= nr * ntheta;
42+
// The label says "dtheta" but the value printed is nr * ntheta.
printf("dtheta=%d\n", nr * ntheta);
43+
I_estimate = sum * dr;
44+
45+
// Update the step size
46+
dr /= 0.8;
47+
// dtheta is recomputed from the already-updated dr.
dtheta = dr / 4.0;
48+
}
49+
return I_estimate;
50+
}
51+
}
52+
53+
void __enzyme_autodiff(double (*)(double, double*), ...);
54+
55+
// Test driver: differentiates integrate_image via __enzyme_autodiff at
// dr = 200.0 and compares every d_out[i] with a hard-coded expected value.
int main()
56+
{
57+
// NOTE(review): out is never initialized, yet integrate_image reads it
// through `out[k] *= ...`; confirm this is intentional.
double out[10];
58+
double d_out[10];
59+
// d_out is set to 1.0 in every slot before the autodiff call.
for(int i=0; i<10; i++)
60+
d_out[i] = 1.0;
61+
62+
// Expected d_out[i] after the autodiff call, one entry per loop iteration.
int answer[10] = {
63+
200,
64+
15500,
65+
24336,
66+
37830,
67+
59536,
68+
92720,
69+
144780,
70+
226814,
71+
355216,
72+
554280
73+
};
74+
75+
// NOTE(review): presumably out is the primal buffer and d_out its shadow
// (gradient) buffer; confirm against the Enzyme calling convention.
__enzyme_autodiff(integrate_image, 200.0, out, d_out);
76+
77+
for (int i = 0; i < 10; i++)
78+
{
79+
printf("dout[%d]=%f answer=%d\n", i, d_out[i], answer[i]);
80+
// APPROX_EQ comes from test_utils.h (not shown here); the third argument is
// presumably the comparison tolerance.
APPROX_EQ(d_out[i], answer[i], 1e-6);
81+
}
82+
83+
return 0;
84+
}
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// Note: -O0 is deliberately not used below, to ensure we get TBAA.
2+
// RUN: if [ %llvmver -ge 9 ]; then %clang -fopenmp -std=c11 -O1 -fno-vectorize -fno-unroll-loops -disable-llvm-optzns %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %clang -fopenmp -x ir - -o %s.out && %s.out; fi
3+
// RUN: if [ %llvmver -ge 9 ]; then %clang -fopenmp -std=c11 -O1 -fno-vectorize -fno-unroll-loops %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %clang -fopenmp -x ir - -o %s.out && %s.out ; fi
4+
// RUN: if [ %llvmver -ge 9 ]; then %clang -fopenmp -std=c11 -O2 -fno-vectorize -fno-unroll-loops %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %clang -fopenmp -x ir - -o %s.out && %s.out ; fi
5+
// RUN: if [ %llvmver -ge 9 ]; then %clang -fopenmp -std=c11 -O3 -fno-vectorize -fno-unroll-loops %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %clang -fopenmp -x ir - -o %s.out && %s.out ; fi
6+
// Note: -O0 is deliberately not used below, to ensure we get TBAA.
7+
// RUN: if [ %llvmver -ge 9 ]; then %clang -fopenmp -std=c11 -O1 -fno-vectorize -fno-unroll-loops -Xclang -disable-llvm-optzns %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out; fi
8+
// RUN: if [ %llvmver -ge 9 ]; then %clang -fopenmp -std=c11 -O1 -fno-vectorize -fno-unroll-loops %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out ; fi
9+
// RUN: if [ %llvmver -ge 9 ]; then %clang -fopenmp -std=c11 -O2 -fno-vectorize -fno-unroll-loops %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out ; fi
10+
// RUN: if [ %llvmver -ge 9 ]; then %clang -fopenmp -std=c11 -O3 -fno-vectorize -fno-unroll-loops %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out ; fi
11+
12+
#include <stdio.h>
13+
#include <math.h>
14+
#include <assert.h>
15+
16+
#include "test_utils.h"
17+
18+
double __enzyme_autodiff(void*, ...);
19+
20+
/*
21+
void omp(float& a, int N) {
22+
#define N 20
23+
#pragma omp parallel for
24+
for (int i=0; i<N; i++) {
25+
//a[i] *= a[i];
26+
(&a)[i] *= (&a)[i];
27+
}
28+
#undef N
29+
(&a)[0] = 0;
30+
}
31+
*/
32+
// Function differentiated by the test.
// Squares each of the first N elements of a in place inside an OpenMP
// `parallel for`, then overwrites a[0] with 0 after the loop.
void omp(float* a, int N) {
33+
#pragma omp parallel for
34+
for (int i=0; i<N; i++) {
35+
//a[i] *= a[i];
36+
a[i] *= a[i];
37+
}
38+
// Discards the squared value in slot 0; main expects d_a[0] to be 0.
a[0] = 0;
39+
}
40+
41+
// Test driver: differentiates omp() via __enzyme_autodiff over a 20-element
// array and checks the resulting d_a. argc and argv are unused.
int main(int argc, char** argv) {
42+
43+
int N = 20;
44+
// a[i] = i + 1, so the inputs are 1..20.
float a[N];
45+
for(int i=0; i<N; i++) {
46+
a[i] = i+1;
47+
}
48+
49+
// d_a is set to 1.0f in every slot before the autodiff call.
float d_a[N];
50+
for(int i=0; i<N; i++)
51+
d_a[i] = 1.0f;
52+
53+
// The direct call is commented out, so "ran omp" is printed even though
// omp() is not called directly here.
//omp(*a, N);
54+
printf("ran omp\n");
55+
// NOTE(review): presumably a is the primal buffer and d_a its shadow
// (gradient) buffer; confirm against the Enzyme calling convention.
__enzyme_autodiff((void*)omp, a, d_a, N);
56+
57+
for(int i=0; i<N; i++) {
58+
printf("a[%d]=%f d_a[%d]=%f\n", i, a[i], i, d_a[i]);
59+
}
60+
61+
// The three commented-out lines below refer to names (da, db, ret) that are
// not declared in this function.
//APPROX_EQ(da, 17711.0*2, 1e-10);
62+
//APPROX_EQ(db, 17711.0*2, 1e-10);
63+
//printf("hello! %f, res2 %f, da: %f, db: %f\n", ret, ret, da,db);
64+
// Slot 0 is expected to be 0 because omp() overwrites a[0] with 0.
APPROX_EQ(d_a[0], 0.0f, 1e-10);
65+
// For i >= 1 the expected value is 2 * (i + 1), i.e. twice the initial a[i],
// which is the derivative of a[i] * a[i].
for(int i=1; i<N; i++) {
66+
APPROX_EQ(d_a[i], 2.0f*(i+1), 1e-10);
67+
}
68+
return 0;
69+
}

0 commit comments

Comments
 (0)