Skip to content

Commit 44e7122

Browse files
authored
Correct and simplify sdot/ddot (rust-lang#498)
1 parent 0ef2b90 commit 44e7122

22 files changed

+435
-2061
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4652,6 +4652,11 @@ class AdjointGenerator
46524652

46534653
bool handleBLAS(llvm::CallInst &call, Function *called, StringRef funcName,
46544654
const std::map<Argument *, bool> &uncacheable_args) {
4655+
// Forward Mode not handled yet
4656+
assert(Mode != DerivativeMode::ForwardMode &&
4657+
Mode != DerivativeMode::ForwardModeSplit);
4658+
// Vector Mode not handled yet
4659+
assert(gutils->getWidth() == 1);
46554660
CallInst *const newCall = cast<CallInst>(gutils->getNewFromOriginal(&call));
46564661
IRBuilder<> BuilderZ(newCall);
46574662
BuilderZ.setFastMathFlags(getFast());
@@ -4671,9 +4676,6 @@ class AdjointGenerator
46714676
}
46724677
Type *castvals[2] = {call.getArgOperand(1)->getType(),
46734678
call.getArgOperand(3)->getType()};
4674-
auto *cachetype =
4675-
StructType::get(call.getContext(), ArrayRef<Type *>(castvals));
4676-
Value *undefinit = UndefValue::get(cachetype);
46774679
Value *cacheval;
46784680
auto in_arg = call.getCalledFunction()->arg_begin();
46794681
in_arg++;
@@ -4694,15 +4696,16 @@ class AdjointGenerator
46944696
if (xcache) {
46954697
auto dmemcpy =
46964698
getOrInsertMemcpyStrided(*gutils->oldFunc->getParent(),
4697-
PointerType::getUnqual(innerType), 0, 0);
4699+
cast<PointerType>(castvals[0]), 0, 0);
46984700
auto malins = CallInst::CreateMalloc(
46994701
gutils->getNewFromOriginal(&call), size->getType(), innerType,
4700-
size, call.getArgOperand(0), nullptr, "");
4701-
arg1 =
4702-
BuilderZ.CreateBitCast(malins, call.getArgOperand(1)->getType());
4702+
size, gutils->getNewFromOriginal(call.getArgOperand(0)), nullptr,
4703+
"");
4704+
arg1 = BuilderZ.CreateBitCast(malins, castvals[0]);
47034705
Value *args[4] = {arg1,
47044706
gutils->getNewFromOriginal(call.getArgOperand(1)),
4705-
call.getArgOperand(0), call.getArgOperand(2)};
4707+
gutils->getNewFromOriginal(call.getArgOperand(0)),
4708+
gutils->getNewFromOriginal(call.getArgOperand(2))};
47064709

