Skip to content

Commit 55114ed

Browse files
authored
Fix phi constant pred (rust-lang#612)
1 parent 2cafc87 commit 55114ed

File tree

2 files changed

+170
-11
lines changed

2 files changed

+170
-11
lines changed

enzyme/Enzyme/EnzymeLogic.cpp

-11
Original file line numberDiff line numberDiff line change
@@ -3614,17 +3614,6 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
36143614
if (NumPreds == 0)
36153615
continue;
36163616
Phi->removeIncomingValue(newBB);
3617-
3618-
// If we have a single predecessor, removeIncomingValue may have
3619-
// erased the PHI node itself.
3620-
if (NumPreds == 1)
3621-
continue;
3622-
3623-
// Try to replace the PHI node with a constant value.
3624-
if (Value *PhiConstant = Phi->hasConstantValue()) {
3625-
Phi->replaceAllUsesWith(PhiConstant);
3626-
Phi->eraseFromParent();
3627-
}
36283617
}
36293618
}
36303619

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -simplifycfg -instsimplify -gvn -adce -S | FileCheck %s
2+
3+
source_filename = "text"
4+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128-ni:10:11:12:13"
5+
target triple = "x86_64-pc-linux-gnu"
6+
7+
8+
; Probe loop over a byte table reached through %arg, looking for %arg1.
; Returns %value_phi (the probed index plus one) when the loaded entry is
; pointer-equal to %arg1, and -1 when a slot byte of 0 is hit or the probe
; counter exceeds the limit loaded from offset 56 of %arg. Calls @jl_throw if
; the entry loaded for a probed slot is null.
; NOTE(review): the name suggests this is Julia's Dict `ht_keyindex`; the field
; meanings (slots/keys/maxprobe) are presumed from that, not shown here.
define internal fastcc i64 @julia_ht_keyindex_1432({} addrspace(10)* nocapture nonnull readonly align 8 dereferenceable(64) %arg, {} addrspace(10)* nonnull %arg1) {
9+
top:
10+
%i = call {}*** @julia.ptls_states()
11+
%i2 = bitcast {} addrspace(10)* %arg to i8 addrspace(10)*
12+
%i3 = addrspacecast i8 addrspace(10)* %i2 to i8 addrspace(11)*
13+
%i4 = getelementptr inbounds i8, i8 addrspace(11)* %i3, i64 8
14+
%i5 = bitcast i8 addrspace(11)* %i4 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* addrspace(11)*
15+
%i6 = load atomic { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* addrspace(11)* %i5 unordered, align 8
16+
%i7 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %i6 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*
17+
%i8 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %i7, i64 0, i32 1
18+
%i9 = load i64, i64 addrspace(11)* %i8, align 8
19+
%i10 = getelementptr inbounds i8, i8 addrspace(11)* %i3, i64 56
20+
%i11 = bitcast i8 addrspace(11)* %i10 to i64 addrspace(11)*
21+
%i12 = load i64, i64 addrspace(11)* %i11, align 8
22+
; Mix the object id of %arg1 through shift/xor/multiply rounds; the result
; (%i25) seeds the first probe position.
%i13 = call i64 @jl_object_id({} addrspace(10)* nonnull %arg1)
23+
%i14 = shl i64 %i13, 21
24+
%i15 = xor i64 %i14, -1
25+
%i16 = add i64 %i13, %i15
26+
%i17 = lshr i64 %i16, 24
27+
%i18 = xor i64 %i17, %i16
28+
%i19 = mul i64 %i18, 265
29+
%i20 = lshr i64 %i19, 14
30+
%i21 = xor i64 %i20, %i19
31+
%i22 = mul i64 %i21, 21
32+
%i23 = lshr i64 %i22, 28
33+
%i24 = xor i64 %i23, %i22
34+
%i25 = mul i64 %i24, 2147483649
35+
; %i26 = %i9 - 1 is used below as the AND mask for the probe index.
%i26 = add nsw i64 %i9, -1
36+
%i27 = bitcast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %i6 to {} addrspace(10)* addrspace(13)* addrspace(10)*
37+
%i28 = bitcast {} addrspace(10)* %arg to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* addrspace(10)*
38+
%i29 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* addrspace(10)* %i28 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* addrspace(11)*
39+
%i30 = load atomic { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* addrspace(11)* %i29 unordered, align 8
40+
%i31 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %i30 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*
41+
%i32 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %i31, i64 0, i32 0
42+
%i33 = load i8 addrspace(13)*, i8 addrspace(13)* addrspace(11)* %i32, align 16
43+
%i34 = addrspacecast {} addrspace(10)* addrspace(13)* addrspace(10)* %i27 to {} addrspace(10)* addrspace(13)* addrspace(11)*
44+
%i35 = load {} addrspace(10)* addrspace(13)*, {} addrspace(10)* addrspace(13)* addrspace(11)* %i34, align 16
45+
br label %L84
46+
47+
L84: ; preds = %L106, %top
48+
%.pn = phi i64 [ %i25, %top ], [ %value_phi, %L106 ]
49+
%value_phi1 = phi i64 [ 0, %top ], [ %i40, %L106 ]
50+
%value_phi.in = and i64 %.pn, %i26
51+
%value_phi = add i64 %value_phi.in, 1
52+
%i36 = getelementptr inbounds i8, i8 addrspace(13)* %i33, i64 %value_phi.in
53+
%i37 = load i8, i8 addrspace(13)* %i36, align 1
54+
; Slot byte 0 ends the search with -1, 2 advances to the next probe, and any
; other value goes on to compare the stored entry.
switch i8 %i37, label %L97 [
55+
i8 0, label %L105
56+
i8 2, label %L106
57+
]
58+
59+
L97: ; preds = %L84
60+
%i38 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %i35, i64 %value_phi.in
61+
%i39 = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %i38 unordered, align 8
62+
%.not12 = icmp eq {} addrspace(10)* %i39, null
63+
br i1 %.not12, label %fail, label %pass
64+
65+
L105: ; preds = %pass, %L106, %L84
66+
%merge.ph = phi i64 [ %value_phi, %pass ], [ -1, %L84 ], [ -1, %L106 ]
67+
ret i64 %merge.ph
68+
69+
L106: ; preds = %pass, %L84
70+
; Bump the probe counter; give up (return -1 via %L105) once the limit %i12
; is less than the new count.
%i40 = add i64 %value_phi1, 1
71+
%.not13 = icmp slt i64 %i12, %i40
72+
br i1 %.not13, label %L105, label %L84
73+
74+
fail: ; preds = %L97
75+
call void @jl_throw({} addrspace(10)* addrspacecast ({}* inttoptr (i64 140161201230928 to {}*) to {} addrspace(10)*)) #1
76+
unreachable
77+
78+
pass: ; preds = %L97
79+
%i41 = icmp eq {} addrspace(10)* %i39, %arg1
80+
br i1 %i41, label %L105, label %L106
81+
}
82+
83+
; Function Attrs: readnone
84+
declare {}*** @julia.ptls_states() local_unnamed_addr #0
85+
86+
; Function Attrs: noreturn
87+
declare void @jl_throw({} addrspace(10)*) local_unnamed_addr #1
88+
89+
; Function Attrs: inaccessiblememonly allocsize(1)
90+
declare noalias nonnull {} addrspace(12)* @julia.gc_alloc_obj(i8*, i64, {} addrspace(10)*) local_unnamed_addr #2
91+
92+
; Function Attrs: readonly
93+
declare i64 @jl_object_id({} addrspace(10)*) local_unnamed_addr #3
94+
95+
declare double @__enzyme_autodiff(...)
96+
97+
; Test entry point: asks Enzyme to differentiate @julia_sum_rec_1428.inner.1 by
; calling @__enzyme_autodiff with %arg and %arg1 following the "enzyme_dup"
; metadata marker, and returns the call's result unchanged.
; NOTE(review): presumably %arg1 is the shadow (derivative) of %arg, as is usual
; for enzyme_dup — confirm against the Enzyme calling convention.
define double @dsquare({} addrspace(10)* nocapture nonnull readonly align 8 dereferenceable(64) %arg, {} addrspace(10)* nocapture nonnull readonly align 8 dereferenceable(64) %arg1) local_unnamed_addr {
98+
entry:
99+
%call = tail call double (...) @__enzyme_autodiff(i8* bitcast (double ({} addrspace(10)*)* @julia_sum_rec_1428.inner.1 to i8*), metadata !"enzyme_dup", {} addrspace(10)* nocapture nonnull readonly align 8 dereferenceable(64) %arg, {} addrspace(10)* nocapture nonnull readonly align 8 dereferenceable(64) %arg1)
100+
ret double %call
101+
}
102+
103+
; Function handed to @__enzyme_autodiff by @dsquare. Its only return is the
; constant 1.0; the remaining paths end in @jl_throw. The loop formed by
; %L66.i and %L101.i calls @julia_ht_keyindex_1432 and recurses into this
; function. %L62.i holds a PHI whose incoming values are %i26 (from %L42.i)
; and null (from %L101.i).
; The "; preds" annotations below list the predecessors implied by the
; branches in this function.
define double @julia_sum_rec_1428.inner.1({} addrspace(10)* nocapture nonnull readonly align 8 dereferenceable(64) %arg) local_unnamed_addr {
104+
entry:
105+
%i1 = bitcast {} addrspace(10)* %arg to i8 addrspace(10)*
106+
%i2 = addrspacecast i8 addrspace(10)* %i1 to i8 addrspace(11)*
107+
108+
%i31 = bitcast i8 addrspace(11)* %i2 to {} addrspace(10)* addrspace(13)* addrspace(10)* addrspace(11)*
109+
%i47 = load atomic {} addrspace(10)* addrspace(13)* addrspace(10)*, {} addrspace(10)* addrspace(13)* addrspace(10)* addrspace(11)* %i31 unordered, align 8
110+
111+
112+
%i3 = getelementptr inbounds i8, i8 addrspace(11)* %i2, i64 48
113+
%i4 = bitcast i8 addrspace(11)* %i3 to i64 addrspace(11)*
114+
%i5 = load i64, i64 addrspace(11)* %i4, align 8
115+
%i6 = bitcast {} addrspace(10)* %arg to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* addrspace(10)*
116+
%i7 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* addrspace(10)* %i6 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* addrspace(11)*
117+
%i8 = load atomic { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* addrspace(11)* %i7 unordered, align 8
118+
%i9 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %i8 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*
119+
%i14 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %i9, i64 0, i32 0
120+
%i15 = load i8 addrspace(13)*, i8 addrspace(13)* addrspace(11)* %i14, align 16
121+
%i16 = add i64 %i5, -1
122+
%i17 = getelementptr inbounds i8, i8 addrspace(13)* %i15, i64 %i16
123+
%i18 = load i8, i8 addrspace(13)* %i17, align 1
124+
%.not26 = icmp eq i8 %i18, 1
125+
126+
%i21 = bitcast i8 addrspace(11)* %i2 to {} addrspace(10)* addrspace(13)* addrspace(10)* addrspace(11)*
127+
%i22 = load atomic {} addrspace(10)* addrspace(13)* addrspace(10)*, {} addrspace(10)* addrspace(13)* addrspace(10)* addrspace(11)* %i21 unordered, align 8
128+
%i23 = addrspacecast {} addrspace(10)* addrspace(13)* addrspace(10)* %i22 to {} addrspace(10)* addrspace(13)* addrspace(11)*
129+
130+
%i24 = load {} addrspace(10)* addrspace(13)*, {} addrspace(10)* addrspace(13)* addrspace(11)* %i23, align 16
131+
132+
%i26 = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %i24 unordered, align 8
133+
134+
br i1 %.not26, label %L42.i, label %julia_sum_rec_1428.inner.exit
135+
136+
L42.i: ; preds = %entry
137+
%i28 = icmp sgt i64 %i5, -1
138+
br i1 %i28, label %L66.i, label %L62.i
139+
140+
L62.i: ; preds = %L101.i, %L42.i
141+
%value_phi3.i.lcssa = phi {} addrspace(10)* [ %i26, %L42.i ], [ null, %L101.i ]
142+
call void @jl_throw({} addrspace(10)* %value_phi3.i.lcssa)
143+
unreachable
144+
145+
L66.i: ; preds = %L101.i, %L42.i
146+
%i77 = call fastcc i64 @julia_ht_keyindex_1432({} addrspace(10)* null, {} addrspace(10)* nonnull %i26)
147+
%i49 = addrspacecast {} addrspace(10)* addrspace(13)* addrspace(10)* %i47 to {} addrspace(10)* addrspace(13)* addrspace(11)*
148+
%i50 = load {} addrspace(10)* addrspace(13)*, {} addrspace(10)* addrspace(13)* addrspace(11)* %i49, align 16
149+
%i52 = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %i50 unordered, align 8
150+
%i84 = call double @julia_sum_rec_1428.inner.1({} addrspace(10)* nocapture nonnull readonly align 8 dereferenceable(64) %i52)
151+
%.not17 = icmp sgt i64 %i5, 0
152+
%i72 = icmp sgt i64 %i5, -1
153+
br i1 %.not17, label %julia_sum_rec_1428.inner.exit, label %L101.i
154+
155+
L101.i: ; preds = %L66.i
156+
br i1 %i72, label %L66.i, label %L62.i
157+
158+
julia_sum_rec_1428.inner.exit: ; preds = %L66.i, %entry
159+
ret double 1.000000e+00
160+
}
161+
162+
attributes #0 = { readnone "enzyme_inactive" }
163+
attributes #1 = { noreturn }
164+
attributes #2 = { inaccessiblememonly allocsize(1) }
165+
attributes #3 = { readonly "enzyme_inactive" }
166+
attributes #4 = { argmemonly nounwind }
167+
attributes #5 = { readonly }
168+
attributes #6 = { allocsize(1) }
169+
170+
; CHECK: define internal void @diffejulia_sum_rec_1428.inner.1

0 commit comments

Comments
 (0)