Skip to content

Commit e4c100e

Browse files
authored
handle more IO (rust-lang#606)
* handle cin * wip, more input-fncs * adding various IO handling * fix ci * extract cerr test, extra handling * wcout and local, still buggy * handle different mangling * add allocator as inactive * update Cacheable List and rename Fn * update if condition, undo renaming * remove unconditional print * make condition more precise * Undo isCertainPrintMallocOrFree changes
1 parent 55114ed commit e4c100e

File tree

5 files changed

+185
-10
lines changed

5 files changed

+185
-10
lines changed

enzyme/Enzyme/ActivityAnalysis.cpp

Lines changed: 63 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,19 +79,53 @@ cl::opt<bool>
7979
#include <unordered_map>
8080

8181
const char *KnownInactiveFunctionsStartingWith[] = {
82-
"_ZN4core3fmt", "_ZN3std2io5stdio6_print", "f90io", "$ss5print",
82+
"_ZN4core3fmt",
83+
"_ZN3std2io5stdio6_print",
84+
"f90io",
85+
"$ss5print",
8386
"_ZNSt7__cxx1112basic_string",
87+
"_ZNSt7__cxx1118basic_string",
8488
// ostream generic <<
85-
"_ZStlsISt11char_traitsIcEERSt13basic_ostreamIcT_ES5_",
86-
"_ZSt16__ostream_insert", "_ZNSo9_M_insert",
89+
"_ZStlsISt11char_traitsIcEERSt13basic_ostream",
90+
"_ZSt16__ostream_insert",
91+
"_ZStlsIwSt11char_traitsIwEERSt13basic_ostream",
92+
"_ZNSo9_M_insert",
93+
// ostream wchar
94+
"_ZNSt13basic_ostream",
8795
// ostream put
8896
"_ZNSo3put",
97+
// std::istream: widen_init, get, getline, >>, sync, ignore
98+
"_ZNKSt5ctypeIcE13_M_widen_init",
99+
"_ZNSi3get",
100+
"_ZNSi7getline",
101+
"_ZNSirsER",
102+
"_ZNSt7__cxx1115basic_stringbuf",
103+
"_ZNSi6ignore",
104+
// std::ios_base
105+
"_ZNSt8ios_base",
106+
"_ZNSt9basic_ios",
107+
"_ZStorSt13_Ios_OpenmodeS_",
108+
// std::local
109+
"_ZNSt6locale",
110+
"_ZNKSt6locale4name",
111+
// init
112+
"_ZStL8__ioinit"
113+
"_ZNSt9basic_ios",
89114
// std::cout
90115
"_ZSt4cout",
116+
// std::cin
117+
"_ZSt3cin",
118+
"_ZNSi10_M_extract",
91119
// generic <<
92120
"_ZNSolsE",
121+
// std::flush
122+
"_ZSt5flush",
123+
"_ZNSo5flush",
93124
// std::endl
94-
"_ZNSo5flushEv", "_ZSt4endl"};
125+
"_ZSt4endl",
126+
// std::allocator
127+
"_ZNSaIcE",
128+
};
95129

