1
+ // ===- CApi.cpp - Enzyme API exported to C for external use -----------===//
2
+ //
3
+ // Enzyme Project
4
+ //
5
+ // Part of the Enzyme Project, under the Apache License v2.0 with LLVM
6
+ // Exceptions. See https://llvm.org/LICENSE.txt for license information.
7
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8
+ //
9
+ // If using this code in an academic setting, please cite the following:
10
+ // @incollection{enzymeNeurips,
11
+ // title = {Instead of Rewriting Foreign Code for Machine Learning,
12
+ // Automatically Synthesize Fast Gradients},
13
+ // author = {Moses, William S. and Churavy, Valentin},
14
+ // booktitle = {Advances in Neural Information Processing Systems 33},
15
+ // year = {2020},
16
+ // note = {To appear in},
17
+ // }
18
+ //
19
+ // ===----------------------------------------------------------------------===//
20
+ //
21
+ // This file defines various utility functions of Enzyme for access via C
22
+ //
23
+ // ===----------------------------------------------------------------------===//
24
+ #include " CApi.h"
25
+ #include " EnzymeLogic.h"
26
+ #include " SCEV/TargetLibraryInfo.h"
27
+
28
+ #include " llvm/ADT/Triple.h"
29
+ #include " llvm/Analysis/CallGraph.h"
30
+ #include " llvm/Analysis/GlobalsModRef.h"
31
+
32
+ using namespace llvm ;
33
+
34
+ TargetLibraryInfo eunwrap (LLVMTargetLibraryInfoRef P) {
35
+ return TargetLibraryInfo (*reinterpret_cast <TargetLibraryInfoImpl *>(P));
36
+ }
37
+
38
+ TypeAnalysis &eunwrap (EnzymeTypeAnalysisRef TAR) {
39
+ return *(TypeAnalysis *)TAR;
40
+ }
41
+ llvm::AAResults &eunwrap (EnzymeAAResultsRef AAR) {
42
+ return *(llvm::AAResults *)AAR.AA ;
43
+ }
44
+ AugmentedReturn *eunwrap (EnzymeAugmentedReturnPtr ARP) {
45
+ return (AugmentedReturn *)ARP;
46
+ }
47
+ EnzymeAugmentedReturnPtr ewrap (const AugmentedReturn &AR) {
48
+ return (EnzymeAugmentedReturnPtr)(&AR);
49
+ }
50
+
51
+ ConcreteType eunwrap (CConcreteType CDT, llvm::LLVMContext &ctx) {
52
+ switch (CDT) {
53
+ case DT_Anything:
54
+ return BaseType::Anything;
55
+ case DT_Integer:
56
+ return BaseType::Integer;
57
+ case DT_Pointer:
58
+ return BaseType::Pointer;
59
+ case DT_Half:
60
+ return ConcreteType (llvm::Type::getHalfTy (ctx));
61
+ case DT_Float:
62
+ return ConcreteType (llvm::Type::getFloatTy (ctx));
63
+ case DT_Double:
64
+ return ConcreteType (llvm::Type::getDoubleTy (ctx));
65
+ case DT_Unknown:
66
+ return BaseType::Unknown;
67
+ }
68
+ llvm_unreachable (" Unknown concrete type to unwrap" );
69
+ }
70
+
71
+ std::vector<int > eunwrap (IntList IL) {
72
+ std::vector<int > v;
73
+ for (size_t i = 0 ; i < IL.size ; i++) {
74
+ v.push_back ((int )IL.data [i]);
75
+ }
76
+ return v;
77
+ }
78
+ std::set<int64_t > eunwrap64 (IntList IL) {
79
+ std::set<int64_t > v;
80
+ for (size_t i = 0 ; i < IL.size ; i++) {
81
+ v.insert ((int64_t )IL.data [i]);
82
+ }
83
+ return v;
84
+ }
85
+ TypeTree eunwrap (CTypeTreeRef CTT) { return *(TypeTree *)CTT; }
86
+
87
+ CConcreteType ewrap (const ConcreteType &CT) {
88
+ if (auto flt = CT.isFloat ()) {
89
+ if (flt->isHalfTy ())
90
+ return DT_Half;
91
+ if (flt->isFloatTy ())
92
+ return DT_Float;
93
+ if (flt->isDoubleTy ())
94
+ return DT_Double;
95
+ } else {
96
+ switch (CT.SubTypeEnum ) {
97
+ case BaseType::Integer:
98
+ return DT_Integer;
99
+ case BaseType::Pointer:
100
+ return DT_Pointer;
101
+ case BaseType::Anything:
102
+ return DT_Anything;
103
+ case BaseType::Unknown:
104
+ return DT_Unknown;
105
+ case BaseType::Float:
106
+ llvm_unreachable (" Illegal conversion of concretetype" );
107
+ }
108
+ }
109
+ llvm_unreachable (" Illegal conversion of concretetype" );
110
+ }
111
+
112
+ IntList ewrap (const std::vector<int > &offsets) {
113
+ IntList IL;
114
+ IL.size = offsets.size ();
115
+ IL.data = (int64_t *)malloc (IL.size * sizeof (*IL.data ));
116
+ for (size_t i = 0 ; i < offsets.size (); i++) {
117
+ IL.data [i] = offsets[i];
118
+ }
119
+ return IL;
120
+ }
121
+
122
+ CTypeTreeRef ewrap (const TypeTree &TT) {
123
+ return (CTypeTreeRef)(new TypeTree (TT));
124
+ }
125
+
126
+ FnTypeInfo eunwrap (CFnTypeInfo CTI, llvm::Function *F) {
127
+ FnTypeInfo FTI (F);
128
+ // auto &ctx = F->getContext();
129
+ FTI.Return = eunwrap (CTI.Return );
130
+
131
+ size_t argnum = 0 ;
132
+ for (auto &arg : F->args ()) {
133
+ FTI.Arguments [&arg] = eunwrap (CTI.Arguments [argnum]);
134
+ FTI.KnownValues [&arg] = eunwrap64 (CTI.KnownValues [argnum]);
135
+ argnum++;
136
+ }
137
+ return FTI;
138
+ }
139
+
140
+ extern " C" {
141
+
142
+ EnzymeTypeAnalysisRef CreateTypeAnalysis (char *TripleStr,
143
+ char **customRuleNames,
144
+ CustomRuleType *customRules,
145
+ size_t numRules) {
146
+ TypeAnalysis *TA = new TypeAnalysis (*(
147
+ new TargetLibraryInfo (*(new TargetLibraryInfoImpl (Triple (TripleStr))))));
148
+ for (size_t i = 0 ; i < numRules; i++) {
149
+ CustomRuleType rule = customRules[i];
150
+ TA->CustomRules [customRuleNames[i]] =
151
+ [=](int direction, TypeTree &returnTree,
152
+ std::vector<TypeTree> &argTrees,
153
+ std::vector<std::set<int64_t >> &knownValues,
154
+ CallInst *call) -> uint8_t {
155
+ CTypeTreeRef creturnTree = (CTypeTreeRef)(&returnTree);
156
+ CTypeTreeRef *cargs = new CTypeTreeRef[argTrees.size ()];
157
+ IntList *kvs = new IntList[argTrees.size ()];
158
+ for (size_t i = 0 ; i < argTrees.size (); ++i) {
159
+ cargs[i] = (CTypeTreeRef)(&(argTrees[i]));
160
+ kvs[i].size = knownValues[i].size ();
161
+ kvs[i].data = (int64_t *)malloc (kvs[i].size * sizeof (*kvs[i].data ));
162
+ size_t j = 0 ;
163
+ for (auto val : knownValues[i]) {
164
+ kvs[i].data [j] = val;
165
+ j++;
166
+ }
167
+ }
168
+ uint8_t result =
169
+ rule (direction, creturnTree, cargs, kvs, argTrees.size (), wrap (call));
170
+ delete[] cargs;
171
+ for (size_t i = 0 ; i < argTrees.size (); ++i) {
172
+ free (kvs[i].data );
173
+ }
174
+ delete[] kvs;
175
+ return result;
176
+ };
177
+ }
178
+ return (EnzymeTypeAnalysisRef)TA;
179
+ }
180
+ void FreeTypeAnalysis (EnzymeTypeAnalysisRef TAR) {
181
+ TypeAnalysis *TA = (TypeAnalysis *)TAR;
182
+ delete &TA->TLI .Impl ;
183
+ delete &TA->TLI ;
184
+ delete TA;
185
+ }
186
+
187
+ LLVMValueRef EnzymeCreatePrimalAndGradient (
188
+ LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args,
189
+ size_t constant_args_size, EnzymeTypeAnalysisRef TA,
190
+ EnzymeAAResultsRef global_AA, uint8_t returnValue, uint8_t dretUsed,
191
+ uint8_t topLevel, LLVMTypeRef additionalArg, CFnTypeInfo typeInfo,
192
+ uint8_t *_uncacheable_args, size_t uncacheable_args_size,
193
+ EnzymeAugmentedReturnPtr augmented, uint8_t AtomicAdd, uint8_t PostOpt) {
194
+ std::vector<DIFFE_TYPE> nconstant_args ((DIFFE_TYPE *)constant_args,
195
+ (DIFFE_TYPE *)constant_args +
196
+ constant_args_size);
197
+ std::map<llvm::Argument *, bool > uncacheable_args;
198
+ size_t argnum = 0 ;
199
+ for (auto &arg : cast<Function>(unwrap (todiff))->args ()) {
200
+ assert (argnum < uncacheable_args_size);
201
+ uncacheable_args[&arg] = _uncacheable_args[argnum];
202
+ argnum++;
203
+ }
204
+ return wrap (CreatePrimalAndGradient (
205
+ cast<Function>(unwrap (todiff)), (DIFFE_TYPE)retType, nconstant_args,
206
+ eunwrap (TA).TLI , eunwrap (TA), eunwrap (global_AA), returnValue, dretUsed,
207
+ topLevel, unwrap (additionalArg),
208
+ eunwrap (typeInfo, cast<Function>(unwrap (todiff))), uncacheable_args,
209
+ eunwrap (augmented), AtomicAdd, PostOpt));
210
+ }
211
+ EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal (
212
+ LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args,
213
+ size_t constant_args_size, EnzymeTypeAnalysisRef TA,
214
+ EnzymeAAResultsRef global_AA, uint8_t returnUsed, CFnTypeInfo typeInfo,
215
+ uint8_t *_uncacheable_args, size_t uncacheable_args_size,
216
+ uint8_t forceAnonymousTape, uint8_t AtomicAdd, uint8_t PostOpt) {
217
+
218
+ std::vector<DIFFE_TYPE> nconstant_args ((DIFFE_TYPE *)constant_args,
219
+ (DIFFE_TYPE *)constant_args +
220
+ constant_args_size);
221
+ std::map<llvm::Argument *, bool > uncacheable_args;
222
+ size_t argnum = 0 ;
223
+ for (auto &arg : cast<Function>(unwrap (todiff))->args ()) {
224
+ assert (argnum < uncacheable_args_size);
225
+ uncacheable_args[&arg] = _uncacheable_args[argnum];
226
+ argnum++;
227
+ }
228
+ return ewrap (CreateAugmentedPrimal (
229
+ cast<Function>(unwrap (todiff)), (DIFFE_TYPE)retType, nconstant_args,
230
+ eunwrap (TA).TLI , eunwrap (TA), eunwrap (global_AA), returnUsed,
231
+ eunwrap (typeInfo, cast<Function>(unwrap (todiff))), uncacheable_args,
232
+ forceAnonymousTape, AtomicAdd, PostOpt));
233
+ }
234
+
235
+ EnzymeAAResultsRef EnzymeGetGlobalAA (LLVMModuleRef M) {
236
+ ModuleAnalysisManager *AM = new ModuleAnalysisManager ();
237
+ AM->registerPass ([] { return CallGraphAnalysis (); });
238
+ FunctionAnalysisManager *FAM = new FunctionAnalysisManager ();
239
+ AM->registerPass ([=] { return FunctionAnalysisManagerModuleProxy (*FAM); });
240
+ FAM->registerPass ([=] { return ModuleAnalysisManagerFunctionProxy (*AM); });
241
+ FAM->registerPass ([] { return TargetLibraryAnalysis (); });
242
+ #if LLVM_VERSION_MAJOR >= 8
243
+ AM->registerPass ([] { return PassInstrumentationAnalysis (); });
244
+ FAM->registerPass ([] { return PassInstrumentationAnalysis (); });
245
+ #endif
246
+
247
+ #if LLVM_VERSION_MAJOR >= 10
248
+ auto GetTLI = [=](Function &F) -> TargetLibraryInfo & {
249
+ return FAM->getResult <TargetLibraryAnalysis>(F);
250
+ };
251
+ return (EnzymeAAResultsRef){
252
+ (struct EnzymeOpaqueAAResults *)(new GlobalsAAResult (
253
+ GlobalsAAResult::analyzeModule (
254
+ *unwrap (M), GetTLI,
255
+ AM->getResult <CallGraphAnalysis>(*unwrap (M))))),
256
+ AM, FAM};
257
+ #else
258
+ AM->registerPass ([] { return TargetLibraryAnalysis (); });
259
+ return (EnzymeAAResultsRef){
260
+ (struct EnzymeOpaqueAAResults *)(new GlobalsAAResult (
261
+ GlobalsAAResult::analyzeModule (
262
+ *unwrap (M), AM->getResult <TargetLibraryAnalysis>(*unwrap (M)),
263
+ AM->getResult <CallGraphAnalysis>(*unwrap (M))))),
264
+ AM, FAM};
265
+
266
+ #endif
267
+ }
268
+ void EnzymeFreeGlobalAA (EnzymeAAResultsRef AA) {
269
+ delete ((GlobalsAAResult *)AA.AA );
270
+ delete ((ModuleAnalysisManager *)AA.AM );
271
+ delete ((FunctionAnalysisManager *)AA.FAM );
272
+ }
273
+
274
+ LLVMValueRef
275
+ EnzymeExtractFunctionFromAugmentation (EnzymeAugmentedReturnPtr ret) {
276
+ auto AR = (AugmentedReturn *)ret;
277
+ return wrap (AR->fn );
278
+ }
279
+
280
+ LLVMTypeRef
281
+ EnzymeExtractTapeTypeFromAugmentation (EnzymeAugmentedReturnPtr ret) {
282
+ auto AR = (AugmentedReturn *)ret;
283
+ auto found = AR->returns .find (AugmentedStruct::Tape);
284
+ if (found == AR->returns .end ()) {
285
+ return wrap ((Type *)nullptr );
286
+ }
287
+ if (found->second == -1 ) {
288
+ return wrap (AR->fn ->getReturnType ());
289
+ }
290
+ return wrap (
291
+ cast<StructType>(AR->fn ->getReturnType ())->getTypeAtIndex (found->second ));
292
+ }
293
+
294
+ void EnzymeExtractReturnInfo (EnzymeAugmentedReturnPtr ret, int64_t *data,
295
+ uint8_t *existed, size_t len) {
296
+ assert (len == 3 );
297
+ auto AR = (AugmentedReturn *)ret;
298
+ AugmentedStruct todo[] = {AugmentedStruct::Tape, AugmentedStruct::Return,
299
+ AugmentedStruct::DifferentialReturn};
300
+ for (size_t i = 0 ; i < len; i++) {
301
+ auto found = AR->returns .find (todo[i]);
302
+ if (found != AR->returns .end ()) {
303
+ existed[i] = true ;
304
+ data[i] = (int64_t )found->second ;
305
+ } else {
306
+ existed[i] = false ;
307
+ }
308
+ }
309
+ }
310
+
311
+ CTypeTreeRef EnzymeNewTypeTree () { return (CTypeTreeRef)(new TypeTree ()); }
312
+ CTypeTreeRef EnzymeNewTypeTreeCT (CConcreteType CT, LLVMContextRef ctx) {
313
+ return (CTypeTreeRef)(new TypeTree (eunwrap (CT, *unwrap (ctx))));
314
+ }
315
+ CTypeTreeRef EnzymeNewTypeTreeTR (CTypeTreeRef CTR) {
316
+ return (CTypeTreeRef)(new TypeTree (*(TypeTree *)(CTR)));
317
+ }
318
+ void EnzymeFreeTypeTree (CTypeTreeRef CTT) { delete (TypeTree *)CTT; }
319
+ void EnzymeSetTypeTree (CTypeTreeRef dst, CTypeTreeRef src) {
320
+ *(TypeTree *)dst = *(TypeTree *)src;
321
+ }
322
+ void EnzymeMergeTypeTree (CTypeTreeRef dst, CTypeTreeRef src) {
323
+ ((TypeTree *)dst)->orIn (*(TypeTree *)src, /* PointerIntSame*/ false );
324
+ }
325
+
326
+ void EnzymeTypeTreeOnlyEq (CTypeTreeRef CTT, int64_t x) {
327
+ *(TypeTree *)CTT = ((TypeTree *)CTT)->Only (x);
328
+ }
329
+ void EnzymeTypeTreeShiftIndiciesEq (CTypeTreeRef CTT, const char *datalayout,
330
+ int64_t offset, int64_t maxSize,
331
+ uint64_t addOffset) {
332
+ DataLayout DL (datalayout);
333
+ *(TypeTree *)CTT =
334
+ ((TypeTree *)CTT)->ShiftIndices (DL, offset, maxSize, addOffset);
335
+ }
336
+ }
0 commit comments