Skip to content

Commit 51988b0

Browse files
authored
Move Assertions (rust-lang#306)
* move assertions * don't use lookup in invertPointerM when running in forward mode
1 parent 1010096 commit 51988b0

File tree

3 files changed

+153
-97
lines changed

3 files changed

+153
-97
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 53 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -716,8 +716,8 @@ class AdjointGenerator
716716
if (constantval) {
717717
ts = setPtrDiffe(orig_ptr, Constant::getNullValue(valType), Builder2);
718718
} else {
719-
auto dif1 =
720-
Builder2.CreateLoad(gutils->invertPointerM(orig_ptr, Builder2));
719+
auto dif1 = Builder2.CreateLoad(
720+
lookup(gutils->invertPointerM(orig_ptr, Builder2), Builder2));
721721
#if LLVM_VERSION_MAJOR >= 10
722722
dif1->setAlignment(SI.getAlign());
723723
#else
@@ -1340,8 +1340,6 @@ class AdjointGenerator
13401340

13411341
std::vector<SelectInst *> addToDiffe(Value *val, Value *dif,
13421342
IRBuilder<> &Builder, Type *T) {
1343-
assert(Mode == DerivativeMode::ReverseModeGradient ||
1344-
Mode == DerivativeMode::ReverseModeCombined);
13451343
return ((DiffeGradientUtils *)gutils)->addToDiffe(val, dif, Builder, T);
13461344
}
13471345

@@ -1928,7 +1926,8 @@ class AdjointGenerator
19281926
// (which thus == src and may be illegal)
19291927
if (gutils->isConstantValue(orig_src)) {
19301928
SmallVector<Value *, 4> args;
1931-
args.push_back(gutils->invertPointerM(orig_dst, Builder2));
1929+
args.push_back(
1930+
lookup(gutils->invertPointerM(orig_dst, Builder2), Builder2));
19321931
if (args[0]->getType()->isIntegerTy())
19331932
args[0] = Builder2.CreateIntToPtr(
19341933
args[0], Type::getInt8PtrTy(MTI->getContext()));
@@ -1958,7 +1957,8 @@ class AdjointGenerator
19581957

19591958
} else {
19601959
SmallVector<Value *, 4> args;
1961-
auto dsto = gutils->invertPointerM(orig_dst, Builder2);
1960+
auto dsto =
1961+
lookup(gutils->invertPointerM(orig_dst, Builder2), Builder2);
19621962
if (dsto->getType()->isIntegerTy())
19631963
dsto = Builder2.CreateIntToPtr(
19641964
dsto, Type::getInt8PtrTy(dsto->getContext()));
@@ -1968,7 +1968,8 @@ class AdjointGenerator
19681968
if (offset != 0)
19691969
dsto = Builder2.CreateConstInBoundsGEP1_64(dsto, offset);
19701970
args.push_back(Builder2.CreatePointerCast(dsto, secretpt));
1971-
auto srco = gutils->invertPointerM(orig_src, Builder2);
1971+
auto srco =
1972+
lookup(gutils->invertPointerM(orig_src, Builder2), Builder2);
19721973
if (srco->getType()->isIntegerTy())
19731974
srco = Builder2.CreateIntToPtr(
19741975
srco, Type::getInt8PtrTy(srco->getContext()));
@@ -2949,7 +2950,8 @@ class AdjointGenerator
29492950
IRBuilder<> Builder2(call.getParent());
29502951
getReverseBuilder(Builder2);
29512952
args.push_back(
2952-
gutils->invertPointerM(call.getArgOperand(i), Builder2));
2953+
lookup(gutils->invertPointerM(call.getArgOperand(i), Builder2),
2954+
Builder2));
29532955
}
29542956
pre_args.push_back(
29552957
gutils->invertPointerM(call.getArgOperand(i), BuilderZ));
@@ -3715,7 +3717,8 @@ class AdjointGenerator
37153717
llvm::errs() << " warning could not automatically determine mpi "
37163718
"status type, assuming [24 x i8]\n";
37173719
}
3718-
Value *d_req = gutils->invertPointerM(call.getOperand(6), Builder2);
3720+
Value *d_req = lookup(
3721+
gutils->invertPointerM(call.getOperand(6), Builder2), Builder2);
37193722
Value *args[] = {/*req*/ d_req,
37203723
/*status*/ IRBuilder<>(gutils->inversionAllocs)
37213724
.CreateAlloca(statusType)};
@@ -3769,7 +3772,8 @@ class AdjointGenerator
37693772
ConstantInt::get(Type::getInt8Ty(Builder2.getContext()), 0);
37703773
auto volatile_arg = ConstantInt::getFalse(Builder2.getContext());
37713774
assert(!gutils->isConstantValue(call.getOperand(0)));
3772-
auto dbuf = gutils->invertPointerM(call.getOperand(0), Builder2);
3775+
auto dbuf = lookup(
3776+
gutils->invertPointerM(call.getOperand(0), Builder2), Builder2);
37733777
if (dbuf->getType()->isIntegerTy())
37743778
dbuf = Builder2.CreateIntToPtr(
37753779
dbuf, Type::getInt8PtrTy(call.getContext()));
@@ -3790,8 +3794,8 @@ class AdjointGenerator
37903794
memset->addParamAttr(0, Attribute::NonNull);
37913795
} else if (funcName == "MPI_Isend" || funcName == "PMPI_Isend") {
37923796
assert(!gutils->isConstantValue(call.getOperand(0)));
3793-
Value *shadow =
3794-
gutils->invertPointerM(call.getOperand(0), Builder2);
3797+
Value *shadow = lookup(
3798+
gutils->invertPointerM(call.getOperand(0), Builder2), Builder2);
37953799
if (Mode == DerivativeMode::ReverseModeCombined) {
37963800
assert(firstallocation);
37973801
firstallocation = lookup(firstallocation, Builder2);
@@ -3830,7 +3834,8 @@ class AdjointGenerator
38303834
getReverseBuilder(Builder2);
38313835

38323836
assert(!gutils->isConstantValue(call.getOperand(0)));
3833-
Value *d_req = gutils->invertPointerM(call.getOperand(0), Builder2);
3837+
Value *d_req = lookup(
3838+
gutils->invertPointerM(call.getOperand(0), Builder2), Builder2);
38343839
if (d_req->getType()->isIntegerTy()) {
38353840
d_req = Builder2.CreateIntToPtr(
38363841
d_req,
@@ -3908,8 +3913,8 @@ class AdjointGenerator
39083913
assert(!gutils->isConstantValue(call.getOperand(1)));
39093914
Value *count =
39103915
lookup(gutils->getNewFromOriginal(call.getOperand(0)), Builder2);
3911-
Value *d_req_orig =
3912-
gutils->invertPointerM(call.getOperand(1), Builder2);
3916+
Value *d_req_orig = lookup(
3917+
gutils->invertPointerM(call.getOperand(1), Builder2), Builder2);
39133918
if (d_req_orig->getType()->isIntegerTy()) {
39143919
d_req_orig = Builder2.CreateIntToPtr(
39153920
d_req_orig,
@@ -4007,7 +4012,8 @@ class AdjointGenerator
40074012
Mode == DerivativeMode::ReverseModeCombined) {
40084013
IRBuilder<> Builder2(call.getParent());
40094014
getReverseBuilder(Builder2);
4010-
Value *shadow = gutils->invertPointerM(call.getOperand(0), Builder2);
4015+
Value *shadow = lookup(
4016+
gutils->invertPointerM(call.getOperand(0), Builder2), Builder2);
40114017

40124018
if (shadow->getType()->isIntegerTy())
40134019
shadow = Builder2.CreateIntToPtr(
@@ -4095,7 +4101,8 @@ class AdjointGenerator
40954101
IRBuilder<> Builder2(call.getParent());
40964102
getReverseBuilder(Builder2);
40974103

4098-
Value *shadow = gutils->invertPointerM(call.getOperand(0), Builder2);
4104+
Value *shadow = lookup(
4105+
gutils->invertPointerM(call.getOperand(0), Builder2), Builder2);
40994106
if (shadow->getType()->isIntegerTy())
41004107
shadow = Builder2.CreateIntToPtr(
41014108
shadow, Type::getInt8PtrTy(call.getContext()));
@@ -4165,7 +4172,8 @@ class AdjointGenerator
41654172
IRBuilder<> Builder2(call.getParent());
41664173
getReverseBuilder(Builder2);
41674174

4168-
Value *shadow = gutils->invertPointerM(call.getOperand(0), Builder2);
4175+
Value *shadow = lookup(
4176+
gutils->invertPointerM(call.getOperand(0), Builder2), Builder2);
41694177
if (shadow->getType()->isIntegerTy())
41704178
shadow = Builder2.CreateIntToPtr(
41714179
shadow, Type::getInt8PtrTy(call.getContext()));
@@ -4365,11 +4373,13 @@ class AdjointGenerator
43654373
report_fatal_error("unhandled mpi_allreduce op");
43664374
}
43674375

4368-
Value *shadow_recvbuf = gutils->invertPointerM(orig_recvbuf, Builder2);
4376+
Value *shadow_recvbuf =
4377+
lookup(gutils->invertPointerM(orig_recvbuf, Builder2), Builder2);
43694378
if (shadow_recvbuf->getType()->isIntegerTy())
43704379
shadow_recvbuf = Builder2.CreateIntToPtr(
43714380
shadow_recvbuf, Type::getInt8PtrTy(call.getContext()));
4372-
Value *shadow_sendbuf = gutils->invertPointerM(orig_sendbuf, Builder2);
4381+
Value *shadow_sendbuf =
4382+
lookup(gutils->invertPointerM(orig_sendbuf, Builder2), Builder2);
43734383
if (shadow_sendbuf->getType()->isIntegerTy())
43744384
shadow_sendbuf = Builder2.CreateIntToPtr(
43754385
shadow_sendbuf, Type::getInt8PtrTy(call.getContext()));
@@ -4552,11 +4562,13 @@ class AdjointGenerator
45524562
report_fatal_error("unhandled mpi_allreduce op");
45534563
}
45544564

4555-
Value *shadow_recvbuf = gutils->invertPointerM(orig_recvbuf, Builder2);
4565+
Value *shadow_recvbuf =
4566+
lookup(gutils->invertPointerM(orig_recvbuf, Builder2), Builder2);
45564567
if (shadow_recvbuf->getType()->isIntegerTy())
45574568
shadow_recvbuf = Builder2.CreateIntToPtr(
45584569
shadow_recvbuf, Type::getInt8PtrTy(call.getContext()));
4559-
Value *shadow_sendbuf = gutils->invertPointerM(orig_sendbuf, Builder2);
4570+
Value *shadow_sendbuf =
4571+
lookup(gutils->invertPointerM(orig_sendbuf, Builder2), Builder2);
45604572
if (shadow_sendbuf->getType()->isIntegerTy())
45614573
shadow_sendbuf = Builder2.CreateIntToPtr(
45624574
shadow_sendbuf, Type::getInt8PtrTy(call.getContext()));
@@ -4665,11 +4677,13 @@ class AdjointGenerator
46654677
Value *orig_root = call.getOperand(6);
46664678
Value *orig_comm = call.getOperand(7);
46674679

4668-
Value *shadow_recvbuf = gutils->invertPointerM(orig_recvbuf, Builder2);
4680+
Value *shadow_recvbuf =
4681+
lookup(gutils->invertPointerM(orig_recvbuf, Builder2), Builder2);
46694682
if (shadow_recvbuf->getType()->isIntegerTy())
46704683
shadow_recvbuf = Builder2.CreateIntToPtr(
46714684
shadow_recvbuf, Type::getInt8PtrTy(call.getContext()));
4672-
Value *shadow_sendbuf = gutils->invertPointerM(orig_sendbuf, Builder2);
4685+
Value *shadow_sendbuf =
4686+
lookup(gutils->invertPointerM(orig_sendbuf, Builder2), Builder2);
46734687
if (shadow_sendbuf->getType()->isIntegerTy())
46744688
shadow_sendbuf = Builder2.CreateIntToPtr(
46754689
shadow_sendbuf, Type::getInt8PtrTy(call.getContext()));
@@ -4820,11 +4834,13 @@ class AdjointGenerator
48204834
Value *orig_root = call.getOperand(6);
48214835
Value *orig_comm = call.getOperand(7);
48224836

4823-
Value *shadow_recvbuf = gutils->invertPointerM(orig_recvbuf, Builder2);
4837+
Value *shadow_recvbuf =
4838+
lookup(gutils->invertPointerM(orig_recvbuf, Builder2), Builder2);
48244839
if (shadow_recvbuf->getType()->isIntegerTy())
48254840
shadow_recvbuf = Builder2.CreateIntToPtr(
48264841
shadow_recvbuf, Type::getInt8PtrTy(call.getContext()));
4827-
Value *shadow_sendbuf = gutils->invertPointerM(orig_sendbuf, Builder2);
4842+
Value *shadow_sendbuf =
4843+
lookup(gutils->invertPointerM(orig_sendbuf, Builder2), Builder2);
48284844
if (shadow_sendbuf->getType()->isIntegerTy())
48294845
shadow_sendbuf = Builder2.CreateIntToPtr(
48304846
shadow_sendbuf, Type::getInt8PtrTy(call.getContext()));
@@ -5008,11 +5024,13 @@ class AdjointGenerator
50085024
Value *orig_recvcount = call.getOperand(4);
50095025
Value *orig_comm = call.getOperand(6);
50105026

5011-
Value *shadow_recvbuf = gutils->invertPointerM(orig_recvbuf, Builder2);
5027+
Value *shadow_recvbuf =
5028+
lookup(gutils->invertPointerM(orig_recvbuf, Builder2), Builder2);
50125029
if (shadow_recvbuf->getType()->isIntegerTy())
50135030
shadow_recvbuf = Builder2.CreateIntToPtr(
50145031
shadow_recvbuf, Type::getInt8PtrTy(call.getContext()));
5015-
Value *shadow_sendbuf = gutils->invertPointerM(orig_sendbuf, Builder2);
5032+
Value *shadow_sendbuf =
5033+
lookup(gutils->invertPointerM(orig_sendbuf, Builder2), Builder2);
50165034
if (shadow_sendbuf->getType()->isIntegerTy())
50175035
shadow_sendbuf = Builder2.CreateIntToPtr(
50185036
shadow_sendbuf, Type::getInt8PtrTy(call.getContext()));
@@ -5502,7 +5520,8 @@ class AdjointGenerator
55025520
diffe(orig, Builder2),
55035521
structarg1,
55045522
estride,
5505-
gutils->invertPointerM(orig->getArgOperand(3), Builder2),
5523+
lookup(gutils->invertPointerM(orig->getArgOperand(3), Builder2),
5524+
Builder2),
55065525
lookup(gutils->getNewFromOriginal(orig->getArgOperand(4)),
55075526
Builder2)};
55085527
firstdcall = Builder2.CreateCall(derivcall, args1);
@@ -5520,7 +5539,8 @@ class AdjointGenerator
55205539
diffe(orig, Builder2),
55215540
structarg2,
55225541
estride,
5523-
gutils->invertPointerM(orig->getArgOperand(1), Builder2),
5542+
lookup(gutils->invertPointerM(orig->getArgOperand(1), Builder2),
5543+
Builder2),
55245544
lookup(gutils->getNewFromOriginal(orig->getArgOperand(2)),
55255545
Builder2)};
55265546
seconddcall = Builder2.CreateCall(derivcall, args2);
@@ -7267,7 +7287,8 @@ class AdjointGenerator
72677287
IRBuilder<> Builder2(call.getParent());
72687288
getReverseBuilder(Builder2);
72697289
args.push_back(
7270-
gutils->invertPointerM(orig->getArgOperand(i), Builder2));
7290+
lookup(gutils->invertPointerM(orig->getArgOperand(i), Builder2),
7291+
Builder2));
72717292
}
72727293
pre_args.push_back(
72737294
gutils->invertPointerM(orig->getArgOperand(i), BuilderZ));
@@ -7702,7 +7723,7 @@ class AdjointGenerator
77027723
llvm::errs() << " orig: " << *orig << " callval: " << *callval << "\n";
77037724
}
77047725
assert(!gutils->isConstantValue(callval));
7705-
newcalled = gutils->invertPointerM(callval, Builder2);
7726+
newcalled = lookup(gutils->invertPointerM(callval, Builder2), Builder2);
77067727

77077728
auto ft = cast<FunctionType>(
77087729
cast<PointerType>(callval->getType())->getElementType());

0 commit comments

Comments
 (0)