@@ -417,11 +417,18 @@ class AdjointGenerator
417
417
}
418
418
#endif
419
419
420
- llvm::errs() << *gutils->oldFunc << "\n";
421
- llvm::errs() << *gutils->newFunc << "\n";
422
- llvm::errs() << "in Mode: " << to_string(Mode) << "\n";
423
- llvm::errs() << "cannot handle unknown instruction\n" << inst;
424
- report_fatal_error("unknown value");
420
+ std::string s;
421
+ llvm::raw_string_ostream ss(s);
422
+ ss << *gutils->oldFunc << "\n";
423
+ ss << *gutils->newFunc << "\n";
424
+ ss << "in Mode: " << to_string(Mode) << "\n";
425
+ ss << "cannot handle unknown instruction\n" << inst;
426
+ if (CustomErrorHandler) {
427
+ CustomErrorHandler(ss.str().c_str(), wrap(&inst), ErrorType::NoDerivative,
428
+ nullptr);
429
+ }
430
+ llvm::errs() << ss.str() << "\n";
431
+ report_fatal_error("unknown instruction");
425
432
}
426
433
427
434
// Common function for falling back to the implementation
@@ -1170,10 +1177,16 @@ class AdjointGenerator
1170
1177
// TODO CHECK THIS
1171
1178
return Builder2.CreateZExt(dif, op0->getType());
1172
1179
} else {
1180
+ std::string s;
1181
+ llvm::raw_string_ostream ss(s);
1182
+ ss << *I.getParent()->getParent() << "\n" << *I.getParent() << "\n";
1183
+ ss << "cannot handle above cast " << I << "\n";
1184
+ if (CustomErrorHandler) {
1185
+ CustomErrorHandler(ss.str().c_str(), wrap(&I),
1186
+ ErrorType::NoDerivative, nullptr);
1187
+ }
1173
1188
TR.dump();
1174
- llvm::errs() << *I.getParent()->getParent() << "\n"
1175
- << *I.getParent() << "\n";
1176
- llvm::errs() << "cannot handle above cast " << I << "\n";
1189
+ llvm::errs() << ss.str() << "\n";
1177
1190
report_fatal_error("unknown instruction");
1178
1191
}
1179
1192
};
@@ -2204,24 +2217,30 @@ class AdjointGenerator
2204
2217
}
2205
2218
default:
2206
2219
def:;
2207
- llvm::errs() << *gutils->oldFunc->getParent() << "\n";
2208
- llvm::errs() << *gutils->oldFunc << "\n";
2220
+ std::string s;
2221
+ llvm::raw_string_ostream ss(s);
2222
+ ss << *gutils->oldFunc->getParent() << "\n";
2223
+ ss << *gutils->oldFunc << "\n";
2209
2224
for (auto &arg : gutils->oldFunc->args()) {
2210
- llvm::errs() << " constantarg[" << arg
2211
- << "] = " << gutils->isConstantValue(&arg)
2212
- << " type: " << TR.query(&arg).str() << " - vals: {";
2225
+ ss << " constantarg[" << arg << "] = " << gutils->isConstantValue(&arg)
2226
+ << " type: " << TR.query(&arg).str() << " - vals: {";
2213
2227
for (auto v : TR.knownIntegralValues(&arg))
2214
- llvm::errs() << v << ",";
2215
- llvm::errs() << "}\n";
2228
+ ss << v << ",";
2229
+ ss << "}\n";
2216
2230
}
2217
2231
for (auto &BB : *gutils->oldFunc)
2218
2232
for (auto &I : BB) {
2219
- llvm::errs() << " constantinst[" << I
2220
- << "] = " << gutils->isConstantInstruction(&I)
2221
- << " val:" << gutils->isConstantValue(&I)
2222
- << " type: " << TR.query(&I).str() << "\n";
2233
+ ss << " constantinst[" << I
2234
+ << "] = " << gutils->isConstantInstruction(&I)
2235
+ << " val:" << gutils->isConstantValue(&I)
2236
+ << " type: " << TR.query(&I).str() << "\n";
2223
2237
}
2224
- llvm::errs() << "cannot handle unknown binary operator: " << BO << "\n";
2238
+ ss << "cannot handle unknown binary operator: " << BO << "\n";
2239
+ if (CustomErrorHandler) {
2240
+ CustomErrorHandler(ss.str().c_str(), wrap(&BO), ErrorType::NoDerivative,
2241
+ nullptr);
2242
+ }
2243
+ llvm::errs() << ss.str() << "\n";
2225
2244
report_fatal_error("unknown binary operator");
2226
2245
}
2227
2246
@@ -2592,24 +2611,30 @@ class AdjointGenerator
2592
2611
}
2593
2612
default:
2594
2613
def:;
2595
- llvm::errs() << *gutils->oldFunc->getParent() << "\n";
2596
- llvm::errs() << *gutils->oldFunc << "\n";
2614
+ std::string s;
2615
+ llvm::raw_string_ostream ss(s);
2616
+ ss << *gutils->oldFunc->getParent() << "\n";
2617
+ ss << *gutils->oldFunc << "\n";
2597
2618
for (auto &arg : gutils->oldFunc->args()) {
2598
- llvm::errs() << " constantarg[" << arg
2599
- << "] = " << gutils->isConstantValue(&arg)
2600
- << " type: " << TR.query(&arg).str() << " - vals: {";
2619
+ ss << " constantarg[" << arg << "] = " << gutils->isConstantValue(&arg)
2620
+ << " type: " << TR.query(&arg).str() << " - vals: {";
2601
2621
for (auto v : TR.knownIntegralValues(&arg))
2602
- llvm::errs() << v << ",";
2603
- llvm::errs() << "}\n";
2622
+ ss << v << ",";
2623
+ ss << "}\n";
2604
2624
}
2605
2625
for (auto &BB : *gutils->oldFunc)
2606
2626
for (auto &I : BB) {
2607
- llvm::errs() << " constantinst[" << I
2608
- << "] = " << gutils->isConstantInstruction(&I)
2609
- << " val:" << gutils->isConstantValue(&I)
2610
- << " type: " << TR.query(&I).str() << "\n";
2627
+ ss << " constantinst[" << I
2628
+ << "] = " << gutils->isConstantInstruction(&I)
2629
+ << " val:" << gutils->isConstantValue(&I)
2630
+ << " type: " << TR.query(&I).str() << "\n";
2611
2631
}
2612
- llvm::errs() << "cannot handle unknown binary operator: " << BO << "\n";
2632
+ ss << "cannot handle unknown binary operator: " << BO << "\n";
2633
+ if (CustomErrorHandler) {
2634
+ CustomErrorHandler(ss.str().c_str(), wrap(&BO), ErrorType::NoDerivative,
2635
+ nullptr);
2636
+ }
2637
+ llvm::errs() << ss.str() << "\n";
2613
2638
report_fatal_error("unknown binary operator");
2614
2639
break;
2615
2640
}
@@ -2637,9 +2662,16 @@ class AdjointGenerator
2637
2662
}
2638
2663
2639
2664
if (!gutils->isConstantValue(orig_op1)) {
2640
- llvm::errs() << "couldn't handle non constant inst in memset to "
2641
- "propagate differential to\n"
2642
- << MS;
2665
+ std::string s;
2666
+ llvm::raw_string_ostream ss(s);
2667
+ ss << "couldn't handle non constant inst in memset to "
2668
+ "propagate differential to\n"
2669
+ << MS;
2670
+ if (CustomErrorHandler) {
2671
+ CustomErrorHandler(ss.str().c_str(), wrap(&MS), ErrorType::NoDerivative,
2672
+ nullptr);
2673
+ }
2674
+ llvm::errs() << ss.str() << "\n";
2643
2675
report_fatal_error("non constant in memset");
2644
2676
}
2645
2677
@@ -3123,9 +3155,16 @@ class AdjointGenerator
3123
3155
default:
3124
3156
if (gutils->isConstantInstruction(&I))
3125
3157
return;
3126
- llvm::errs() << *gutils->oldFunc << "\n";
3127
- llvm::errs() << *gutils->newFunc << "\n";
3128
- llvm::errs() << "cannot handle (augmented) unknown intrinsic\n" << I;
3158
+ std::string s;
3159
+ llvm::raw_string_ostream ss(s);
3160
+ ss << *gutils->oldFunc << "\n";
3161
+ ss << *gutils->newFunc << "\n";
3162
+ ss << "cannot handle (augmented) unknown intrinsic\n" << I;
3163
+ if (CustomErrorHandler) {
3164
+ CustomErrorHandler(ss.str().c_str(), wrap(&I),
3165
+ ErrorType::NoDerivative, nullptr);
3166
+ }
3167
+ llvm::errs() << ss.str() << "\n";
3129
3168
report_fatal_error("(augmented) unknown intrinsic");
3130
3169
}
3131
3170
return;
@@ -3649,25 +3688,32 @@ class AdjointGenerator
3649
3688
default:
3650
3689
if (gutils->isConstantInstruction(&I))
3651
3690
return;
3652
- llvm::errs() << *gutils->oldFunc << "\n";
3653
- llvm::errs() << *gutils->newFunc << "\n";
3691
+
3692
+ std::string s;
3693
+ llvm::raw_string_ostream ss(s);
3694
+ ss << *gutils->oldFunc << "\n";
3695
+ ss << *gutils->newFunc << "\n";
3654
3696
if (Intrinsic::isOverloaded(ID))
3655
3697
#if LLVM_VERSION_MAJOR >= 13
3656
- llvm::errs() << "cannot handle (reverse) unknown intrinsic\n"
3657
- << Intrinsic::getName(ID, ArrayRef<Type *>(),
3658
- gutils->oldFunc->getParent(),
3659
- nullptr)
3660
- << "\n"
3661
- << I;
3698
+ ss << "cannot handle (reverse) unknown intrinsic\n"
3699
+ << Intrinsic::getName(ID, ArrayRef<Type *>(),
3700
+ gutils->oldFunc->getParent(), nullptr)
3701
+ << "\n"
3702
+ << I;
3662
3703
#else
3663
- llvm::errs() << "cannot handle (reverse) unknown intrinsic\n"
3664
- << Intrinsic::getName(ID, ArrayRef<Type *>()) << "\n"
3665
- << I;
3704
+ ss << "cannot handle (reverse) unknown intrinsic\n"
3705
+ << Intrinsic::getName(ID, ArrayRef<Type *>()) << "\n"
3706
+ << I;
3666
3707
#endif
3667
3708
else
3668
- llvm::errs() << "cannot handle (reverse) unknown intrinsic\n"
3669
- << Intrinsic::getName(ID) << "\n"
3670
- << I;
3709
+ ss << "cannot handle (reverse) unknown intrinsic\n"
3710
+ << Intrinsic::getName(ID) << "\n"
3711
+ << I;
3712
+ if (CustomErrorHandler) {
3713
+ CustomErrorHandler(ss.str().c_str(), wrap(&I),
3714
+ ErrorType::NoDerivative, nullptr);
3715
+ }
3716
+ llvm::errs() << ss.str() << "\n";
3671
3717
report_fatal_error("(reverse) unknown intrinsic");
3672
3718
}
3673
3719
return;
@@ -4163,25 +4209,31 @@ class AdjointGenerator
4163
4209
default:
4164
4210
if (gutils->isConstantInstruction(&I))
4165
4211
return;
4166
- llvm::errs() << *gutils->oldFunc << "\n";
4167
- llvm::errs() << *gutils->newFunc << "\n";
4212
+ std::string s;
4213
+ llvm::raw_string_ostream ss(s);
4214
+ ss << *gutils->oldFunc << "\n";
4215
+ ss << *gutils->newFunc << "\n";
4168
4216
if (Intrinsic::isOverloaded(ID))
4169
4217
#if LLVM_VERSION_MAJOR >= 13
4170
- llvm::errs() << "cannot handle (forward) unknown intrinsic\n"
4171
- << Intrinsic::getName(ID, ArrayRef<Type *>(),
4172
- gutils->oldFunc->getParent(),
4173
- nullptr)
4174
- << "\n"
4175
- << I;
4218
+ ss << "cannot handle (forward) unknown intrinsic\n"
4219
+ << Intrinsic::getName(ID, ArrayRef<Type *>(),
4220
+ gutils->oldFunc->getParent(), nullptr)
4221
+ << "\n"
4222
+ << I;
4176
4223
#else
4177
- llvm::errs() << "cannot handle (forward) unknown intrinsic\n"
4178
- << Intrinsic::getName(ID, ArrayRef<Type *>()) << "\n"
4179
- << I;
4224
+ ss << "cannot handle (forward) unknown intrinsic\n"
4225
+ << Intrinsic::getName(ID, ArrayRef<Type *>()) << "\n"
4226
+ << I;
4180
4227
#endif
4181
4228
else
4182
- llvm::errs() << "cannot handle (forward) unknown intrinsic\n"
4183
- << Intrinsic::getName(ID) << "\n"
4184
- << I;
4229
+ ss << "cannot handle (forward) unknown intrinsic\n"
4230
+ << Intrinsic::getName(ID) << "\n"
4231
+ << I;
4232
+ if (CustomErrorHandler) {
4233
+ CustomErrorHandler(ss.str().c_str(), wrap(&I),
4234
+ ErrorType::NoDerivative, nullptr);
4235
+ }
4236
+ llvm::errs() << ss.str() << "\n";
4185
4237
report_fatal_error("(forward) unknown intrinsic");
4186
4238
}
4187
4239
return;
@@ -7056,10 +7108,17 @@ class AdjointGenerator
7056
7108
}
7057
7109
}
7058
7110
if (!isSum) {
7059
- llvm::errs() << *gutils->oldFunc << "\n";
7060
- llvm::errs() << *gutils->newFunc << "\n";
7061
- llvm::errs() << " call: " << call << "\n";
7062
- llvm::errs() << " unhandled mpi_allreduce op: " << *orig_op << "\n";
7111
+ std::string s;
7112
+ llvm::raw_string_ostream ss(s);
7113
+ ss << *gutils->oldFunc << "\n";
7114
+ ss << *gutils->newFunc << "\n";
7115
+ ss << " call: " << call << "\n";
7116
+ ss << " unhandled mpi_allreduce op: " << *orig_op << "\n";
7117
+ if (CustomErrorHandler) {
7118
+ CustomErrorHandler(ss.str().c_str(), wrap(&call),
7119
+ ErrorType::NoDerivative, nullptr);
7120
+ }
7121
+ llvm::errs() << ss.str() << "\n";
7063
7122
report_fatal_error("unhandled mpi_allreduce op");
7064
7123
}
7065
7124
@@ -7319,10 +7378,17 @@ class AdjointGenerator
7319
7378
}
7320
7379
}
7321
7380
if (!isSum) {
7322
- llvm::errs() << *gutils->oldFunc << "\n";
7323
- llvm::errs() << *gutils->newFunc << "\n";
7324
- llvm::errs() << " call: " << call << "\n";
7325
- llvm::errs() << " unhandled mpi_allreduce op: " << *orig_op << "\n";
7381
+ std::string s;
7382
+ llvm::raw_string_ostream ss(s);
7383
+ ss << *gutils->oldFunc << "\n";
7384
+ ss << *gutils->newFunc << "\n";
7385
+ ss << " call: " << call << "\n";
7386
+ ss << " unhandled mpi_allreduce op: " << *orig_op << "\n";
7387
+ if (CustomErrorHandler) {
7388
+ CustomErrorHandler(ss.str().c_str(), wrap(&call),
7389
+ ErrorType::NoDerivative, nullptr);
7390
+ }
7391
+ llvm::errs() << ss.str() << "\n";
7326
7392
report_fatal_error("unhandled mpi_allreduce op");
7327
7393
}
7328
7394
0 commit comments