Skip to content

Commit ea6916f

Browse files
wsmosesvchuravy
andcommitted
C API (#93)
* Initial Enzyme C API * Cpp api * add merge on TypeTree * change bool to uint8_t * fixup! change bool to uint8_t * fixup! fixup! change bool to uint8_t * fixup! fixup! fixup! change bool to uint8_t * fixup! fixup! fixup! fixup! change bool to uint8_t * Validate c header for rust usage Co-authored-by: Valentin Churavy <[email protected]>
1 parent b0ccebe commit ea6916f

File tree

10 files changed

+664
-56
lines changed

10 files changed

+664
-56
lines changed

Diff for: enzyme/CMakeLists.txt

+5
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ file(READ ${LLVM_IDIR}/llvm/Analysis/ScalarEvolution.h INPUT_TEXT)
6262
string(REPLACE private public INPUT_TEXT "${INPUT_TEXT}")
6363
file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/include/SCEV/ScalarEvolution.h" "${INPUT_TEXT}")
6464

65+
66+
file(READ ${LLVM_IDIR}/llvm/Analysis/TargetLibraryInfo.h INPUT_TEXT)
67+
string(REPLACE "class TargetLibraryInfo {" "class TargetLibraryInfo {public:" INPUT_TEXT "${INPUT_TEXT}")
68+
file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/include/SCEV/TargetLibraryInfo.h" "${INPUT_TEXT}")
69+
6570
if (${LLVM_VERSION_MAJOR} GREATER_EQUAL 11)
6671
file(READ ${LLVM_IDIR}/llvm/Transforms/Utils/ScalarEvolutionExpander.h INPUT_TEXT)
6772
else()

Diff for: enzyme/Enzyme/CApi.cpp

+336
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,336 @@
1+
//===- CApi.cpp - Enzyme API exported to C for external use -----------===//
2+
//
3+
// Enzyme Project
4+
//
5+
// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
6+
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
7+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8+
//
9+
// If using this code in an academic setting, please cite the following:
10+
// @incollection{enzymeNeurips,
11+
// title = {Instead of Rewriting Foreign Code for Machine Learning,
12+
// Automatically Synthesize Fast Gradients},
13+
// author = {Moses, William S. and Churavy, Valentin},
14+
// booktitle = {Advances in Neural Information Processing Systems 33},
15+
// year = {2020},
16+
// note = {To appear in},
17+
// }
18+
//
19+
//===----------------------------------------------------------------------===//
20+
//
21+
// This file defines various utility functions of Enzyme for access via C
22+
//
23+
//===----------------------------------------------------------------------===//
24+
#include "CApi.h"
25+
#include "EnzymeLogic.h"
26+
#include "SCEV/TargetLibraryInfo.h"
27+
28+
#include "llvm/ADT/Triple.h"
29+
#include "llvm/Analysis/CallGraph.h"
30+
#include "llvm/Analysis/GlobalsModRef.h"
31+
32+
using namespace llvm;
33+
34+
TargetLibraryInfo eunwrap(LLVMTargetLibraryInfoRef P) {
35+
return TargetLibraryInfo(*reinterpret_cast<TargetLibraryInfoImpl *>(P));
36+
}
37+
38+
TypeAnalysis &eunwrap(EnzymeTypeAnalysisRef TAR) {
39+
return *(TypeAnalysis *)TAR;
40+
}
41+
llvm::AAResults &eunwrap(EnzymeAAResultsRef AAR) {
42+
return *(llvm::AAResults *)AAR.AA;
43+
}
44+
AugmentedReturn *eunwrap(EnzymeAugmentedReturnPtr ARP) {
45+
return (AugmentedReturn *)ARP;
46+
}
47+
EnzymeAugmentedReturnPtr ewrap(const AugmentedReturn &AR) {
48+
return (EnzymeAugmentedReturnPtr)(&AR);
49+
}
50+
51+
ConcreteType eunwrap(CConcreteType CDT, llvm::LLVMContext &ctx) {
52+
switch (CDT) {
53+
case DT_Anything:
54+
return BaseType::Anything;
55+
case DT_Integer:
56+
return BaseType::Integer;
57+
case DT_Pointer:
58+
return BaseType::Pointer;
59+
case DT_Half:
60+
return ConcreteType(llvm::Type::getHalfTy(ctx));
61+
case DT_Float:
62+
return ConcreteType(llvm::Type::getFloatTy(ctx));
63+
case DT_Double:
64+
return ConcreteType(llvm::Type::getDoubleTy(ctx));
65+
case DT_Unknown:
66+
return BaseType::Unknown;
67+
}
68+
llvm_unreachable("Unknown concrete type to unwrap");
69+
}
70+
71+
std::vector<int> eunwrap(IntList IL) {
72+
std::vector<int> v;
73+
for (size_t i = 0; i < IL.size; i++) {
74+
v.push_back((int)IL.data[i]);
75+
}
76+
return v;
77+
}
78+
std::set<int64_t> eunwrap64(IntList IL) {
79+
std::set<int64_t> v;
80+
for (size_t i = 0; i < IL.size; i++) {
81+
v.insert((int64_t)IL.data[i]);
82+
}
83+
return v;
84+
}
85+
TypeTree eunwrap(CTypeTreeRef CTT) { return *(TypeTree *)CTT; }
86+
87+
CConcreteType ewrap(const ConcreteType &CT) {
88+
if (auto flt = CT.isFloat()) {
89+
if (flt->isHalfTy())
90+
return DT_Half;
91+
if (flt->isFloatTy())
92+
return DT_Float;
93+
if (flt->isDoubleTy())
94+
return DT_Double;
95+
} else {
96+
switch (CT.SubTypeEnum) {
97+
case BaseType::Integer:
98+
return DT_Integer;
99+
case BaseType::Pointer:
100+
return DT_Pointer;
101+
case BaseType::Anything:
102+
return DT_Anything;
103+
case BaseType::Unknown:
104+
return DT_Unknown;
105+
case BaseType::Float:
106+
llvm_unreachable("Illegal conversion of concretetype");
107+
}
108+
}
109+
llvm_unreachable("Illegal conversion of concretetype");
110+
}
111+
112+
IntList ewrap(const std::vector<int> &offsets) {
113+
IntList IL;
114+
IL.size = offsets.size();
115+
IL.data = (int64_t *)malloc(IL.size * sizeof(*IL.data));
116+
for (size_t i = 0; i < offsets.size(); i++) {
117+
IL.data[i] = offsets[i];
118+
}
119+
return IL;
120+
}
121+
122+
CTypeTreeRef ewrap(const TypeTree &TT) {
123+
return (CTypeTreeRef)(new TypeTree(TT));
124+
}
125+
126+
FnTypeInfo eunwrap(CFnTypeInfo CTI, llvm::Function *F) {
127+
FnTypeInfo FTI(F);
128+
// auto &ctx = F->getContext();
129+
FTI.Return = eunwrap(CTI.Return);
130+
131+
size_t argnum = 0;
132+
for (auto &arg : F->args()) {
133+
FTI.Arguments[&arg] = eunwrap(CTI.Arguments[argnum]);
134+
FTI.KnownValues[&arg] = eunwrap64(CTI.KnownValues[argnum]);
135+
argnum++;
136+
}
137+
return FTI;
138+
}
139+
140+
extern "C" {
141+
142+
EnzymeTypeAnalysisRef CreateTypeAnalysis(char *TripleStr,
143+
char **customRuleNames,
144+
CustomRuleType *customRules,
145+
size_t numRules) {
146+
TypeAnalysis *TA = new TypeAnalysis(*(
147+
new TargetLibraryInfo(*(new TargetLibraryInfoImpl(Triple(TripleStr))))));
148+
for (size_t i = 0; i < numRules; i++) {
149+
CustomRuleType rule = customRules[i];
150+
TA->CustomRules[customRuleNames[i]] =
151+
[=](int direction, TypeTree &returnTree,
152+
std::vector<TypeTree> &argTrees,
153+
std::vector<std::set<int64_t>> &knownValues,
154+
CallInst *call) -> uint8_t {
155+
CTypeTreeRef creturnTree = (CTypeTreeRef)(&returnTree);
156+
CTypeTreeRef *cargs = new CTypeTreeRef[argTrees.size()];
157+
IntList *kvs = new IntList[argTrees.size()];
158+
for (size_t i = 0; i < argTrees.size(); ++i) {
159+
cargs[i] = (CTypeTreeRef)(&(argTrees[i]));
160+
kvs[i].size = knownValues[i].size();
161+
kvs[i].data = (int64_t *)malloc(kvs[i].size * sizeof(*kvs[i].data));
162+
size_t j = 0;
163+
for (auto val : knownValues[i]) {
164+
kvs[i].data[j] = val;
165+
j++;
166+
}
167+
}
168+
uint8_t result =
169+
rule(direction, creturnTree, cargs, kvs, argTrees.size(), wrap(call));
170+
delete[] cargs;
171+
for (size_t i = 0; i < argTrees.size(); ++i) {
172+
free(kvs[i].data);
173+
}
174+
delete[] kvs;
175+
return result;
176+
};
177+
}
178+
return (EnzymeTypeAnalysisRef)TA;
179+
}
180+
void FreeTypeAnalysis(EnzymeTypeAnalysisRef TAR) {
181+
TypeAnalysis *TA = (TypeAnalysis *)TAR;
182+
delete &TA->TLI.Impl;
183+
delete &TA->TLI;
184+
delete TA;
185+
}
186+
187+
LLVMValueRef EnzymeCreatePrimalAndGradient(
188+
LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args,
189+
size_t constant_args_size, EnzymeTypeAnalysisRef TA,
190+
EnzymeAAResultsRef global_AA, uint8_t returnValue, uint8_t dretUsed,
191+
uint8_t topLevel, LLVMTypeRef additionalArg, CFnTypeInfo typeInfo,
192+
uint8_t *_uncacheable_args, size_t uncacheable_args_size,
193+
EnzymeAugmentedReturnPtr augmented, uint8_t AtomicAdd, uint8_t PostOpt) {
194+
std::vector<DIFFE_TYPE> nconstant_args((DIFFE_TYPE *)constant_args,
195+
(DIFFE_TYPE *)constant_args +
196+
constant_args_size);
197+
std::map<llvm::Argument *, bool> uncacheable_args;
198+
size_t argnum = 0;
199+
for (auto &arg : cast<Function>(unwrap(todiff))->args()) {
200+
assert(argnum < uncacheable_args_size);
201+
uncacheable_args[&arg] = _uncacheable_args[argnum];
202+
argnum++;
203+
}
204+
return wrap(CreatePrimalAndGradient(
205+
cast<Function>(unwrap(todiff)), (DIFFE_TYPE)retType, nconstant_args,
206+
eunwrap(TA).TLI, eunwrap(TA), eunwrap(global_AA), returnValue, dretUsed,
207+
topLevel, unwrap(additionalArg),
208+
eunwrap(typeInfo, cast<Function>(unwrap(todiff))), uncacheable_args,
209+
eunwrap(augmented), AtomicAdd, PostOpt));
210+
}
211+
EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal(
212+
LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args,
213+
size_t constant_args_size, EnzymeTypeAnalysisRef TA,
214+
EnzymeAAResultsRef global_AA, uint8_t returnUsed, CFnTypeInfo typeInfo,
215+
uint8_t *_uncacheable_args, size_t uncacheable_args_size,
216+
uint8_t forceAnonymousTape, uint8_t AtomicAdd, uint8_t PostOpt) {
217+
218+
std::vector<DIFFE_TYPE> nconstant_args((DIFFE_TYPE *)constant_args,
219+
(DIFFE_TYPE *)constant_args +
220+
constant_args_size);
221+
std::map<llvm::Argument *, bool> uncacheable_args;
222+
size_t argnum = 0;
223+
for (auto &arg : cast<Function>(unwrap(todiff))->args()) {
224+
assert(argnum < uncacheable_args_size);
225+
uncacheable_args[&arg] = _uncacheable_args[argnum];
226+
argnum++;
227+
}
228+
return ewrap(CreateAugmentedPrimal(
229+
cast<Function>(unwrap(todiff)), (DIFFE_TYPE)retType, nconstant_args,
230+
eunwrap(TA).TLI, eunwrap(TA), eunwrap(global_AA), returnUsed,
231+
eunwrap(typeInfo, cast<Function>(unwrap(todiff))), uncacheable_args,
232+
forceAnonymousTape, AtomicAdd, PostOpt));
233+
}
234+
235+
EnzymeAAResultsRef EnzymeGetGlobalAA(LLVMModuleRef M) {
236+
ModuleAnalysisManager *AM = new ModuleAnalysisManager();
237+
AM->registerPass([] { return CallGraphAnalysis(); });
238+
FunctionAnalysisManager *FAM = new FunctionAnalysisManager();
239+
AM->registerPass([=] { return FunctionAnalysisManagerModuleProxy(*FAM); });
240+
FAM->registerPass([=] { return ModuleAnalysisManagerFunctionProxy(*AM); });
241+
FAM->registerPass([] { return TargetLibraryAnalysis(); });
242+
#if LLVM_VERSION_MAJOR >= 8
243+
AM->registerPass([] { return PassInstrumentationAnalysis(); });
244+
FAM->registerPass([] { return PassInstrumentationAnalysis(); });
245+
#endif
246+
247+
#if LLVM_VERSION_MAJOR >= 10
248+
auto GetTLI = [=](Function &F) -> TargetLibraryInfo & {
249+
return FAM->getResult<TargetLibraryAnalysis>(F);
250+
};
251+
return (EnzymeAAResultsRef){
252+
(struct EnzymeOpaqueAAResults *)(new GlobalsAAResult(
253+
GlobalsAAResult::analyzeModule(
254+
*unwrap(M), GetTLI,
255+
AM->getResult<CallGraphAnalysis>(*unwrap(M))))),
256+
AM, FAM};
257+
#else
258+
AM->registerPass([] { return TargetLibraryAnalysis(); });
259+
return (EnzymeAAResultsRef){
260+
(struct EnzymeOpaqueAAResults *)(new GlobalsAAResult(
261+
GlobalsAAResult::analyzeModule(
262+
*unwrap(M), AM->getResult<TargetLibraryAnalysis>(*unwrap(M)),
263+
AM->getResult<CallGraphAnalysis>(*unwrap(M))))),
264+
AM, FAM};
265+
266+
#endif
267+
}
268+
void EnzymeFreeGlobalAA(EnzymeAAResultsRef AA) {
269+
delete ((GlobalsAAResult *)AA.AA);
270+
delete ((ModuleAnalysisManager *)AA.AM);
271+
delete ((FunctionAnalysisManager *)AA.FAM);
272+
}
273+
274+
LLVMValueRef
275+
EnzymeExtractFunctionFromAugmentation(EnzymeAugmentedReturnPtr ret) {
276+
auto AR = (AugmentedReturn *)ret;
277+
return wrap(AR->fn);
278+
}
279+
280+
LLVMTypeRef
281+
EnzymeExtractTapeTypeFromAugmentation(EnzymeAugmentedReturnPtr ret) {
282+
auto AR = (AugmentedReturn *)ret;
283+
auto found = AR->returns.find(AugmentedStruct::Tape);
284+
if (found == AR->returns.end()) {
285+
return wrap((Type *)nullptr);
286+
}
287+
if (found->second == -1) {
288+
return wrap(AR->fn->getReturnType());
289+
}
290+
return wrap(
291+
cast<StructType>(AR->fn->getReturnType())->getTypeAtIndex(found->second));
292+
}
293+
294+
void EnzymeExtractReturnInfo(EnzymeAugmentedReturnPtr ret, int64_t *data,
295+
uint8_t *existed, size_t len) {
296+
assert(len == 3);
297+
auto AR = (AugmentedReturn *)ret;
298+
AugmentedStruct todo[] = {AugmentedStruct::Tape, AugmentedStruct::Return,
299+
AugmentedStruct::DifferentialReturn};
300+
for (size_t i = 0; i < len; i++) {
301+
auto found = AR->returns.find(todo[i]);
302+
if (found != AR->returns.end()) {
303+
existed[i] = true;
304+
data[i] = (int64_t)found->second;
305+
} else {
306+
existed[i] = false;
307+
}
308+
}
309+
}
310+
311+
CTypeTreeRef EnzymeNewTypeTree() { return (CTypeTreeRef)(new TypeTree()); }
312+
CTypeTreeRef EnzymeNewTypeTreeCT(CConcreteType CT, LLVMContextRef ctx) {
313+
return (CTypeTreeRef)(new TypeTree(eunwrap(CT, *unwrap(ctx))));
314+
}
315+
CTypeTreeRef EnzymeNewTypeTreeTR(CTypeTreeRef CTR) {
316+
return (CTypeTreeRef)(new TypeTree(*(TypeTree *)(CTR)));
317+
}
318+
void EnzymeFreeTypeTree(CTypeTreeRef CTT) { delete (TypeTree *)CTT; }
319+
void EnzymeSetTypeTree(CTypeTreeRef dst, CTypeTreeRef src) {
320+
*(TypeTree *)dst = *(TypeTree *)src;
321+
}
322+
void EnzymeMergeTypeTree(CTypeTreeRef dst, CTypeTreeRef src) {
323+
((TypeTree *)dst)->orIn(*(TypeTree *)src, /*PointerIntSame*/ false);
324+
}
325+
326+
void EnzymeTypeTreeOnlyEq(CTypeTreeRef CTT, int64_t x) {
327+
*(TypeTree *)CTT = ((TypeTree *)CTT)->Only(x);
328+
}
329+
void EnzymeTypeTreeShiftIndiciesEq(CTypeTreeRef CTT, const char *datalayout,
330+
int64_t offset, int64_t maxSize,
331+
uint64_t addOffset) {
332+
DataLayout DL(datalayout);
333+
*(TypeTree *)CTT =
334+
((TypeTree *)CTT)->ShiftIndices(DL, offset, maxSize, addOffset);
335+
}
336+
}

0 commit comments

Comments
 (0)