@@ -716,8 +716,8 @@ class AdjointGenerator
716
716
if (constantval) {
717
717
ts = setPtrDiffe (orig_ptr, Constant::getNullValue (valType), Builder2);
718
718
} else {
719
- auto dif1 =
720
- Builder2. CreateLoad (gutils->invertPointerM (orig_ptr, Builder2));
719
+ auto dif1 = Builder2. CreateLoad (
720
+ lookup (gutils->invertPointerM (orig_ptr, Builder2) , Builder2));
721
721
#if LLVM_VERSION_MAJOR >= 10
722
722
dif1->setAlignment (SI.getAlign ());
723
723
#else
@@ -1340,8 +1340,6 @@ class AdjointGenerator
1340
1340
1341
1341
std::vector<SelectInst *> addToDiffe (Value *val, Value *dif,
1342
1342
IRBuilder<> &Builder, Type *T) {
1343
- assert (Mode == DerivativeMode::ReverseModeGradient ||
1344
- Mode == DerivativeMode::ReverseModeCombined);
1345
1343
return ((DiffeGradientUtils *)gutils)->addToDiffe (val, dif, Builder, T);
1346
1344
}
1347
1345
@@ -1928,7 +1926,8 @@ class AdjointGenerator
1928
1926
// (which thus == src and may be illegal)
1929
1927
if (gutils->isConstantValue (orig_src)) {
1930
1928
SmallVector<Value *, 4 > args;
1931
- args.push_back (gutils->invertPointerM (orig_dst, Builder2));
1929
+ args.push_back (
1930
+ lookup (gutils->invertPointerM (orig_dst, Builder2), Builder2));
1932
1931
if (args[0 ]->getType ()->isIntegerTy ())
1933
1932
args[0 ] = Builder2.CreateIntToPtr (
1934
1933
args[0 ], Type::getInt8PtrTy (MTI->getContext ()));
@@ -1958,7 +1957,8 @@ class AdjointGenerator
1958
1957
1959
1958
} else {
1960
1959
SmallVector<Value *, 4 > args;
1961
- auto dsto = gutils->invertPointerM (orig_dst, Builder2);
1960
+ auto dsto =
1961
+ lookup (gutils->invertPointerM (orig_dst, Builder2), Builder2);
1962
1962
if (dsto->getType ()->isIntegerTy ())
1963
1963
dsto = Builder2.CreateIntToPtr (
1964
1964
dsto, Type::getInt8PtrTy (dsto->getContext ()));
@@ -1968,7 +1968,8 @@ class AdjointGenerator
1968
1968
if (offset != 0 )
1969
1969
dsto = Builder2.CreateConstInBoundsGEP1_64 (dsto, offset);
1970
1970
args.push_back (Builder2.CreatePointerCast (dsto, secretpt));
1971
- auto srco = gutils->invertPointerM (orig_src, Builder2);
1971
+ auto srco =
1972
+ lookup (gutils->invertPointerM (orig_src, Builder2), Builder2);
1972
1973
if (srco->getType ()->isIntegerTy ())
1973
1974
srco = Builder2.CreateIntToPtr (
1974
1975
srco, Type::getInt8PtrTy (srco->getContext ()));
@@ -2949,7 +2950,8 @@ class AdjointGenerator
2949
2950
IRBuilder<> Builder2 (call.getParent ());
2950
2951
getReverseBuilder (Builder2);
2951
2952
args.push_back (
2952
- gutils->invertPointerM (call.getArgOperand (i), Builder2));
2953
+ lookup (gutils->invertPointerM (call.getArgOperand (i), Builder2),
2954
+ Builder2));
2953
2955
}
2954
2956
pre_args.push_back (
2955
2957
gutils->invertPointerM (call.getArgOperand (i), BuilderZ));
@@ -3715,7 +3717,8 @@ class AdjointGenerator
3715
3717
llvm::errs () << " warning could not automatically determine mpi "
3716
3718
" status type, assuming [24 x i8]\n " ;
3717
3719
}
3718
- Value *d_req = gutils->invertPointerM (call.getOperand (6 ), Builder2);
3720
+ Value *d_req = lookup (
3721
+ gutils->invertPointerM (call.getOperand (6 ), Builder2), Builder2);
3719
3722
Value *args[] = {/* req*/ d_req,
3720
3723
/* status*/ IRBuilder<>(gutils->inversionAllocs )
3721
3724
.CreateAlloca (statusType)};
@@ -3769,7 +3772,8 @@ class AdjointGenerator
3769
3772
ConstantInt::get (Type::getInt8Ty (Builder2.getContext ()), 0 );
3770
3773
auto volatile_arg = ConstantInt::getFalse (Builder2.getContext ());
3771
3774
assert (!gutils->isConstantValue (call.getOperand (0 )));
3772
- auto dbuf = gutils->invertPointerM (call.getOperand (0 ), Builder2);
3775
+ auto dbuf = lookup (
3776
+ gutils->invertPointerM (call.getOperand (0 ), Builder2), Builder2);
3773
3777
if (dbuf->getType ()->isIntegerTy ())
3774
3778
dbuf = Builder2.CreateIntToPtr (
3775
3779
dbuf, Type::getInt8PtrTy (call.getContext ()));
@@ -3790,8 +3794,8 @@ class AdjointGenerator
3790
3794
memset ->addParamAttr (0 , Attribute::NonNull);
3791
3795
} else if (funcName == " MPI_Isend" || funcName == " PMPI_Isend" ) {
3792
3796
assert (!gutils->isConstantValue (call.getOperand (0 )));
3793
- Value *shadow =
3794
- gutils->invertPointerM (call.getOperand (0 ), Builder2);
3797
+ Value *shadow = lookup (
3798
+ gutils->invertPointerM (call.getOperand (0 ), Builder2), Builder2) ;
3795
3799
if (Mode == DerivativeMode::ReverseModeCombined) {
3796
3800
assert (firstallocation);
3797
3801
firstallocation = lookup (firstallocation, Builder2);
@@ -3830,7 +3834,8 @@ class AdjointGenerator
3830
3834
getReverseBuilder (Builder2);
3831
3835
3832
3836
assert (!gutils->isConstantValue (call.getOperand (0 )));
3833
- Value *d_req = gutils->invertPointerM (call.getOperand (0 ), Builder2);
3837
+ Value *d_req = lookup (
3838
+ gutils->invertPointerM (call.getOperand (0 ), Builder2), Builder2);
3834
3839
if (d_req->getType ()->isIntegerTy ()) {
3835
3840
d_req = Builder2.CreateIntToPtr (
3836
3841
d_req,
@@ -3908,8 +3913,8 @@ class AdjointGenerator
3908
3913
assert (!gutils->isConstantValue (call.getOperand (1 )));
3909
3914
Value *count =
3910
3915
lookup (gutils->getNewFromOriginal (call.getOperand (0 )), Builder2);
3911
- Value *d_req_orig =
3912
- gutils->invertPointerM (call.getOperand (1 ), Builder2);
3916
+ Value *d_req_orig = lookup (
3917
+ gutils->invertPointerM (call.getOperand (1 ), Builder2), Builder2) ;
3913
3918
if (d_req_orig->getType ()->isIntegerTy ()) {
3914
3919
d_req_orig = Builder2.CreateIntToPtr (
3915
3920
d_req_orig,
@@ -4007,7 +4012,8 @@ class AdjointGenerator
4007
4012
Mode == DerivativeMode::ReverseModeCombined) {
4008
4013
IRBuilder<> Builder2 (call.getParent ());
4009
4014
getReverseBuilder (Builder2);
4010
- Value *shadow = gutils->invertPointerM (call.getOperand (0 ), Builder2);
4015
+ Value *shadow = lookup (
4016
+ gutils->invertPointerM (call.getOperand (0 ), Builder2), Builder2);
4011
4017
4012
4018
if (shadow ->getType ()->isIntegerTy ())
4013
4019
shadow = Builder2.CreateIntToPtr (
@@ -4095,7 +4101,8 @@ class AdjointGenerator
4095
4101
IRBuilder<> Builder2 (call.getParent ());
4096
4102
getReverseBuilder (Builder2);
4097
4103
4098
- Value *shadow = gutils->invertPointerM (call.getOperand (0 ), Builder2);
4104
+ Value *shadow = lookup (
4105
+ gutils->invertPointerM (call.getOperand (0 ), Builder2), Builder2);
4099
4106
if (shadow ->getType ()->isIntegerTy ())
4100
4107
shadow = Builder2.CreateIntToPtr (
4101
4108
shadow , Type::getInt8PtrTy (call.getContext ()));
@@ -4165,7 +4172,8 @@ class AdjointGenerator
4165
4172
IRBuilder<> Builder2 (call.getParent ());
4166
4173
getReverseBuilder (Builder2);
4167
4174
4168
- Value *shadow = gutils->invertPointerM (call.getOperand (0 ), Builder2);
4175
+ Value *shadow = lookup (
4176
+ gutils->invertPointerM (call.getOperand (0 ), Builder2), Builder2);
4169
4177
if (shadow ->getType ()->isIntegerTy ())
4170
4178
shadow = Builder2.CreateIntToPtr (
4171
4179
shadow , Type::getInt8PtrTy (call.getContext ()));
@@ -4365,11 +4373,13 @@ class AdjointGenerator
4365
4373
report_fatal_error (" unhandled mpi_allreduce op" );
4366
4374
}
4367
4375
4368
- Value *shadow_recvbuf = gutils->invertPointerM (orig_recvbuf, Builder2);
4376
+ Value *shadow_recvbuf =
4377
+ lookup (gutils->invertPointerM (orig_recvbuf, Builder2), Builder2);
4369
4378
if (shadow_recvbuf->getType ()->isIntegerTy ())
4370
4379
shadow_recvbuf = Builder2.CreateIntToPtr (
4371
4380
shadow_recvbuf, Type::getInt8PtrTy (call.getContext ()));
4372
- Value *shadow_sendbuf = gutils->invertPointerM (orig_sendbuf, Builder2);
4381
+ Value *shadow_sendbuf =
4382
+ lookup (gutils->invertPointerM (orig_sendbuf, Builder2), Builder2);
4373
4383
if (shadow_sendbuf->getType ()->isIntegerTy ())
4374
4384
shadow_sendbuf = Builder2.CreateIntToPtr (
4375
4385
shadow_sendbuf, Type::getInt8PtrTy (call.getContext ()));
@@ -4552,11 +4562,13 @@ class AdjointGenerator
4552
4562
report_fatal_error (" unhandled mpi_allreduce op" );
4553
4563
}
4554
4564
4555
- Value *shadow_recvbuf = gutils->invertPointerM (orig_recvbuf, Builder2);
4565
+ Value *shadow_recvbuf =
4566
+ lookup (gutils->invertPointerM (orig_recvbuf, Builder2), Builder2);
4556
4567
if (shadow_recvbuf->getType ()->isIntegerTy ())
4557
4568
shadow_recvbuf = Builder2.CreateIntToPtr (
4558
4569
shadow_recvbuf, Type::getInt8PtrTy (call.getContext ()));
4559
- Value *shadow_sendbuf = gutils->invertPointerM (orig_sendbuf, Builder2);
4570
+ Value *shadow_sendbuf =
4571
+ lookup (gutils->invertPointerM (orig_sendbuf, Builder2), Builder2);
4560
4572
if (shadow_sendbuf->getType ()->isIntegerTy ())
4561
4573
shadow_sendbuf = Builder2.CreateIntToPtr (
4562
4574
shadow_sendbuf, Type::getInt8PtrTy (call.getContext ()));
@@ -4665,11 +4677,13 @@ class AdjointGenerator
4665
4677
Value *orig_root = call.getOperand (6 );
4666
4678
Value *orig_comm = call.getOperand (7 );
4667
4679
4668
- Value *shadow_recvbuf = gutils->invertPointerM (orig_recvbuf, Builder2);
4680
+ Value *shadow_recvbuf =
4681
+ lookup (gutils->invertPointerM (orig_recvbuf, Builder2), Builder2);
4669
4682
if (shadow_recvbuf->getType ()->isIntegerTy ())
4670
4683
shadow_recvbuf = Builder2.CreateIntToPtr (
4671
4684
shadow_recvbuf, Type::getInt8PtrTy (call.getContext ()));
4672
- Value *shadow_sendbuf = gutils->invertPointerM (orig_sendbuf, Builder2);
4685
+ Value *shadow_sendbuf =
4686
+ lookup (gutils->invertPointerM (orig_sendbuf, Builder2), Builder2);
4673
4687
if (shadow_sendbuf->getType ()->isIntegerTy ())
4674
4688
shadow_sendbuf = Builder2.CreateIntToPtr (
4675
4689
shadow_sendbuf, Type::getInt8PtrTy (call.getContext ()));
@@ -4820,11 +4834,13 @@ class AdjointGenerator
4820
4834
Value *orig_root = call.getOperand (6 );
4821
4835
Value *orig_comm = call.getOperand (7 );
4822
4836
4823
- Value *shadow_recvbuf = gutils->invertPointerM (orig_recvbuf, Builder2);
4837
+ Value *shadow_recvbuf =
4838
+ lookup (gutils->invertPointerM (orig_recvbuf, Builder2), Builder2);
4824
4839
if (shadow_recvbuf->getType ()->isIntegerTy ())
4825
4840
shadow_recvbuf = Builder2.CreateIntToPtr (
4826
4841
shadow_recvbuf, Type::getInt8PtrTy (call.getContext ()));
4827
- Value *shadow_sendbuf = gutils->invertPointerM (orig_sendbuf, Builder2);
4842
+ Value *shadow_sendbuf =
4843
+ lookup (gutils->invertPointerM (orig_sendbuf, Builder2), Builder2);
4828
4844
if (shadow_sendbuf->getType ()->isIntegerTy ())
4829
4845
shadow_sendbuf = Builder2.CreateIntToPtr (
4830
4846
shadow_sendbuf, Type::getInt8PtrTy (call.getContext ()));
@@ -5008,11 +5024,13 @@ class AdjointGenerator
5008
5024
Value *orig_recvcount = call.getOperand (4 );
5009
5025
Value *orig_comm = call.getOperand (6 );
5010
5026
5011
- Value *shadow_recvbuf = gutils->invertPointerM (orig_recvbuf, Builder2);
5027
+ Value *shadow_recvbuf =
5028
+ lookup (gutils->invertPointerM (orig_recvbuf, Builder2), Builder2);
5012
5029
if (shadow_recvbuf->getType ()->isIntegerTy ())
5013
5030
shadow_recvbuf = Builder2.CreateIntToPtr (
5014
5031
shadow_recvbuf, Type::getInt8PtrTy (call.getContext ()));
5015
- Value *shadow_sendbuf = gutils->invertPointerM (orig_sendbuf, Builder2);
5032
+ Value *shadow_sendbuf =
5033
+ lookup (gutils->invertPointerM (orig_sendbuf, Builder2), Builder2);
5016
5034
if (shadow_sendbuf->getType ()->isIntegerTy ())
5017
5035
shadow_sendbuf = Builder2.CreateIntToPtr (
5018
5036
shadow_sendbuf, Type::getInt8PtrTy (call.getContext ()));
@@ -5502,7 +5520,8 @@ class AdjointGenerator
5502
5520
diffe (orig, Builder2),
5503
5521
structarg1,
5504
5522
estride,
5505
- gutils->invertPointerM (orig->getArgOperand (3 ), Builder2),
5523
+ lookup (gutils->invertPointerM (orig->getArgOperand (3 ), Builder2),
5524
+ Builder2),
5506
5525
lookup (gutils->getNewFromOriginal (orig->getArgOperand (4 )),
5507
5526
Builder2)};
5508
5527
firstdcall = Builder2.CreateCall (derivcall, args1);
@@ -5520,7 +5539,8 @@ class AdjointGenerator
5520
5539
diffe (orig, Builder2),
5521
5540
structarg2,
5522
5541
estride,
5523
- gutils->invertPointerM (orig->getArgOperand (1 ), Builder2),
5542
+ lookup (gutils->invertPointerM (orig->getArgOperand (1 ), Builder2),
5543
+ Builder2),
5524
5544
lookup (gutils->getNewFromOriginal (orig->getArgOperand (2 )),
5525
5545
Builder2)};
5526
5546
seconddcall = Builder2.CreateCall (derivcall, args2);
@@ -7267,7 +7287,8 @@ class AdjointGenerator
7267
7287
IRBuilder<> Builder2 (call.getParent ());
7268
7288
getReverseBuilder (Builder2);
7269
7289
args.push_back (
7270
- gutils->invertPointerM (orig->getArgOperand (i), Builder2));
7290
+ lookup (gutils->invertPointerM (orig->getArgOperand (i), Builder2),
7291
+ Builder2));
7271
7292
}
7272
7293
pre_args.push_back (
7273
7294
gutils->invertPointerM (orig->getArgOperand (i), BuilderZ));
@@ -7702,7 +7723,7 @@ class AdjointGenerator
7702
7723
llvm::errs () << " orig: " << *orig << " callval: " << *callval << " \n " ;
7703
7724
}
7704
7725
assert (!gutils->isConstantValue (callval));
7705
- newcalled = gutils->invertPointerM (callval, Builder2);
7726
+ newcalled = lookup ( gutils->invertPointerM (callval, Builder2) , Builder2);
7706
7727
7707
7728
auto ft = cast<FunctionType>(
7708
7729
cast<PointerType>(callval->getType ())->getElementType ());
0 commit comments