Skip to content

Commit 93a827e

Browse files
committed
Map experimental C (actually C++) API for gradient tape
1 parent 0d73a9b commit 93a827e

File tree

88 files changed

+2504
-69
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

88 files changed

+2504
-69
lines changed

Diff for: tensorflow-core/pom.xml

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
<javacpp.platform.macosx-x86_64.extension>macosx-x86_64${javacpp.platform.extension}</javacpp.platform.macosx-x86_64.extension>
6262
<javacpp.platform.windows-x86.extension>windows-x86${javacpp.platform.extension}</javacpp.platform.windows-x86.extension>
6363
<javacpp.platform.windows-x86_64.extension>windows-x86_64${javacpp.platform.extension}</javacpp.platform.windows-x86_64.extension>
64-
<javacpp.version>1.5.4</javacpp.version>
64+
<javacpp.version>1.5.5</javacpp.version>
6565
</properties>
6666

6767
<profiles>

Diff for: tensorflow-core/tensorflow-core-api/pom.xml

+25
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,19 @@
141141
</execution>
142142
</executions>
143143
</plugin>
144+
<plugin>
145+
<artifactId>maven-resources-plugin</artifactId>
146+
<version>3.1.0</version>
147+
<executions>
148+
<execution>
149+
<id>javacpp-parser</id>
150+
<phase>generate-sources</phase>
151+
<goals>
152+
<goal>resources</goal>
153+
</goals>
154+
</execution>
155+
</executions>
156+
</plugin>
144157
<plugin>
145158
<artifactId>maven-compiler-plugin</artifactId>
146159
<version>3.8.0</version>
@@ -209,7 +222,15 @@
209222
<classPath>${project.build.outputDirectory}</classPath>
210223
<includePaths>
211224
<includePath>${project.basedir}/</includePath>
225+
<includePath>${project.basedir}/bazel-bin/external/llvm-project/llvm/include/</includePath>
226+
<includePath>${project.basedir}/bazel-bin/external/org_tensorflow/</includePath>
227+
<includePath>${project.basedir}/bazel-${project.artifactId}/external/eigen_archive/</includePath>
228+
<includePath>${project.basedir}/bazel-${project.artifactId}/external/com_google_absl/</includePath>
229+
<includePath>${project.basedir}/bazel-${project.artifactId}/external/com_google_protobuf/src/</includePath>
230+
<includePath>${project.basedir}/bazel-${project.artifactId}/external/farmhash_archive/src/</includePath>
231+
<includePath>${project.basedir}/bazel-${project.artifactId}/external/llvm-project/llvm/include/</includePath>
212232
<includePath>${project.basedir}/bazel-${project.artifactId}/external/org_tensorflow/</includePath>
233+
<includePath>${project.basedir}/target/classes/org/tensorflow/internal/c_api/include/</includePath>
213234
</includePaths>
214235
<linkPaths>
215236
<linkPath>${project.basedir}/bazel-bin/external/llvm_openmp/</linkPath>
@@ -315,6 +336,10 @@
315336
<outputDirectory>${project.build.directory}/native/org/tensorflow/internal/c_api/${native.classifier}/</outputDirectory>
316337
<skip>${javacpp.compiler.skip}</skip>
317338
<classOrPackageName>org.tensorflow.internal.c_api.**</classOrPackageName>
339+
<compilerOptions>
340+
<!-- TODO: Remove files from here as they get integrated into the Bazel build -->
341+
<compilerOption>${project.basedir}/bazel-${project.artifactId}/external/org_tensorflow/tensorflow/c/eager/gradients.cc</compilerOption>
342+
</compilerOptions>
318343
<copyLibs>true</copyLibs>
319344
<copyResources>true</copyResources>
320345
</configuration>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// Targeted by JavaCPP version 1.5.5: DO NOT EDIT THIS FILE
2+
3+
package org.tensorflow.internal.c_api;
4+
5+
import java.nio.*;
6+
import org.bytedeco.javacpp.*;
7+
import org.bytedeco.javacpp.annotation.*;
8+
9+
import static org.tensorflow.internal.c_api.global.tensorflow.*;
10+
11+
12+
// Abstract interface to a context.
13+
//
14+
// This serves as a factory for creating `AbstractOperation`s and for
15+
// registering traced functions.
16+
// Operations creation within a context can only be executed in that context
17+
// (for now at least).
18+
// Implementations of the context may contain some state e.g. an execution
19+
// environment, a traced representation etc.
20+
@Namespace("tensorflow") @NoOffset @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
21+
public class AbstractContext extends Pointer {
22+
static { Loader.load(); }
23+
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
24+
public AbstractContext(Pointer p) { super(p); }
25+
26+
public native int getKind();
27+
28+
// Release any underlying resources, including the interface object.
29+
//
30+
// WARNING: The destructor of this class is marked as protected to disallow
31+
// clients from directly destroying this object since it may manage it's own
32+
// lifetime through ref counting. Thus clients MUST call Release() in order to
33+
// destroy an instance of this class.
34+
public native void Release();
35+
36+
// Creates an operation builder and ties it to this context.
37+
// The returned object can be used for setting operation's attributes,
38+
// adding inputs and finally executing (immediately or lazily as in tracing)
39+
// it in this context.
40+
public native AbstractOperation CreateOperation();
41+
42+
// Registers a function with this context, after this the function is
43+
// available to be called/referenced by its name in this context.
44+
public native @ByVal Status RegisterFunction(AbstractFunction arg0);
45+
// Remove a function. 'func' argument is the name of a previously added
46+
// FunctionDef. The name is in fdef.signature.name.
47+
public native @ByVal Status RemoveFunction(@StdString BytePointer func);
48+
public native @ByVal Status RemoveFunction(@StdString String func);
49+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// Targeted by JavaCPP version 1.5.5: DO NOT EDIT THIS FILE
2+
3+
package org.tensorflow.internal.c_api;
4+
5+
import java.nio.*;
6+
import org.bytedeco.javacpp.*;
7+
import org.bytedeco.javacpp.annotation.*;
8+
9+
import static org.tensorflow.internal.c_api.global.tensorflow.*;
10+
11+
@Namespace("tensorflow::internal") @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
12+
public class AbstractContextDeleter extends Pointer {
13+
static { Loader.load(); }
14+
/** Default native constructor. */
15+
public AbstractContextDeleter() { super((Pointer)null); allocate(); }
16+
/** Native array allocator. Access with {@link Pointer#position(long)}. */
17+
public AbstractContextDeleter(long size) { super((Pointer)null); allocateArray(size); }
18+
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
19+
public AbstractContextDeleter(Pointer p) { super(p); }
20+
private native void allocate();
21+
private native void allocateArray(long size);
22+
@Override public AbstractContextDeleter position(long position) {
23+
return (AbstractContextDeleter)super.position(position);
24+
}
25+
@Override public AbstractContextDeleter getPointer(long i) {
26+
return new AbstractContextDeleter((Pointer)this).position(position + i);
27+
}
28+
29+
public native @Name("operator ()") void apply(AbstractContext p);
30+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// Targeted by JavaCPP version 1.5.5: DO NOT EDIT THIS FILE
2+
3+
package org.tensorflow.internal.c_api;
4+
5+
import java.nio.*;
6+
import org.bytedeco.javacpp.*;
7+
import org.bytedeco.javacpp.annotation.*;
8+
9+
import static org.tensorflow.internal.c_api.global.tensorflow.*;
10+
11+
12+
// A traced function: this hides the complexity of converting the serialized
13+
// representation between various supported formats e.g. FunctionDef and Mlir
14+
// function.
15+
@Namespace("tensorflow") @NoOffset @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
16+
public class AbstractFunction extends Pointer {
17+
static { Loader.load(); }
18+
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
19+
public AbstractFunction(Pointer p) { super(p); }
20+
21+
// Returns which subclass is this instance of.
22+
public native int getKind();
23+
24+
// Returns the AbstractFunction as a FunctionDef.
25+
public native @ByVal Status GetFunctionDef(@Cast("tensorflow::FunctionDef**") PointerPointer arg0);
26+
public native @ByVal Status GetFunctionDef(@Cast("tensorflow::FunctionDef**") @ByPtrPtr Pointer arg0);
27+
}

0 commit comments

Comments
 (0)