Skip to content

Commit 09fc07e

Browse files
authored
Merge pull request #4 from tensorflow/master
Merge main branch to local branch
2 parents c57a2e7 + 6f3ec0f commit 09fc07e

File tree

43 files changed

+2197
-152
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+2197
-152
lines changed

README.md

+4-21
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,5 @@
11
# TensorFlow for Java
22

3-
***!!! IMPORTANT NOTICE !!! This repository is UNDER CONSTRUCTION and does not yet host the code of the
4-
offical TensorFlow Java artifacts!***
5-
6-
***Please refer to the [TensorFlow Java module](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/java)
7-
of the main repository for the actual code.***
8-
93
## Welcome to the Java world of TensorFlow!
104

115
TensorFlow can run on any JVM for building, training and running machine learning models. It comes with
@@ -30,19 +24,10 @@ The following describes the layout of the repository and its different artifacts
3024
TensorFlow and just want a thin layer to access the TensorFlow runtime from the JVM
3125

3226
* `tensorflow-framework`
33-
* Complete but fairly primitive API for building and training neural networks with TensorFlow
34-
* Intended audience: expert neural network developers who prefer to make explicit, detailed decisions
35-
about their models and training algorithms
36-
37-
* `tensorflow-keras` (early WIP; only defined in `dev` profile)
38-
* Partially covers the framework API to allow simpler definition of models and training algorithms
39-
* Intended to be familiar if you know the Python Keras API, but prioritizes clean, idiomatic Java
40-
over fidelity to Python
41-
* Provides defaults based on common best practices
42-
* Allows developers to selectively be more explicit by overriding defaults or dipping into the framework API
43-
* Intended audience: neural network developers across the spectrum from beginner to expert who prefer to
44-
rely mostly on best-practice defaults and then selectively fine-tune
45-
27+
* Primary API for building and training neural networks with TensorFlow
28+
* Intended audience: neural network developers
29+
* For more information: [tensorflow-framework/README.md](tensorflow-framework/README.md)
30+
4631
* `ndarray`
4732
* Generic utility library for n-dimensional data I/O operations
4833
* Used by TensorFlow but does not depend on TensorFlow
@@ -172,8 +157,6 @@ This table shows the mapping between different version of TensorFlow for Java an
172157

173158
| TensorFlow Java Version | TensorFlow Version |
174159
| ------------- | ------------- |
175-
| 0.1.0-SNAPSHOT | 2.2.0 |
176-
| 0.2.0-SNAPSHOT | 2.3.1 |
177160
| 0.2.0 | 2.3.1 |
178161
| 0.3.0-SNAPSHOT | 2.3.1 |
179162

pom.xml

-13
Original file line numberDiff line numberDiff line change
@@ -119,19 +119,6 @@
119119
</dependencyManagement>
120120

121121
<profiles>
122-
<!--
123-
Developing profile
124-
The 'dev' profile is used for local development or PR compilation check only.
125-
Here, we enable the `tensorflow-keras` module only under this profile, until
126-
it is mature enough for being deployed and distributed for the end users.
127-
-->
128-
<profile>
129-
<id>dev</id>
130-
<modules>
131-
<!-- Disabled while the library is still empty -->
132-
<!--module>tensorflow-keras</module-->
133-
</modules>
134-
</profile>
135122
<!--
136123
Deploying profile
137124
Build the Javadoc when deploying

tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LeakyRelu.pbtxt

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
op {
22
graph_op_name: "LeakyRelu"
3+
visibility: VISIBLE
34
endpoint {
45
name: "nn.LeakyRelu"
56
}

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java

+14
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
import org.tensorflow.op.nn.FusedResizeAndPadConv2d;
6060
import org.tensorflow.op.nn.InTopK;
6161
import org.tensorflow.op.nn.L2Loss;
62+
import org.tensorflow.op.nn.LeakyRelu;
6263
import org.tensorflow.op.nn.LearnedUnigramCandidateSampler;
6364
import org.tensorflow.op.nn.LocalResponseNormalization;
6465
import org.tensorflow.op.nn.LogSoftmax;
@@ -1226,6 +1227,19 @@ public <T extends TNumber> L2Loss<T> l2Loss(Operand<T> t) {
12261227
return L2Loss.create(scope, t);
12271228
}
12281229

1230+
/**
1231+
* Computes rectified linear: `max(features, features * alpha)`.
1232+
*
1233+
* @param <T> data type for {@code activations()} output
1234+
* @param features
1235+
* @param options carries optional attributes values
1236+
* @return a new instance of LeakyRelu
1237+
*/
1238+
public <T extends TNumber> LeakyRelu<T> leakyRelu(Operand<T> features,
1239+
LeakyRelu.Options... options) {
1240+
return LeakyRelu.create(scope, features, options);
1241+
}
1242+
12291243
/**
12301244
* Generates labels for candidate sampling with a learned unigram distribution.
12311245
* <p>

tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/LeakyRelu.java

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
*
3434
* @param <T> data type for {@code activations()} output
3535
*/
36+
@Operator(group = "nn")
3637
public final class LeakyRelu<T extends TNumber> extends RawOp implements Operand<T> {
3738

3839
/**

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ opts, runOpts, new BytePointer(exportDir), new PointerPointer(tags),
429429
}
430430

431431
private static void validateTags(String[] tags) {
432-
if (tags == null || tags.length == 0 || Arrays.stream(tags).anyMatch(t -> t == null || t.isEmpty())) {
432+
if (tags == null || Arrays.stream(tags).anyMatch(t -> t == null || t.isEmpty())) {
433433
throw new IllegalArgumentException("Invalid tags: " + Arrays.toString(tags));
434434
}
435435
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBfloat16.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
import org.tensorflow.ndarray.NdArray;
3131
import org.tensorflow.ndarray.StdArrays;
3232
import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray;
33-
import org.tensorflow.types.family.TNumber;
33+
import org.tensorflow.types.family.TFloating;
3434

3535
/**
3636
* Brain 16-bit float tensor type.
@@ -48,7 +48,7 @@
4848
* <p>Note that some CPUs support the bfloat16 format natively, which can result in faster
4949
* computation compared to {@link TFloat16} when GPUs are not used.
5050
*/
51-
public interface TBfloat16 extends FloatNdArray, TNumber {
51+
public interface TBfloat16 extends FloatNdArray, TFloating {
5252
/** readable-name for the data type */
5353
static final String NAME = "BFLOAT16";
5454

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat16.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
import org.tensorflow.ndarray.NdArray;
3131
import org.tensorflow.ndarray.StdArrays;
3232
import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray;
33-
import org.tensorflow.types.family.TNumber;
33+
import org.tensorflow.types.family.TFloating;
3434

3535
/**
3636
* IEEE-754 half-precision 16-bit float tensor type.
@@ -45,7 +45,7 @@
4545
* most CPUs do not support this format natively. For CPU computation on 16-bit floats, the {@link
4646
* TBfloat16} tensor type might be a better option.
4747
*/
48-
public interface TFloat16 extends FloatNdArray, TNumber {
48+
public interface TFloat16 extends FloatNdArray, TFloating {
4949

5050
/** readable-name for the data type */
5151
static final String NAME = "FLOAT16";

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@
2929
import org.tensorflow.ndarray.NdArray;
3030
import org.tensorflow.ndarray.StdArrays;
3131
import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray;
32-
import org.tensorflow.types.family.TNumber;
32+
import org.tensorflow.types.family.TFloating;
3333

3434
/** IEEE-754 single-precision 32-bit float tensor type. */
35-
public interface TFloat32 extends FloatNdArray, TNumber {
35+
public interface TFloat32 extends FloatNdArray, TFloating {
3636

3737
/** readable-name for the data type */
3838
static final String NAME = "FLOAT";

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat64.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,11 @@
2929
import org.tensorflow.ndarray.NdArray;
3030
import org.tensorflow.ndarray.StdArrays;
3131
import org.tensorflow.ndarray.impl.dense.DoubleDenseNdArray;
32-
import org.tensorflow.types.family.TNumber;
32+
import org.tensorflow.types.family.TFloating;
33+
3334

3435
/** IEEE-754 double-precision 64-bit float tensor type. */
35-
public interface TFloat64 extends DoubleNdArray, TNumber {
36+
public interface TFloat64 extends DoubleNdArray, TFloating {
3637

3738
/** readable-name for the data type */
3839
static final String NAME = "DOUBLE";
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package org.tensorflow.types.family;
2+
3+
/**
4+
* Marker interface for floating point tensor types.
5+
*
6+
* <p>Operations that only accepts floating point values as some of their operands enforce that the tensor
7+
* types for these operands to be bound to this interface. For example:
8+
*
9+
* <pre>{@code
10+
* TFloat32 tensor1 = TFloat32.vectorOf(1, 2, 3);
11+
* TBool tensor2 = TBool.vectorOf(true, false, true);
12+
*
13+
* Ops tf = Ops.create();
14+
* Exponential<TFloat32> exp = new Exponential<>(tf);
15+
* exp.call(tf.constant(tensor1)); // OK
16+
* exp.call(tf.constant(tensor2)); // Compilation failure
17+
* }</pre>
18+
*/
19+
public interface TFloating extends TNumber {}

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java

+2-8
Original file line numberDiff line numberDiff line change
@@ -262,10 +262,7 @@ public void cannotExportMultipleFunctionsWithSameSignatureKey() throws IOExcepti
262262
@Test
263263
public void cannotExportOrImportInvalidTags() {
264264
assertThrows(IllegalArgumentException.class, () ->
265-
SavedModelBundle.loader("/").withTags()
266-
);
267-
assertThrows(IllegalArgumentException.class, () ->
268-
SavedModelBundle.loader("/").withTags(new String[]{})
265+
SavedModelBundle.loader("/").withTags(null)
269266
);
270267
assertThrows(IllegalArgumentException.class, () ->
271268
SavedModelBundle.loader("/").withTags(new String[]{"tag", null})
@@ -274,10 +271,7 @@ public void cannotExportOrImportInvalidTags() {
274271
SavedModelBundle.loader("/").withTags(new String[]{"tag", ""})
275272
);
276273
assertThrows(IllegalArgumentException.class, () ->
277-
SavedModelBundle.exporter("/").withTags()
278-
);
279-
assertThrows(IllegalArgumentException.class, () ->
280-
SavedModelBundle.exporter("/").withTags(new String[]{})
274+
SavedModelBundle.exporter("/").withTags(null)
281275
);
282276
assertThrows(IllegalArgumentException.class, () ->
283277
SavedModelBundle.exporter("/").withTags(new String[]{"tag", null})

tensorflow-framework/README.md

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Framework API
2+
3+
This is the primary Java API for building and training neural networks with TensorFlow.
4+
This API deliberately mirrors the overall structure of Python Keras. However, it
5+
is intended as a comfortable, idiomatic Java API for developers who may or may not
6+
be familiar with Keras.
7+
8+
This API is intended to provide convenient, sensible defaults, while still allowing you to
9+
exercise fine control over the details of your model, training, and inference when necessary.
10+
11+
More specifically, the following goals drive API evolution:
12+
13+
* If either you know how to implement a model in the Python Keras API, or you are reimplementing an
14+
existing Python Keras model in Java, you should be able to cleanly and naturally follow the same
15+
high-level structure in the framework API.
16+
17+
* Also, given some familiarity with patterns followed throughout the framework API, you should be
18+
able to easily translate every detail of a Python Keras implementation into the framework API.
19+
20+
* However, the framework API is not intended to literally mimic the Python Keras API. Rather, it
21+
should expose the same capabilities in an API that feels natural and idiomatic to a Java
22+
programmer who does not know Keras. If we ever find ourselves unable to reconcile this goal with
23+
easy translation from Python Keras, we may split out a Keras layer.
24+
25+
* Also, the framework API should support fine control over all aspects of modeling, training, and
26+
inference. Unlike with Python Keras, we want this to feel like staying in the same API rather
27+
than diving into a separate layer. But here again, if we are ever unable to reconcile this goal
28+
with easy translation from Python Keras, we may split the framework API into two layers.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
=======================================================================*/
15+
package org.tensorflow.framework.activations;
16+
17+
import org.tensorflow.Operand;
18+
import org.tensorflow.op.Ops;
19+
import org.tensorflow.types.family.TNumber;
20+
21+
/**
22+
* Abstract base class for Activations
23+
*
24+
* <p><b>Note:</b> The {@link #tf} attribute must be set prior to invoking the call method. See
25+
* {@link #setTF(Ops)} and the constructor {@link #Activation(Ops)}.
26+
*
27+
* @param <T> the data type of the activation
28+
*/
29+
public abstract class Activation<T extends TNumber> {
30+
31+
/** The TensorFlow Ops */
32+
protected Ops tf;
33+
34+
/**
35+
* Creates the abstract class for an Activation
36+
*
37+
* @param tf the TensorFlow Ops
38+
*/
39+
protected Activation(Ops tf) {
40+
this.tf = tf;
41+
}
42+
43+
/**
44+
* Sets the TensorFlow Ops
45+
*
46+
* @param tf the TensorFlow Ops
47+
*/
48+
protected void setTF(Ops tf) {
49+
this.tf = tf;
50+
}
51+
52+
/**
53+
* Gets the TensorFlow Ops
54+
*
55+
* @return the TensorFlow Ops
56+
*/
57+
protected Ops getTF() {
58+
return this.tf;
59+
}
60+
61+
/**
62+
* Gets the calculation operation for the activation.
63+
*
64+
* @param input the input tensor
65+
* @return The operand for the activation
66+
*/
67+
public abstract Operand<T> call(Operand<T> input);
68+
}

0 commit comments

Comments
 (0)