Skip to content

Commit 9e76494

Browse files
wsmosesZuseZ4
andauthored
Handle ostream (rust-lang#242)
* Handle ostream * add more printFnc to utils.h * small refactor / formating * addressing feedback Co-authored-by: Manuel Drehwald <[email protected]>
1 parent 16ed986 commit 9e76494

File tree

4 files changed

+52
-26
lines changed

4 files changed

+52
-26
lines changed

enzyme/Enzyme/ActivityAnalysis.cpp

+12-2
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,18 @@ cl::opt<bool>
8080

8181
const char *KnownInactiveFunctionsStartingWith[] = {
8282
"_ZN4core3fmt", "_ZN3std2io5stdio6_print", "f90io", "$ss5print",
83-
"_ZNSt7__cxx1112basic_string"};
83+
"_ZNSt7__cxx1112basic_string",
84+
// ostream generic <<
85+
"_ZStlsISt11char_traitsIcEERSt13basic_ostreamIcT_ES5_",
86+
"_ZSt16__ostream_insert", "_ZNSo9_M_insert",
87+
// ostream put
88+
"_ZNSo3put",
89+
// std::cout
90+
"_ZSt4cout",
91+
// generic <<
92+
"_ZNSolsE",
93+
// std::endl
94+
"_ZNSo5flushEv", "_ZSt4endl"};
8495

8596
const char *KnownInactiveFunctionsContains[] = {
8697
"__enzyme_float", "__enzyme_double", "__enzyme_integer",
@@ -178,7 +189,6 @@ const std::set<std::string> KnownInactiveFunctions = {
178189
"_msize",
179190
"ftnio_fmt_write64",
180191
"f90_strcmp_klen",
181-
"vprintf",
182192
"__swift_instantiateConcreteTypeFromMangledName"};
183193

184194
/// Is the use of value val as an argument of call CI known to be inactive

enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -4246,7 +4246,7 @@ void TypeAnalyzer::visitCallInst(CallInst &call) {
42464246
}
42474247

42484248
if (funcName == "__cxa_guard_acquire" || funcName == "printf" ||
4249-
funcName == "vprintf" || funcName == "puts") {
4249+
funcName == "vprintf" || funcName == "puts" || funcName == "fprintf") {
42504250
updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1), &call);
42514251
}
42524252

enzyme/Enzyme/Utils.h

+13-23
Original file line numberDiff line numberDiff line change
@@ -532,8 +532,16 @@ static inline bool isCertainPrintOrFree(llvm::Function *called) {
532532
return false;
533533

534534
if (called->getName() == "printf" || called->getName() == "puts" ||
535-
called->getName() == "fprintf" ||
535+
called->getName() == "fprintf" || called->getName() == "putchar" ||
536+
called->getName().startswith(
537+
"_ZStlsISt11char_traitsIcEERSt13basic_ostreamIcT_ES5_") ||
538+
called->getName().startswith("_ZNSolsE") ||
539+
called->getName().startswith("_ZNSo9_M_insert") ||
540+
called->getName().startswith("_ZSt16__ostream_insert") ||
541+
called->getName().startswith("_ZNSo3put") ||
542+
called->getName().startswith("_ZSt4endl") ||
536543
called->getName().startswith("_ZN3std2io5stdio6_print") ||
544+
called->getName().startswith("_ZNSo5flushEv") ||
537545
called->getName().startswith("_ZN4core3fmt") ||
538546
called->getName() == "vprintf" || called->getName() == "_ZdlPv" ||
539547
called->getName() == "_ZdlPvm" || called->getName() == "free" ||
@@ -562,30 +570,12 @@ static inline bool isCertainPrintMallocOrFree(llvm::Function *called) {
562570
if (called == nullptr)
563571
return false;
564572

565-
if (called->getName() == "printf" || called->getName() == "puts" ||
566-
called->getName() == "fprintf" ||
567-
called->getName().startswith("_ZN3std2io5stdio6_print") ||
568-
called->getName().startswith("_ZN4core3fmt") ||
569-
called->getName() == "vprintf" || called->getName() == "malloc" ||
570-
called->getName() == "swift_allocObject" ||
571-
called->getName() == "swift_release" || called->getName() == "_Znwm" ||
572-
called->getName() == "_ZdlPv" || called->getName() == "_ZdlPvm" ||
573-
called->getName() == "free" ||
574-
shadowHandlers.find(called->getName().str()) != shadowHandlers.end())
573+
if (isCertainPrintOrFree(called))
575574
return true;
576-
switch (called->getIntrinsicID()) {
577-
case llvm::Intrinsic::dbg_declare:
578-
case llvm::Intrinsic::dbg_value:
579-
#if LLVM_VERSION_MAJOR > 6
580-
case llvm::Intrinsic::dbg_label:
581-
#endif
582-
case llvm::Intrinsic::dbg_addr:
583-
case llvm::Intrinsic::lifetime_start:
584-
case llvm::Intrinsic::lifetime_end:
575+
576+
if (isCertainMallocOrFree(called))
585577
return true;
586-
default:
587-
break;
588-
}
578+
589579
return false;
590580
}
591581

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S
2+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S
3+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S
4+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S
5+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S
6+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S
7+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S
8+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S
9+
10+
#include <iostream>
11+
#include "test_utils.h"
12+
13+
extern double __enzyme_autodiff(void*, double);
14+
15+
double fn(double vec) {
16+
std::cout << "hello" << 7 << '7' << std::endl;
17+
std::cerr << vec << vec * vec << "\n";
18+
return vec * vec;
19+
}
20+
21+
int main() {
22+
double x = 2.1;
23+
double dsq = __enzyme_autodiff((void*)fn, x);
24+
25+
APPROX_EQ(dsq, 2 * x, 1e-7);
26+
}

0 commit comments

Comments
 (0)