96130
const char *KnownInactiveFunctionsContains[] = {
97131
"__enzyme_float", "__enzyme_double", "__enzyme_integer",
@@ -104,6 +138,22 @@ const std::set<std::string> InactiveGlobals = {
104138
"stderr",
105139
"stdout",
106140
"stdin",
141+
"_ZSt3cin",
142+
"_ZSt4cout",
143+
"_ZSt5wcout",
144+
"_ZSt4cerr",
145+
"_ZTVNSt7__cxx1115basic_stringbufIcSt11char_traitsIcESaIcEEE",
146+
"_ZTVSt15basic_streambufIcSt11char_traitsIcEE",
147+
"_ZTVSt9basic_iosIcSt11char_traitsIcEE",
148+
// istream
149+
"_ZTVNSt7__cxx1119basic_istringstreamIcSt11char_traitsIcESaIcEEE",
150+
"_ZTTNSt7__cxx1119basic_istringstreamIcSt11char_traitsIcESaIcEEE",
151+
// ostream
152+
"_ZTVNSt7__cxx1119basic_ostringstreamIcSt11char_traitsIcESaIcEEE",
153+
"_ZTTNSt7__cxx1119basic_ostringstreamIcSt11char_traitsIcESaIcEEE",
154+
// stringstream
155+
"_ZTVNSt7__cxx1118basic_stringstreamIcSt11char_traitsIcESaIcEEE",
156+
"_ZTTNSt7__cxx1118basic_stringstreamIcSt11char_traitsIcESaIcEEE",
107157
};
108158

109159
const std::map<std::string, size_t> MPIInactiveCommAllocators = {
@@ -861,7 +911,8 @@ bool ActivityAnalyzer::isConstantValue(TypeResults &TR, Value *Val) {
861911

862912
// If this global is unchanging and the internal constant data
863913
// is inactive, the global is inactive
864-
if (GI->isConstant() && isConstantValue(TR, GI->getInitializer())) {
914+
if (GI->isConstant() && GI->hasInitializer() &&
915+
isConstantValue(TR, GI->getInitializer())) {
865916
InsertConstantValue(TR, Val);
866917
if (EnzymePrintActivity)
867918
llvm::errs() << " VALUE const global " << *Val
@@ -1136,10 +1187,12 @@ bool ActivityAnalyzer::isConstantValue(TypeResults &TR, Value *Val) {
11361187

11371188
// If requesting empty unknown functions to be considered inactive,
11381189
// abide by those rules
1139-
if (!isCertainPrintMallocOrFree(called) && called->empty() &&
1190+
if (EnzymeEmptyFnInactive && called->empty() &&
11401191
!hasMetadata(called, "enzyme_gradient") &&
11411192
!hasMetadata(called, "enzyme_derivative") &&
1142-
!isa<IntrinsicInst>(op) && EnzymeEmptyFnInactive) {
1193+
!isAllocationFunction(*called, TLI) &&
1194+
!isDeallocationFunction(*called, TLI) &&
1195+
!isa<IntrinsicInst>(op)) {
11431196
InsertConstantValue(TR, Val);
11441197
insertConstantsFrom(TR, *UpHypothesis);
11451198
return true;
@@ -1816,10 +1869,11 @@ bool ActivityAnalyzer::isInstructionInactiveFromOrigin(TypeResults &TR,
18161869

18171870
// If requesting empty unknown functions to be considered inactive, abide
18181871
// by those rules
1819-
if (!isCertainPrintMallocOrFree(called) && called->empty() &&
1872+
if (EnzymeEmptyFnInactive && called->empty() &&
18201873
!hasMetadata(called, "enzyme_gradient") &&
18211874
!hasMetadata(called, "enzyme_derivative") &&
1822-
!isa<IntrinsicInst>(op) && EnzymeEmptyFnInactive) {
1875+
!isAllocationFunction(*called, TLI) &&
1876+
!isDeallocationFunction(*called, TLI) && !isa<IntrinsicInst>(op)) {
18231877
if (EnzymePrintActivity)
18241878
llvm::errs() << "constant(" << (int)directions << ") up-emptyconst "
18251879
<< *inst << "\n";
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S
2+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S
3+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S
4+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S
5+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S
6+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S
7+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S
8+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S
9+
10+
#include "test_utils.h"
11+
#include <iostream>
12+
#include <sstream>
13+
#include <utility>
14+
15+
extern double __enzyme_autodiff(void*, double);
16+
17+
double fn(double vec) {
18+
std::cerr << "foo" << std::endl;
19+
std::cerr << "foo" << std::flush;
20+
std::cerr << "foo" << '\n';
21+
std::flush(std::cerr);
22+
23+
return vec * vec;
24+
}
25+
26+
int main() {
27+
double x = 2.1;
28+
29+
double dsq = __enzyme_autodiff((void *)fn, x);
30+
APPROX_EQ(dsq, 2 * x, 1e-7);
31+
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S
2+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S
3+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S
4+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S
5+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S
6+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S
7+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S
8+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S
9+
10+
#include "test_utils.h"
11+
#include <iostream>
12+
#include <sstream>
13+
#include <utility>
14+
15+
extern double __enzyme_autodiff(void*, double);
16+
17+
double fn(double vec) {
18+
std::stringstream testInput("14 1.5 somerandomextrachars");
19+
double in;
20+
float in2;
21+
testInput >> in >> in2;
22+
23+
testInput.ignore();
24+
25+
char ch;
26+
testInput.get(ch);
27+
28+
char foo[5];
29+
const char fdelim = '\t';
30+
testInput.get(foo, 3, fdelim);
31+
32+
// The following two lines cause a segfault with Enzyme
33+
// char bar[5];
34+
// testInput.getline(bar, 3);
35+
36+
return vec * vec * in * in2;
37+
}
38+
39+
int main() {
40+
double x = 2.1;
41+
double dsq = __enzyme_autodiff((void*)fn, x);
42+
43+
APPROX_EQ(dsq, 14 * 1.5 * 2 * x, 1e-7);
44+
}

enzyme/test/Integration/ReverseMode/cout.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ extern double __enzyme_autodiff(void*, double);
1414

1515
double fn(double vec) {
1616
std::cout << "hello" << 7 << '7' << std::endl;
17-
std::cerr << vec << vec * vec << "\n";
17+
std::cout << "foo" << std::flush;
1818
return vec * vec;
1919
}
2020

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S
2+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S
3+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S
4+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S
5+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S
6+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S
7+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S
8+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S
9+
10+
#include <iostream>
11+
#include "test_utils.h"
12+
13+
#include <stdlib.h>
14+
#include <stdio.h>
15+
#include <stdbool.h>
16+
#include <math.h>
17+
18+
19+
extern double __enzyme_autodiff(void*, double);
20+
21+
// testing both std::wcout and std::locale
22+
// https://en.cppreference.com/w/cpp/locale/locale
23+
double fn(double vec) {
24+
std::wcout.put('f');
25+
std::wcout << 1 << 1.0 << "somerandomchars";
26+
std::locale::global(std::locale(""));
27+
std::wcout.sync_with_stdio();
28+
std::wcout.imbue(std::locale());
29+
30+
// Currently not working
31+
// std::wcout << "User-preferred locale setting is ";
32+
// std::wcout << std::locale("").name().c_str();
33+
// std::wcout << 1000.01 << '\n';
34+
// std::wcout << 1000.01 << std::endl;
35+
// std::wcout << 1000.01 << std::flush;
36+
37+
return vec * vec;
38+
}
39+
40+
int main() {
41+
double x = 2.1;
42+
double dsq = __enzyme_autodiff((void*)fn, x);
43+
44+
APPROX_EQ(dsq, 2 * x, 1e-7);
45+
}
46+

0 commit comments

Comments
 (0)