Skip to content

Commit 0d73a9b

Browse files
authored
Move op generation to Java (#244)
1 parent d570edb commit 0d73a9b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+37146
-54
lines changed

Diff for: .gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ __pycache__
1616
cmake_build/
1717
tensorflow/contrib/cmake/_build/
1818
.idea/**
19+
.run
1920
/build/
2021
[Bb]uild/
2122
/tensorflow/core/util/version_info.cc

Diff for: tensorflow-core/tensorflow-core-api/build.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ $BAZEL_BIN/java_op_generator \
8989
GEN_RESOURCE_DIR=src/gen/resources/org/tensorflow/op
9090
mkdir -p $GEN_RESOURCE_DIR
9191

92-
# Generate Java operator wrappers
92+
# Export op defs
9393
$BAZEL_BIN/java_op_exporter \
9494
--api_dirs=$BAZEL_SRCS/external/org_tensorflow/tensorflow/core/api_def/base_api,src/bazel/api_def \
9595
$TENSORFLOW_LIB > $GEN_RESOURCE_DIR/ops.pb

Diff for: tensorflow-core/tensorflow-core-api/pom.xml

+35
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,41 @@
328328
</dependency>
329329
</dependencies>
330330
</plugin>
331+
<plugin>
332+
<!--
333+
An execution to execute the op class generator (which lives in tensorflow-core-generator).
334+
Must be ran after build.sh, which generates the ops.pb file it reads.
335+
Will be ran during the generate-sources phase.
336+
-->
337+
<groupId>org.codehaus.mojo</groupId>
338+
<artifactId>exec-maven-plugin</artifactId>
339+
<version>3.0.0</version>
340+
<executions>
341+
<execution>
342+
<id>generate-ops</id>
343+
<goals>
344+
<goal>java</goal>
345+
</goals>
346+
<!-- <phase>generate-sources</phase>-->
347+
</execution>
348+
</executions>
349+
<dependencies>
350+
<dependency>
351+
<groupId>org.tensorflow</groupId>
352+
<artifactId>tensorflow-core-generator</artifactId>
353+
<version>${project.version}</version>
354+
</dependency>
355+
</dependencies>
356+
<configuration>
357+
<includeProjectDependencies>false</includeProjectDependencies>
358+
<includePluginDependencies>true</includePluginDependencies>
359+
<mainClass>org.tensorflow.op.generator.OpGenerator</mainClass>
360+
<arguments>
361+
<argument>${project.basedir}/src/gen/java</argument>
362+
<argument>${project.basedir}/src/gen/resources/org/tensorflow/op/ops.pb</argument>
363+
</arguments>
364+
</configuration>
365+
</plugin>
331366
<plugin>
332367
<artifactId>maven-jar-plugin</artifactId>
333368
<version>3.1.0</version>

Diff for: tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/registry/TensorTypeRegistry.java

+1
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ private static <T extends TType> void register(Class<T> type) {
9191

9292
static {
9393
// TODO (karllessard) scan and registered automatically all annotated tensors types
94+
// TODO use in generator?
9495
register(TBool.class);
9596
register(TFloat64.class);
9697
register(TFloat32.class);

Diff for: tensorflow-core/tensorflow-core-generator/pom.xml

+13-2
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
<artifactId>tensorflow-core-generator</artifactId>
1111
<packaging>jar</packaging>
1212

13-
<name>TensorFlow Core Annotation Processor</name>
14-
<description>Annotation processor for TensorFlow Java client</description>
13+
<name>TensorFlow Core Generators</name>
14+
<description>Code generators for TensorFlow Java client</description>
1515

1616
<properties>
1717
<java.module.name>org.tensorflow.core.generator</java.module.name>
@@ -28,11 +28,22 @@
2828
<artifactId>javapoet</artifactId>
2929
<version>1.12.1</version>
3030
</dependency>
31+
<dependency>
32+
<groupId>com.google.protobuf</groupId>
33+
<artifactId>protobuf-java</artifactId>
34+
<version>${protobuf.version}</version>
35+
</dependency>
3136
<dependency>
3237
<groupId>com.github.javaparser</groupId>
3338
<artifactId>javaparser-core</artifactId>
3439
<version>3.15.12</version>
3540
</dependency>
41+
<!-- https://mvnrepository.com/artifact/org.commonmark/commonmark -->
42+
<dependency>
43+
<groupId>org.commonmark</groupId>
44+
<artifactId>commonmark</artifactId>
45+
<version>0.17.1</version>
46+
</dependency>
3647
</dependencies>
3748

3849
<build>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/*
2+
Copyright 2021 The TensorFlow Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
==============================================================================
16+
*/
17+
package org.tensorflow;
18+
19+
import com.squareup.javapoet.ClassName;
20+
import com.squareup.javapoet.ParameterizedTypeName;
21+
import com.squareup.javapoet.TypeName;
22+
23+
public class Names {
24+
25+
public static final String TensorflowPackage = "org.tensorflow";
26+
public static final String OpPackage = TensorflowPackage + ".op";
27+
public static final String TypesPackage = TensorflowPackage + ".types";
28+
29+
public static final ClassName Operator = ClassName.get(OpPackage + ".annotation", "Operator");
30+
public static final ClassName Endpoint = ClassName.get(OpPackage + ".annotation", "Endpoint");
31+
32+
public static final ClassName TType = ClassName.get(TypesPackage + ".family", "TType");
33+
public static final ClassName TString = ClassName.get(TypesPackage, "TString");
34+
public static final ClassName TBool = ClassName.get(TypesPackage, "TBool");
35+
36+
public static final ClassName TNumber = ClassName.get(TypesPackage + ".family", "TNumber");
37+
38+
public static final ClassName TFloating = ClassName.get(TypesPackage + ".family", "TFloating");
39+
public static final ClassName TBfloat16 = ClassName.get(TypesPackage, "TBfloat16");
40+
public static final ClassName TFloat16 = ClassName.get(TypesPackage, "TFloat16");
41+
public static final ClassName TFloat32 = ClassName.get(TypesPackage, "TFloat32");
42+
public static final ClassName TFloat64 = ClassName.get(TypesPackage, "TFloat64");
43+
44+
public static final ClassName TIntegral = ClassName.get(TypesPackage + ".family", "TIntegral");
45+
public static final ClassName TUint8 = ClassName.get(TypesPackage, "TUint8");
46+
public static final ClassName TInt32 = ClassName.get(TypesPackage, "TInt32");
47+
public static final ClassName TInt64 = ClassName.get(TypesPackage, "TInt64");
48+
49+
public static final TypeName Op = ClassName.get(OpPackage, "Op");
50+
public static final ClassName RawOp = ClassName.get(OpPackage, "RawOp");
51+
public static final ClassName Operation = ClassName.get(TensorflowPackage, "Operation");
52+
public static final ClassName Operands = ClassName.get(OpPackage, "Operands");
53+
public static final ClassName OperationBuilder = ClassName.get(TensorflowPackage, "OperationBuilder");
54+
public static final TypeName IterableOp = ParameterizedTypeName.get(ClassName.get(Iterable.class), Op);
55+
56+
public static final ClassName Operand = ClassName.get(TensorflowPackage, "Operand");
57+
public static final ClassName Output = ClassName.get(TensorflowPackage, "Output");
58+
59+
public static final ClassName Shape = ClassName.get(TensorflowPackage + ".ndarray", "Shape");
60+
public static final ClassName Tensor = ClassName.get(TensorflowPackage, "Tensor");
61+
public static final ClassName ConcreteFunction = ClassName.get(TensorflowPackage, "ConcreteFunction");
62+
63+
public static final ClassName Scope = ClassName.get(OpPackage, "Scope");
64+
public static final TypeName DeviceSpec = ClassName.get(TensorflowPackage, "DeviceSpec");
65+
public static final ClassName Ops = ClassName.get(OpPackage, "Ops");
66+
67+
public static final TypeName ExecutionEnvironment =
68+
ClassName.get(TensorflowPackage, "ExecutionEnvironment");
69+
public static final TypeName EagerSession = ClassName.get(TensorflowPackage, "EagerSession");
70+
71+
public static final TypeName String = ClassName.get(String.class);
72+
73+
}

0 commit comments

Comments
 (0)