Skip to content

Commit 3316e62

Browse files
tgymnich and wsmoses authored
Fix unwrap (rust-lang#647)
* fix unwrap stack overflow for vector mode
* check for exactly one knownRecomputeHeuristic.count
* added indices check
* add tests

Co-authored-by: William S. Moses <[email protected]>
1 parent 21bd55e commit 3316e62

File tree

6 files changed

+405
-28
lines changed

6 files changed

+405
-28
lines changed

enzyme/Enzyme/DifferentialUseAnalysis.h

+41-4
Original file line number | Diff line number | Diff line change
@@ -106,16 +106,15 @@ static inline bool is_use_directly_needed_in_reverse(
106106
}
107107

108108
if (isa<CmpInst>(user) || isa<BranchInst>(user) || isa<ReturnInst>(user) ||
109-
isa<FPExtInst>(user) || isa<FPTruncInst>(user) ||
110-
(isa<ExtractElementInst>(user) &&
111-
cast<ExtractElementInst>(user)->getIndexOperand() != val)
109+
isa<FPExtInst>(user) || isa<FPTruncInst>(user)
112110
#if LLVM_VERSION_MAJOR >= 10
113111
|| isa<FreezeInst>(user)
114112
#endif
115113
// isa<ExtractElement>(use) ||
116114
// isa<InsertElementInst>(use) || isa<ShuffleVectorInst>(use) ||
117115
// isa<ExtractValueInst>(use) || isa<AllocaInst>(use)
118-
/*|| isa<StoreInst>(use)*/) {
116+
// || isa<StoreInst>(use)
117+
) {
119118
return false;
120119
}
121120

@@ -134,6 +133,21 @@ static inline bool is_use_directly_needed_in_reverse(
134133
// Otherwise, we need the value.
135134
return true;
136135
}
136+
if (auto EEI = dyn_cast<ExtractElementInst>(user)) {
137+
// Only need the index in the reverse, so if the value is not
138+
// the index, short circuit and say we don't need
139+
if (EEI->getIndexOperand() != val) {
140+
return false;
141+
}
142+
// The index is only needed in the reverse if the value being inserted
143+
// is a possible active floating point value
144+
if (gutils->isConstantValue(const_cast<ExtractElementInst *>(EEI)) ||
145+
TR.query(const_cast<ExtractElementInst *>(EEI))[{-1}] ==
146+
BaseType::Pointer)
147+
return false;
148+
// Otherwise, we need the value.
149+
return true;
150+
}
137151

138152
if (auto IVI = dyn_cast<InsertValueInst>(user)) {
139153
// Only need the index in the reverse, so if the value is not
@@ -157,6 +171,29 @@ static inline bool is_use_directly_needed_in_reverse(
157171
return true;
158172
}
159173

174+
if (auto EVI = dyn_cast<ExtractValueInst>(user)) {
175+
// Only need the index in the reverse, so if the value is not
176+
// the index, short circuit and say we don't need
177+
bool valueIsIndex = false;
178+
for (unsigned i = 2; i < EVI->getNumOperands(); ++i) {
179+
if (EVI->getOperand(i) == val) {
180+
valueIsIndex = true;
181+
}
182+
}
183+
184+
if (!valueIsIndex)
185+
return false;
186+
187+
// The index is only needed in the reverse if the value being inserted
188+
// is a possible active floating point value
189+
if (gutils->isConstantValue(const_cast<ExtractValueInst *>(EVI)) ||
190+
TR.query(const_cast<ExtractValueInst *>(EVI))[{-1}] ==
191+
BaseType::Pointer)
192+
return false;
193+
// Otherwise, we need the value.
194+
return true;
195+
}
196+
160197
Intrinsic::ID ID = Intrinsic::not_intrinsic;
161198
if (auto II = dyn_cast<IntrinsicInst>(user)) {
162199
ID = II->getIntrinsicID();

enzyme/Enzyme/GradientUtils.cpp

+36-14
Original file line number | Diff line number | Diff line change
@@ -442,22 +442,44 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
442442
assert(val->getType() == toreturn->getType());
443443
return toreturn;
444444
} else if (auto op = dyn_cast<InsertValueInst>(val)) {
445-
auto op0 = getOp(op->getAggregateOperand());
446-
if (op0 == nullptr)
447-
goto endCheck;
448-
auto op1 = getOp(op->getInsertedValueOperand());
449-
if (op1 == nullptr)
445+
// Unwrapped Aggregate, Indices, parent
446+
SmallVector<std::tuple<Value *, ArrayRef<unsigned>, InsertValueInst *>, 1>
447+
insertElements;
448+
449+
Value *agg = op;
450+
while (auto op1 = dyn_cast<InsertValueInst>(agg)) {
451+
if (Value *orig = isOriginal(op1)) {
452+
if (knownRecomputeHeuristic.count(orig)) {
453+
if (!knownRecomputeHeuristic[orig]) {
454+
break;
455+
}
456+
}
457+
}
458+
Value *valOp = op1->getInsertedValueOperand();
459+
valOp = getOp(valOp);
460+
if (valOp == nullptr)
461+
goto endCheck;
462+
insertElements.push_back({valOp, op1->getIndices(), op1});
463+
agg = op1->getAggregateOperand();
464+
}
465+
466+
Value *toreturn = getOp(agg);
467+
if (toreturn == nullptr)
450468
goto endCheck;
451-
auto toreturn = BuilderM.CreateInsertValue(op0, op1, op->getIndices(),
452-
op->getName() + "_unwrap");
453-
if (permitCache)
454-
unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn;
455-
if (auto newi = dyn_cast<Instruction>(toreturn)) {
456-
newi->copyIRFlags(op);
457-
unwrappedLoads[newi] = val;
458-
if (newi->getParent()->getParent() != op->getParent()->getParent())
459-
newi->setDebugLoc(nullptr);
469+
for (auto &&[valOp, idcs, parent] : reverse(insertElements)) {
470+
toreturn = BuilderM.CreateInsertValue(toreturn, valOp, idcs,
471+
parent->getName() + "_unwrap");
472+
473+
if (permitCache)
474+
unwrap_cache[BuilderM.GetInsertBlock()][parent][idx.second] = toreturn;
475+
if (auto newi = dyn_cast<Instruction>(toreturn)) {
476+
newi->copyIRFlags(parent);
477+
unwrappedLoads[newi] = val;
478+
if (newi->getParent()->getParent() != parent->getParent()->getParent())
479+
newi->setDebugLoc(nullptr);
480+
}
460481
}
482+
461483
assert(val->getType() == toreturn->getType());
462484
return toreturn;
463485
} else if (auto op = dyn_cast<ExtractElementInst>(val)) {
+135
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,135 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -S | FileCheck %s
2+
3+
define void @tester(double* %in0, double* %in1, i1 %c) {
4+
entry:
5+
br i1 %c, label %trueb, label %exit
6+
7+
trueb:
8+
%pre_x0 = load double, double* %in0
9+
store double 0.000000e+00, double* %in0
10+
%x0 = insertvalue {double, double, double*} undef, double %pre_x0, 0
11+
12+
%pre_x1 = load double, double* %in1
13+
store double 0.000000e+00, double* %in1
14+
%x1 = insertvalue {double, double, double*} %x0, double %pre_x1, 1
15+
16+
%out1 = insertvalue {double, double, double*} %x1, double* %in0, 2
17+
18+
%post_x0 = extractvalue {double, double, double*} %out1, 0
19+
%post_x1 = extractvalue {double, double, double*} %x1, 1
20+
21+
%mul0 = fmul double %post_x0, %post_x1
22+
store double %mul0, double* %in0
23+
24+
br label %exit
25+
26+
exit:
27+
ret void
28+
}
29+
30+
define void @test_derivative(double* %x, double* %dx, double* %y, double* %dy) {
31+
entry:
32+
tail call void (...) @__enzyme_autodiff(void (double*, double*, i1)* nonnull @tester, double* %x, double* %dx, double* %y, double* %dy, i1 true)
33+
ret void
34+
}
35+
36+
; Function Attrs: nounwind
37+
declare void @__enzyme_autodiff(...)
38+
39+
; CHECK: define internal void @diffetester(double* %in0, double* %"in0'", double* %in1, double* %"in1'", i1 %c)
40+
; CHECK-NEXT: entry:
41+
; CHECK-NEXT: %"x1'de" = alloca { double, double, double* }
42+
; CHECK-NEXT: store { double, double, double* } zeroinitializer, { double, double, double* }* %"x1'de"
43+
; CHECK-NEXT: %"out1'de" = alloca { double, double, double* }
44+
; CHECK-NEXT: store { double, double, double* } zeroinitializer, { double, double, double* }* %"out1'de"
45+
; CHECK-NEXT: %"x0'de" = alloca { double, double, double* }
46+
; CHECK-NEXT: store { double, double, double* } zeroinitializer, { double, double, double* }* %"x0'de"
47+
; CHECK-NEXT: br i1 %c, label %trueb, label %exit
48+
49+
; CHECK: trueb: ; preds = %entry
50+
; CHECK-NEXT: %pre_x0 = load double, double* %in0
51+
; CHECK-NEXT: store double 0.000000e+00, double* %in0
52+
; CHECK-NEXT: %x0 = insertvalue { double, double, double* } undef, double %pre_x0, 0
53+
; CHECK-NEXT: %pre_x1 = load double, double* %in1
54+
; CHECK-NEXT: store double 0.000000e+00, double* %in1
55+
; CHECK-NEXT: %x1 = insertvalue { double, double, double* } %x0, double %pre_x1, 1
56+
; CHECK-NEXT: %out1 = insertvalue { double, double, double* } %x1, double* %in0, 2
57+
; CHECK-NEXT: %post_x0 = extractvalue { double, double, double* } %out1, 0
58+
; CHECK-NEXT: %post_x1 = extractvalue { double, double, double* } %x1, 1
59+
; CHECK-NEXT: %mul0 = fmul double %post_x0, %post_x1
60+
; CHECK-NEXT: store double %mul0, double* %in0
61+
; CHECK-NEXT: br label %exit
62+
63+
; CHECK: exit: ; preds = %trueb, %entry
64+
; CHECK-NEXT: %x1_cache.0 = phi { double, double, double* } [ %x1, %trueb ], [ undef, %entry ]
65+
; CHECK-NEXT: br label %invertexit
66+
67+
; CHECK: invertentry: ; preds = %invertexit, %inverttrueb
68+
; CHECK-NEXT: ret void
69+
70+
; CHECK: inverttrueb: ; preds = %invertexit
71+
; CHECK-NEXT: %0 = load double, double* %"in0'"
72+
; CHECK-NEXT: store double 0.000000e+00, double* %"in0'"
73+
; CHECK-NEXT: %1 = fadd fast double 0.000000e+00, %0
74+
; CHECK-NEXT: %post_x1_unwrap = extractvalue { double, double, double* } %x1_cache.0, 1
75+
; CHECK-NEXT: %m0diffepost_x0 = fmul fast double %1, %post_x1_unwrap
76+
; CHECK-NEXT: %out1_unwrap = insertvalue { double, double, double* } %x1_cache.0, double* %in0, 2
77+
; CHECK-NEXT: %post_x0_unwrap = extractvalue { double, double, double* } %out1_unwrap, 0
78+
; CHECK-NEXT: %m1diffepost_x1 = fmul fast double %1, %post_x0_unwrap
79+
; CHECK-NEXT: %2 = fadd fast double 0.000000e+00, %m0diffepost_x0
80+
; CHECK-NEXT: %3 = fadd fast double 0.000000e+00, %m1diffepost_x1
81+
; CHECK-NEXT: %4 = getelementptr inbounds { double, double, double* }, { double, double, double* }* %"x1'de", i32 0, i32 1
82+
; CHECK-NEXT: %5 = load double, double* %4
83+
; CHECK-NEXT: %6 = fadd fast double %5, %3
84+
; CHECK-NEXT: store double %6, double* %4
85+
; CHECK-NEXT: %7 = getelementptr inbounds { double, double, double* }, { double, double, double* }* %"out1'de", i32 0, i32 0
86+
; CHECK-NEXT: %8 = load double, double* %7
87+
; CHECK-NEXT: %9 = fadd fast double %8, %2
88+
; CHECK-NEXT: store double %9, double* %7
89+
; CHECK-NEXT: %10 = load { double, double, double* }, { double, double, double* }* %"out1'de"
90+
; CHECK-NEXT: %11 = insertvalue { double, double, double* } %10, double* null, 2
91+
; CHECK-NEXT: %12 = load { double, double, double* }, { double, double, double* }* %"x1'de"
92+
; CHECK-NEXT: %13 = extractvalue { double, double, double* } %10, 0
93+
; CHECK-NEXT: %14 = getelementptr inbounds { double, double, double* }, { double, double, double* }* %"x1'de", i32 0, i32 0
94+
; CHECK-NEXT: %15 = load double, double* %14
95+
; CHECK-NEXT: %16 = fadd fast double %15, %13
96+
; CHECK-NEXT: store double %16, double* %14
97+
; CHECK-NEXT: %17 = extractvalue { double, double, double* } %10, 1
98+
; CHECK-NEXT: %18 = getelementptr inbounds { double, double, double* }, { double, double, double* }* %"x1'de", i32 0, i32 1
99+
; CHECK-NEXT: %19 = load double, double* %18
100+
; CHECK-NEXT: %20 = fadd fast double %19, %17
101+
; CHECK-NEXT: store double %20, double* %18
102+
; CHECK-NEXT: store { double, double, double* } zeroinitializer, { double, double, double* }* %"out1'de"
103+
; CHECK-NEXT: %21 = load { double, double, double* }, { double, double, double* }* %"x1'de"
104+
; CHECK-NEXT: %22 = extractvalue { double, double, double* } %21, 1
105+
; CHECK-NEXT: %23 = fadd fast double 0.000000e+00, %22
106+
; CHECK-NEXT: %24 = load { double, double, double* }, { double, double, double* }* %"x1'de"
107+
; CHECK-NEXT: %25 = insertvalue { double, double, double* } %24, double 0.000000e+00, 1
108+
; CHECK-NEXT: %26 = load { double, double, double* }, { double, double, double* }* %"x0'de"
109+
; CHECK-NEXT: %27 = extractvalue { double, double, double* } %24, 0
110+
; CHECK-NEXT: %28 = getelementptr inbounds { double, double, double* }, { double, double, double* }* %"x0'de", i32 0, i32 0
111+
; CHECK-NEXT: %29 = load double, double* %28
112+
; CHECK-NEXT: %30 = fadd fast double %29, %27
113+
; CHECK-NEXT: store double %30, double* %28
114+
; CHECK-NEXT: %31 = getelementptr inbounds { double, double, double* }, { double, double, double* }* %"x0'de", i32 0, i32 1
115+
; CHECK-NEXT: %32 = load double, double* %31
116+
; CHECK-NEXT: %33 = fadd fast double %32, 0.000000e+00
117+
; CHECK-NEXT: store double %33, double* %31
118+
; CHECK-NEXT: store { double, double, double* } zeroinitializer, { double, double, double* }* %"x1'de"
119+
; CHECK-NEXT: store double 0.000000e+00, double* %"in1'"
120+
; CHECK-NEXT: %34 = load double, double* %"in1'"
121+
; CHECK-NEXT: %35 = fadd fast double %34, %23
122+
; CHECK-NEXT: store double %35, double* %"in1'"
123+
; CHECK-NEXT: %36 = load { double, double, double* }, { double, double, double* }* %"x0'de"
124+
; CHECK-NEXT: %37 = extractvalue { double, double, double* } %36, 0
125+
; CHECK-NEXT: %38 = fadd fast double 0.000000e+00, %37
126+
; CHECK-NEXT: store { double, double, double* } zeroinitializer, { double, double, double* }* %"x0'de"
127+
; CHECK-NEXT: store double 0.000000e+00, double* %"in0'"
128+
; CHECK-NEXT: %39 = load double, double* %"in0'"
129+
; CHECK-NEXT: %40 = fadd fast double %39, %38
130+
; CHECK-NEXT: store double %40, double* %"in0'"
131+
; CHECK-NEXT: br label %invertentry
132+
133+
; CHECK: invertexit: ; preds = %exit
134+
; CHECK-NEXT: br i1 %c, label %inverttrueb, label %invertentry
135+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)