47074710
BuilderZ.CreateCall(
47084711
dmemcpy, args,
@@ -4715,15 +4718,16 @@ class AdjointGenerator
47154718
if (ycache) {
47164719
auto dmemcpy =
47174720
getOrInsertMemcpyStrided(*gutils->oldFunc->getParent(),
4718-
PointerType::getUnqual(innerType), 0, 0);
4721+
cast<PointerType>(castvals[1]), 0, 0);
47194722
auto malins = CallInst::CreateMalloc(
47204723
gutils->getNewFromOriginal(&call), size->getType(), innerType,
4721-
size, call.getArgOperand(0), nullptr, "");
4722-
arg2 =
4723-
BuilderZ.CreateBitCast(malins, call.getArgOperand(3)->getType());
4724+
size, gutils->getNewFromOriginal(call.getArgOperand(0)), nullptr,
4725+
"");
4726+
arg2 = BuilderZ.CreateBitCast(malins, castvals[1]);
47244727
Value *args[4] = {arg2,
47254728
gutils->getNewFromOriginal(call.getArgOperand(3)),
4726-
call.getArgOperand(0), call.getArgOperand(4)};
4729+
gutils->getNewFromOriginal(call.getArgOperand(0)),
4730+
gutils->getNewFromOriginal(call.getArgOperand(4))};
47274731
BuilderZ.CreateCall(
47284732
dmemcpy, args,
47294733
gutils->getInvertedBundles(&call,
@@ -4733,7 +4737,10 @@ class AdjointGenerator
47334737
BuilderZ, /*lookup*/ false));
47344738
}
47354739
if (xcache && ycache) {
4736-
auto valins1 = BuilderZ.CreateInsertValue(undefinit, arg1, 0);
4740+
Type *cachetype =
4741+
StructType::get(call.getContext(), ArrayRef<Type *>(castvals));
4742+
auto valins1 =
4743+
BuilderZ.CreateInsertValue(UndefValue::get(cachetype), arg1, 0);
47374744
cacheval = BuilderZ.CreateInsertValue(valins1, arg2, 1);
47384745
} else if (xcache)
47394746
cacheval = arg1;
@@ -4758,6 +4765,16 @@ class AdjointGenerator
47584765
if (Mode == DerivativeMode::ReverseModeGradient &&
47594766
(!gutils->isConstantValue(call.getArgOperand(1)) ||
47604767
!gutils->isConstantValue(call.getArgOperand(3)))) {
4768+
Type *cachetype = nullptr;
4769+
if (xcache && ycache)
4770+
cachetype = StructType::get(call.getContext(),
4771+
ArrayRef<Type *>(castvals));
4772+
else if (xcache)
4773+
cachetype = castvals[0];
4774+
else {
4775+
assert(ycache);
4776+
cachetype = castvals[1];
4777+
}
47614778
cacheval = BuilderZ.CreatePHI(cachetype, 0);
47624779
}
47634780
cacheval =

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3266,9 +3266,14 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
32663266
"return or non-constant");
32673267
}
32683268

