Skip to content

Commit 9013de3

Browse files
authored
Add inactive global specifier (rust-lang#938)
* Add inactive global * Add custom inactive global specification
1 parent b14088f commit 9013de3

File tree

3 files changed

+70
-0
lines changed

3 files changed

+70
-0
lines changed

enzyme/Enzyme/ActivityAnalysis.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -1092,6 +1092,10 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {
10921092
InsertConstantValue(TR, Val);
10931093
return true;
10941094
}
1095+
if (hasMetadata(GI, "enzyme_inactive")) {
1096+
InsertConstantValue(TR, Val);
1097+
return true;
1098+
}
10951099

10961100
if (GI->getName().contains("enzyme_const") ||
10971101
InactiveGlobals.count(GI->getName().str())) {
@@ -2571,6 +2575,9 @@ bool ActivityAnalyzer::isValueInactiveFromUsers(TypeResults const &TR,
25712575
EnzymeNonmarkedGlobalsInactive) {
25722576
continue;
25732577
}
2578+
if (hasMetadata(GV, "enzyme_inactive")) {
2579+
continue;
2580+
}
25742581
if (GV->getName().contains("enzyme_const") ||
25752582
InactiveGlobals.count(GV->getName().str())) {
25762583
continue;

enzyme/Enzyme/PreserveNVVM.cpp

+26
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,32 @@ class PreserveNVVM final : public FunctionPass {
150150
}
151151
}
152152
}
153+
SmallVector<GlobalVariable *, 1> toErase;
154+
for (GlobalVariable &g : F.getParent()->globals()) {
155+
if (g.getName().contains("__enzyme_inactive_global")) {
156+
if (g.hasInitializer()) {
157+
Value *V = g.getInitializer();
158+
while (1) {
159+
if (auto CE = dyn_cast<ConstantExpr>(V)) {
160+
V = CE->getOperand(0);
161+
continue;
162+
}
163+
if (auto CA = dyn_cast<ConstantAggregate>(V)) {
164+
V = CA->getOperand(0);
165+
continue;
166+
}
167+
break;
168+
}
169+
if (auto GV = cast<GlobalVariable>(V)) {
170+
GV->setMetadata("enzyme_inactive", MDNode::get(g.getContext(), {}));
171+
toErase.push_back(&g);
172+
}
173+
}
174+
}
175+
}
176+
for (auto G : toErase)
177+
G->eraseFromParent();
178+
153179
if (!Begin && F.hasFnAttribute("prev_fixup")) {
154180
changed = true;
155181
F.removeFnAttr("prev_fixup");
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -O0 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi
2+
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -O1 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi
3+
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -O2 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi
4+
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -O3 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi
5+
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -O0 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi
6+
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi
7+
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi
8+
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi
9+
10+
#include "test_utils.h"
11+
12+
double __enzyme_autodiff(void*, ...);
13+
14+
enum class MyMemoryType
15+
{
16+
DEFAULT
17+
};
18+
19+
extern MyMemoryType host_mem_type;
20+
21+
__attribute__((noinline))
22+
void* alloc(int size, MyMemoryType mt) {
23+
return malloc(size);
24+
}
25+
26+
double square(double a)
27+
{
28+
double* D = (double*)alloc(sizeof(double), host_mem_type);
29+
D[0] = a;
30+
return D[0];
31+
}
32+
void* __enzyme_inactive_global = &host_mem_type;
33+
34+
int main() {
35+
double out = __enzyme_autodiff((void*)square, 10.0);
36+
APPROX_EQ(out, 1.0, 1e-7);
37+
}

0 commit comments

Comments
 (0)