Skip to content

Commit 85524a3

Browse files
authored
Improve debug handler (rust-lang#675)
1 parent 3fd233c commit 85524a3

File tree

1 file changed

+140
-74
lines changed

1 file changed

+140
-74
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 140 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -417,11 +417,18 @@ class AdjointGenerator
417417
}
418418
#endif
419419

420-
llvm::errs() << *gutils->oldFunc << "\n";
421-
llvm::errs() << *gutils->newFunc << "\n";
422-
llvm::errs() << "in Mode: " << to_string(Mode) << "\n";
423-
llvm::errs() << "cannot handle unknown instruction\n" << inst;
424-
report_fatal_error("unknown value");
420+
std::string s;
421+
llvm::raw_string_ostream ss(s);
422+
ss << *gutils->oldFunc << "\n";
423+
ss << *gutils->newFunc << "\n";
424+
ss << "in Mode: " << to_string(Mode) << "\n";
425+
ss << "cannot handle unknown instruction\n" << inst;
426+
if (CustomErrorHandler) {
427+
CustomErrorHandler(ss.str().c_str(), wrap(&inst), ErrorType::NoDerivative,
428+
nullptr);
429+
}
430+
llvm::errs() << ss.str() << "\n";
431+
report_fatal_error("unknown instruction");
425432
}
426433

427434
// Common function for falling back to the implementation
@@ -1170,10 +1177,16 @@ class AdjointGenerator
11701177
// TODO CHECK THIS
11711178
return Builder2.CreateZExt(dif, op0->getType());
11721179
} else {
1180+
std::string s;
1181+
llvm::raw_string_ostream ss(s);
1182+
ss << *I.getParent()->getParent() << "\n" << *I.getParent() << "\n";
1183+
ss << "cannot handle above cast " << I << "\n";
1184+
if (CustomErrorHandler) {
1185+
CustomErrorHandler(ss.str().c_str(), wrap(&I),
1186+
ErrorType::NoDerivative, nullptr);
1187+
}
11731188
TR.dump();
1174-
llvm::errs() << *I.getParent()->getParent() << "\n"
1175-
<< *I.getParent() << "\n";
1176-
llvm::errs() << "cannot handle above cast " << I << "\n";
1189+
llvm::errs() << ss.str() << "\n";
11771190
report_fatal_error("unknown instruction");
11781191
}
11791192
};
@@ -2204,24 +2217,30 @@ class AdjointGenerator
22042217
}
22052218
default:
22062219
def:;
2207-
llvm::errs() << *gutils->oldFunc->getParent() << "\n";
2208-
llvm::errs() << *gutils->oldFunc << "\n";
2220+
std::string s;
2221+
llvm::raw_string_ostream ss(s);
2222+
ss << *gutils->oldFunc->getParent() << "\n";
2223+
ss << *gutils->oldFunc << "\n";
22092224
for (auto &arg : gutils->oldFunc->args()) {
2210-
llvm::errs() << " constantarg[" << arg
2211-
<< "] = " << gutils->isConstantValue(&arg)
2212-
<< " type: " << TR.query(&arg).str() << " - vals: {";
2225+
ss << " constantarg[" << arg << "] = " << gutils->isConstantValue(&arg)
2226+
<< " type: " << TR.query(&arg).str() << " - vals: {";
22132227
for (auto v : TR.knownIntegralValues(&arg))
2214-
llvm::errs() << v << ",";
2215-
llvm::errs() << "}\n";
2228+
ss << v << ",";
2229+
ss << "}\n";
22162230
}
22172231
for (auto &BB : *gutils->oldFunc)
22182232
for (auto &I : BB) {
2219-
llvm::errs() << " constantinst[" << I
2220-
<< "] = " << gutils->isConstantInstruction(&I)
2221-
<< " val:" << gutils->isConstantValue(&I)
2222-
<< " type: " << TR.query(&I).str() << "\n";
2233+
ss << " constantinst[" << I
2234+
<< "] = " << gutils->isConstantInstruction(&I)
2235+
<< " val:" << gutils->isConstantValue(&I)
2236+
<< " type: " << TR.query(&I).str() << "\n";
22232237
}
2224-
llvm::errs() << "cannot handle unknown binary operator: " << BO << "\n";
2238+
ss << "cannot handle unknown binary operator: " << BO << "\n";
2239+
if (CustomErrorHandler) {
2240+
CustomErrorHandler(ss.str().c_str(), wrap(&BO), ErrorType::NoDerivative,
2241+
nullptr);
2242+
}
2243+
llvm::errs() << ss.str() << "\n";
22252244
report_fatal_error("unknown binary operator");
22262245
}
22272246