3269-
if (key.todiff->empty() && CustomErrorHandler) {
3270-
std::string s = ("No derivative found for " + key.todiff->getName()).str();
3271-
CustomErrorHandler(s.c_str());
3269+
if (key.todiff->empty()) {
3270+
std::string str =
3271+
("No derivative found for " + key.todiff->getName()).str();
3272+
if (CustomErrorHandler) {
3273+
CustomErrorHandler(str.c_str());
3274+
} else {
3275+
llvm_unreachable(str.c_str());
3276+
}
32723277
}
32733278
assert(!key.todiff->empty());
32743279

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
;RUN: %opt < %s %loadEnzyme -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s
2+
3+
target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
4+
target triple = "x86_64-unknown-linux-gnu"
5+
6+
declare dso_local void @__enzyme_autodiff(...)
7+
8+
declare double @cblas_ddot(i32, double*, i32, double*, i32)
9+
10+
; Test driver: asks Enzyme to differentiate @f with both vector arguments
; active. Each of %m and %n is followed by a second pointer (%dm, %dn); in the
; generated function these appear as %"m'" and %"n'".
; The expectations later in this file show the result as one cblas_ddot call
; followed by two cblas_daxpy calls scaled by %differeturn.
define void @active(i32 %len, double* noalias %m, double* %dm, i32 %incm, double* noalias %n, double* %dn, i32 %incn) {
11+
entry:
12+
; First operand is the function to differentiate; the rest are its arguments.
call void (...) @__enzyme_autodiff(double (i32, double*, i32, double*, i32)* @f, i32 %len, double* noalias %m, double* %dm, i32 %incm, double* noalias %n, double* %dn, i32 %incn)
13+
ret void
14+
}
15+
16+
; Test driver: differentiates @f with the first vector (%m) marked
; "enzyme_const", so only %n is paired with a second pointer (%dn).
; The expectations later in this file show a single cblas_daxpy that writes
; into %"n'".
define void @inactiveFirst(i32 %len, double* noalias %m, i32 %incm, double* noalias %n, double* %dn, i32 %incn) {
17+
entry:
18+
; The metadata operand applies to the argument that follows it (%m), which is
; passed without a companion pointer.
call void (...) @__enzyme_autodiff(double (i32, double*, i32, double*, i32)* @f, i32 %len, metadata !"enzyme_const", double* noalias %m, i32 %incm, double* noalias %n, double* %dn, i32 %incn)
19+
ret void
20+
}
21+
22+
; Test driver: differentiates @f with the second vector (%n) marked
; "enzyme_const", so only %m is paired with a second pointer (%dm).
; The expectations later in this file show a single cblas_daxpy that writes
; into %"m'".
define void @inactiveSecond(i32 %len, double* noalias %m, double* noalias %dm, i32 %incm, double* noalias %n, i32 %incn) {
23+
entry:
24+
; The metadata operand applies to the argument that follows it (%n), which is
; passed without a companion pointer.
call void (...) @__enzyme_autodiff(double (i32, double*, i32, double*, i32)* @f, i32 %len, double* noalias %m, double* noalias %dm, i32 %incm, metadata !"enzyme_const", double* noalias %n, i32 %incn)
25+
ret void
26+
}
27+
28+
; Test driver: same argument layout as @active, but differentiates @modf,
; which stores to %m and %n after the dot product.
; The expectations later in this file show an augmented call returning
; { double*, double* } and a separate reverse call consuming that value.
define void @activeMod(i32 %len, double* noalias %m, double* %dm, i32 %incm, double* noalias %n, double* %dn, i32 %incn) {
29+
entry:
30+
; Both vectors are active: %m/%dm and %n/%dn are passed as pairs.
call void (...) @__enzyme_autodiff(double (i32, double*, i32, double*, i32)* @modf, i32 %len, double* noalias %m, double* %dm, i32 %incm, double* noalias %n, double* %dn, i32 %incn)
31+
ret void
32+
}
33+
34+
; Test driver: differentiates @modf with the first vector (%m) marked
; "enzyme_const"; only %n is paired with a second pointer (%dn).
; The expectations later in this file show the augmented call returning a
; single double* that is a strided copy of %m.
define void @inactiveModFirst(i32 %len, double* noalias %m, i32 %incm, double* noalias %n, double* %dn, i32 %incn) {
35+
entry:
36+
; The metadata operand applies to the argument that follows it (%m).
call void (...) @__enzyme_autodiff(double (i32, double*, i32, double*, i32)* @modf, i32 %len, metadata !"enzyme_const", double* noalias %m, i32 %incm, double* noalias %n, double* %dn, i32 %incn)
37+
ret void
38+
}
39+
40+
; Test driver: differentiates @modf with the second vector (%n) marked
; "enzyme_const"; only %m is paired with a second pointer (%dm).
; The expectations later in this file show the augmented call returning a
; single double* that is a strided copy of %n.
define void @inactiveModSecond(i32 %len, double* noalias %m, double* noalias %dm, i32 %incm, double* noalias %n, i32 %incn) {
41+
entry:
42+
; The metadata operand applies to the argument that follows it (%n).
call void (...) @__enzyme_autodiff(double (i32, double*, i32, double*, i32)* @modf, i32 %len, double* noalias %m, double* noalias %dm, i32 %incm, metadata !"enzyme_const", double* noalias %n, i32 %incn)
43+
ret void
44+
}
45+
46+
; Function under differentiation: forwards all five arguments unchanged to
; @cblas_ddot and returns its result. It performs no stores of its own.
define double @f(i32 %len, double* noalias %m, i32 %incm, double* noalias %n, i32 %incn) {
47+
entry:
48+
%call = call double @cblas_ddot(i32 %len, double* %m, i32 %incm, double* %n, i32 %incn)
49+
ret double %call
50+
}
51+
52+
; Function under differentiation for the *Mod drivers: calls @f (the ddot
; wrapper), then writes 0.0 through %m and %n before returning the dot product.
; NOTE(review): the stores after the call presumably exist to make the ddot
; inputs overwritten before the reverse pass runs; the expectations for the
; *Mod variants show malloc plus a strided memcpy of the inputs ahead of the
; ddot call, which is consistent with that.
define double @modf(i32 %len, double* noalias %m, i32 %incm, double* noalias %n, i32 %incn) {
53+
entry:
54+
%call = call double @f(i32 %len, double* %m, i32 %incm, double* %n, i32 %incn)
55+
; Overwrite the element that %m points at after it has been read by the call.
store double 0.000000e+00, double* %m
56+
; Overwrite the element that %n points at after it has been read by the call.
store double 0.000000e+00, double* %n
57+
ret double %call
58+
}
59+
60+
61+
; CHECK: define void @active
62+
; CHECK-NEXT: entry
63+
; CHECK-NEXT: call void @[[active:.+]](
64+
65+
; CHECK: define void @inactiveFirst
66+
; CHECK-NEXT: entry
67+
; CHECK-NEXT: call void @[[inactiveFirst:.+]](
68+
69+
; CHECK: define void @inactiveSecond
70+
; CHECK-NEXT: entry
71+
; CHECK-NEXT: call void @[[inactiveSecond:.+]](
72+
73+
74+
; CHECK: define void @activeMod
75+
; CHECK-NEXT: entry
76+
; CHECK-NEXT: call void @[[activeMod:.+]](
77+
78+
; CHECK: define void @inactiveModFirst
79+
; CHECK-NEXT: entry
80+
; CHECK-NEXT: call void @[[inactiveModFirst:.+]](
81+
82+
; CHECK: define void @inactiveModSecond
83+
; CHECK-NEXT: entry
84+
; CHECK-NEXT: call void @[[inactiveModSecond:.+]](
85+
86+
87+
; CHECK: define internal void @[[active]](i32 %len, double* noalias %m, double* %"m'", i32 %incm, double* noalias %n, double* %"n'", i32 %incn, double %differeturn)
88+
; CHECK-NEXT: entry:
89+
; CHECK-NEXT: %call = call double @cblas_ddot(i32 %len, double* nocapture readonly %m, i32 %incm, double* nocapture readonly %n, i32 %incn)
90+
; CHECK-NEXT: call void @cblas_daxpy(i32 %len, double %differeturn, double* %m, i32 %incm, double* %"n'", i32 %incn)
91+
; CHECK-NEXT: call void @cblas_daxpy(i32 %len, double %differeturn, double* %n, i32 %incn, double* %"m'", i32 %incm)
92+
; CHECK-NEXT: ret void
93+
; CHECK-NEXT: }
94+
95+
; CHECK: define internal void @[[inactiveFirst]](i32 %len, double* noalias %m, i32 %incm, double* noalias %n, double* %"n'", i32 %incn, double %differeturn)
96+
; CHECK-NEXT: entry:
97+
; CHECK-NEXT: %call = call double @cblas_ddot(i32 %len, double* nocapture readonly %m, i32 %incm, double* nocapture readonly %n, i32 %incn)
98+
; CHECK-NEXT: call void @cblas_daxpy(i32 %len, double %differeturn, double* %m, i32 %incm, double* %"n'", i32 %incn)
99+
; CHECK-NEXT: ret void
100+
; CHECK-NEXT: }
101+
102+
; CHECK: define internal void @[[inactiveSecond]](i32 %len, double* noalias %m, double* %"m'", i32 %incm, double* noalias %n, i32 %incn, double %differeturn)
103+
; CHECK-NEXT: entry:
104+
; CHECK-NEXT: %call = call double @cblas_ddot(i32 %len, double* nocapture readonly %m, i32 %incm, double* nocapture readonly %n, i32 %incn)
105+
; CHECK-NEXT: call void @cblas_daxpy(i32 %len, double %differeturn, double* %n, i32 %incn, double* %"m'", i32 %incm)
106+
; CHECK-NEXT: ret void
107+
; CHECK-NEXT: }
108+
109+
; CHECK: define internal void @[[activeMod]](i32 %len, double* noalias %m, double* %"m'", i32 %incm, double* noalias %n, double* %"n'", i32 %incn, double %differeturn)
110+
; CHECK-NEXT: entry:
111+
; CHECK: %call_augmented = call { double*, double* } @[[augMod:.+]](i32 %len, double* %m, double* %"m'", i32 %incm, double* %n, double* %"n'", i32 %incn)
112+
; CHECK: call void @[[revMod:.+]](i32 %len, double* %m, double* %"m'", i32 %incm, double* %n, double* %"n'", i32 %incn, double %differeturn, { double*, double* } %call_augmented)
113+
; CHECK-NEXT: ret void
114+
; CHECK-NEXT: }
115+
116+
; CHECK: define internal { double*, double* } @[[augMod]](i32 %len, double* noalias %m, double* %"m'", i32 %incm, double* noalias %n, double* %"n'", i32 %incn)
117+
; CHECK-NEXT: entry:
118+
; CHECK-NEXT: %0 = zext i32 %len to i64
119+
; CHECK-NEXT: %mallocsize = mul i64 %0, ptrtoint (double* getelementptr (double, double* null, i32 1) to i64)
120+
; CHECK-NEXT: %malloccall = tail call i8* @malloc(i64 %mallocsize)
121+
; CHECK-NEXT: %1 = bitcast i8* %malloccall to double*
122+
; CHECK-NEXT: call void @__enzyme_memcpy_doubleda0sa0stride(double* %1, double* %m, i32 %len, i32 %incm)
123+
; CHECK-NEXT: %2 = zext i32 %len to i64
124+
; CHECK-NEXT: %mallocsize1 = mul i64 %2, ptrtoint (double* getelementptr (double, double* null, i32 1) to i64)
125+
; CHECK-NEXT: %malloccall2 = tail call i8* @malloc(i64 %mallocsize1)
126+
; CHECK-NEXT: %3 = bitcast i8* %malloccall2 to double*
127+
; CHECK-NEXT: call void @__enzyme_memcpy_doubleda0sa0stride(double* %3, double* %n, i32 %len, i32 %incn)
128+
; CHECK-NEXT: %4 = insertvalue { double*, double* } undef, double* %1, 0
129+
; CHECK-NEXT: %5 = insertvalue { double*, double* } %4, double* %3, 1
130+
; CHECK-NEXT: %call = call double @cblas_ddot(i32 %len, double* nocapture readonly %m, i32 %incm, double* nocapture readonly %n, i32 %incn)
131+
; CHECK-NEXT: ret { double*, double* } %5
132+
; CHECK-NEXT: }
133+
134+
; CHECK: define internal void @[[revMod]](i32 %len, double* noalias %m, double* %"m'", i32 %incm, double* noalias %n, double* %"n'", i32 %incn, double %differeturn, { double*, double* }
135+
; CHECK-NEXT: entry:
136+
; CHECK-NEXT: %1 = extractvalue { double*, double* } %0, 0
137+
; CHECK-NEXT: %2 = extractvalue { double*, double* } %0, 1
138+
; CHECK-NEXT: call void @cblas_daxpy(i32 %len, double %differeturn, double* %1, i32 1, double* %"n'", i32 %incn)
139+
; CHECK-NEXT: %3 = bitcast double* %1 to i8*
140+
; CHECK-NEXT: tail call void @free(i8* %3)
141+
; CHECK-NEXT: call void @cblas_daxpy(i32 %len, double %differeturn, double* %2, i32 1, double* %"m'", i32 %incm)
142+
; CHECK-NEXT: %4 = bitcast double* %2 to i8*
143+
; CHECK-NEXT: tail call void @free(i8* %4)
144+
; CHECK-NEXT: ret void
145+
; CHECK-NEXT: }
146+
147+
; CHECK: define internal void @[[inactiveModFirst]](i32 %len, double* noalias %m, i32 %incm, double* noalias %n, double* %"n'", i32 %incn, double %differeturn)
148+
; CHECK-NEXT: entry:
149+
; CHECK: %call_augmented = call double* @[[augModFirst:.+]](i32 %len, double* %m, i32 %incm, double* %n, double* %"n'", i32 %incn)
150+
; CHECK: call void @[[revModFirst:.+]](i32 %len, double* %m, i32 %incm, double* %n, double* %"n'", i32 %incn, double %differeturn, double* %call_augmented)
151+
; CHECK-NEXT: ret void
152+
; CHECK-NEXT: }
153+
154+
; CHECK: define internal double* @[[augModFirst]](i32 %len, double* noalias %m, i32 %incm, double* noalias %n, double* %"n'", i32 %incn)
155+
; CHECK-NEXT: entry:
156+
; CHECK-NEXT: %0 = zext i32 %len to i64
157+
; CHECK-NEXT: %mallocsize = mul i64 %0, ptrtoint (double* getelementptr (double, double* null, i32 1) to i64)
158+
; CHECK-NEXT: %malloccall = tail call i8* @malloc(i64 %mallocsize)
159+
; CHECK-NEXT: %1 = bitcast i8* %malloccall to double*
160+
; CHECK-NEXT: call void @__enzyme_memcpy_doubleda0sa0stride(double* %1, double* %m, i32 %len, i32 %incm)
161+
; CHECK-NEXT: %call = call double @cblas_ddot(i32 %len, double* nocapture readonly %m, i32 %incm, double* nocapture readonly %n, i32 %incn)
162+
; CHECK-NEXT: ret double* %1
163+
; CHECK-NEXT: }
164+
165+
; CHECK: define internal void @[[revModFirst]](i32 %len, double* noalias %m, i32 %incm, double* noalias %n, double* %"n'", i32 %incn, double %differeturn, double*
166+
; CHECK-NEXT: entry:
167+
; CHECK-NEXT: call void @cblas_daxpy(i32 %len, double %differeturn, double* %0, i32 1, double* %"n'", i32 %incn)
168+
; CHECK-NEXT: %1 = bitcast double* %0 to i8*
169+
; CHECK-NEXT: tail call void @free(i8* %1)
170+
; CHECK-NEXT: ret void
171+
; CHECK-NEXT: }
172+
173+
; CHECK: define internal void @[[inactiveModSecond]](i32 %len, double* noalias %m, double* %"m'", i32 %incm, double* noalias %n, i32 %incn, double %differeturn)
174+
; CHECK-NEXT: entry:
175+
; CHECK: %call_augmented = call double* @[[augModSecond:.+]](i32 %len, double* %m, double* %"m'", i32 %incm, double* %n, i32 %incn)
176+
; CHECK: call void @[[revModSecond:.+]](i32 %len, double* %m, double* %"m'", i32 %incm, double* %n, i32 %incn, double %differeturn, double* %call_augmented)
177+
; CHECK-NEXT: ret void
178+
; CHECK-NEXT: }
179+
180+
; CHECK: define internal double* @[[augModSecond]](i32 %len, double* noalias %m, double* %"m'", i32 %incm, double* noalias %n, i32 %incn)
181+
; CHECK-NEXT: entry:
182+
; CHECK-NEXT: %0 = zext i32 %len to i64
183+
; CHECK-NEXT: %mallocsize = mul i64 %0, ptrtoint (double* getelementptr (double, double* null, i32 1) to i64)
184+
; CHECK-NEXT: %malloccall = tail call i8* @malloc(i64 %mallocsize)
185+
; CHECK-NEXT: %1 = bitcast i8* %malloccall to double*
186+
; CHECK-NEXT: call void @__enzyme_memcpy_doubleda0sa0stride(double* %1, double* %n, i32 %len, i32 %incn)
187+
; CHECK-NEXT: %call = call double @cblas_ddot(i32 %len, double* nocapture readonly %m, i32 %incm, double* nocapture readonly %n, i32 %incn)
188+
; CHECK-NEXT: ret double* %1
189+
; CHECK-NEXT: }
190+
191+
; CHECK: define internal void @[[revModSecond]](i32 %len, double* noalias %m, double* %"m'", i32 %incm, double* noalias %n, i32 %incn, double %differeturn, double*
192+
; CHECK-NEXT: entry:
193+
; CHECK-NEXT: call void @cblas_daxpy(i32 %len, double %differeturn, double* %0, i32 1, double* %"m'", i32 %incm)
194+
; CHECK-NEXT: %1 = bitcast double* %0 to i8*
195+
; CHECK-NEXT: tail call void @free(i8* %1)
196+
; CHECK-NEXT: ret void
197+
; CHECK-NEXT: }
198+

0 commit comments

Comments
 (0)