@@ -2592,24 +2611,30 @@ class AdjointGenerator
25922611
}
25932612
default:
25942613
def:;
2595-
llvm::errs() << *gutils->oldFunc->getParent() << "\n";
2596-
llvm::errs() << *gutils->oldFunc << "\n";
2614+
std::string s;
2615+
llvm::raw_string_ostream ss(s);
2616+
ss << *gutils->oldFunc->getParent() << "\n";
2617+
ss << *gutils->oldFunc << "\n";
25972618
for (auto &arg : gutils->oldFunc->args()) {
2598-
llvm::errs() << " constantarg[" << arg
2599-
<< "] = " << gutils->isConstantValue(&arg)
2600-
<< " type: " << TR.query(&arg).str() << " - vals: {";
2619+
ss << " constantarg[" << arg << "] = " << gutils->isConstantValue(&arg)
2620+
<< " type: " << TR.query(&arg).str() << " - vals: {";
26012621
for (auto v : TR.knownIntegralValues(&arg))
2602-
llvm::errs() << v << ",";
2603-
llvm::errs() << "}\n";
2622+
ss << v << ",";
2623+
ss << "}\n";
26042624
}
26052625
for (auto &BB : *gutils->oldFunc)
26062626
for (auto &I : BB) {
2607-
llvm::errs() << " constantinst[" << I
2608-
<< "] = " << gutils->isConstantInstruction(&I)
2609-
<< " val:" << gutils->isConstantValue(&I)
2610-
<< " type: " << TR.query(&I).str() << "\n";
2627+
ss << " constantinst[" << I
2628+
<< "] = " << gutils->isConstantInstruction(&I)
2629+
<< " val:" << gutils->isConstantValue(&I)
2630+
<< " type: " << TR.query(&I).str() << "\n";
26112631
}
2612-
llvm::errs() << "cannot handle unknown binary operator: " << BO << "\n";
2632+
ss << "cannot handle unknown binary operator: " << BO << "\n";
2633+
if (CustomErrorHandler) {
2634+
CustomErrorHandler(ss.str().c_str(), wrap(&BO), ErrorType::NoDerivative,
2635+
nullptr);
2636+
}
2637+
llvm::errs() << ss.str() << "\n";
26132638
report_fatal_error("unknown binary operator");
26142639
break;
26152640
}
@@ -2637,9 +2662,16 @@ class AdjointGenerator
26372662
}
26382663

26392664
if (!gutils->isConstantValue(orig_op1)) {
2640-
llvm::errs() << "couldn't handle non constant inst in memset to "
2641-
"propagate differential to\n"
2642-
<< MS;
2665+
std::string s;
2666+
llvm::raw_string_ostream ss(s);
2667+
ss << "couldn't handle non constant inst in memset to "
2668+
"propagate differential to\n"
2669+
<< MS;
2670+
if (CustomErrorHandler) {
2671+
CustomErrorHandler(ss.str().c_str(), wrap(&MS), ErrorType::NoDerivative,
2672+
nullptr);
2673+
}
2674+
llvm::errs() << ss.str() << "\n";
26432675
report_fatal_error("non constant in memset");
26442676
}
26452677

@@ -3123,9 +3155,16 @@ class AdjointGenerator
31233155
default:
31243156
if (gutils->isConstantInstruction(&I))
31253157
return;
3126-
llvm::errs() << *gutils->oldFunc << "\n";
3127-
llvm::errs() << *gutils->newFunc << "\n";
3128-
llvm::errs() << "cannot handle (augmented) unknown intrinsic\n" << I;
3158+
std::string s;
3159+
llvm::raw_string_ostream ss(s);
3160+
ss << *gutils->oldFunc << "\n";
3161+
ss << *gutils->newFunc << "\n";
3162+
ss << "cannot handle (augmented) unknown intrinsic\n" << I;
3163+
if (CustomErrorHandler) {
3164+
CustomErrorHandler(ss.str().c_str(), wrap(&I),
3165+
ErrorType::NoDerivative, nullptr);
3166+
}
3167+
llvm::errs() << ss.str() << "\n";
31293168
report_fatal_error("(augmented) unknown intrinsic");
31303169
}
31313170
return;
@@ -3649,25 +3688,32 @@ class AdjointGenerator
36493688
default:
36503689
if (gutils->isConstantInstruction(&I))
36513690
return;
3652-
llvm::errs() << *gutils->oldFunc << "\n";
3653-
llvm::errs() << *gutils->newFunc << "\n";
3691+
3692+
std::string s;
3693+
llvm::raw_string_ostream ss(s);
3694+
ss << *gutils->oldFunc << "\n";
3695+
ss << *gutils->newFunc << "\n";
36543696
if (Intrinsic::isOverloaded(ID))
36553697
#if LLVM_VERSION_MAJOR >= 13
3656-
llvm::errs() << "cannot handle (reverse) unknown intrinsic\n"
3657-
<< Intrinsic::getName(ID, ArrayRef<Type *>(),
3658-
gutils->oldFunc->getParent(),
3659-
nullptr)
3660-
<< "\n"
3661-
<< I;
3698+
ss << "cannot handle (reverse) unknown intrinsic\n"
3699+
<< Intrinsic::getName(ID, ArrayRef<Type *>(),
3700+
gutils->oldFunc->getParent(), nullptr)
3701+
<< "\n"
3702+
<< I;
36623703
#else
3663-
llvm::errs() << "cannot handle (reverse) unknown intrinsic\n"
3664-
<< Intrinsic::getName(ID, ArrayRef<Type *>()) << "\n"
3665-
<< I;
3704+
ss << "cannot handle (reverse) unknown intrinsic\n"
3705+
<< Intrinsic::getName(ID, ArrayRef<Type *>()) << "\n"
3706+
<< I;
36663707
#endif
36673708
else
3668-
llvm::errs() << "cannot handle (reverse) unknown intrinsic\n"
3669-
<< Intrinsic::getName(ID) << "\n"
3670-
<< I;
3709+
ss << "cannot handle (reverse) unknown intrinsic\n"
3710+
<< Intrinsic::getName(ID) << "\n"
3711+
<< I;
3712+
if (CustomErrorHandler) {
3713+
CustomErrorHandler(ss.str().c_str(), wrap(&I),
3714+
ErrorType::NoDerivative, nullptr);
3715+
}
3716+
llvm::errs() << ss.str() << "\n";
36713717
report_fatal_error("(reverse) unknown intrinsic");
36723718
}
36733719
return;
@@ -4163,25 +4209,31 @@ class AdjointGenerator
41634209
default:
41644210
if (gutils->isConstantInstruction(&I))
41654211
return;
4166-
llvm::errs() << *gutils->oldFunc << "\n";
4167-
llvm::errs() << *gutils->newFunc << "\n";
4212+
std::string s;
4213+
llvm::raw_string_ostream ss(s);
4214+
ss << *gutils->oldFunc << "\n";
4215+
ss << *gutils->newFunc << "\n";
41684216
if (Intrinsic::isOverloaded(ID))
41694217
#if LLVM_VERSION_MAJOR >= 13
4170-
llvm::errs() << "cannot handle (forward) unknown intrinsic\n"
4171-
<< Intrinsic::getName(ID, ArrayRef<Type *>(),
4172-
gutils->oldFunc->getParent(),
4173-
nullptr)
4174-
<< "\n"
4175-
<< I;
4218+
ss << "cannot handle (forward) unknown intrinsic\n"
4219+
<< Intrinsic::getName(ID, ArrayRef<Type *>(),
4220+
gutils->oldFunc->getParent(), nullptr)
4221+
<< "\n"
4222+
<< I;
41764223
#else
4177-
llvm::errs() << "cannot handle (forward) unknown intrinsic\n"
4178-
<< Intrinsic::getName(ID, ArrayRef<Type *>()) << "\n"
4179-
<< I;
4224+
ss << "cannot handle (forward) unknown intrinsic\n"
4225+
<< Intrinsic::getName(ID, ArrayRef<Type *>()) << "\n"
4226+
<< I;
41804227
#endif
41814228
else
4182-
llvm::errs() << "cannot handle (forward) unknown intrinsic\n"
4183-
<< Intrinsic::getName(ID) << "\n"
4184-
<< I;
4229+
ss << "cannot handle (forward) unknown intrinsic\n"
4230+
<< Intrinsic::getName(ID) << "\n"
4231+
<< I;
4232+
if (CustomErrorHandler) {
4233+
CustomErrorHandler(ss.str().c_str(), wrap(&I),
4234+
ErrorType::NoDerivative, nullptr);
4235+
}
4236+
llvm::errs() << ss.str() << "\n";
41854237
report_fatal_error("(forward) unknown intrinsic");
41864238
}
41874239
return;
@@ -7056,10 +7108,17 @@ class AdjointGenerator
70567108
}
70577109
}
70587110
if (!isSum) {
7059-
llvm::errs() << *gutils->oldFunc << "\n";
7060-
llvm::errs() << *gutils->newFunc << "\n";
7061-
llvm::errs() << " call: " << call << "\n";
7062-
llvm::errs() << " unhandled mpi_allreduce op: " << *orig_op << "\n";
7111+
std::string s;
7112+
llvm::raw_string_ostream ss(s);
7113+
ss << *gutils->oldFunc << "\n";
7114+
ss << *gutils->newFunc << "\n";
7115+
ss << " call: " << call << "\n";
7116+
ss << " unhandled mpi_allreduce op: " << *orig_op << "\n";
7117+
if (CustomErrorHandler) {
7118+
CustomErrorHandler(ss.str().c_str(), wrap(&call),
7119+
ErrorType::NoDerivative, nullptr);
7120+
}
7121+
llvm::errs() << ss.str() << "\n";
70637122
report_fatal_error("unhandled mpi_allreduce op");
70647123
}
70657124

@@ -7319,10 +7378,17 @@ class AdjointGenerator
73197378
}
73207379
}
73217380
if (!isSum) {
7322-
llvm::errs() << *gutils->oldFunc << "\n";
7323-
llvm::errs() << *gutils->newFunc << "\n";
7324-
llvm::errs() << " call: " << call << "\n";
7325-
llvm::errs() << " unhandled mpi_allreduce op: " << *orig_op << "\n";
7381+
std::string s;
7382+
llvm::raw_string_ostream ss(s);
7383+
ss << *gutils->oldFunc << "\n";
7384+
ss << *gutils->newFunc << "\n";
7385+
ss << " call: " << call << "\n";
7386+
ss << " unhandled mpi_allreduce op: " << *orig_op << "\n";
7387+
if (CustomErrorHandler) {
7388+
CustomErrorHandler(ss.str().c_str(), wrap(&call),
7389+
ErrorType::NoDerivative, nullptr);
7390+
}
7391+
llvm::errs() << ss.str() << "\n";
73267392
report_fatal_error("unhandled mpi_allreduce op");
73277393
}
73287394

0 commit comments

Comments
 (0)