diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SoftmaxCrossEntropyWithLogits.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SoftmaxCrossEntropyWithLogits.pbtxt index 5dba2164cd6..e064562c0f2 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SoftmaxCrossEntropyWithLogits.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SoftmaxCrossEntropyWithLogits.pbtxt @@ -1,6 +1,6 @@ op { graph_op_name: "SoftmaxCrossEntropyWithLogits" endpoint { - name: "nn.raw.SoftmaxCrossEntropyWithLogits" + name: "nn.SoftmaxCrossEntropyWithLogits" } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SparseSoftmaxCrossEntropyWithLogits.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SparseSoftmaxCrossEntropyWithLogits.pbtxt index cf80ff77565..7627d5f6074 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SparseSoftmaxCrossEntropyWithLogits.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SparseSoftmaxCrossEntropyWithLogits.pbtxt @@ -1,6 +1,6 @@ op { graph_op_name: "SparseSoftmaxCrossEntropyWithLogits" endpoint { - name: "nn.raw.SparseSoftmaxCrossEntropyWithLogits" + name: "nn.SparseSoftmaxCrossEntropyWithLogits" } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java index 8958b4fe2ff..2bd4d13145f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java @@ -83,7 +83,6 @@ import org.tensorflow.op.nn.Relu; import org.tensorflow.op.nn.Relu6; import org.tensorflow.op.nn.Selu; -import org.tensorflow.op.nn.SigmoidCrossEntropyWithLogits; import org.tensorflow.op.nn.Softmax; import org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits; import org.tensorflow.op.nn.Softsign; @@ -103,8 +102,6 @@ * @see {@link Ops} */ public final class NnOps { - public final NnRawOps raw; - private final Scope scope; private final Ops ops; @@ -112,7 +109,6 @@ public final class NnOps { NnOps(Ops ops) { this.scope = ops.scope(); this.ops = ops; - raw = new NnRawOps(ops); } /** @@ -1797,56 +1793,6 @@ public Selu selu(Operand features) { return Selu.create(scope, features); } - /** - * Computes sigmoid cross entropy given logits. - * - *

-   * <p>Measures the probability error in discrete classification tasks in which each class is
-   * independent and not mutually exclusive. For instance, one could perform multilabel
-   * classification where a picture can contain both an elephant and a dog at the same time.
-   *
-   * <p>For brevity, let x = logits, z = labels. The logistic loss in
-   * pseudo-code is
-   *
-   * <pre>
-   *  z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
-   *   = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
-   *   = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
-   *   = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
-   *   = (1 - z) * x + log(1 + exp(-x))
-   *   = x - x * z + log(1 + exp(-x))
-   *  </pre>
-   *
-   * <p>For x < 0, to avoid overflow in exp(-x), we reformulate the above
-   *
-   * <pre>
-   *  x - x * z + log(1 + exp(-x))
-   *   = log(exp(x)) - x * z + log(1 + exp(-x))
-   *   = - x * z + log(1 + exp(x))
-   *  </pre>
-   *
-   * <p>Hence, to ensure stability and avoid overflow, the implementation uses this equivalent
-   * formulation
-   *
-   * <pre>
-   *    max(x, 0) - x * z + log(1 + exp(-abs(x)))
-   *  </pre>
-   *
-   * <p>logits and labels must have the same type and shape.
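For reference, the stable formulation above maps directly onto the generated math ops; a minimal sketch, where the helper name is illustrative and labels (z) and logits (x) are assumed to be same-shaped TFloat32 operands:

    import org.tensorflow.Operand;
    import org.tensorflow.op.Ops;
    import org.tensorflow.types.TFloat32;

    // max(x, 0) - x * z + log(1 + exp(-abs(x))), element-wise and overflow-safe
    static Operand<TFloat32> stableSigmoidCrossEntropy(
        Ops tf, Operand<TFloat32> labels, Operand<TFloat32> logits) {
      Operand<TFloat32> zeros = tf.zerosLike(logits);
      return tf.math.add(
          tf.math.sub(tf.math.maximum(logits, zeros), tf.math.mul(logits, labels)),
          tf.math.log1p(tf.math.exp(tf.math.neg(tf.math.abs(logits)))));
    }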
- * - * @param scope The TensorFlow scope - * @param labels the labels - * @param logits the logits of type float32 or float64 - * @param the type of labels and logits - * @return the component-wise logistic losses. - * @throws IllegalArgumentException if logits' and labels' do not have the same shape - */ - public Operand sigmoidCrossEntropyWithLogits(Operand labels, - Operand logits) { - return SigmoidCrossEntropyWithLogits.sigmoidCrossEntropyWithLogits(scope, labels, logits); - } - /** * Computes softmax activations. * For each batch {@code i} and class {@code j} we have @@ -1864,54 +1810,20 @@ public Softmax softmax(Operand logits) { } /** - * Computes softmax cross entropy between logits and labels. - * - *

-   * <p>Measures the probability error in discrete classification tasks in which the classes are
-   * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is
-   * labeled with one and only one label: an image can be a dog or a truck, but not both.
-   *
-   * <p>NOTE:
-   *
-   * <p>While the classes are mutually exclusive, their probabilities need not be. All that is
-   * required is that each row of labels is a valid probability distribution. If they
-   * are not, the computation of the gradient will be incorrect.
-   *
-   * <p>If using exclusive labels (wherein one and only one class is true at a time),
-   * see {@link org.tensorflow.op.NnOps#sparseSoftmaxCrossEntropyWithLogits}
+   * Computes softmax cross entropy cost and gradients to backpropagate.
+   * Inputs are the logits, not probabilities.
-   *
-   * <p>Usage:
-   *
-   * <pre>
-   *    Operand<TFloat32> logits =
-   *        tf.constant(new float[][] {{4.0F, 2.0F, 1.0F}, {0.0F, 5.0F, 1.0F}} );
-   *    Operand<TFloat32> labels =
-   *        tf.constant(new float[][] {{1.0F, 0.0F, 0.0F}, {0.0F, 0.8F, 0.2F}} );
-   *    Operand<TFloat32> output =
-   *        tf.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
-   *    // output Shape = [2]
-   *    // dataType = FLOAT (1)
-   *    // values { 0.169846, 0.824745 }
-   *  </pre>
-   *
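With this change, tf.nn.softmaxCrossEntropyWithLogits returns the raw op instance rather than a plain loss operand; a minimal sketch of the new call shape, reusing the values from the removed example above and assuming the relocated org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits is imported:

    Operand<TFloat32> logits =
        tf.constant(new float[][] {{4.0F, 2.0F, 1.0F}, {0.0F, 5.0F, 1.0F}});
    Operand<TFloat32> labels =
        tf.constant(new float[][] {{1.0F, 0.0F, 0.0F}, {0.0F, 0.8F, 0.2F}});
    SoftmaxCrossEntropyWithLogits<TFloat32> xent =
        tf.nn.softmaxCrossEntropyWithLogits(logits, labels);
    Operand<TFloat32> loss = xent.loss();         // per-example loss, shape [2]
    Operand<TFloat32> backprop = xent.backprop(); // gradient w.r.t. logits, shape [2, 3]

Note the argument order is now (features, labels), and the removed axis parameter has no counterpart on the raw op, which expects batch_size x num_classes matrices.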
Backpropagation will happen into both logits and labels. To - * disallow backpropagation into labels, pass label tensors through - * tf.stopGradient before feeding it to this function. - * - * @param scope current scope - * @param labels Each vector along the class dimension should hold a valid probability - * distribution e.g. for the case in which labels are of shape [batch_size, num_classes] - * , each row of labels[i] must be a valid probability distribution. - * @param logits Per-label activations, typically a linear output. These activation energies are - * interpreted as unnormalized log probabilities. - * @param axis The class dimension. -1 is the last dimension. - * @param the number type of the operands - * @return the softmax cross entropy loss. Its type is the same as logits and its - * shape is the same as labels except that it does not have the last dimension of - * labels. + * @param data type for {@code loss} output + * @param features batch_size x num_classes matrix + * @param labels batch_size x num_classes matrix + * The caller must ensure that each batch of labels represents a valid + * probability distribution. + * @param data type for {@code SoftmaxCrossEntropyWithLogits} output and operands + * @return a new instance of SoftmaxCrossEntropyWithLogits */ - public Operand softmaxCrossEntropyWithLogits( - Operand labels, Operand logits, int axis) { - return SoftmaxCrossEntropyWithLogits.softmaxCrossEntropyWithLogits(scope, labels, logits, axis); + public SoftmaxCrossEntropyWithLogits softmaxCrossEntropyWithLogits( + Operand features, Operand labels) { + return SoftmaxCrossEntropyWithLogits.create(scope, features, labels); } /** @@ -2098,51 +2010,23 @@ public SpaceToDepth spaceToDepth(Operand input, Long blo } /** - * Computes sparse softmax cross entropy between logits and labels. - * - *

-   * <p>Measures the probability error in discrete classification tasks in which the classes are
-   * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is
-   * labeled with one and only one label: an image can be a dog or a truck, but not both.
-   *
-   * <p>NOTE:
-   *
-   * <p>For this operation, the probability of a given label is considered exclusive. That is, soft
-   * classes are not allowed, and the labels vector must provide a single specific
-   * index for the true class for each row of logits (each minibatch entry). For soft
-   * softmax classification with a probability distribution for each entry, see {@link
-   * org.tensorflow.op.NnOps#softmaxCrossEntropyWithLogits}.
-   *
-   * <p>WARNING:
-   *
-   * <p>This op expects unscaled logits, since it performs a softmax on logits
-   * internally for efficiency. Do not call this op with the output of softmax,
-   * as it will produce incorrect results.
-   *
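The sparse variant changes the same way; a minimal sketch, assuming one integer class index per row of logits:

    Operand<TFloat32> logits =
        tf.constant(new float[][] {{4.0F, 2.0F, 1.0F}, {0.0F, 5.0F, 1.0F}});
    Operand<TInt32> classes = tf.constant(new int[] {0, 1}); // each value in [0, numClasses)
    SparseSoftmaxCrossEntropyWithLogits<TFloat32> xent =
        tf.nn.sparseSoftmaxCrossEntropyWithLogits(logits, classes);
    Operand<TFloat32> loss = xent.loss();         // shape [2]
    Operand<TFloat32> backprop = xent.backprop(); // shape [2, 3]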
A common use case is to have logits of shape [batchSize, numClasses] and have - * labels of shape [batchSize], but higher dimensions are supported, in which case - * the dim-th dimension is assumed to be of size numClasses. - * logits must have the dataType of TFloat16, TFloat32 - * , or TFloat64, and labels must have the dtype of TInt32 - * or TInt64. - * - * @param scope current scope - * @param labels Tensor of shape [d_0, d_1, ..., d_{r-1}] (where r - * is rank of labels and result) and the dataType is TInt32 - * or TInt64. Each entry in labels must be an index in [0, - * numClasses). Other values will raise an exception when this op is run on CPU, and - * return NaN for corresponding loss and gradient rows on GPU. - * @param logits Per-label activations (typically a linear output) of shape [d_0, d_1, ..., - * d_{r-1}, numClasses] and dataType of TFloat16, TFloat32, - * or TFloat64. These activation energies are interpreted as unnormalized log - * probabilities. - * @return A Tensor of the same shape as labels and of the same type as - * logits with the softmax cross entropy loss. - * @throws IllegalArgumentException If logits are scalars (need to have rank >= 1) or if the rank - * of the labels is not equal to the rank of the logits minus one. - */ - public Operand sparseSoftmaxCrossEntropyWithLogits( - Operand labels, Operand logits) { - return SparseSoftmaxCrossEntropyWithLogits.sparseSoftmaxCrossEntropyWithLogits(scope, labels, logits); + * Computes softmax cross entropy cost and gradients to backpropagate. + * Unlike {@code SoftmaxCrossEntropyWithLogits}, this operation does not accept + * a matrix of label probabilities, but rather a single label per row + * of features. This label is considered to have probability 1.0 for the + * given row. + *
+   * <p>
Inputs are the logits, not probabilities. + * + * @param data type for {@code loss} output + * @param features batch_size x num_classes matrix + * @param labels batch_size vector with values in [0, num_classes). + * This is the label for the given minibatch entry. + * @param data type for {@code SparseSoftmaxCrossEntropyWithLogits} output and operands + * @return a new instance of SparseSoftmaxCrossEntropyWithLogits + */ + public SparseSoftmaxCrossEntropyWithLogits sparseSoftmaxCrossEntropyWithLogits( + Operand features, Operand labels) { + return SparseSoftmaxCrossEntropyWithLogits.create(scope, features, labels); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnRawOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnRawOps.java deleted file mode 100644 index c287459c460..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnRawOps.java +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright 2020 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================== -// -// This class has been generated, DO NOT EDIT! -// -package org.tensorflow.op; - -import org.tensorflow.Operand; -import org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits; -import org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits; -import org.tensorflow.types.family.TNumber; - -/** - * An API for building {@code nn.raw} operations as {@link Op Op}s - * - * @see {@link Ops} - */ -public final class NnRawOps { - private final Scope scope; - - private final Ops ops; - - NnRawOps(Ops ops) { - this.scope = ops.scope(); - this.ops = ops; - } - - /** - * Computes softmax cross entropy cost and gradients to backpropagate. - * Inputs are the logits, not probabilities. - * - * @param data type for {@code loss} output - * @param features batch_size x num_classes matrix - * @param labels batch_size x num_classes matrix - * The caller must ensure that each batch of labels represents a valid - * probability distribution. - * @param data type for {@code SoftmaxCrossEntropyWithLogits} output and operands - * @return a new instance of SoftmaxCrossEntropyWithLogits - */ - public SoftmaxCrossEntropyWithLogits softmaxCrossEntropyWithLogits( - Operand features, Operand labels) { - return SoftmaxCrossEntropyWithLogits.create(scope, features, labels); - } - - /** - * Computes softmax cross entropy cost and gradients to backpropagate. - * Unlike {@code SoftmaxCrossEntropyWithLogits}, this operation does not accept - * a matrix of label probabilities, but rather a single label per row - * of features. This label is considered to have probability 1.0 for the - * given row. - *
-   * <p>
Inputs are the logits, not probabilities. - * - * @param data type for {@code loss} output - * @param features batch_size x num_classes matrix - * @param labels batch_size vector with values in [0, num_classes). - * This is the label for the given minibatch entry. - * @param data type for {@code SparseSoftmaxCrossEntropyWithLogits} output and operands - * @return a new instance of SparseSoftmaxCrossEntropyWithLogits - */ - public SparseSoftmaxCrossEntropyWithLogits sparseSoftmaxCrossEntropyWithLogits( - Operand features, Operand labels) { - return SparseSoftmaxCrossEntropyWithLogits.create(scope, features, labels); - } - - /** - * Get the parent {@link Ops} object. - */ - public final Ops ops() { - return ops; - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 733e7ca7051..250ea35b9fa 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -354,20 +354,20 @@ public final class Ops { public final SparseOps sparse; - public final TpuOps tpu; - public final BitwiseOps bitwise; + public final TpuOps tpu; + public final MathOps math; public final AudioOps audio; public final SignalOps signal; - public final TrainOps train; - public final QuantizationOps quantization; + public final TrainOps train; + private final Scope scope; private Ops(Scope scope) { @@ -385,13 +385,13 @@ private Ops(Scope scope) { random = new RandomOps(this); strings = new StringsOps(this); sparse = new SparseOps(this); - tpu = new TpuOps(this); bitwise = new BitwiseOps(this); + tpu = new TpuOps(this); math = new MathOps(this); audio = new AudioOps(this); signal = new SignalOps(this); - train = new TrainOps(this); quantization = new QuantizationOps(this); + train = new TrainOps(this); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SoftmaxCrossEntropyWithLogits.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java similarity index 98% rename from tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SoftmaxCrossEntropyWithLogits.java rename to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java index 331933979c7..d6eed5cbe28 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java @@ -15,7 +15,7 @@ // This class has been generated, DO NOT EDIT! 
-package org.tensorflow.op.nn.raw; +package org.tensorflow.op.nn; import org.tensorflow.Operand; import org.tensorflow.Operation; @@ -34,7 +34,7 @@ * @param data type for {@code loss} output */ @Operator( - group = "nn.raw" + group = "nn" ) public final class SoftmaxCrossEntropyWithLogits extends RawOp { /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SparseSoftmaxCrossEntropyWithLogits.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java similarity index 98% rename from tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SparseSoftmaxCrossEntropyWithLogits.java rename to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java index 8c48cd0db4d..26498cdce7a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SparseSoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java @@ -15,7 +15,7 @@ // This class has been generated, DO NOT EDIT! -package org.tensorflow.op.nn.raw; +package org.tensorflow.op.nn; import org.tensorflow.Operand; import org.tensorflow.Operation; @@ -38,7 +38,7 @@ * @param data type for {@code loss} output */ @Operator( - group = "nn.raw" + group = "nn" ) public final class SparseSoftmaxCrossEntropyWithLogits extends RawOp { /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/resources/ops.pb b/tensorflow-core/tensorflow-core-api/src/gen/resources/ops.pb index 5472f5f8839..fbcecceb5bd 100644 Binary files a/tensorflow-core/tensorflow-core-api/src/gen/resources/ops.pb and b/tensorflow-core/tensorflow-core-api/src/gen/resources/ops.pb differ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Activation.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Activation.java new file mode 100644 index 00000000000..0bde8e0889c --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Activation.java @@ -0,0 +1,120 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TType; + +import java.util.ArrayList; +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Layer that applies an activation function to an output. + * + * @param the data type for the layer's weights and computation. 
+ */ +public class Activation extends Layer { + private final org.tensorflow.framework.activations.Activation activation; + + /** + * Creates an Activation layer using {@link Class#getSimpleName()} for the name. + * + * @param tf the TensorFlow Ops. + * @param activation the activation to apply + * @param type the data type for the weights and computation + */ + public Activation( + Ops tf, org.tensorflow.framework.activations.Activation activation, Class type) { + this(tf, null, activation, type, null); + } + + /** + * Creates an Activation layer using {@link Class#getSimpleName()} for the name. + * + * @param tf the TensorFlow Ops. + * @param activation the activation to apply + * @param type the data type for the weights and computation + * @param options the layer's options, may be null + */ + public Activation( + Ops tf, + org.tensorflow.framework.activations.Activation activation, + Class type, + Options options) { + this(tf, null, activation, type, options); + } + + /** + * Creates an Activation layer + * + * @param tf the TensorFlow Ops. + * @param name the unique name for this layer, if null will use {@link Class#getSimpleName()} for + * the name. + * @param activation the activation to apply + * @param type the data type for the weights and computation + */ + public Activation( + Ops tf, + String name, + org.tensorflow.framework.activations.Activation activation, + Class type) { + this(tf, name, activation, type, null); + } + /** + * Creates an Activation layer + * + * @param tf the TensorFlow Ops. + * @param name the unique name for this layer, if null will use {@link Class#getSimpleName()} for + * the name. + * @param activation the activation to apply + * @param type the data type for the weights and computation + * @param options the layer's options, may be null + */ + public Activation( + Ops tf, + String name, + org.tensorflow.framework.activations.Activation activation, + Class type, + Options options) { + super(tf, name, true, type, options); + this.activation = activation; + } + + /** {@inheritDoc} */ + @Override + public List> call( + List> inputs, + List> masks, + boolean training, + Class resultType) { + Ops tf = getTF(); + List> results = new ArrayList<>(); + inputs.forEach( + input -> results.add(cast(tf, activation.call(cast(tf, input, getType())), resultType))); + return callPostProcess(results, training); + } + + /** {@inheritDoc} */ + @Override + public List computeOutputShape(List inputShapes) { + return inputShapes; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Add.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Add.java new file mode 100644 index 00000000000..02979c02942 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Add.java @@ -0,0 +1,92 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.framework.layers.impl.Merge; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Layer that adds a list of inputs element-wise. + * + *
+ * <p>
It takes as input a list of tensors, all of the same shape, and returns a single tensor (also + * of the same shape). + * + * @param the data type for the layer's weights and computation. + */ +public class Add extends Merge { + + /** + * Creates an Add Layer using {@link Class#getSimpleName()} as the layer name. + * + * @param tf the TensorFlow Ops + * @param type the data type for the weights and computation + */ + public Add(Ops tf, Class type) { + this(tf, null, type, null); + } + + /** + * Creates an Add Layer using {@link Class#getSimpleName()} as the layer name. + * + * @param tf the TensorFlow Ops + * @param type the data type for the weights and computation + * @param options the layer options + */ + public Add(Ops tf, Class type, Options options) { + this(tf, null, type, options); + } + + /** + * Creates an Add Layer + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param type the data type for the weights and computation + */ + public Add(Ops tf, String name, Class type) { + this(tf, name, type, null); + } + /** + * Creates an Add Layer + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param type the data type for the weights and computation + * @param options the layer options + */ + public Add(Ops tf, String name, Class type, Options options) { + + super(tf, name, type, options); + } + + /** {@inheritDoc} */ + @Override + protected Operand mergeFunction(List> inputs) { + Ops tf = getTF(); + Operand output = cast(tf, tf.identity(inputs.get(0)), getType()); + for (int i = 1; i < inputs.size(); i++) { + output = tf.math.add(output, cast(tf, inputs.get(i), getType())); + } + return output; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/AlphaDropout.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/AlphaDropout.java new file mode 100644 index 00000000000..b8f50991f43 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/AlphaDropout.java @@ -0,0 +1,174 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.op.random.RandomUniform; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TType; + +import java.util.ArrayList; +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Applies Alpha Dropout to the input. + * + *
+ * <p>
Alpha Dropout is a Dropout that keeps mean and variance of inputs to their + * original values, in order to ensure the self-normalizing property even after this dropout. Alpha + * Dropout fits well to Scaled Exponential Linear Units by randomly setting activations to the + * negative saturation value. + */ +public class AlphaDropout extends Layer { + private static final long DEFAULT_GRAPH_SEED = 87654321; + private final float rate; + private final Shape noiseShape; + private final long seed; + + /** + * Creates a AlphaDropout layer, using a unique name will be generated based on {@link + * Class#getSimpleName()} and no noiseShape. + * + * @param tf the TensorFlow Ops + * @param rate A number between 0 and 1. Drop probability (as with {@link Dropout}). The + * multiplicative noise will have standard deviation sqrt(rate / (1 - rate)). + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the weights and computation + * @param options the layer's options. + */ + public AlphaDropout(Ops tf, float rate, long seed, Class type, Options options) { + + this(tf, null, rate, null, seed, type, options); + } + + /** + * Creates a AlphaDropout layer, using a unique name will be generated based on {@link + * Class#getSimpleName()}. + * + * @param tf the TensorFlow Ops + * @param rate A number between 0 and 1. Drop probability (as with {@link Dropout}). The + * multiplicative noise will have standard deviation sqrt(rate / (1 - rate)). + * @param noiseShape Optional, 1D integer tensor representing the shape of the binary dropout mask + * that will be multiplied with the input. For instance, if your inputs have shape + * (batch_size, timesteps, features) and you want the dropout mask to be the same for all + * timesteps, you can use noise_shape=(batch_size, 1, features). May be null. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the weights and computation + * @param options the layer's options. + */ + public AlphaDropout( + Ops tf, float rate, Shape noiseShape, long seed, Class type, Options options) { + this(tf, null, rate, noiseShape, seed, type, options); + } + + /** + * Creates a AlphaDropout layer + * + * @param tf the TensorFlow Ops + * @param name name the unique name for this layer. If null, a unique name will be generated based + * on {@link Class#getSimpleName()}. + * @param rate A number between 0 and 1. Drop probability (as with {@link Dropout}). The + * multiplicative noise will have standard deviation sqrt(rate / (1 - rate)). + * @param noiseShape Optional, 1D integer tensor representing the shape of the binary dropout mask + * that will be multiplied with the input. For instance, if your inputs have shape + * (batch_size, timesteps, features) and you want the dropout mask to be the same for all + * timesteps, you can use noise_shape=(batch_size, 1, features). May be null. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the weights and computation + * @param options the layer's options. 
+ */ + public AlphaDropout( + Ops tf, + String name, + float rate, + Shape noiseShape, + long seed, + Class type, + Options options) { + super(tf, name, true, type, options); + if (rate < 0 || rate >= 1) + throw new IllegalArgumentException("The rate must be between >= 0 and < 1, inclusive."); + this.rate = rate; + this.noiseShape = noiseShape; + this.seed = seed; + setSupportsMasking(true); + } + + /** {@inheritDoc} */ + @Override + public List> call( + List> inputs, + List> masks, + boolean training, + Class resultType) { + Ops tf = getTF(); + + if (!training || rate < 0 || rate > 1) { + return convertList(inputs, resultType); + } + + // training = true + List> outputs = new ArrayList<>(); + Operand rateT = cast(tf, tf.constant(rate), getType()); + Operand alpha = cast(tf, tf.constant(1.6732632423543772848170429916717), getType()); + Operand scale = cast(tf, tf.constant(1.0507009873554804934193349852946), getType()); + // alpha_p = -alpha * scale + Operand alpha_p = tf.math.mul(tf.math.neg(alpha), scale); + Operand one = cast(tf, tf.constant(1), getType()); + Operand minusPoint5 = cast(tf, tf.constant(-0.5), getType()); + // a = ((1 - rate) * (1 + rate * alpha_p**2))**-0.5 + Operand a = + tf.math.pow( + tf.math.mul( + tf.math.sub(one, rateT), + tf.math.add(one, tf.math.mul(rateT, tf.math.mul(alpha_p, alpha_p)))), + minusPoint5); + // b = -a * alpha_p * rate + Operand b = tf.math.mul(tf.math.neg(a), tf.math.mul(alpha_p, rateT)); + + for (Operand input : inputs) { + Operand tInput = cast(tf, input, getType()); + Operand noise = + noiseShape == null ? tf.shape(input, TInt64.class) : tf.constant(noiseShape); + Operand randomTensor = + tf.random.randomUniform( + noise, getType(), RandomUniform.seed(DEFAULT_GRAPH_SEED), RandomUniform.seed2(seed)); + Operand keptIdx = cast(tf, tf.math.greaterEqual(randomTensor, rateT), getType()); + Operand x = + tf.math.add( + tf.math.mul(tInput, keptIdx), tf.math.mul(alpha_p, tf.math.sub(one, keptIdx))); + // result = a*x + b + //noinspection SuspiciousNameCombination + Operand result = tf.math.add(tf.math.mul(a, x), b); + outputs.add(result); + } + + return callPostProcess(convertTo(outputs, resultType), true); + } + + /** {@inheritDoc} */ + @Override + public List computeOutputShape(List inputShapes) { + return inputShapes; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Average.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Average.java new file mode 100644 index 00000000000..7e31e662258 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Average.java @@ -0,0 +1,92 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.framework.layers.impl.Merge; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Layer that averages a list of inputs element-wise. + * + *
+ * <p>
It takes as input a list of tensors, all of the same shape, and returns a single tensor (also + * of the same shape). + * + * @param the data type for the layer's weights and computation. + */ +public class Average extends Merge { + + /** + * Creates an Add Layer using {@link Class#getSimpleName()} as the layer name. + * + * @param tf the TensorFlow Ops + * @param type the data type for the weights and computation + */ + public Average(Ops tf, Class type) { + this(tf, null, type, null); + } + + /** + * Creates an Add Layer using {@link Class#getSimpleName()} as the layer name. + * + * @param tf the TensorFlow Ops + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Average(Ops tf, Class type, Options options) { + this(tf, null, type, options); + } + + /** + * Creates an Add Layer + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param type the data type for the weights and computation + */ + public Average(Ops tf, String name, Class type) { + this(tf, name, type, null); + } + + /** + * Creates an Add Layer + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Average(Ops tf, String name, Class type, Options options) { + super(tf, name, type, options); + } + + /** {@inheritDoc} */ + @Override + protected Operand mergeFunction(List> inputs) { + Ops tf = getTF(); + Operand output = cast(tf, tf.identity(inputs.get(0)), getType()); + for (int i = 1; i < inputs.size(); i++) { + output = tf.math.add(output, cast(tf, inputs.get(i), getType())); + } + return tf.math.div(output, cast(tf, tf.constant(inputs.size()), getType())); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Concatenate.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Concatenate.java new file mode 100644 index 00000000000..8af1f9a1c18 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Concatenate.java @@ -0,0 +1,387 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.framework.layers.impl.Merge; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Layer that concatenates a list of inputs element-wise. + * + *
+ * <p>
It takes as input a list of tensors, all of the same shape except for the concatenation axis, + * and returns a single tensor that is the concatenation of all inputs. + * + * @param the data type for the layer's weights and computation. + */ +public class Concatenate extends Merge { + public static final int DEFAULT_AXIS = -1; + private int axis; + + /** + * Creates a Concatenate Layer using {@link Class#getSimpleName()} as the layer name , and using + * {@link #DEFAULT_AXIS} for the axis along which to concatenate. + * + * @param type the data type for the weights and computation + */ + public Concatenate(Class type) { + this(null, null, DEFAULT_AXIS, type, null); + } + + /** + * Creates a Concatenate Layer using {@link Class#getSimpleName()} as the layer name , and using + * {@link #DEFAULT_AXIS} for the axis along which to concatenate. + * + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Concatenate(Class type, Options options) { + this(null, null, DEFAULT_AXIS, type, options); + } + + /** + * Creates a Concatenate Layer using {@link Class#getSimpleName()} as the layer name. + * + * @param axis Axis along which to concatenate. + * @param type the data type for the weights and computation + */ + public Concatenate(int axis, Class type) { + this(null, null, axis, type, null); + } + + /** + * Creates a Concatenate Layer using {@link Class#getSimpleName()} as the layer name. + * + * @param axis Axis along which to concatenate. + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Concatenate(int axis, Class type, Options options) { + this(null, null, axis, type, options); + } + + /** + * Creates a Concatenate Layer using {@link #DEFAULT_AXIS} for the axis along which to + * concatenate. + * + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param type the data type for the weights and computation + */ + public Concatenate(String name, Class type) { + this(null, name, DEFAULT_AXIS, type, null); + } + + /** + * Creates a Concatenate Layer using {@link #DEFAULT_AXIS} for the axis along which to + * concatenate. + * + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Concatenate(String name, Class type, Options options) { + this(null, name, DEFAULT_AXIS, type, options); + } + + /** + * Creates a Concatenate Layer + * + * @param axis Axis along which to concatenate. + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param type the data type for the weights and computation + */ + public Concatenate(String name, int axis, Class type) { + this(null, name, axis, type, null); + } + + /** + * Creates a Concatenate Layer + * + * @param axis Axis along which to concatenate. + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Concatenate(String name, int axis, Class type, Options options) { + this(null, name, axis, type, options); + } + + /** + * Creates a Concatenate Layer using {@link Class#getSimpleName()} as the layer name, and using + * {@link #DEFAULT_AXIS} for the axis along which to concatenate. 
+ * + * @param tf the TensorFlow Ops + * @param type the data type for the weights and computation + */ + public Concatenate(Ops tf, Class type) { + this(tf, null, DEFAULT_AXIS, type, null); + } + + /** + * Creates a Concatenate Layer using {@link Class#getSimpleName()} as the layer name, and using + * {@link #DEFAULT_AXIS} for the axis along which to concatenate. + * + * @param tf the TensorFlow Ops + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Concatenate(Ops tf, Class type, Options options) { + this(tf, null, DEFAULT_AXIS, type, options); + } + + /** + * Creates a Concatenate Layer using {@link Class#getSimpleName()} as the layer name. + * + * @param tf the TensorFlow Ops + * @param axis Axis along which to concatenate. + * @param type the data type for the weights and computation + */ + public Concatenate(Ops tf, int axis, Class type) { + this(tf, null, axis, type, null); + } + + /** + * Creates a Concatenate Layer using {@link Class#getSimpleName()} as the layer name. + * + * @param tf the TensorFlow Ops + * @param axis Axis along which to concatenate. + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Concatenate(Ops tf, int axis, Class type, Options options) { + this(tf, null, axis, type, options); + } + + /** + * Creates a Concatenate Layer using {@link #DEFAULT_AXIS} for the axis along which to + * concatenate. + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param type the data type for the weights and computation + */ + public Concatenate(Ops tf, String name, Class type) { + this(tf, name, DEFAULT_AXIS, type, null); + } + + /** + * Creates a Concatenate Layer using {@link #DEFAULT_AXIS} for the axis along which to + * concatenate. + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Concatenate(Ops tf, String name, Class type, Options options) { + this(tf, name, DEFAULT_AXIS, type, options); + } + + /** + * Creates a Concatenate Layer + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param axis Axis along which to concatenate. + * @param type the data type for the weights and computation + */ + public Concatenate(Ops tf, String name, int axis, Class type) { + this(tf, name, axis, type, null); + } + /** + * Creates a Concatenate Layer + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param axis Axis along which to concatenate. 
+ * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Concatenate(Ops tf, String name, int axis, Class type, Options options) { + super(tf, name, type, options); + this.axis = axis; + this.setSupportsMasking(true); + } + + /** {@inheritDoc} */ + @Override + public List> computeMask( + List> inputs, List> masks) { + if (masks == null || masks.isEmpty()) { + return null; + } + if (inputs.size() != masks.size()) { + throw new IllegalArgumentException("The lists inputs and masks should have the same length."); + } + boolean allNull = true; + for (Operand m : masks) { + if (m != null) { + allNull = false; + break; + } + } + if (allNull) { + return null; + } + + final Ops tf = getTF(); + + List> rMasks = + masks.stream().map(m -> cast(getTF(), m, TBool.class)).collect(Collectors.toList()); + + List> newMasks = new ArrayList<>(); + for (int i = 0; i < inputs.size(); i++) { + Operand input = inputs.get(i); + Operand mask = rMasks.get(i); + if (mask == null) { + newMasks.add(cast(tf, tf.onesLike(input), TBool.class)); + } else if (mask.rank() < input.rank()) { + newMasks.add(tf.expandDims(mask, tf.constant(-1))); + } else { + newMasks.add(mask); + } + } + Operand concat = tf.concat(newMasks, tf.constant(axis)); + return Collections.singletonList(tf.reduceAll(concat, tf.constant(-1))); + } + + /** {@inheritDoc} */ + @Override + public void build(List inputShapes) { + + // Used purely for shape validation. + if (inputShapes.size() < 2) { + throw new IllegalArgumentException("A Concatenate layer must have at least 2 inputs."); + } + boolean allShapesUnknown = true; + for (Shape shape : inputShapes) { + if (!shape.isUnknown()) { + allShapesUnknown = false; + break; + } + } + if (allShapesUnknown) { + this.setBuilt(true); + return; + } + Integer rank = null; + long[][] shapesArray = new long[inputShapes.size()][]; + for (int i = 0; i < inputShapes.size(); i++) { + + Shape shape = inputShapes.get(i); + long[] dims = new long[shape.numDimensions() - 1]; + for (int j = 0, k = 0; j < dims.length; k++) { + if (k == axis) continue; + dims[j++] = shape.size(i); + } + + if (rank == null || rank == Shape.UNKNOWN_SIZE) { + rank = shape.numDimensions(); + } else if (rank != shape.numDimensions()) { + throw new IllegalArgumentException( + String.format( + "A Concatenate layer requires inputs with matching shapes %s", + shapesToString(inputShapes))); + } + shapesArray[i] = dims; + } + + if (axis < 0) { + axis = Math.floorMod(axis, rank); + } + long[] firstShape = shapesArray[0]; + for (int i = 1; i < shapesArray.length; i++) { + for (int j = 0; j < shapesArray[i].length; j++) { + if (shapesArray[i][j] != firstShape[j] + && shapesArray[i][j] != Shape.UNKNOWN_SIZE + && firstShape[j] != Shape.UNKNOWN_SIZE) { + throw new IllegalArgumentException( + String.format( + "A Concatenate layer requires inputs with matching shapes %s", + shapesToString(inputShapes))); + } + } + } + + this.setBuilt(true); + } + + /** + * Coverts a list of shapes to a String + * + * @param shapes the list of shapes. 
+ * @return list of shapes as a String + */ + private String shapesToString(List shapes) { + StringBuilder sb = new StringBuilder("[ "); + boolean first = true; + for (Shape shape : shapes) { + if (!first) { + sb.append(", "); + } else { + first = false; + } + sb.append(shape); + } + sb.append(" ]"); + return sb.toString(); + } + + /** {@inheritDoc} */ + @Override + protected Operand mergeFunction(List> inputs) { + Ops tf = getTF(); + if (inputs.size() < 2) { + throw new IllegalArgumentException("A Concatenate layer must have at least 2 inputs."); + } + Class inputType = inputs.get(0).type(); + List> tList = + inputs.stream().map(item -> cast(tf, item, getType())).collect(Collectors.toList()); + return cast(tf, tf.concat(tList, tf.constant(axis)), inputType); + } + + public List computeOutputShape(List inputShapes) { + build(inputShapes); + Shape outputShape = inputShapes.get(0); + long[] dims = outputShape.asArray(); + if (dims == null) { + dims = new long[] {Shape.UNKNOWN_SIZE}; + } + + for (int i = 1; i < inputShapes.size(); i++) { + Shape shape = inputShapes.get(0); + if (outputShape.size(axis) == Shape.UNKNOWN_SIZE || shape.size(axis) == Shape.UNKNOWN_SIZE) { + dims[axis] = Shape.UNKNOWN_SIZE; + break; + } + dims[axis] += shape.size(axis); + } + + Shape result = Shape.of(dims); + return Collections.singletonList(result); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dense.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dense.java new file mode 100644 index 00000000000..77b0219dc45 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dense.java @@ -0,0 +1,418 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.framework.activations.Activation; +import org.tensorflow.framework.initializers.Glorot; +import org.tensorflow.framework.initializers.Initializer; +import org.tensorflow.framework.initializers.VarianceScaling; +import org.tensorflow.framework.initializers.Zeros; +import org.tensorflow.framework.layers.impl.InputSpec; +import org.tensorflow.framework.layers.impl.VariableDef; +import org.tensorflow.framework.op.FrameworkOps; +import org.tensorflow.framework.regularizers.Regularizer; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TType; + +import java.util.Collections; +import java.util.List; +import java.util.function.UnaryOperator; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * A regular densely-connected NN layer. + * + *
+ * <p>Dense implements the operation: output = activation(dot(input, kernel) + bias)
+ * where activation is the element-wise activation function passed as the activation
+ * argument, kernel is a weights matrix created by the layer, and bias is a bias
+ * vector created by the layer (only applicable if useBias is true).
+ *
+ * <p>
Note: If the input to the layer has a rank greater than 2, then Dense + * computes the dot product between the inputs and the kernel along the + * last axis of the inputs and axis 1 of the kernel (using + * tf.tensordot). For example, if input has dimensions (batch_size, d0, + * d1), then we create a kernel with shape (d1, units), and the + * kernel operates along axis 2 of the input, on every sub-tensor of shape + * (1, 1, d1) (there are batch_size * d0 such sub-tensors). The output in + * this case will have shape (batch_size, d0, units). + * + * @param the data type for the layer's weights and computation. + */ +public class Dense extends Layer { + + private final Integer units; + private final Activation activation; + private final boolean useBias; + private final long seed; + + private final UnaryOperator> kernelConstraint; + private final UnaryOperator> biasConstraint; + private final Regularizer biasRegularizer; + private final Regularizer kernelRegularizer; + + private Initializer kernelInitializer; + private Initializer biasInitializer; + private Variable kernel; + private Variable bias; + + /** + * Creates a Dense layer. + * + * @param tf the TensorFlow Ops. + * @param units Positive integer, dimensionality of the output space. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the weights and computation + */ + public Dense(Ops tf, Integer units, long seed, Class type) { + this(tf, null, units, null, true, null, null, null, null, null, null, null, seed, type, null); + } + + /** + * Creates a Dense layer. + * + * @param tf the TensorFlow Ops. + * @param units Positive integer, dimensionality of the output space. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the weights and computation + * @param options the layer's options. + */ + public Dense(Ops tf, Integer units, long seed, Class type, Options options) { + this( + tf, null, units, null, true, null, null, null, null, null, null, null, seed, type, options); + } + + /** + * Creates a Dense layer. + * + * @param tf the TensorFlow Ops. + * @param name name the unique name for this layer. If null, a unique name will be generated based + * on {@link Class#getSimpleName()}. + * @param units Positive integer, dimensionality of the output space. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the weights and computation + */ + public Dense(Ops tf, String name, Integer units, long seed, Class type) { + this(tf, name, units, null, true, null, null, null, null, null, null, null, seed, type, null); + } + + /** + * Creates a Dense layer. + * + * @param tf the TensorFlow Ops. + * @param name name the unique name for this layer. If null, a unique name will be generated based + * on {@link Class#getSimpleName()}. + * @param units Positive integer, dimensionality of the output space. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the weights and computation + * @param options the layer's options. 
+ */ + public Dense(Ops tf, String name, Integer units, long seed, Class type, Options options) { + this( + tf, name, units, null, true, null, null, null, null, null, null, null, seed, type, options); + } + + /** + * Creates a Dense layer. + * + * @param tf the TensorFlow Ops. + * @param name name the unique name for this layer. If null, a unique name will be generated based + * on {@link Class#getSimpleName()}. + * @param units Positive integer, dimensionality of the output space. + * @param activation Activation function to use. If you don't specify anything, no activation is + * applied (ie. "linear" activation: a(x) = x). + * @param useBias whether the layer uses a bias vector. + * @param kernelInitializer Initializer for the kernel weights matrix. + * @param biasInitializer Initializer for the bias vector. + * @param kernelRegularizer Regularizer applied to the kernel weights matrix. + * @param biasRegularizer Regularizer function applied to the bias vector. + * @param activityRegularizer Regularizer function applied to the output of the layer (its + * "activation"). + * @param kernelConstraint a constraint on the kernel variable + * @param biasConstraint a constraint on the bias variable + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the weights and computation + */ + public Dense( + Ops tf, + String name, + Integer units, + Activation activation, + boolean useBias, + Initializer kernelInitializer, + Initializer biasInitializer, + Regularizer kernelRegularizer, + Regularizer biasRegularizer, + Regularizer activityRegularizer, + UnaryOperator> kernelConstraint, + UnaryOperator> biasConstraint, + long seed, + Class type) { + this( + tf, + name, + units, + activation, + useBias, + kernelInitializer, + biasInitializer, + kernelRegularizer, + biasRegularizer, + activityRegularizer, + kernelConstraint, + biasConstraint, + seed, + type, + null); + } + /** + * Creates a Dense layer. + * + * @param tf the TensorFlow Ops. + * @param name name the unique name for this layer. If null, a unique name will be generated based + * on {@link Class#getSimpleName()}. + * @param units Positive integer, dimensionality of the output space. + * @param activation Activation function to use. If you don't specify anything, no activation is + * applied (ie. "linear" activation: a(x) = x). + * @param useBias whether the layer uses a bias vector. + * @param kernelInitializer Initializer for the kernel weights matrix. + * @param biasInitializer Initializer for the bias vector. + * @param kernelRegularizer Regularizer applied to the kernel weights matrix. + * @param biasRegularizer Regularizer function applied to the bias vector. + * @param activityRegularizer Regularizer function applied to the output of the layer (its + * "activation"). + * @param kernelConstraint a constraint on the kernel variable + * @param biasConstraint a constraint on the bias variable + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the weights and computation + * @param options the layer's options. 
+ */ + @SuppressWarnings("unchecked") + public Dense( + Ops tf, + String name, + Integer units, + Activation activation, + boolean useBias, + Initializer kernelInitializer, + Initializer biasInitializer, + Regularizer kernelRegularizer, + Regularizer biasRegularizer, + Regularizer activityRegularizer, + UnaryOperator> kernelConstraint, + UnaryOperator> biasConstraint, + long seed, + Class type, + Options options) { + super(tf, name, true, type, options); + this.units = units; + this.activation = activation; + this.useBias = useBias; + + this.kernelInitializer = + kernelInitializer != null + ? kernelInitializer + : (Initializer) new Glorot<>(tf, VarianceScaling.Distribution.UNIFORM, seed); + this.biasInitializer = biasInitializer != null ? biasInitializer : new Zeros<>(tf); + this.kernelConstraint = kernelConstraint; + this.biasConstraint = biasConstraint; + this.biasRegularizer = biasRegularizer; + this.kernelRegularizer = kernelRegularizer; + setActivityRegularizer(activityRegularizer); + this.seed = seed; + addInputSpec(new InputSpec(InputSpec.Options.create().minRank(2))); + setSupportsMasking(true); + } + + /** + * Implements the operation: {@code output = activation(dot(input, kernel) + bias)} + * + * @param inputs the input Operands, an N-D tensor with shape: {@code (batch_size, ..., + * input_dim)}. The most common situation would be a 2D input with shape {@code (batch_size, + * input_dim)}. + * @param masks a list of masks, one for each input, to apply to the result, may be null + * @param training whether the call is in inference mode or training mode + * @param resultType the data type for the result + * @param the data type for the result + * @return the output with shape {@code (batch_size, ..., units)}. For instance, for a 2D input + * with shape {@code (batch_size, input_dim)}, the output would have shape {@code (batch_size, + * units)}. + */ + @Override + public List> call( + List> inputs, + List> masks, + boolean training, + Class resultType) { + if (inputs == null || inputs.size() != 1) + throw new IllegalArgumentException("Dense only supports 1 input."); + Operand singleInput = inputs.get(0); + Operand input = cast(getTF(), singleInput, getType()); + if (!isBuilt()) build(input.shape()); + Shape inputShape = input.shape(); + int rank = inputShape.numDimensions(); + Operand tOutput; + if (rank == 2 || rank == Shape.UNKNOWN_SIZE) { + tOutput = getTF().linalg.matMul(input, getKernel()); + } else { + FrameworkOps fops = FrameworkOps.create(getTF()); + tOutput = fops.math.tensordot(input, getKernel(), new int[] {rank - 1, 0}); + // Reshape the output back to the original number of dimensions of the input.
+ Shape newShape = inputShape.take(rank - 1).append(getUnits()); + tOutput = getTF().reshape(tOutput, getTF().constant(newShape)); + } + if (isUseBias()) { + tOutput = getTF().nn.biasAdd(tOutput, getBias()); + } + if (activation != null) { + tOutput = activation.call(tOutput); + } + + return callPostProcess(Collections.singletonList(cast(getTF(), tOutput, resultType)), training); + } + + /** {@inheritDoc} */ + @Override + public void build(List inputShapes) { + super.build(inputShapes); + if (inputShapes == null || inputShapes.size() != 1) { + throw new IllegalArgumentException("Dense only supports 1 input."); + } + if (!TFloating.class.isAssignableFrom(getType())) + throw new IllegalArgumentException( + String.format( + "Unable to build Dense layer with non-floating point type: %s", + getType().toString())); + + if (kernelInitializer == null) { + // Cast is required because Glorot is TFloating. + kernelInitializer = new Glorot<>(getTF(), VarianceScaling.Distribution.UNIFORM, getSeed()); + } + if (biasInitializer == null) { + biasInitializer = new Zeros<>(getTF()); + } + + Shape inputShape = inputShapes.get(0); + if (inputShape.size(-1) == Shape.UNKNOWN_SIZE) { + throw new IllegalArgumentException( + "The last dimension of the inputs to `Dense` should be defined. Found `UNKNOWN`."); + } + long lastDim = inputShape.size(-1); + addInputSpec(new InputSpec(InputSpec.Options.create().minRank(2).axesMap(-1, lastDim))); + + kernel = + addWeight( + "kernel", + Shape.of(lastDim, this.getUnits()), + kernelInitializer, + kernelConstraint, + kernelRegularizer, + true, + getSeed()); + if (isUseBias()) + bias = + addWeight( + "bias", + Shape.of(this.getUnits()), + biasInitializer, + biasConstraint, + biasRegularizer, + true, + getSeed()); + } + + public Operand applyConstraint(Variable variable) { + VariableDef variableDef = getVariableDef(variable); + if(variableDef != null && variableDef.getConstraint() != null) { + return variableDef.getConstraint().apply(variable); + }else { + return variable; + } + } + + /** {@inheritDoc} */ + @Override + public List computeOutputShape(List inputShapes) { + if (inputShapes == null || inputShapes.size() != 1) + throw new IllegalArgumentException("Dense layer: there must be one input shape"); + if (!isBuilt()) build(inputShapes); + Shape singleShape = inputShapes.get(0); + if (singleShape.size(-1) == Shape.UNKNOWN_SIZE) + throw new IllegalArgumentException( + String.format( + "Dense layer: The innermost dimension of input_shape must be defined, but saw: %s", + singleShape)); + Shape headShape = singleShape.take(singleShape.numDimensions() - 1).append(getUnits()); + + return Collections.singletonList(headShape); + } + + /** + * Gets the dense units + * + * @return the dense units + */ + public Integer getUnits() { + return units; + } + + /** + * Gets the use bias flag + * + * @return the use bias flag + */ + public boolean isUseBias() { + return useBias; + } + + /** + * Gets the seed + * + * @return the seed + */ + public long getSeed() { + return seed; + } + + /** + * Gets the kernel variable + * + * @return the kernel variable + */ + public Variable getKernel() { + return kernel; + } + + /** + * Gets the bias variable + * + * @return the bias variable + */ + public Variable getBias() { + return bias; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dot.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dot.java new file mode 100644 index 00000000000..bb0353ec3e6 --- /dev/null +++ 
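A minimal usage sketch of the Dense layer above (not part of the patch; it assumes the call(List, List, boolean, Class) signature shown in this diff, and the seed and shape values are arbitrary). It exercises the rank-3 tensordot path described in the class Javadoc:

import java.util.Collections;
import java.util.List;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.framework.layers.Dense;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.family.TType;

public class DenseSketch {
  public static void main(String[] args) {
    try (Graph graph = new Graph()) {
      Ops tf = Ops.create(graph);
      // Input shaped (batch_size=2, d0=3, d1=4); the kernel is built as (4, units).
      Operand<TFloat32> input = tf.zeros(tf.constant(new long[] {2, 3, 4}), TFloat32.class);
      Dense<TFloat32> dense = new Dense<>(tf, 8, 1001L, TFloat32.class);
      List<Operand<? extends TType>> inputs = Collections.singletonList(input);
      // training = false: a plain forward pass.
      List<Operand<TFloat32>> outputs = dense.call(inputs, null, false, TFloat32.class);
      // tensordot along the last input axis: expected output shape (2, 3, 8).
      System.out.println(outputs.get(0).shape());
    }
  }
}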
b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dot.java @@ -0,0 +1,556 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.framework.layers.impl.Merge; +import org.tensorflow.framework.op.FrameworkOps; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Squeeze; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Layer that computes a dot product between samples in two tensors. + * + *

E.g. if applied to a list of two tensors a and b of shape + * (batch_size, n), the output will be a tensor of shape (batch_size, 1) where + * each entry i will be the dot product between `a[i]` and `b[i]`. + * + * @param the data type for the layer's weights and computation. + */ +public class Dot extends Merge { + private final int[] axes; + private final boolean normalize; + + /** + * Creates a Layer that computes a dot product between samples in two tensors, using {@link + * Class#getSimpleName()} as the layer name, and no L2 Normalization. + * + * @param tf the TensorFlow Ops + * @param axes axes along which to take the dot product. + * @param type the data type for the weights and computation + */ + public Dot(Ops tf, int axes, Class type) { + this(tf, null, new int[] {axes}, false, type, null); + } + + /** + * Creates a Layer that computes a dot product between samples in two tensors, using {@link + * Class#getSimpleName()} as the layer name, and no L2 Normalization. + * + * @param tf the TensorFlow Ops + * @param axes axes along which to take the dot product. + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Dot(Ops tf, int axes, Class type, Options options) { + this(tf, null, new int[] {axes}, false, type, options); + } + + /** + * Creates a Layer that computes a dot product between samples in two tensors, using {@link + * Class#getSimpleName()} as the layer name and no L2 Normalization. + * + * @param tf the TensorFlow Ops + * @param axes axes along which to take the dot product. Should be one or two integers + * corresponding to the desired axis from the first input and the desired axis from the second + * input, respectively. Note that the size of the two selected axes must match. + * @param type the data type for the weights and computation + */ + public Dot(Ops tf, int[] axes, Class type) { + this(tf, null, axes, false, type, null); + } + + /** + * Creates a Layer that computes a dot product between samples in two tensors, using {@link + * Class#getSimpleName()} as the layer name and no L2 Normalization. + * + * @param tf the TensorFlow Ops + * @param axes axes along which to take the dot product. Should be one or two integers + * corresponding to the desired axis from the first input and the desired axis from the second + * input, respectively. Note that the size of the two selected axes must match. + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Dot(Ops tf, int[] axes, Class type, Options options) { + this(tf, null, axes, false, type, options); + } + + /** + * Creates a Layer that computes a dot product between samples in two tensors with no L2 + * Normalization. + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param axes axes along which to take the dot product. + * @param type the data type for the weights and computation + */ + public Dot(Ops tf, String name, int axes, Class type) { + this(tf, name, new int[] {axes}, false, type, null); + } + + /** + * Creates a Layer that computes a dot product between samples in two tensors with no L2 + * Normalization. + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param axes axes along which to take the dot product. 
+ * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Dot(Ops tf, String name, int axes, Class type, Options options) { + this(tf, name, new int[] {axes}, false, type, options); + } + + /** + * Creates a Layer that computes a dot product between samples in two tensors with no L2 + * Normalization. + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param axes axes along which to take the dot product. Should be one or two integers + * corresponding to the desired axis from the first input and the desired axis from the second + * input, respectively. Note that the size of the two selected axes must match. + * @param type the data type for the weights and computation + */ + public Dot(Ops tf, String name, int[] axes, Class type) { + this(tf, name, axes, false, type, null); + } + + /** + * Creates a Layer that computes a dot product between samples in two tensors with no L2 + * Normalization. + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param axes axes along which to take the dot product. Should be one or two integers + * corresponding to the desired axis from the first input and the desired axis from the second + * input, respectively. Note that the size of the two selected axes must match. + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Dot(Ops tf, String name, int[] axes, Class type, Options options) { + this(tf, name, axes, false, type, options); + } + + /** + * Creates a Layer that computes a dot product between samples in two tensors. + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param axes axes along which to take the dot product. + * @param normalize Whether to L2-normalize samples along the dot product axis before taking the + * dot product. If set to True, then the output of the dot product is the cosine proximity + * between the two samples. + * @param type the data type for the weights and computation + */ + public Dot(Ops tf, String name, int axes, boolean normalize, Class type) { + this(tf, name, new int[] {axes}, normalize, type, null); + } + + /** + * Creates a Layer that computes a dot product between samples in two tensors. + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param axes axes along which to take the dot product. + * @param normalize Whether to L2-normalize samples along the dot product axis before taking the + * dot product. If set to True, then the output of the dot product is the cosine proximity + * between the two samples. + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Dot(Ops tf, String name, int axes, boolean normalize, Class type, Options options) { + this(tf, name, new int[] {axes}, normalize, type, options); + } + + /** + * Creates a Layer that computes a dot product between samples in two tensors. + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param axes axes along which to take the dot product. 
Should be one or two integers + * corresponding to the desired axis from the first input and the desired axis from the second + * input, respectively. Note that the size of the two selected axes must match. + * @param normalize Whether to L2-normalize samples along the dot product axis before taking the + * dot product. If set to True, then the output of the dot product is the cosine proximity + * between the two samples. + * @param type the data type for the weights and computation + */ + public Dot(Ops tf, String name, int[] axes, boolean normalize, Class type) { + this(tf, name, axes, normalize, type, null); + } + + /** + * Creates a Layer that computes a dot product between samples in two tensors. + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param axes axes along which to take the dot product. Should be one or two integers + * corresponding to the desired axis from the first input and the desired axis from the second + * input, respectively. Note that the size of the two selected axes must match. + * @param normalize Whether to L2-normalize samples along the dot product axis before taking the + * dot product. If set to True, then the output of the dot product is the cosine proximity + * between the two samples. + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Dot(Ops tf, String name, int[] axes, boolean normalize, Class type, Options options) { + super(tf, name, type, options); + if (axes.length < 1 || axes.length > 2) { + throw new IllegalArgumentException( + "Invalid format for axes - must only contain one or two elements."); + } + this.axes = axes; + this.normalize = normalize; + this.setSupportsMasking(true); + } + + /** {@inheritDoc} */ + @Override + protected void build(List inputShapes) { + // Used purely for shape validation. + if (inputShapes.size() != 2) { + throw new IllegalArgumentException("A Dot layer should be called on exactly 2 inputs"); + } + Shape shape1 = inputShapes.get(0); + Shape shape2 = inputShapes.get(1); + if (shape1.isUnknown() || shape2.isUnknown()) { + return; + } + int[] newAxes; + if (axes.length == 1) { + newAxes = new int[2]; + // convert negative axes + if (axes[0] < 0) { + newAxes[0] = Math.floorMod(axes[0], shape1.numDimensions()); + newAxes[1] = Math.floorMod(axes[0], shape2.numDimensions()); + } else { + newAxes[0] = axes[0]; + newAxes[1] = axes[0]; + } + } else { + newAxes = axes; + } + // use the converted axes; indexing the raw axes array would go out of bounds when only one axis was given + if (shape1.size(newAxes[0]) != shape2.size(newAxes[1])) { + throw new IllegalArgumentException( + String.format( + "Dimension incompatibility %s != %s. Layer shapes: %s, %s. Chosen axes: %s", + shape1.size(newAxes[0]), + shape2.size(newAxes[1]), + shape1, + shape2, + Arrays.toString(newAxes))); + } + } + + /** {@inheritDoc} */ + @Override + protected Operand mergeFunction(List> inputs) { + Ops tf = getTF(); + if (inputs.size() != 2) { + throw new IllegalArgumentException("A Dot layer should be called on exactly 2 inputs"); + } + Operand input1 = inputs.get(0); + Operand input2 = inputs.get(1); + int[] newAxes = new int[2]; + if (axes.length == 1) { + if (axes[0] < 0) { + newAxes[0] = Math.floorMod(axes[0], input1.shape().numDimensions()); + newAxes[1] = Math.floorMod(axes[0], input2.shape().numDimensions()); + } else { + newAxes[0] = axes[0]; + newAxes[1] = axes[0]; + } + } else { + for (int i = 0; i < axes.length; i++) { + if (axes[i] < 0) { + newAxes[i] = Math.floorMod(axes[i], inputs.get(i).shape().numDimensions()); + } else { + newAxes[i] = axes[i]; + } + } + } + if (normalize) { + FrameworkOps fops = FrameworkOps.create(tf); + input1 = fops.math.l2Normalize(input1, new int[] {newAxes[0]}); + input2 = fops.math.l2Normalize(input2, new int[] {newAxes[1]}); + } + return batchDot(input1, input2, newAxes); + } + + /** {@inheritDoc} */ + @Override + public List> computeMask( + List> inputs, List> masks) { + return null; + } + + /** {@inheritDoc} */ + @Override + public List computeOutputShape(List inputShapes) { + if (inputShapes.size() != 2) { + throw new IllegalArgumentException("A Dot layer should be called on a list of 2 inputs."); + } + Shape shape1 = inputShapes.get(0); + Shape shape2 = inputShapes.get(1); + // convert each axis against its own input, on a copy so the layer's axes array is not modified + int[] lAxes = new int[2]; + if (axes.length == 1) { + lAxes[0] = Math.floorMod(axes[0], shape1.numDimensions()); + lAxes[1] = Math.floorMod(axes[0], shape2.numDimensions()); + } else { + lAxes[0] = Math.floorMod(axes[0], shape1.numDimensions()); + lAxes[1] = Math.floorMod(axes[1], shape2.numDimensions()); + } + + // pop(axes[0]): drop the reduced dimension but keep any trailing dimensions + Shape full1 = shape1; + shape1 = full1.take(lAxes[0]); + long remainder = full1.numDimensions() - (lAxes[0] + 1); + if (remainder > 0) { + shape1 = shape1.append(full1.takeLast((int) remainder)); + } + + // pop(axes[1]) + Shape full2 = shape2; + shape2 = full2.take(lAxes[1]); + remainder = full2.numDimensions() - (lAxes[1] + 1); + if (remainder > 0) { + shape2 = shape2.append(full2.takeLast((int) remainder)); + } + if (shape2.numDimensions() > 0) { + // pop(0) + shape2 = shape2.takeLast(shape2.numDimensions() - 1); + } + Shape outputShape = shape1.append(shape2); + + if (outputShape.numDimensions() == 1) { + outputShape = outputShape.append(1); + } + return Collections.singletonList(outputShape); + } + + /** + * Computes the batch-wise dot product. + * + *

batchDot is used to compute the dot product of x and y + * when x and y are data in batches, i.e. in a shape of + * (batch_size, :). batchDot results in an Operand with fewer dimensions than + * the input. If the number of dimensions is reduced to 1, we use expandDims + * to make sure that the number of dimensions is at least 2. + * + * @param x Operand with rank >= 2. + * @param y Operand with rank >= 2. + * @param axes the axes along which to perform the dot product. + * @return an Operand with shape equal to the concatenation of x's shape (less the + * dimension that was summed over) and y's shape (less the batch dimension and + * the dimension that was summed over). If the final rank is 1, the result is reshaped to + * (batch_size, 1). + */ + private Operand batchDot( + Operand x, Operand y, int[] axes) { + Ops tf = getTF(); + FrameworkOps fops = FrameworkOps.create(tf); + // work on a copy so the caller's array is not modified below + int[] dotAxes = axes == null ? null : axes.clone(); + Operand tX = cast(tf, x, getType()); + Operand tY = cast(tf, y, getType()); + + Shape xShape = tX.shape(); + Shape yShape = tY.shape(); + + int xRank = xShape.numDimensions(); + int yRank = yShape.numDimensions(); + + if (xRank < 2 || yRank < 2) { + throw new IllegalArgumentException( + String.format( + "Cannot do batch_dot on inputs with rank < 2. Received inputs with shapes %s and %s.", + xShape, yShape)); + } + + int xBatchSize = (int) xShape.size(0); + int yBatchSize = (int) yShape.size(0); + if (xBatchSize != Shape.UNKNOWN_SIZE && yBatchSize != Shape.UNKNOWN_SIZE) { + if (xBatchSize != yBatchSize) { + throw new IllegalArgumentException( + String.format( + "Cannot do batchDot on inputs with different batch sizes. Received inputs with shapes %s and %s.", + xShape, yShape)); + } + } + + if (dotAxes == null) { + dotAxes = new int[2]; + dotAxes[0] = xRank - 1; + if (yRank == 2) { + dotAxes[1] = yRank - 1; + } else { + dotAxes[1] = yRank - 2; + } + } else if (dotAxes.length == 1) { + dotAxes = new int[] {dotAxes[0], dotAxes[0]}; + } + + if (dotAxes[0] < 0) { + dotAxes[0] = Math.floorMod(dotAxes[0], xRank); + } + if (dotAxes[1] < 0) { + dotAxes[1] = Math.floorMod(dotAxes[1], yRank); + } + if (dotAxes[0] == 0 || dotAxes[1] == 0) { + throw new IllegalArgumentException( + "Cannot perform batch_dot over axis 0. If your inputs are not batched, add a dummy batch dimension to your inputs using tf.expandDims(x, 0)"); + } + + int a0 = dotAxes[0]; + int a1 = dotAxes[1]; + int d1 = (int) xShape.size(a0); + int d2 = (int) yShape.size(a1); + + if (d1 != Shape.UNKNOWN_SIZE && d2 != Shape.UNKNOWN_SIZE && d1 != d2) { + throw new IllegalArgumentException( + String.format( + "Cannot do batch_dot on inputs with shapes %s and %s with axes %s. x.shape[%d] != %d, y.shape[%d] != %d", + xShape, yShape, Arrays.toString(dotAxes), a0, d1, a1, d2)); + } + + // back up the original ranks; they determine the final squeeze below. + int origXRank = xRank; + int origYRank = yRank; + if (xRank == 2) { + tX = tf.expandDims(tX, tf.constant(1)); + xRank++; + a0++; + } + if (yRank == 2) { + tY = tf.expandDims(tY, tf.constant(2)); + yRank += 1; + } + + // move x's dimension to be reduced to last axis. + if (a0 != xRank - 1) { + int[] pattern = new int[xRank]; + // move a0 to last + for (int i = 0; i < a0; i++) { + pattern[i] = i; + } + for (int i = a0; i < xRank - 1; i++) { + pattern[i] = i + 1; + } + pattern[xRank - 1] = a0; + tX = tf.linalg.transpose(tX, tf.constant(pattern)); + } + // move y's dimension to be reduced to axis 1.
+ if (a1 != 1) { + int[] pattern = new int[yRank]; + pattern[0] = 0; + // skip slot 1 + for (int i = 1; i < a1; i++) { + pattern[i + 1] = i; + } + for (int i = a1; i < pattern.length - 1; i++) { + pattern[i + 1] = i + 1; + } + pattern[1] = a1; + //noinspection SuspiciousNameCombination + tY = tf.linalg.transpose(tY, tf.constant(pattern)); + } + + // normalize both inputs to rank 3. + boolean xSquashed = false; + Operand xMidShape = null; + if (xRank > 3) { + org.tensorflow.op.core.Shape tmpShape = tf.shape(tX, TInt64.class); + xMidShape = tf.shape.takeLast(tmpShape, tf.constant((long) (xRank - 1)), TInt64.class); + + Operand squashedShape = + tf.stack( + Arrays.asList( + tf.shape.size(tmpShape, tf.constant(0L), TInt64.class), + tf.constant(Shape.UNKNOWN_SIZE), + tf.shape.size(tmpShape, tf.constant((long) (xRank - 1)), TInt64.class))); + tX = tf.reshape(tX, squashedShape); + xSquashed = true; + } + + boolean ySquashed = false; + Operand yTrailDims = null; + if (yRank > 3) { + // use the current (possibly transposed) tY, not the original y, when squashing + org.tensorflow.op.core.Shape tmpShape = tf.shape(tY, TInt64.class); + yTrailDims = + tf.shape.takeLast(tmpShape, tf.constant((long) (yRank - 2)), TInt64.class); + + Operand squashedShape = + tf.stack( + Arrays.asList( + tf.shape.size(tmpShape, tf.constant(0L), TInt64.class), + tf.shape.size(tmpShape, tf.constant(1L), TInt64.class), + tf.constant(-1L))); + tY = tf.reshape(tY, squashedShape); + ySquashed = true; + } + + Operand result = fops.linalg.matmul(tX, tY); + boolean doReshape = false; + Operand outputShape = tf.shape(result, TInt64.class); + + if (xSquashed) { + outputShape = + tf.concat( + Arrays.asList( + tf.shape.size(outputShape, tf.constant(0L), TInt64.class), + xMidShape, + tf.shape.size(outputShape, tf.constant(-1L), TInt64.class)), + tf.constant(0)); + doReshape = true; + } + + if (ySquashed) { + + outputShape = + tf.concat( + Arrays.asList( + tf.slice(outputShape, tf.constant(0), tf.constant(outputShape.rank() - 1)), + yTrailDims), + tf.constant(0)); + doReshape = true; + } + + if (doReshape) { + result = tf.reshape(result, outputShape); + } + + if (origXRank == 2) { + result = tf.squeeze(result, Squeeze.axis(Collections.singletonList(1L))); + } else if (origYRank == 2) { + result = tf.squeeze(result, Squeeze.axis(Collections.singletonList(-1L))); + } + return result; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dropout.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dropout.java new file mode 100644 index 00000000000..6b9e98feda8 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dropout.java @@ -0,0 +1,250 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
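A hedged sketch of the Dot layer above (same scaffolding and imports as the Dense sketch earlier, plus java.util.Arrays; shapes are illustrative). Two (batch_size, n) inputs dotted along axis 1 reduce to (batch_size, 1) through batchDot:

// Sketch only: assumes tf and the imports from the Dense sketch above.
Operand<TFloat32> a = tf.zeros(tf.constant(new long[] {5, 8}), TFloat32.class);
Operand<TFloat32> b = tf.zeros(tf.constant(new long[] {5, 8}), TFloat32.class);
Dot<TFloat32> dot = new Dot<>(tf, 1, TFloat32.class);
List<Operand<? extends TType>> pair = Arrays.asList(a, b);
List<Operand<TFloat32>> out = dot.call(pair, null, false, TFloat32.class);
// Each entry i is dot(a[i], b[i]); expected output shape: (5, 1).
System.out.println(out.get(0).shape());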
+=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.op.random.RandomUniform; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TType; + +import java.util.ArrayList; +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Applies Dropout to the input. + * + *

The Dropout layer randomly sets input units to 0 with a frequency of rate at each step during + * training time, which helps prevent overfitting. Inputs not set to 0 are scaled up by 1/(1 - rate) + * such that the sum over all inputs is unchanged. + * + *

Note that the Dropout layer only applies when training is set to true, so that no values are + * dropped during inference. When fitting a model, the training flag is normally set to true + * automatically; in other contexts, set the training argument explicitly to true when calling the + * layer. + * + *

(This is in contrast to setting trainable=false for a Dropout layer. trainable does not affect + * the layer's behavior, as Dropout does not have any variables/weights that can be frozen during + * training.) + * + * @param the data type for the layer's weights and computation. + * @see Hinton G, et al. 2012, Improving neural networks + * by preventing co-adaptation of feature detectors + */ +public class Dropout extends Layer { + + private final float rate; + private final Shape noiseShape; + private final long seed; + + /** + * Creates a Dropout layer, using a unique name generated based on {@link + * Class#getSimpleName()}, and no noiseShape. + * + * @param tf the TensorFlow Ops. + * @param rate A number between 0 and 1. Fraction of the input units to drop. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the weights and computation + */ + public Dropout(Ops tf, float rate, long seed, Class type) { + + this(tf, null, rate, null, seed, type, null); + } + + /** + * Creates a Dropout layer, using a unique name generated based on {@link + * Class#getSimpleName()}, and no noiseShape. + * + * @param tf the TensorFlow Ops. + * @param rate A number between 0 and 1. Fraction of the input units to drop. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the weights and computation + * @param options the layer's options. + */ + public Dropout(Ops tf, float rate, long seed, Class type, Options options) { + + this(tf, null, rate, null, seed, type, options); + } + + /** + * Creates a Dropout layer, using a unique name generated based on {@link + * Class#getSimpleName()}. + * + * @param tf the TensorFlow Ops + * @param rate A number between 0 and 1. Fraction of the input units to drop. + * @param noiseShape Optional, 1D integer tensor representing the shape of the binary dropout mask + * that will be multiplied with the input. For instance, if your inputs have shape + * (batch_size, timesteps, features) and you want the dropout mask to be the same for all + * timesteps, you can use noiseShape = (batch_size, 1, features). May be null. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the weights and computation + */ + public Dropout(Ops tf, float rate, Shape noiseShape, long seed, Class type) { + this(tf, null, rate, noiseShape, seed, type, null); + } + + /** + * Creates a Dropout layer, using a unique name generated based on {@link + * Class#getSimpleName()}. + * + * @param tf the TensorFlow Ops + * @param rate A number between 0 and 1. Fraction of the input units to drop. + * @param noiseShape Optional, 1D integer tensor representing the shape of the binary dropout mask + * that will be multiplied with the input. For instance, if your inputs have shape + * (batch_size, timesteps, features) and you want the dropout mask to be the same for all + * timesteps, you can use noiseShape = (batch_size, 1, features). May be null. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type.
+ * @param type the data type for the weights and computation + * @param options the layer's options. + */ + public Dropout(Ops tf, float rate, Shape noiseShape, long seed, Class type, Options options) { + this(tf, null, rate, noiseShape, seed, type, options); + } + + /** + * Creates a Dropout layer. + * + * @param tf the TensorFlow Ops. + * @param name the unique name for this layer. If null, a unique name will be generated based + * on {@link Class#getSimpleName()}. + * @param rate A number between 0 and 1. Fraction of the input units to drop. + * @param noiseShape Optional, 1D integer tensor representing the shape of the binary dropout mask + * that will be multiplied with the input. For instance, if your inputs have shape + * (batch_size, timesteps, features) and you want the dropout mask to be the same for all + * timesteps, you can use noiseShape = (batch_size, 1, features). May be null. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the weights and computation + */ + public Dropout(Ops tf, String name, float rate, Shape noiseShape, long seed, Class type) { + this(tf, name, rate, noiseShape, seed, type, null); + } + + /** + * Creates a Dropout layer. + * + * @param tf the TensorFlow Ops. + * @param name the unique name for this layer. If null, a unique name will be generated based + * on {@link Class#getSimpleName()}. + * @param rate A number between 0 and 1. Fraction of the input units to drop. + * @param noiseShape Optional, 1D integer tensor representing the shape of the binary dropout mask + * that will be multiplied with the input. For instance, if your inputs have shape + * (batch_size, timesteps, features) and you want the dropout mask to be the same for all + * timesteps, you can use noiseShape = (batch_size, 1, features). May be null. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the weights and computation + * @param options the layer's options. + */ + public Dropout( + Ops tf, + String name, + float rate, + Shape noiseShape, + long seed, + Class type, + Options options) { + super(tf, name, true, type, options); + if (rate < 0 || rate >= 1) + throw new IllegalArgumentException("The rate must be >= 0 and < 1."); + this.rate = rate; + this.noiseShape = noiseShape; + this.seed = seed; + setSupportsMasking(true); + } + + /** {@inheritDoc} */ + @SuppressWarnings("unchecked") + @Override + public List> call( + List> inputs, + List> masks, + boolean training, + Class resultType) { + + Ops tf = getTF(); + List> outputs = new ArrayList<>(); + + for (Operand input : inputs) { + + Operand output; + if (!TFloating.class.isAssignableFrom(input.type())) { + output = cast(tf, input, TFloat64.class); + } else { + output = (Operand) input; + } + + if (training) { + Operand rateV = cast(tf, tf.constant(rate), getType()); + + Operand noise = + noiseShape == null ?
tf.shape(input, TInt64.class) : tf.constant(noiseShape); + + Operand tOutput = cast(getTF(), output, getType()); + tOutput = dropout(tOutput, rateV, noise, seed); + + outputs.add(cast(getTF(), tOutput, resultType)); + } else { + outputs.add(cast(getTF(), output, resultType)); + } + } + return callPostProcess(outputs, training); + } + + /** {@inheritDoc} */ + @Override + public List computeOutputShape(List inputShapes) { + return inputShapes; + } + + /** + * Computes dropout: randomly sets elements to zero to prevent overfitting. + * + * @param input the input + * @param rate the drop out rate, the probability that each element is dropped. For example, + * setting rate=0.1 would drop 10% of input elements. + * @param noiseShape the noise shape representing the shape for randomly generated keep/drop + * flags. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @return an Operand with the same shape as the input with a percentage of entries dropped out. + */ + /* TODO, this is defined as tf.dropout() in python and is an nn op. Do we want it here? */ + private Operand dropout( + Operand input, Operand rate, Operand noiseShape, long seed) { + Ops tf = getTF(); + + Operand one = cast(tf, tf.constant(1.), input.type()); + Operand keepProb = tf.math.sub(one, rate); + Operand scale = tf.math.div(one, keepProb); + Operand ret = tf.math.mul(input, scale); + + Operand randomTensor = + tf.random.randomUniform(noiseShape, input.type(), RandomUniform.seed(seed)); + Operand keepMask = tf.math.greaterEqual(randomTensor, rate); + ret = tf.math.mul(ret, cast(tf, keepMask, ret.type())); + return tf.reshape(ret, tf.shape(input)); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ELU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ELU.java new file mode 100644 index 00000000000..90dfc6e2d4e --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ELU.java @@ -0,0 +1,110 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TType; + +import java.util.ArrayList; +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Exponential Linear Unit layer. + * + *
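A hedged sketch of the Dropout layer above (same scaffolding as the Dense sketch; the rate, seed, and shape are illustrative). With rate = 0.25, kept entries are scaled by 1 / (1 - 0.25), so the expected sum over all entries is preserved:

// Sketch only: assumes tf and the imports from the Dense sketch above.
Dropout<TFloat32> dropout = new Dropout<>(tf, 0.25f, 42L, TFloat32.class);
Operand<TFloat32> x = tf.fill(tf.constant(new long[] {4, 4}), tf.constant(1.0f));
List<Operand<? extends TType>> inputs = Collections.singletonList(x);
// training = true: roughly 25% of entries become 0, the rest become 1 / 0.75.
List<Operand<TFloat32>> trained = dropout.call(inputs, null, true, TFloat32.class);
// training = false: the input passes through unchanged.
List<Operand<TFloat32>> inference = dropout.call(inputs, null, false, TFloat32.class);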

It follows: + *

{@code
+ * f(x) = alpha * (exp(x) - 1.) for x < 0
+ * f(x) = x                     for x >= 0
+ * }
+ * + * @param the data type for the layer's weights and computation. + */ +public class ELU extends Layer { + public static final float DEFAULT_ALPHA = 1.0f; + + private final float alpha; + + /** + * Creates an ELU Layer with a unique name generated based on {@link Class#getSimpleName()} and + * {@link #DEFAULT_ALPHA} for the alpha value. + * + * @param tf the TensorFlow Ops. + * @param type the data type for the layer's weights and computation. + */ + public ELU(Ops tf, Class type) { + this(tf, null, DEFAULT_ALPHA, type, null); + } + + /** + * Creates an ELU Layer with {@link #DEFAULT_ALPHA} for the alpha value. + * + * @param tf the TensorFlow Ops. + * @param name the unique name for this layer. If null, a unique name will be generated based on + * {@link Class#getSimpleName()}. + * @param type the data type for the layer's weights and computation. + */ + public ELU(Ops tf, String name, Class type) { + this(tf, name, DEFAULT_ALPHA, type, null); + } + + /** + * Creates an ELU Layer with a unique name generated based on {@link Class#getSimpleName()}. + * + * @param tf the TensorFlow Ops. + * @param alpha the scale for the negative section of the activation (the alpha in the formula above). Must be >= 0. + * @param type the data type for the layer's weights and computation. + * @param options the layer's options + */ + public ELU(Ops tf, float alpha, Class type, Options options) { + this(tf, null, alpha, type, options); + } + /** + * Creates an ELU Layer. + * + * @param tf the TensorFlow Ops. + * @param name the unique name for this layer. If null, a unique name will be generated based on + * {@link Class#getSimpleName()}. + * @param alpha the scale for the negative section of the activation (the alpha in the formula above). Must be >= 0. + * @param type the data type for the layer's weights and computation. + * @param options the layer's options. + */ + public ELU(Ops tf, String name, float alpha, Class type, Options options) { + super(tf, name, true, type, options); + this.alpha = alpha; + setSupportsMasking(true); + } + + /** {@inheritDoc} */ + @Override + public List> call( + List> inputs, + List> masks, + boolean training, + Class resultType) { + + org.tensorflow.framework.activations.ELU elu = + new org.tensorflow.framework.activations.ELU<>(getTF(), alpha); + List> tInputs = convertList(inputs, getType()); + List> results = new ArrayList<>(); + tInputs.forEach(tInput -> results.add(cast(getTF(), elu.call(tInput), resultType))); + return callPostProcess(results, training); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Flatten.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Flatten.java new file mode 100644 index 00000000000..2a07d6f623f --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Flatten.java @@ -0,0 +1,192 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
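A hedged numeric sketch of the ELU layer above (same scaffolding as the Dense sketch). With the default alpha of 1.0, the formula in the class Javadoc gives:

// Sketch only: assumes tf and the imports from the Dense sketch above.
ELU<TFloat32> elu = new ELU<>(tf, TFloat32.class);
Operand<TFloat32> x = tf.constant(new float[] {-1.0f, 0.0f, 2.0f});
List<Operand<? extends TType>> inputs = Collections.singletonList(x);
List<Operand<TFloat32>> y = elu.call(inputs, null, false, TFloat32.class);
// f(-1) = 1.0 * (exp(-1) - 1) ≈ -0.632, f(0) = 0, f(2) = 2.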
+=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.framework.layers.impl.TensorFormat; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TType; + +import java.util.Collections; +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Flattens the input. Does not affect the batch size. + * + *

Note: If inputs are shaped {@code (batch,)} without a feature axis, then flattening + * adds an extra channel dimension and output shape is {@code (batch, 1)}. + * + * @param the data type for the layer's weights and computation. + */ +public class Flatten extends Layer { + private final TensorFormat dataFormat; + + /** + * Creates a Flatten Layer with a unique name generated based on * {@link Class#getSimpleName()} + * and {@link TensorFormat#NHWC} for the data format + * + * @param tf the TensorFlow Ops. + * @param type the data type for the layer's weights and computation. + */ + public Flatten(Ops tf, Class type) { + this(tf, null, TensorFormat.NHWC, type, null); + } + + /** + * Creates a Flatten Layer with a unique name generated based on {@link Class#getSimpleName()} and + * {@link TensorFormat#NHWC} for the data format + * + * @param tf the TensorFlow Ops. + * @param name the unique name for this layer. If null, a unique name will be generated based on + * {@link Class#getSimpleName()}. + * @param type the data type for the layer's weights and computation. + */ + public Flatten(Ops tf, String name, Class type) { + this(tf, name, TensorFormat.NHWC, type, null); + } + + /** + * Creates a Flatten Layer with a unique name generated based on * {@link Class#getSimpleName()}. + * + * @param tf the TensorFlow Ops. + * @param dataFormat The ordering of the dimensions in the inputs. {@link TensorFormat#NHWC} + * corresponds to inputs with shape {@code (batch, ..., channels) } while {@link + * TensorFormat#NCHW} corresponds to inputs with shape {@code (batch, channels, ...)}. + * @param type the data type for the layer's weights and computation. + */ + public Flatten(Ops tf, TensorFormat dataFormat, Class type) { + this(tf, null, dataFormat, type, null); + } + + /** + * Creates a Flatten Layer with a unique name generated based on * {@link Class#getSimpleName()}. + * + * @param tf the TensorFlow Ops. + * @param dataFormat The ordering of the dimensions in the inputs. {@link TensorFormat#NHWC} + * corresponds to inputs with shape {@code (batch, ..., channels) } while {@link + * TensorFormat#NCHW} corresponds to inputs with shape {@code (batch, channels, ...)}. + * @param type the data type for the layer's weights and computation. + * @param options the layer's options + */ + public Flatten(Ops tf, TensorFormat dataFormat, Class type, Options options) { + this(tf, null, dataFormat, type, options); + } + + /** + * Creates a Flatten Layer + * + * @param tf the TensorFlow Ops. + * @param name the unique name for this layer. If null, a unique name will be generated based on + * {@link Class#getSimpleName()}. + * @param dataFormat The ordering of the dimensions in the inputs. {@link TensorFormat#NHWC} + * corresponds to inputs with shape {@code (batch, ..., channels) } while {@link + * TensorFormat#NCHW} corresponds to inputs with shape {@code (batch, channels, ...)}. + * @param type the data type for the layer's weights and computation. + */ + public Flatten(Ops tf, String name, TensorFormat dataFormat, Class type) { + this(tf, name, dataFormat, type, null); + } + /** + * Creates a Flatten Layer + * + * @param tf the TensorFlow Ops. + * @param name the unique name for this layer. If null, a unique name will be generated based on + * {@link Class#getSimpleName()}. + * @param dataFormat The ordering of the dimensions in the inputs. 
{@link TensorFormat#NHWC} + * corresponds to inputs with shape {@code (batch, ..., channels) } while {@link + * TensorFormat#NCHW} corresponds to inputs with shape {@code (batch, channels, ...)}. + * @param type the data type for the layer's weights and computation. + * @param options the layer's options. + */ + public Flatten(Ops tf, String name, TensorFormat dataFormat, Class type, Options options) { + super(tf, name, true, type, options); + this.dataFormat = dataFormat; + } + + /** {@inheritDoc} */ + @Override + public List> call( + List> inputs, + List> masks, + boolean training, + Class resultType) { + // this layer only accepts one input + if (inputs == null || inputs.size() != 1) + throw new IllegalArgumentException("Flatten layer: only accepts 1 input"); + + Operand input = inputs.get(0); + if (!isBuilt()) build(input.shape()); + Shape shape = input.shape(); + int rank = shape.numDimensions(); + if (this.dataFormat == TensorFormat.NCHW) { + if (rank != Shape.UNKNOWN_SIZE && rank > 1) { + long[] permutation = new long[rank + 1]; + permutation[0] = 0; + for (int i = 2; i < rank; i++) permutation[i - 1] = i; + permutation[rank] = 1; + input = getTF().linalg.transpose(input, getTF().constant(permutation)); + } + } + + if (rank == 1) { + input = getTF().expandDims(input, getTF().constant(1)); + } else { + Operand flattenedShape; + long[] dims = shape.asArray(); + if (dims != null) { + long batchDim = dims[0]; + long[] nonBatchDims = new long[dims.length - 1]; + System.arraycopy(dims, 1, nonBatchDims, 0, nonBatchDims.length); + Shape nonBatchShape = Shape.of(nonBatchDims); + if (!nonBatchShape.hasUnknownDimension()) { + int lastDim = 1; + for (long dim : nonBatchDims) lastDim *= dim; + flattenedShape = getTF().constant(Shape.of(-1L, lastDim)); + } else if (batchDim != Shape.UNKNOWN_SIZE) { + flattenedShape = getTF().constant(Shape.of(batchDim, -1L)); + } else { + Operand batchDimension = + getTF().shape.size(input, getTF().constant(0L), TInt64.class); + flattenedShape = getTF().shape.append(batchDimension, getTF().constant(0L)); + } + input = getTF().reshape(input, flattenedShape); + } + } + return callPostProcess(Collections.singletonList(cast(getTF(), input, resultType)), training); + } + + /** {@inheritDoc} */ + @Override + public List computeOutputShape(List inputShapes) { + if (inputShapes == null || inputShapes.size() != 1) + throw new IllegalArgumentException("Dense layer: there must be one input shape"); + if (!isBuilt()) build(inputShapes); + Shape inputShape = inputShapes.get(0); + long lastDim = 1L; + for (int i = 1; i < inputShape.numDimensions(); i++) { + lastDim *= inputShape.size(i); + } + // creates a new shape of (batchSize, rest) + Shape newShape = Shape.of(inputShape.size(0)); + newShape = newShape.append(lastDim); + return Collections.singletonList(newShape); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/GaussianDropout.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/GaussianDropout.java new file mode 100644 index 00000000000..9c1a15eaf74 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/GaussianDropout.java @@ -0,0 +1,170 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
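A hedged sketch of the Flatten layer above (same scaffolding as the Dense sketch; NHWC is the default data format):

// Sketch only: assumes tf and the imports from the Dense sketch above.
Flatten<TFloat32> flatten = new Flatten<>(tf, TFloat32.class);
Operand<TFloat32> x = tf.zeros(tf.constant(new long[] {2, 3, 4}), TFloat32.class);
List<Operand<? extends TType>> inputs = Collections.singletonList(x);
List<Operand<TFloat32>> y = flatten.call(inputs, null, false, TFloat32.class);
// Non-batch dimensions collapse: expected output shape (2, 12).
// A rank-1 (batch,) input would instead gain a channel axis: (batch, 1).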
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.op.random.ParameterizedTruncatedNormal; +import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TType; + +import java.util.ArrayList; +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Apply multiplicative 1-centered Gaussian noise. + * + *

As it is a regularization layer, it is only active at training time. + * + * @param the data type for the layer's weights and computation. + */ +public class GaussianDropout extends Layer { + + private final float rate; + private final long seed; + + /** + * Creates a GaussianDropout layer, using a unique name generated based on {@link + * Class#getSimpleName()}. + * + * @param tf the TensorFlow Ops. + * @param rate A number between 0 and 1. Drop probability. The multiplicative noise will have + * standard deviation: sqrt(rate / (1 - rate)). + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the weights and computation + */ + public GaussianDropout(Ops tf, float rate, long seed, Class type) { + + this(tf, null, rate, seed, type, null); + } + + /** + * Creates a GaussianDropout layer, using a unique name generated based on {@link + * Class#getSimpleName()}. + * + * @param tf the TensorFlow Ops. + * @param rate A number between 0 and 1. Drop probability. The multiplicative noise will have + * standard deviation: sqrt(rate / (1 - rate)). + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the weights and computation + * @param options the layer's options. + */ + public GaussianDropout(Ops tf, float rate, long seed, Class type, Options options) { + + this(tf, null, rate, seed, type, options); + } + + /** + * Creates a GaussianDropout layer. + * + * @param tf the TensorFlow Ops. + * @param name the unique name for this layer. If null, a unique name will be generated based + * on {@link Class#getSimpleName()}. + * @param rate A number between 0 and 1. Drop probability. The multiplicative noise will have + * standard deviation: sqrt(rate / (1 - rate)). + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the weights and computation + */ + public GaussianDropout(Ops tf, String name, float rate, long seed, Class type) { + this(tf, name, rate, seed, type, null); + } + + /** + * Creates a GaussianDropout layer. + * + * @param tf the TensorFlow Ops. + * @param name the unique name for this layer. If null, a unique name will be generated based + * on {@link Class#getSimpleName()}. + * @param rate A number between 0 and 1. Drop probability. The multiplicative noise will have + * standard deviation: sqrt(rate / (1 - rate)). + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the weights and computation + * @param options the layer's options.
+ */ + public GaussianDropout( + Ops tf, String name, float rate, long seed, Class type, Options options) { + super(tf, name, true, type, options); + if (rate < 0 || rate >= 1) + throw new IllegalArgumentException("The rate must be >= 0 and < 1."); + this.rate = rate; + this.seed = seed; + setSupportsMasking(true); + } + + /** {@inheritDoc} */ + @Override + public List> call( + List> inputs, + List> masks, + boolean training, + Class resultType) { + + Ops tf = getTF(); + List> outputs = new ArrayList<>(); + + for (Operand input : inputs) { + + Operand output = cast(tf, input, getType()); + + // apply the multiplicative Gaussian noise only in training mode + if (training && rate >= 0 && rate <= 1) { + Operand rateV = cast(tf, tf.constant(rate), getType()); + output = dropout(output, rateV, seed); + } + outputs.add(output); + } + return callPostProcess(convertTo(outputs, resultType), training); + } + + /** {@inheritDoc} */ + @Override + public List computeOutputShape(List inputShapes) { + return inputShapes; + } + + /** + * Applies multiplicative 1-centered Gaussian noise to the input. + * + * @param input the input + * @param rate the drop probability; the noise has standard deviation sqrt(rate / (1 - rate)). + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @return an Operand with the same shape as the input, with multiplicative noise applied. + */ + /* TODO, this is defined as tf.dropout() in python and is an nn op. Do we want it here? */ + private Operand dropout(Operand input, Operand rate, long seed) { + Ops tf = getTF(); + + Operand one = cast(tf, tf.constant(1), input.type()); + Operand zero = cast(tf, tf.constant(0), input.type()); + Operand keepProb = tf.math.sub(one, rate); + Operand stdDev = tf.math.sqrt(tf.math.div(rate, keepProb)); + + Operand randomNormal = + tf.random.parameterizedTruncatedNormal( + tf.shape(input), one, stdDev, zero, one, ParameterizedTruncatedNormal.seed(seed)); + + return tf.math.mul(input, randomNormal); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/GaussianNoise.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/GaussianNoise.java new file mode 100644 index 00000000000..18718a68f15 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/GaussianNoise.java @@ -0,0 +1,167 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
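To make the noise parameters concrete, a hedged sketch of GaussianDropout (same scaffolding as the earlier sketches; rate and seed are illustrative, and x is assumed defined as in the Dropout sketch). For rate = 0.2 the multiplicative noise is centered at 1 with standard deviation sqrt(0.2 / 0.8) = 0.5:

// Sketch only: assumes tf, x and the imports from the earlier sketches.
GaussianDropout<TFloat32> gd = new GaussianDropout<>(tf, 0.2f, 42L, TFloat32.class);
List<Operand<? extends TType>> inputs = Collections.singletonList(x);
// Noise is applied only when training = true; inference returns the input as-is.
List<Operand<TFloat32>> y = gd.call(inputs, null, true, TFloat32.class);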
+=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.op.random.ParameterizedTruncatedNormal; +import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TType; + +import java.util.ArrayList; +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Apply additive zero-centered Gaussian noise. + * + *

This is useful to mitigate overfitting (you could see it as a form of random data + * augmentation). Gaussian noise is a natural choice as a corruption process for real-valued + * inputs. + * + * @param the data type for the layer's weights and computation. + */ +public class GaussianNoise extends Layer { + + private final float stddev; + private final long seed; + + /** + * Creates a GaussianNoise layer, using a unique name generated based on {@link + * Class#getSimpleName()}. + * + * @param tf the TensorFlow Ops. + * @param stddev the standard deviation of the zero-centered Gaussian noise distribution. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the weights and computation + */ + public GaussianNoise(Ops tf, float stddev, long seed, Class type) { + + this(tf, null, stddev, seed, type, null); + } + + /** + * Creates a GaussianNoise layer, using a unique name generated based on {@link + * Class#getSimpleName()}. + * + * @param tf the TensorFlow Ops. + * @param stddev the standard deviation of the zero-centered Gaussian noise distribution. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the weights and computation + * @param options the layer's options. + */ + public GaussianNoise(Ops tf, float stddev, long seed, Class type, Options options) { + + this(tf, null, stddev, seed, type, options); + } + + /** + * Creates a GaussianNoise layer. + * + * @param tf the TensorFlow Ops. + * @param name the unique name for this layer. If null, a unique name will be generated based + * on {@link Class#getSimpleName()}. + * @param stddev the standard deviation of the zero-centered Gaussian noise distribution. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the weights and computation + */ + public GaussianNoise(Ops tf, String name, float stddev, long seed, Class type) { + this(tf, name, stddev, seed, type, null); + } + /** + * Creates a GaussianNoise layer. + * + * @param tf the TensorFlow Ops. + * @param name the unique name for this layer. If null, a unique name will be generated based + * on {@link Class#getSimpleName()}. + * @param stddev the standard deviation of the zero-centered Gaussian noise distribution. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the weights and computation + * @param options the layer's options.
+ */
+ public GaussianNoise(
+ Ops tf, String name, float stddev, long seed, Class type, Options options) {
+ super(tf, name, true, type, options);
+ this.stddev = stddev;
+ this.seed = seed;
+ setSupportsMasking(true);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public List> call(
+ List> inputs,
+ List> masks,
+ boolean training,
+ Class resultType) {
+
+ Ops tf = getTF();
+ List> outputs = new ArrayList<>();
+
+ for (Operand input : inputs) {
+
+ Operand output = cast(tf, input, getType());
+
+ if (training) {
+
+ Operand stddevV = cast(tf, tf.constant(stddev), getType());
+
+ output = addNoise(output, stddevV, seed);
+ outputs.add(output);
+ } else {
+ outputs.add(output);
+ }
+ }
+ return callPostProcess(convertTo(outputs, resultType), training);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public List computeOutputShape(List inputShapes) {
+ return inputShapes;
+ }
+
+ /**
+ * Adds zero-centered Gaussian noise to the input.
+ *
+ * @param input the input
+ * @param stdDev the standard deviation of the noise distribution
+ * @param seed the seed for random number generation. An initializer created with a given seed
+ * will always produce the same random tensor for a given shape and data type.
+ * @return an Operand with the same shape as the input, with Gaussian noise added to each entry.
+ */
+ private Operand addNoise(Operand input, Operand stdDev, long seed) {
+ Ops tf = getTF();
+
+ Operand zero = cast(tf, tf.constant(0), input.type());
+ Operand two = cast(tf, tf.constant(2), input.type());
+ // truncate the noise at two standard deviations on either side of the mean,
+ // so the added noise remains zero-centered
+ Operand bound = tf.math.mul(two, stdDev);
+
+ Operand randomNormal =
+ tf.random.parameterizedTruncatedNormal(
+ tf.shape(input), zero, stdDev, tf.math.neg(bound), bound,
+ ParameterizedTruncatedNormal.seed(seed));
+
+ return tf.math.add(input, randomNormal);
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Input.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Input.java
new file mode 100644
index 00000000000..2a2c414b86b
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Input.java
@@ -0,0 +1,348 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.layers;
+
+import org.tensorflow.Operand;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.op.Ops;
+import org.tensorflow.op.core.Placeholder;
+import org.tensorflow.types.TBool;
+import org.tensorflow.types.family.TFloating;
+import org.tensorflow.types.family.TNumber;
+import org.tensorflow.types.family.TType;
+
+import java.util.Collections;
+import java.util.List;
+
+import static org.tensorflow.framework.utils.CastHelper.cast;
+
+/**
+ * Layer that handles model input.
+ *
+ * @param the data type for the layer's calculations.
+ */
+public class Input extends Layer {
+
+ private final Class inputType;
+ private final boolean placeholder;
+ private final Operand output;
+
+ /**
+ * Creates an input layer using {@link Class#getSimpleName()} for the name.
+ *
+ * @param tf the TensorFlow Ops.
+ * @param input The input
+ * @param type the data type for the layer's weights and computation.
+ */
+ public Input(Ops tf, Operand input, Class type) {
+
+ this(tf, null, input, null, type, null);
+ }
+
+ /**
+ * Creates an input layer using {@link Class#getSimpleName()} for the name.
+ *
+ * @param tf the TensorFlow Ops.
+ * @param input The input
+ * @param type the data type for the layer's weights and computation.
+ * @param options the Layer options
+ */
+ public Input(Ops tf, Operand input, Class type, Options options) {
+
+ this(tf, null, input, null, type, options);
+ }
+
+ /**
+ * Creates an input layer.
+ *
+ * @param tf the TensorFlow Ops.
+ * @param name the unique name for this layer, if null, will generate a name based on {@link
+ * Class#getSimpleName()}
+ * @param input The input
+ * @param type the data type for the layer's weights and computation.
+ */
+ public Input(Ops tf, String name, Operand input, Class type) {
+
+ this(tf, name, input, null, type, null);
+ }
+
+ /**
+ * Creates an input layer.
+ *
+ * @param tf the TensorFlow Ops.
+ * @param name the unique name for this layer, if null, will generate a name based on {@link
+ * Class#getSimpleName()}
+ * @param input The input
+ * @param type the data type for the layer's weights and computation.
+ * @param options the layer's options
+ */
+ public Input(
+ Ops tf, String name, Operand input, Class type, Options options) {
+
+ this(tf, name, input, null, type, options);
+ }
+
+ /**
+ * Creates an input layer using {@link Class#getSimpleName()} for the name.
+ *
+ * @param tf the TensorFlow Ops; must be set before the first call to the {@link #call} method.
+ * @param inputType the data type for the input and output, if null, input.type() is used
+ * @param type the data type for the layer's weights and computation.
+ */
+ public Input(Ops tf, Class inputType, Class type) {
+ this(tf, null, null, inputType, type, null);
+ }
+
+ /**
+ * Creates an input layer using {@link Class#getSimpleName()} for the name.
+ *
+ * @param tf the TensorFlow Ops; must be set before the first call to the {@link #call} method.
+ * @param inputType the data type for the input and output, if null, input.type() is used
+ * @param type the data type for the layer's weights and computation.
+ * @param options the layer's options, may be null
+ */
+ public Input(Ops tf, Class inputType, Class type, Options options) {
+ this(tf, null, null, inputType, type, options);
+ }
+
+ /**
+ * Creates an input layer.
+ *
+ * @param tf the TensorFlow Ops.
+ * @param name the unique name for this layer, if null, will generate a name based on {@link
+ * Class#getSimpleName()}
+ * @param inputType the data type for the input and output, if null, input.type() is used
+ * @param type the data type for the layer's weights and computation.
+ */
+ public Input(Ops tf, String name, Class inputType, Class type) {
+ this(tf, name, null, inputType, type, null);
+ }
+
+ /**
+ * Creates an input layer.
+ *
+ * @param tf the TensorFlow Ops.
+ * @param name the unique name for this layer, if null, will generate a name based on {@link
+ * Class#getSimpleName()}
+ * @param inputType the data type for the input and output, if null, input.type() is used
+ * @param type the data type for the layer's weights and computation.
+ * @param options the layer's options, may be null
+ */
+ public Input(
+ Ops tf, String name, Class inputType, Class type, Options options) {
+ this(tf, name, null, inputType, type, options);
+ }
+
+ /**
+ * Creates an input layer
+ *
+ * @param tf the TensorFlow Ops.
+ * @param name the unique name for this layer, if null, will generate a name based on {@link
+ * Class#getSimpleName()}
+ * @param input The input
+ * @param inputType the data type for the input and output, if null, input.type() is used
+ * @param type the data type for the layer's weights and computation.
+ * @throws IllegalArgumentException if both input and inputType are null, if input.type() differs
+ * from inputType, or if the inputShape option is provided along with the batchSize or
+ * batchInputShape options.
+ */
+ public Input(
+ Ops tf,
+ String name,
+ Operand input,
+ Class inputType,
+ Class type) {
+ this(tf, name, input, inputType, type, null);
+ }
+ /**
+ * Creates an input layer
+ *
+ * @param tf the TensorFlow Ops.
+ * @param name the unique name for this layer, if null, will generate a name based on {@link
+ * Class#getSimpleName()}
+ * @param input The input
+ * @param inputType the data type for the input and output, if null, input.type() is used
+ * @param type the data type for the layer's weights and computation.
+ * @param options the layer's options, may be null
+ * @throws IllegalArgumentException if both input and inputType are null, if input.type() differs
+ * from inputType, or if the inputShape option is provided along with the batchSize or
+ * batchInputShape options.
+ */
+ public Input(
+ Ops tf,
+ String name,
+ Operand input,
+ Class inputType,
+ Class type,
+ Options options) {
+ super(tf, name, true, type, options);
+ Options inputOptions = getInstanceOptions();
+
+ if (inputType == null && input == null) {
+ throw new IllegalArgumentException("input and inputType cannot both be null");
+ }
+
+ if (input != null && inputType != null && !input.type().equals(inputType)) {
+ throw new IllegalArgumentException(
+ String.format("input.type() differs from inputType: %s vs. %s", input.type(), inputType));
+ }
+
+ if (inputOptions != null) {
+ if (inputOptions.inputShape != null
+ && (inputOptions.batchSize != null || inputOptions.batchInputShape != null)) {
+ throw new IllegalArgumentException(
+ "Only provide the inputShape option, or the batchSize/batchInputShape options, not both.");
+ }
+ }
+
+ Shape lShape;
+
+ if (inputOptions != null && inputOptions.batchInputShape != null) {
+ lShape =
+ inputOptions.batchInputShape.takeLast(inputOptions.batchInputShape.numDimensions() - 1);
+ setBatchInputShape(inputOptions.batchInputShape);
+ if (getBatchSize() == null) {
+ setBatchSize(inputOptions.batchInputShape.size(0));
+ }
+ } else {
+ if (input == null) {
+ lShape =
+ (inputOptions == null || inputOptions.inputShape == null)
+ ? Shape.of(Shape.UNKNOWN_SIZE)
+ : inputOptions.inputShape;
+ } else {
+ lShape =
+ (inputOptions == null || inputOptions.inputShape == null)
+ ? input.shape()
+ : inputOptions.inputShape;
+ }
+
+ setBatchSize(
+ (inputOptions == null || inputOptions.batchSize == null)
+ ? Shape.UNKNOWN_SIZE
+ : inputOptions.batchSize);
+
+ setBatchInputShape(Shape.of(getBatchSize()).append(lShape));
+ }
+ setInputShape(lShape);
+
+ this.inputType = inputType == null ?
input.type() : inputType; + super.build(lShape); + if (input != null) { + output = input; + placeholder = false; + } else { + output = getTF().placeholder(this.inputType, Placeholder.shape(getBatchInputShape())); + placeholder = true; + } + } + + /** + * Gets the input Operand. This is a convenience method to create the input for a Model. + * + * @param tf the TensorFlow Ops. + * @param type the data type for the layer's weights and computation. + * @param the data type for the layer's calculations. + * @return the output + */ + public static Operand input(Ops tf, Class type) { + return input(tf, type, null); + } + + /** + * Gets the input Operand. This is a convenience method to create the input for a Model. + * + * @param tf the TensorFlow Ops. + * @param type the data type for the layer's weights and computation. + * @param options the Layer options + * @param the data type for the layer's calculations. + * @return the output + */ + public static Operand input(Ops tf, Class type, Options options) { + Input layer = new Input<>(tf, type, type, options); + return layer.getOutput(type); + } + + /** + * Gets the input Operand. This is a convenience method to create the input for a Model. + * + * @param tf the TensorFlow Ops. + * @param input The input + * @param type the data type for the layer's weights and computation. + * @param options the Layer options + * @param the data type for the layer's calculations. + * @return the output + */ + public static Operand input( + Ops tf, Operand input, Class type, Options options) { + Input layer = new Input<>(tf, input, type, options); + return layer.getOutput(); + } + + /** {@inheritDoc} */ + @Override + public List> call( + List> inputs, + List> masks, + boolean training, + Class resultType) { + return callPostProcess(Collections.singletonList(getOutput(resultType)), training); + } + + /** + * Gets the output Operand. + * + *

Note: a calling class should call this method directly, rather than calling one of the + * {@link #call} methods + * + * @return the output Operand. + */ + public Operand getOutput() { + return output; + } + + /** + * Gets the output Operand. + * + *

Note: a calling class should call this method directly, rather than calling one of the + * {@link #call} methods + * + * @param resultType the output data type + * @param the data type for the result + * @return the output Operand. + */ + public Operand getOutput(Class resultType) { + + return cast(getTF(), output, resultType); + } + + /** + * Identifies whether the output is a placeholder or not. + * + * @return true, if the output represents a placeholder + */ + public boolean isPlaceholder() { + return placeholder; + } + + /** + * The data type expected by the input. + * + * @return The data type expected by the input. + */ + public Class getInputType() { + return inputType; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Lambda.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Lambda.java new file mode 100644 index 00000000000..cb2f1e9048c --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Lambda.java @@ -0,0 +1,195 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TType; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.function.BiFunction; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Wraps arbitrary Java Lambda as a Layer. + * + *

The Lambda layer exists so that arbitrary TensorFlow functions can be used when + * constructing Sequential models. Lambda layers are best suited for + * simple operations or quick experimentation. + * + *

the Java lambda function is in the form x = function(tf, input). The first + * argument is the TensorFlow Ops, the second argument is the input Operand. For example: + * + *

{@code
+ * Lambda lambda = new Lambda(tf, (ops, input) -> ops.math.mul(ops.constant(2), input), TFloat32.class);
+ * }
+ * + * @param the data type for the layer's weights and computation. + */ +public class Lambda extends Layer { + private BiFunction, Operand> function; + + /** + * Creates a Lambda layer, generating a unique name based on {@link Class#getSimpleName()} + * + * @param tf the TensorFlow Ops + * @param type the data type for the layer's weights and computation. + */ + public Lambda(Ops tf, Class type) { + this(tf, null, null, type, null); + } + + /** + * Creates a Lambda layer, generating a unique name based on {@link Class#getSimpleName()} + * + * @param tf the TensorFlow Ops + * @param type the data type for the layer's weights and computation. + * @param options the layer's options. + */ + public Lambda(Ops tf, Class type, Options options) { + this(tf, null, null, type, options); + } + + /** + * Creates a Lambda layer + * + * @param tf the TensorFlow Ops + * @param name the unique name for this layer, if null, generates a unique name based on {@link + * Class#getSimpleName()}. + * @param type the data type for the layer's weights and computation. + */ + public Lambda(Ops tf, String name, Class type) { + this(tf, name, null, type, null); + } + + /** + * Creates a Lambda layer + * + * @param tf the TensorFlow Ops + * @param name the unique name for this layer, if null, generates a unique name based on {@link + * Class#getSimpleName()}. + * @param type the data type for the layer's weights and computation. + * @param options the layer's options. + */ + public Lambda(Ops tf, String name, Class type, Options options) { + this(tf, name, null, type, options); + } + + /** + * Creates a Lambda layer, generating a unique name based on {@link Class#getSimpleName()} + * + * @param tf the TensorFlow Ops + * @param function The Java lambda function in the form x = function(tf, input). The + * first argument is the TensorFlow Ops, the second argument is the input Operand. If function + * is null, then the input is returned un changed. + * @param type the data type for the layer's weights and computation. + */ + public Lambda(Ops tf, BiFunction, Operand> function, Class type) { + this(tf, null, function, type, null); + } + + /** + * Creates a Lambda layer, generating a unique name based on {@link Class#getSimpleName()} + * + * @param tf the TensorFlow Ops + * @param function The Java lambda function in the form x = function(tf, input). The + * first argument is the TensorFlow Ops, the second argument is the input Operand. If function + * is null, then the input is returned un changed. + * @param type the data type for the layer's weights and computation. + * @param options the layer's options. + */ + public Lambda( + Ops tf, BiFunction, Operand> function, Class type, Options options) { + this(tf, null, function, type, options); + } + + /** + * Creates a Lambda layer + * + * @param tf the TensorFlow Ops + * @param name the unique name for this layer, if null, generates a unique name based on {@link + * Class#getSimpleName()}. + * @param function the Java lambda function in the form x = function(tf, input). The + * first argument is the TensorFlow Ops, the second argument is the input Operand. If function + * is null, then the input is returned un changed. + * @param type the data type for the layer's weights and computation. 
+ */ + public Lambda( + Ops tf, String name, BiFunction, Operand> function, Class type) { + this(tf, name, function, type, null); + } + + /** + * Creates a Lambda layer + * + * @param tf the TensorFlow Ops + * @param name the unique name for this layer, if null, generates a unique name based on {@link + * Class#getSimpleName()}. + * @param function the Java lambda function in the form x = function(tf, input). The + * first argument is the TensorFlow Ops, the second argument is the input Operand. If function + * is null, then the input is returned un changed. + * @param type the data type for the layer's weights and computation. + * @param options the layer's options + */ + public Lambda( + Ops tf, + String name, + BiFunction, Operand> function, + Class type, + Options options) { + super(tf, name, true, type, options); + this.function = function; + } + + /** + * Sets the lambda function + * + * @param function the Java lambda function in the form + * x = function(tf, input). The first argument is the TensorFlow Ops, the second + * argument is the input Operand. + */ + public void setLambda(BiFunction, Operand> function) { + this.function = function; + } + + /** {@inheritDoc} */ + @Override + public List> call( + List> inputs, + List> masks, + boolean training, + Class resultType) { + if (inputs.isEmpty()) { + return Collections.emptyList(); + } + Ops tf = getTF(); + List> outputs = new ArrayList<>(); + for (Operand input : inputs) { + if (function != null) { + Operand tInput = cast(tf, input, getType()); + Operand result = function.apply(tf, tInput); + outputs.add(result); + } else { + outputs.add(cast(tf, input, getType())); + } + } + return convertTo(outputs, resultType); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Layer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Layer.java new file mode 100644 index 00000000000..60ee9c6a9e0 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Layer.java @@ -0,0 +1,1009 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.framework.initializers.Initializer; +import org.tensorflow.framework.layers.impl.InputSpec; +import org.tensorflow.framework.layers.impl.VariableDef; +import org.tensorflow.framework.losses.Loss; +import org.tensorflow.framework.metrics.Metric; +import org.tensorflow.framework.regularizers.Regularizer; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TString; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.UnaryOperator; +import java.util.stream.Collectors; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * The base abstract class for Layers. + * + *

A layer is a callable object that takes as input one or more tensors and that outputs one or + * more tensors. It involves computation, defined in the call() method, and a state (weight + * variables), defined either in the constructor or in the first call to the {@link #call} method + * method. + * + *

Users will just instantiate a layer and then treat it as a callable. + * + * @param the data type for the layer's weights and computation. + */ +public abstract class Layer { + + private static final Map nameMap = new HashMap<>(); + private final String name; + private final Class type; + private final List> weights = new ArrayList<>(); + private final List> trainableWeights = new ArrayList<>(); + private final List> nonTrainableWeights = new ArrayList<>(); + private final List losses = new ArrayList<>(); + // some loss operations don't have an associated Loss class, so this property holds + // the Operands to calculate the loss, used in the model. + private final List> lossOperations = new ArrayList<>(); + private final List> metrics = new ArrayList<>(); + private final Map, VariableDef> variableMap = new HashMap<>(); + // Note that, unlike other classes, tf may not be set in the constructor, but may be set later. + // the idea behind this is that the model can be built with the layers before the model + // sets the tf instance probably during the model.compile phase. + private final Ops tf; + private boolean trainable; + private Regularizer activityRegularizer; + private boolean built; + private boolean stateful; + private boolean supportsMasking; + // These are the inputShapes as presented to build + private List inputShapes; + private List inputSpecs; + // These are the shapes/dimensions presented by Options. + private Shape batchInputShape; + private Long batchSize; + private Shape inputShape; + private Options instanceOptions; + + /** + * Creates the base Layer class + * + * @param tf the TensorFlow Ops. + * @param name the unique name for this layer, if null will generate a name using the method + * {@link #genName()} + * @param trainable whether the layer's variables should be trainable or not. + * @param type the data type for the layer's weights and computation. + * @param options the layer's options. + */ + public Layer(Ops tf, String name, boolean trainable, Class type, Options options) { + this.name = name == null ? genName() : name; + this.setTrainable(trainable); + this.type = type; + this.batchSize = Shape.UNKNOWN_SIZE; + loadOptions(options); + this.tf = tf.withSubScope(getName()); + } + + private void loadOptions(Options options) { + if (options != null) { + instanceOptions = options; + if (instanceOptions.batchInputShape != null) { + this.batchInputShape = instanceOptions.batchInputShape; + } + if (instanceOptions.batchSize != null) { + this.batchSize = instanceOptions.batchSize; + } + if (instanceOptions.inputShape != null) { + this.inputShape = instanceOptions.inputShape; + } + if (instanceOptions.activityRegularizer != null) { + this.activityRegularizer = instanceOptions.activityRegularizer; + } + if (instanceOptions.metrics != null) { + this.metrics.addAll(instanceOptions.metrics); + } + if (instanceOptions.losses != null) { + this.losses.addAll(instanceOptions.losses); + } + } + } + + /** + * Generates an unique name by appending an integer value to the {@link Class#getSimpleName} in + * the form {@link Class#getSimpleName}_<identifier>, e.g Dense_1 + * The first call to generate an unique name will only return {@link Class#getSimpleName} with out + * the suffix, e.g Dense. + * + * @return the generated name for the class. 
+ */ + private String genName() { + String base = getClass().getSimpleName(); + Integer id = nameMap.get(base); + if (id == null) { + nameMap.put(base, 0); + return base; + } else { + id++; + nameMap.put(base, id); + return String.format("%s_%d", base, id); + } + } + + /** + * Invokes the layer's algorithm using a single input, returning a single output. Training mode is + * true. + * + *

This is a convenience call on top of {@link #call}.
+ *
+ * @param input the input Operand
+ * @return the output Operand, or null if no output is generated from the layer's logic.
+ */
+ public Operand call(Operand input) {
+
+ return call(input, null, true, getType());
+ }
+
+ /**
+ * Invokes the layer's algorithm using a single input, returning a single output. Training mode is
+ * true.
+ *
+ *

This is a convenience call on top of {@link #call}.
+ *
+ * @param input the input Operand
+ * @param type the data type for the result
+ * @param the data type for the result
+ * @return the output Operand, or null if no output is generated from the layer's logic.
+ */
+ public Operand call(Operand input, Class type) {
+
+ return call(input, null, true, type);
+ }
+
+ /**
+ * Invokes the layer's algorithm using a single input, returning a single output.
+ *
+ *

This is a convenience call on top of {@link #call}. + * + * @param input the input Operand + * @param training whether the call is in inference mode or training mode + * @param type the data type for the result + * @param the data type for the result + * @return the output Operand, or null if no output is generated from the layer's logic. + */ + public Operand call( + Operand input, boolean training, Class type) { + return call(input, null, training, type); + } + + /** + * Invokes the layer's algorithm using a single input, returning a single output. + * + *

This is a convenience call on top of {@link #call}.
+ *
+ * @param input the input Operand
+ * @param mask the mask to apply to the result, may be null
+ * @param training whether the call is in inference mode or training mode
+ * @param type the data type for the result
+ * @param the data type for the result
+ * @return the output Operand, or null if no output is generated from the layer's logic.
+ */
+ public Operand call(
+ Operand input, Operand mask, boolean training, Class type) {
+ List> result =
+ call(Collections.singletonList(input), Collections.singletonList(mask), training, type);
+ return result != null ? result.get(0) : null;
+ }
+
+ /**
+ * Invokes the layer's algorithm. Training mode is true.
+ *
+ * @param inputs the input Operands
+ * @param type the data type for the result
+ * @param the data type for the result
+ * @return the output Operands
+ */
+ public List> call(
+ List> inputs, Class type) {
+ return call(inputs, null, true, type);
+ }
+
+ /**
+ * Invokes the layer's logic using a list of inputs, returning a list of outputs.
+ *
+ * @param inputs the input Operands
+ * @param masks a list of masks, one for each input, to apply to the result, may be null
+ * @param training whether the call is in inference mode or training mode
+ * @param type the data type for the result
+ * @param the data type for the result
+ * @return the output Operands.
+ */
+ public abstract List> call(
+ List> inputs,
+ List> masks,
+ boolean training,
+ Class type);
+
+ /**
+ * Post processes a layer's call result
+ *
+ * @param inputs the input Operands
+ * @param training true if in training mode
+ * @param the data type of the inputs and result
+ * @return the output Operands.
+ */
+ protected List> callPostProcess(
+ List> inputs, @SuppressWarnings("unused") boolean training) {
+ if (activityRegularizer != null && !inputs.isEmpty()) {
+ boolean aTNumber = TNumber.class.isAssignableFrom(inputs.get(0).type());
+ if (aTNumber) {
+ inputs.forEach(
+ input -> {
+ if (input.type() != TString.class) {
+ Operand tInput = cast(tf, input, getType());
+ addLossOperation(activityRegularizer.call(tInput));
+ }
+ });
+ }
+ }
+
+ return inputs;
+ }
+
+ /**
+ * Converts a list of inputs to a new list of the internal data type defined for this layer.
+ *
+ * @param inputs the inputs.
+ * @return the new list converted to the new type.
+ */
+ protected List> convertList(List> inputs) {
+ return convertList(inputs, getType());
+ }
+ /**
+ * Converts a list of inputs to a new list of the specified data type.
+ *
+ * @param inputs the inputs.
+ * @param resultType the data type of the result
+ * @param the data type of the result
+ * @return the new list converted to the new type.
+ */
+ protected List> convertList(
+ List> inputs, Class resultType) {
+ List> result = new ArrayList<>();
+ inputs.forEach(input -> result.add(cast(getTF(), input, resultType)));
+ return result;
+ }
+
+ /**
+ * Converts a list of inputs of this layer's type to a new list of the new type
+ *
+ * @param inputs the inputs.
+ * @param newType the new type.
+ * @param the data type for the new type.
+ * @return the new list converted to the new type.
+ */
+ protected List> convertTo(
+ List> inputs, Class newType) {
+ List> result = new ArrayList<>();
+ inputs.forEach(input -> result.add(cast(getTF(), input, newType)));
+ return result;
+ }
+
+ /**
+ * Creates the variables of the layer (optional, for subclass implementers).
This is a method that + * implementers of subclasses of Layer or Model can override if they + * need a state-creation step in-between layer instantiation and layer call. This is typically + * used to create the weights of Layer subclasses. + * + *

This method is a convenience method that calls {@link #build(List)}. + * + * @param inputShape the shapes of the inputs, one per input + */ + protected void build(Shape... inputShape) { + build(Arrays.asList(inputShape)); + } + + /** + * Creates the variables of the layer (optional, for subclass implementers). This is a method that + * implementers of subclasses of Layer or Model can override if they + * need a state-creation step in-between layer instantiation and layer call. This is typically + * used to create the weights of Layer subclasses. + * + * @param inputShapes the shapes of the inputs, one per input + * @throws IllegalStateException if the TensorFlow Ops is null. + */ + protected void build(List inputShapes) { + if (tf == null) throw new IllegalStateException("The TensorFlow Ops has not been set yet"); + built = true; + this.inputShapes = inputShapes; + } + + /** + * Computes the output shape of the layer. + * + *

This implementation calls {@link #build(List)} if not already called, and returns the input
+ * shapes as the output shapes. Subclasses may want to alter this default behavior.
+ *
+ *

If the layer has not been built, this method will call {@link #build(List)} on the layer. + * This assumes that the layer will later be used with inputs that match the input shape provided + * here. + * + * @param inputShapes the input shapes, one per input + * @return the output shapes, one per output + */ + public List computeOutputShape(List inputShapes) { + if (!built) build(inputShapes); + return inputShapes; + } + + /** + * Gets the unique name for this layer + * + * @return the unique name for this layer + */ + public String getName() { + return name; + } + + /** + * Gets the trainable setting + * + * @return true, if this layer is trainable + */ + public boolean isTrainable() { + return trainable; + } + + /** + * Sets the trainable indicator + * + * @param trainable the trainable indicator + */ + public void setTrainable(boolean trainable) { + this.trainable = trainable; + } + + /** + * Gets the data type for the layer's weights and computation. + * + * @return the data type for the layer's weights and computation. + */ + public Class getType() { + return type; + } + + /** + * Gets the layer's weights + * + * @return the layer's weights + */ + public List> getWeights() { + return weights; + } + + public void setWeights(List> weights) { + this.weights.clear(); + this.weights.addAll(weights); + } + + /** + * Gets the layer's trainable weights + * + * @return the layer's trainable weights + */ + public List> getTrainableWeights() { + return trainableWeights; + } + + /** + * Gets the layer's non-trainable weights + * + * @return the layer's non-trainable weights + */ + public List> getNonTrainableWeights() { + return nonTrainableWeights; + } + + /** + * Adds a weight to the layer + * + * @param name the variable's name + * @param shape the variable's shape + * @param initializer the variable initializer + * @param constraint a constraint to be applied to the weight + * @param regularizer Regularizer instance + * @param trainable whether the variable should be part of the layer's "trainableWeights" + * @param seed a seed value for random number generation + * @throws IllegalStateException if the property {@link #tf} has not been set yet. + * @return the variable created for the weight + */ + public Variable addWeight( + String name, + Shape shape, + Initializer initializer, + UnaryOperator> constraint, + Regularizer regularizer, + boolean trainable, + long seed) { + if (tf == null) { + throw new IllegalStateException("Parameter \"tf\" has not been set"); + } + + VariableDef variableDef = + new VariableDef<>( + tf, name, shape, initializer, constraint, regularizer, trainable, seed, getType()); + + Variable variable = variableDef.getVariable(); + + variableMap.put(variable, variableDef); + weights.add(variable); + if (trainable) trainableWeights.add(variable); + else nonTrainableWeights.add(variable); + return variable; + } + + /** + * Gets the VariableDef for the specified variable + * + * @param variable the variable + * @return the VariableDef + */ + public VariableDef getVariableDef(Variable variable) { + return variableMap.get(variable); + } + + /** + * Adds a weight to the layer + * + * @param name the weight name + * @param variable the variable to add + * @param initializer the variable initializer + * @param constraint the constraint on the variable + * @param regularizer the regularizer for the variable + * @param trainable whether the variable should be part of the layer's "trainableWeights" + * @param seed the seed for random number generation. 
An initializer created with a given seed + * will always produce the same random tensor for a given shape and type. + * @throws IllegalStateException if the property {@link #tf} has not been set yet. + * @return the variable created for the weight + */ + public Variable addWeight( + String name, + Variable variable, + Initializer initializer, + UnaryOperator> constraint, + Regularizer regularizer, + boolean trainable, + long seed) { + if (tf == null) { + throw new IllegalStateException("Parameter \"tf\" has not been set"); + } + if (variable == null) { + throw new IllegalStateException("Parameter \"variable\" has not been set"); + } + VariableDef variableDef = + new VariableDef<>( + tf, name, variable, initializer, constraint, regularizer, trainable, seed); + variableMap.put(variable, variableDef); + weights.add(variable); + if (trainable) trainableWeights.add(variable); + else nonTrainableWeights.add(variable); + return variable; + } + + /** + * Gets the Operands that initializes all the weights + * + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and type. + * @return the Operands that initializes all the weights + */ + public List> initializeWeights(long seed) { + List> result = new ArrayList<>(); + weights.forEach(w -> result.add(initializeWeight(w, seed))); + return result; + } + + /** + * Creates an Operand that initializes a weight + * + * @param weight the weight to initialize + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and type. + * @return the Operand that initializes the weight + * @throws IllegalArgumentException if the weight does not have a registered initializer + */ + public Operand initializeWeight(Variable weight, long seed) { + VariableDef varDef = variableMap.get(weight); + if (varDef == null) { // this should not happen if addWeight was used to create/add the weight + addWeight(null, weight, null, null, null, true, seed); + varDef = variableMap.get(weight); + } + return varDef.init(); + } + + /** + * Computes an output mask tensor. + * + * @param inputs the input Operands + * @param masks the mask Operands. + * @return null or a list of Operands, one for each output from the layer, + */ + @SuppressWarnings("UnusedParameters") + public List> computeMask( + List> inputs, List> masks) { + // the default implementation merely returns the masks. 
+ if (isSupportsMasking()) {
+ if (masks == null) return null;
+ return masks.stream().map(m -> cast(getTF(), m, TBool.class)).collect(Collectors.toList());
+ }
+ if (masks != null && !masks.isEmpty()) {
+ throw new IllegalArgumentException(
+ String.format("%s does not support masking, but was passed a mask", getName()));
+ }
+
+ return null;
+ }
+
+ /**
+ * Gets the Losses assigned to this layer
+ *
+ * @return the Losses assigned to this layer
+ */
+ public List getLosses() {
+ return losses;
+ }
+
+ /**
+ * Gets the Loss Operations assigned to this layer
+ *
+ * @return the Loss Operations assigned to this layer
+ */
+ public List> getLossOperations() {
+ return lossOperations;
+ }
+
+ /**
+ * Adds a loss to this layer
+ *
+ * @param loss the loss to add
+ */
+ public void addLoss(Loss loss) {
+ losses.add(loss);
+ }
+
+ /**
+ * Adds a loss operation to this layer
+ *
+ * @param lossOperation the loss operation
+ */
+ public void addLossOperation(Operand lossOperation) {
+ this.lossOperations.add(lossOperation);
+ }
+
+ /**
+ * Adds losses to this layer
+ *
+ * @param losses the losses to add
+ */
+ public void addLosses(List losses) {
+ this.losses.addAll(losses);
+ }
+
+ /**
+ * Adds loss operations to this layer
+ *
+ * @param lossOperations the loss operations to add
+ */
+ public void addLossOperations(List> lossOperations) {
+ this.lossOperations.addAll(lossOperations);
+ }
+
+ /**
+ * Gets the Metrics assigned to this layer
+ *
+ * @return the Metrics assigned to this layer
+ */
+ public List> getMetrics() {
+ return metrics;
+ }
+
+ /**
+ * Adds a metric to this layer
+ *
+ * @param metric the metric to add
+ */
+ public void addMetric(Metric metric) {
+ metrics.add(metric);
+ }
+
+ /**
+ * Adds metrics to this layer
+ *
+ * @param metrics the metrics to add
+ */
+ public void addMetrics(List> metrics) {
+ this.metrics.addAll(metrics);
+ }
+
+ /**
+ * Determines whether or not the build method has been called.
+ *
+ * @return true, if the build method has been called.
+ */
+ @SuppressWarnings("BooleanMethodIsAlwaysInverted")
+ public boolean isBuilt() {
+ return built;
+ }
+
+ /**
+ * Sets the build indicator
+ *
+ * @param built the build indicator
+ */
+ public void setBuilt(boolean built) {
+ this.built = built;
+ }
+
+ /**
+ * Gets the input shapes, one per input
+ *
+ * @return the input shapes, one per input
+ */
+ public List getInputShapes() {
+ return inputShapes;
+ }
+
+ /**
+ * Sets the input shapes, one per input
+ *
+ * @param inputShapes the input shapes
+ */
+ public void setInputShapes(List inputShapes) {
+ this.inputShapes = inputShapes;
+ }
+
+ /**
+ * Adds an inputSpec
+ *
+ * @param inputSpec the inputSpec
+ */
+ public void addInputSpec(InputSpec inputSpec) {
+ if (inputSpecs == null) {
+ inputSpecs = new ArrayList<>();
+ }
+ inputSpecs.add(inputSpec);
+ }
+
+ /**
+ * Gets the inputSpecs, one per input
+ *
+ * @return the inputSpecs, one per input
+ */
+ public List getInputSpecs() {
+ return inputSpecs;
+ }
+
+ /**
+ * Sets the inputSpecs, one per input
+ *
+ * @param inputSpecs the inputSpecs
+ */
+ public void setInputSpecs(List inputSpecs) {
+ this.inputSpecs = inputSpecs;
+ }
+
+ /**
+ * Gets the {@link #tf} property
+ *
+ * @return the {@link #tf} property
+ */
+ public Ops getTF() {
+ return tf;
+ }
+
+ /**
+ * Gets the stateful property.
+ *
+ *

A stateful layer is a layer whose updates are run during inference too, for instance + * stateful RNNs. + * + * @return true, if this layer is stateful + */ + public boolean isStateful() { + return stateful; + } + + /** + * Sets the stateful property. + * + *

A stateful layer is a layer whose updates are run during inference too, for instance + * stateful RNNs. + * + * @param stateful true, if this layer is stateful. + */ + public void setStateful(boolean stateful) { + this.stateful = stateful; + } + + /** + * Gets the batch input shape + * + * @return the batch input shape + */ + public Shape getBatchInputShape() { + return batchInputShape; + } + + /** + * Sets the batch input shape + * + * @param batchInputShape the batch input shape + */ + public void setBatchInputShape(Shape batchInputShape) { + this.batchInputShape = batchInputShape; + } + + /** + * Gets the batch size + * + * @return the batch size + */ + public Long getBatchSize() { + return batchSize; + } + + /** + * Sets the batch size + * + * @param batchSize the batch size + */ + public void setBatchSize(Long batchSize) { + this.batchSize = batchSize; + } + + /** + * Gets the input shape for this layer + * + * @return the input shape for this layer + */ + public Shape getInputShape() { + return inputShape; + } + + /** + * Sets the input shape for this layer + * + * @param inputShape the input shape for this layer + */ + public void setInputShape(Shape inputShape) { + this.inputShape = inputShape; + } + + /** + * Gets the options instance for this layer. + * + * @return the options instance for this layer. + */ + public Options getInstanceOptions() { + return instanceOptions; + } + + /** + * Gets the activity Regularizer + * + * @return the activity Regularizer + */ + // TODO change to Regularizer class + public Object getActivityRegularizer() { + return activityRegularizer; + } + + /** + * Sets the activity Regularizer + * + * @param activityRegularizer the activity Regularizer + */ + public void setActivityRegularizer(Regularizer activityRegularizer) { + this.activityRegularizer = activityRegularizer; + } + + /** + * Gets the indicator that this layer supports masking. + * + * @return the indicator that this layer supports masking. + */ + public boolean isSupportsMasking() { + return supportsMasking; + } + + /** + * Sets the indicator that this layer supports masking. + * + * @param supportsMasking the indicator that this layer supports masking. + */ + public void setSupportsMasking(boolean supportsMasking) { + this.supportsMasking = supportsMasking; + } + + /** + * Assigns a value to the variable + * + * @param variable the variable to assign to + * @param value the value to assign + * @return the operand that assigns the value to this variable + * @throws IllegalArgumentException if the variable is not known. + */ + public Operand assign(Variable variable, Operand value) { + VariableDef varDef = variableMap.get(variable); + if (varDef == null) { + throw new IllegalStateException(String.format("Variable %s was not found.", variable)); + } + return varDef.assign(value); + } + + /** + * Adds a value to the variable + * + * @param variable the variable to add to + * @param value the value to add + * @return the operand that adds the value to this variable + * @throws IllegalArgumentException if the variable is not known. 
+ */ + public Operand assignAdd(Variable variable, Operand value) { + VariableDef varDef = variableMap.get(variable); + if (varDef == null) { + throw new IllegalStateException(String.format("Variable %s was not found.", variable)); + } + return varDef.assignAdd(value); + } + + /** + * Subtracts a value from the variable + * + * @param variable the variable to subtract from + * @param value the value to subtract + * @return the operand that subtracts the value from this variable + * @throws IllegalArgumentException if the variable is not known. + */ + public Operand assignSub(Variable variable, Operand value) { + VariableDef varDef = variableMap.get(variable); + if (varDef == null) { + throw new IllegalStateException(String.format("Variable %s was not found.", variable)); + } + return varDef.assignSub(value); + } + + /** Optional attributes for {@link Layer} */ + public static class Options { + protected Shape inputShape; + protected Shape batchInputShape; + protected Long batchSize; + protected List> metrics; + protected List losses; + protected Regularizer activityRegularizer; + + public static Options create() { + return new Options(); + } + + /** + * Sets the inputShape + * + * @param inputShape the input shape for the layer + * @return this options instance + */ + public Layer.Options inputShape(Shape inputShape) { + this.inputShape = inputShape; + return this; + } + + /** + * Sets the batchSize + * + * @param batchSize the batch input shape for the layer + * @return this Options instance + */ + public Layer.Options batchSize(Long batchSize) { + this.batchSize = batchSize; + return this; + } + + /** + * Sets the shared name + * + * @param batchInputShape the batch input shape for the layer + * @return this Options instance + */ + public Layer.Options batchInputShape(Shape batchInputShape) { + this.batchInputShape = batchInputShape; + return this; + } + + /** + * Sets the activityRegularizer + * + * @param activityRegularizer the activity Regularizer + * @return this Options instance + */ + public Layer.Options activityRegularizer(Regularizer activityRegularizer) { + this.activityRegularizer = activityRegularizer; + return this; + } + + /** + * Adds a metric + * + * @param metric the metric + * @return this Options instance + */ + public Layer.Options metric(Metric metric) { + if (this.metrics == null) { + this.metrics = new ArrayList<>(); + } + metrics.add(metric); + return this; + } + + /** + * Adds metrics + * + * @param metrics the metrics to add + * @return this Options instance + */ + public Layer.Options metrics(List> metrics) { + if (this.metrics == null) { + this.metrics = new ArrayList<>(metrics); + } else { + this.metrics.addAll(metrics); + } + return this; + } + + /** + * Adds a loss + * + * @param loss the Loss + * @return this Options instance + */ + public Layer.Options loss(Loss loss) { + if (losses == null) { + losses = new ArrayList<>(); + } + losses.add(loss); + return this; + } + + /** + * Adds losses + * + * @param losses the losses to add + * @return this Options instance + */ + public Layer.Options losses(List losses) { + if (this.losses == null) { + this.losses = new ArrayList<>(losses); + } else { + this.losses.addAll(losses); + } + return this; + } + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/LeakyReLU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/LeakyReLU.java new file mode 100644 index 00000000000..807b2b4430f --- /dev/null +++ 
b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/LeakyReLU.java @@ -0,0 +1,110 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.framework.activations.ReLU; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TType; + +import java.util.ArrayList; +import java.util.List; + +/** + * Leaky version of a Rectified Linear Unit. + * + *

It allows a small gradient when the unit is not active: + * + *

{@code
+ * f(x) = alpha * x if x < 0
+ * f(x) = x if x >= 0
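+ * // e.g. with alpha = 0.3: f(-10) = -3, f(10) = 10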
+ * }
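To make the piecewise definition above concrete, here is a minimal, hypothetical usage sketch. The constructor and the single-input call convenience overload are taken from this diff (with the generics, elided by the extraction, restored); the graph setup and names are illustrative only:

```java
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.framework.layers.LeakyReLU;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;

public class LeakyReLUExample {
  public static void main(String[] args) {
    try (Graph graph = new Graph()) {
      Ops tf = Ops.create(graph);
      // Default alpha (0.3): negative inputs are scaled, positive inputs pass through.
      LeakyReLU<TFloat32> layer = new LeakyReLU<>(tf, "leaky_relu", TFloat32.class);
      Operand<TFloat32> out =
          layer.call(tf.constant(new float[] {-10f, -5f, 0f, 5f}), TFloat32.class);
      // Expected values after the layer: [-3.0, -1.5, 0.0, 5.0]
    }
  }
}
```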
+ *
+ * @param the data type for the layer's weights and computation.
+ */
+public class LeakyReLU extends Layer {
+ public static final float DEFAULT_ALPHA = 0.3f;
+
+ private final float alpha;
+
+ /**
+ * Creates a LeakyReLU Layer with a unique name generated based on {@link Class#getSimpleName()}
+ * and {@link #DEFAULT_ALPHA} for the alpha value.
+ *
+ * @param tf the TensorFlow Ops.
+ * @param type the data type for the layer's weights and computation.
+ * @param options the layer's options.
+ */
+ public LeakyReLU(Ops tf, Class type, Options options) {
+ this(tf, null, DEFAULT_ALPHA, type, options);
+ }
+
+ /**
+ * Creates a LeakyReLU Layer with {@link #DEFAULT_ALPHA} for the alpha value.
+ *
+ * @param tf the TensorFlow Ops.
+ * @param name the unique name for this layer. If null, a unique name will be generated based on
+ * {@link Class#getSimpleName()}.
+ * @param type the data type for the layer's weights and computation.
+ */
+ public LeakyReLU(Ops tf, String name, Class type) {
+ this(tf, name, DEFAULT_ALPHA, type, null);
+ }
+
+ /**
+ * Creates a LeakyReLU Layer with a unique name generated based on {@link
+ * Class#getSimpleName()}.
+ *
+ * @param tf the TensorFlow Ops.
+ * @param alpha Negative slope coefficient. Must be >= 0.
+ * @param type the data type for the layer's weights and computation.
+ * @param options the layer's options.
+ */
+ public LeakyReLU(Ops tf, float alpha, Class type, Options options) {
+ this(tf, null, alpha, type, options);
+ }
+ /**
+ * Creates a LeakyReLU Layer
+ *
+ * @param tf the TensorFlow Ops.
+ * @param name the unique name for this layer. If null, a unique name will be generated based on
+ * {@link Class#getSimpleName()}.
+ * @param alpha Negative slope coefficient. Must be >= 0.
+ * @param type the data type for the layer's weights and computation.
+ * @param options the layer's options.
+ */
+ public LeakyReLU(Ops tf, String name, float alpha, Class type, Options options) {
+ super(tf, name, true, type, options);
+ this.alpha = alpha;
+ setSupportsMasking(true);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public List> call(
+ List> inputs,
+ List> masks,
+ boolean training,
+ Class resultType) {
+
+ ReLU reLU = new ReLU<>(getTF(), alpha, ReLU.MAX_VALUE_DEFAULT, ReLU.THRESHOLD_DEFAULT);
+ List> tInputs = convertList(inputs);
+ List> results = new ArrayList<>();
+ tInputs.forEach(input -> results.add(reLU.call(input)));
+ return callPostProcess(convertTo(results, resultType), training);
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Maximum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Maximum.java
new file mode 100644
index 00000000000..28f1a80821f
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Maximum.java
@@ -0,0 +1,92 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.layers;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.layers.impl.Merge;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TFloating;
+import org.tensorflow.types.family.TNumber;
+import org.tensorflow.types.family.TType;
+
+import java.util.List;
+
+import static org.tensorflow.framework.utils.CastHelper.cast;
+
+/**
+ * Layer that computes the element-wise maximum of a list of inputs.
+ *
+ *

It takes as input a list of tensors, all of the same shape, and returns a single tensor (also
+ * of the same shape).
+ *
+ * @param the data type for the layer's weights and computation.
+ */
+public class Maximum extends Merge {
+
+ /**
+ * Creates a Maximum Layer using {@link Class#getSimpleName()} as the layer name.
+ *
+ * @param tf the TensorFlow Ops
+ * @param type the data type for the weights and computation
+ */
+ public Maximum(Ops tf, Class type) {
+ this(tf, null, type, null);
+ }
+
+ /**
+ * Creates a Maximum Layer using {@link Class#getSimpleName()} as the layer name.
+ *
+ * @param tf the TensorFlow Ops
+ * @param type the data type for the weights and computation
+ * @param options the layer's options.
+ */
+ public Maximum(Ops tf, Class type, Options options) {
+ this(tf, null, type, options);
+ }
+
+ /**
+ * Creates a Maximum Layer
+ *
+ * @param tf the TensorFlow Ops
+ * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()}
+ * @param type the data type for the weights and computation
+ */
+ public Maximum(Ops tf, String name, Class type) {
+ this(tf, name, type, null);
+ }
+
+ /**
+ * Creates a Maximum Layer
+ *
+ * @param tf the TensorFlow Ops
+ * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()}
+ * @param type the data type for the weights and computation
+ * @param options the layer's options.
+ */
+ public Maximum(Ops tf, String name, Class type, Options options) {
+ super(tf, name, type, options);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ protected Operand mergeFunction(List> inputs) {
+ Ops tf = getTF();
+ Operand output = cast(tf, tf.identity(inputs.get(0)), getType());
+ for (int i = 1; i < inputs.size(); i++) {
+ output = tf.math.maximum(output, cast(tf, inputs.get(i), getType()));
+ }
+ return output;
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Minimum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Minimum.java
new file mode 100644
index 00000000000..bc46e7e82f3
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Minimum.java
@@ -0,0 +1,92 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.layers;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.layers.impl.Merge;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TFloating;
+import org.tensorflow.types.family.TNumber;
+import org.tensorflow.types.family.TType;
+
+import java.util.List;
+
+import static org.tensorflow.framework.utils.CastHelper.cast;
+
+/**
+ * Layer that computes the element-wise minimum of a list of inputs.
+ *
+ *

It takes as input a list of tensors, all of the same shape, and returns a single tensor (also
+ * of the same shape).
+ *
+ * @param the data type for the layer's weights and computation.
+ */
+public class Minimum extends Merge {
+
+  /**
+   * Creates a Minimum Layer using {@link Class#getSimpleName()} as the layer name.
+   *
+   * @param tf the TensorFlow Ops
+   * @param type the data type for the weights and computation
+   */
+  public Minimum(Ops tf, Class type) {
+    this(tf, null, type, null);
+  }
+
+  /**
+   * Creates a Minimum Layer using {@link Class#getSimpleName()} as the layer name.
+   *
+   * @param tf the TensorFlow Ops
+   * @param type the data type for the weights and computation
+   * @param options the layer's options.
+   */
+  public Minimum(Ops tf, Class type, Options options) {
+    this(tf, null, type, options);
+  }
+
+  /**
+   * Creates a Minimum Layer
+   *
+   * @param tf the TensorFlow Ops
+   * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()}
+   * @param type the data type for the weights and computation
+   */
+  public Minimum(Ops tf, String name, Class type) {
+    this(tf, name, type, null);
+  }
+
+  /**
+   * Creates a Minimum Layer
+   *
+   * @param tf the TensorFlow Ops
+   * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()}
+   * @param type the data type for the weights and computation
+   * @param options the layer's options.
+   */
+  public Minimum(Ops tf, String name, Class type, Options options) {
+    super(tf, name, type, options);
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  protected Operand mergeFunction(List> inputs) {
+    Ops tf = getTF();
+    Operand output = cast(tf, tf.identity(inputs.get(0)), getType());
+    for (int i = 1; i < inputs.size(); i++) {
+      output = tf.math.minimum(output, cast(tf, inputs.get(i), getType()));
+    }
+    return output;
+  }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Multiply.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Multiply.java
new file mode 100644
index 00000000000..d343463f7af
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Multiply.java
@@ -0,0 +1,92 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.layers;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.layers.impl.Merge;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TFloating;
+import org.tensorflow.types.family.TNumber;
+import org.tensorflow.types.family.TType;
+
+import java.util.List;
+
+import static org.tensorflow.framework.utils.CastHelper.cast;
+
+/**
+ * Layer that multiplies (element-wise) a list of inputs.
+ *
+ *

It takes as input a list of tensors, all of the same shape, and returns a single tensor (also
+ * of the same shape).
+ *
+ * @param the data type for the layer's weights and computation.
+ */
+public class Multiply extends Merge {
+
+  /**
+   * Creates a Multiply Layer using {@link Class#getSimpleName()} as the layer name.
+   *
+   * @param tf the TensorFlow Ops
+   * @param type the data type for the weights and computation
+   */
+  public Multiply(Ops tf, Class type) {
+    this(tf, null, type);
+  }
+
+  /**
+   * Creates a Multiply Layer using {@link Class#getSimpleName()} as the layer name.
+   *
+   * @param tf the TensorFlow Ops
+   * @param type the data type for the weights and computation
+   * @param options the layer's options.
+   */
+  public Multiply(Ops tf, Class type, Options options) {
+    this(tf, null, type, options);
+  }
+
+  /**
+   * Creates a Multiply Layer
+   *
+   * @param tf the TensorFlow Ops
+   * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()}
+   * @param type the data type for the weights and computation
+   */
+  public Multiply(Ops tf, String name, Class type) {
+    this(tf, name, type, null);
+  }
+
+  /**
+   * Creates a Multiply Layer
+   *
+   * @param tf the TensorFlow Ops
+   * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()}
+   * @param type the data type for the weights and computation
+   * @param options the layer's options.
+   */
+  public Multiply(Ops tf, String name, Class type, Options options) {
+    super(tf, name, type, options);
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  protected Operand mergeFunction(List> inputs) {
+    Ops tf = getTF();
+    Operand output = cast(tf, tf.identity(inputs.get(0)), getType());
+    for (int i = 1; i < inputs.size(); i++) {
+      output = tf.math.mul(output, cast(tf, inputs.get(i), getType()));
+    }
+    return output;
+  }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ReLU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ReLU.java
new file mode 100644
index 00000000000..be502aba57b
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ReLU.java
@@ -0,0 +1,223 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.layers;
+
+import org.tensorflow.Operand;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.TBool;
+import org.tensorflow.types.family.TFloating;
+import org.tensorflow.types.family.TType;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Rectified Linear Unit activation layer.
+ *
+ *

With default values, it returns element-wise {@code max(x, 0)}.
+ *

Otherwise, it follows: + * + *

{@code
+ * f(x) = max_value if x >= max_value
+ * f(x) = x if threshold <= x < max_value
+ * f(x) = negative_slope * (x - threshold) otherwise
+ * }
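+ *
+ * <p>For example (a sketch; {@code tf} is an assumed {@code Ops} instance, and the constants
+ * chosen reproduce the common "ReLU6" variant):
+ *
+ * <pre>{@code
+ * // negative_slope = 0, max_value = 6, threshold = 0  =>  f(x) = min(max(x, 0), 6)
+ * ReLU<TFloat32> relu6 = new ReLU<>(tf, 0.0f, 6.0f, 0.0f, TFloat32.class);
+ * }</pre>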
+ *
+ * @param the data type for the layer's weights and computation.
+ */
+public class ReLU extends Layer {
+  public static float DEFAULT_MAX_VALUE = Float.NaN;
+  public static float DEFAULT_NEGATIVE_SLOPE = 0;
+  public static float DEFAULT_THRESHOLD = 0;
+
+  private final float maxValue;
+  private final float negativeSlope;
+  private final float threshold;
+
+  /**
+   * Creates a ReLU Layer with a unique name generated based on {@link Class#getSimpleName()} and
+   * using {@link #DEFAULT_NEGATIVE_SLOPE} for the negative slope, {@link #DEFAULT_MAX_VALUE} as the
+   * maximum value and {@link #DEFAULT_THRESHOLD} as the threshold.
+   *
+   * @param tf the TensorFlow Ops
+   * @param type the data type for the layer's weights and computation.
+   * @param options the layer's options.
+   */
+  public ReLU(Ops tf, Class type, Options options) {
+    this(tf, null, DEFAULT_NEGATIVE_SLOPE, DEFAULT_MAX_VALUE, DEFAULT_THRESHOLD, type, options);
+  }
+
+  /**
+   * Creates a ReLU Layer using {@link #DEFAULT_NEGATIVE_SLOPE} for the negative slope, {@link
+   * #DEFAULT_MAX_VALUE} as the maximum value and {@link #DEFAULT_THRESHOLD} as the threshold.
+   *
+   * @param tf the TensorFlow Ops.
+   * @param name the unique name for this layer. If null, a unique name will be generated based on
+   *     {@link Class#getSimpleName()}.
+   * @param type the data type for the layer's weights and computation.
+   */
+  public ReLU(Ops tf, String name, Class type) {
+    this(tf, name, DEFAULT_NEGATIVE_SLOPE, DEFAULT_MAX_VALUE, DEFAULT_THRESHOLD, type, null);
+  }
+
+  /**
+   * Creates a ReLU Layer using {@link #DEFAULT_NEGATIVE_SLOPE} for the negative slope, {@link
+   * #DEFAULT_MAX_VALUE} as the maximum value and {@link #DEFAULT_THRESHOLD} as the threshold.
+   *
+   * @param tf the TensorFlow Ops.
+   * @param name the unique name for this layer. If null, a unique name will be generated based on
+   *     {@link Class#getSimpleName()}.
+   * @param type the data type for the layer's weights and computation.
+   * @param options the layer's options.
+   */
+  public ReLU(Ops tf, String name, Class type, Options options) {
+    this(tf, name, DEFAULT_NEGATIVE_SLOPE, DEFAULT_MAX_VALUE, DEFAULT_THRESHOLD, type, options);
+  }
+
+  /**
+   * Creates a ReLU Layer with a unique name generated based on {@link Class#getSimpleName()},
+   * using {@link #DEFAULT_MAX_VALUE} as the maximum value and {@link #DEFAULT_THRESHOLD} as the
+   * threshold.
+   *
+   * @param tf the TensorFlow Ops.
+   * @param negativeSlope Negative slope coefficient. Must be >= 0.
+   * @param type the data type for the layer's weights and computation.
+   * @throws IllegalArgumentException if negativeSlope is < 0
+   */
+  public ReLU(Ops tf, float negativeSlope, Class type) {
+    this(tf, null, negativeSlope, DEFAULT_MAX_VALUE, DEFAULT_THRESHOLD, type, null);
+  }
+
+  /**
+   * Creates a ReLU Layer with a unique name generated based on {@link Class#getSimpleName()},
+   * using {@link #DEFAULT_MAX_VALUE} as the maximum value and {@link #DEFAULT_THRESHOLD} as the
+   * threshold.
+   *
+   * @param tf the TensorFlow Ops.
+   * @param negativeSlope Negative slope coefficient. Must be >= 0.
+   * @param type the data type for the layer's weights and computation.
+   * @param options the layer's options.
+   * @throws IllegalArgumentException if negativeSlope is < 0
+   */
+  public ReLU(Ops tf, float negativeSlope, Class type, Options options) {
+    this(tf, null, negativeSlope, DEFAULT_MAX_VALUE, DEFAULT_THRESHOLD, type, options);
+  }
+
+  /**
+   * Creates a ReLU Layer with a unique name generated based on {@link
+   * Class#getSimpleName()}.
+   *
+   * @param tf the TensorFlow Ops
+   * @param negativeSlope Negative slope coefficient. Must be >= 0.
+   * @param maxValue Maximum activation value. Must be >= 0.
+   * @param threshold Threshold value for thresholded activation.
+   * @param type the data type for the layer's weights and computation.
+   * @throws IllegalArgumentException if maxValue or negativeSlope is < 0
+   */
+  public ReLU(Ops tf, float negativeSlope, float maxValue, float threshold, Class type) {
+    this(tf, null, negativeSlope, maxValue, threshold, type, null);
+  }
+
+  /**
+   * Creates a ReLU Layer with a unique name generated based on {@link
+   * Class#getSimpleName()}.
+   *
+   * @param tf the TensorFlow Ops
+   * @param negativeSlope Negative slope coefficient. Must be >= 0.
+   * @param maxValue Maximum activation value. Must be >= 0.
+   * @param threshold Threshold value for thresholded activation.
+   * @param type the data type for the layer's weights and computation.
+   * @param options the layer's options.
+   * @throws IllegalArgumentException if maxValue or negativeSlope is < 0
+   */
+  public ReLU(
+      Ops tf,
+      float negativeSlope,
+      float maxValue,
+      float threshold,
+      Class type,
+      Options options) {
+    this(tf, null, negativeSlope, maxValue, threshold, type, options);
+  }
+
+  /**
+   * Creates a ReLU Layer
+   *
+   * @param tf the TensorFlow Ops
+   * @param name the unique name for this layer. If null, a unique name will be generated based on
+   *     {@link Class#getSimpleName()}.
+   * @param negativeSlope Negative slope coefficient. Must be >= 0.
+   * @param maxValue Maximum activation value. Must be >= 0.
+   * @param threshold Threshold value for thresholded activation.
+   * @param type the data type for the layer's weights and computation.
+   * @throws IllegalArgumentException if maxValue or negativeSlope is < 0
+   */
+  public ReLU(
+      Ops tf, String name, float negativeSlope, float maxValue, float threshold, Class type) {
+    this(tf, name, negativeSlope, maxValue, threshold, type, null);
+  }
+  /**
+   * Creates a ReLU Layer
+   *
+   * @param tf the TensorFlow Ops
+   * @param name the unique name for this layer. If null, a unique name will be generated based on
+   *     {@link Class#getSimpleName()}.
+   * @param negativeSlope Negative slope coefficient. Must be >= 0.
+   * @param maxValue Maximum activation value. Must be >= 0.
+   * @param threshold Threshold value for thresholded activation.
+   * @param type the data type for the layer's weights and computation.
+   * @param options the layer's options.
+   * @throws IllegalArgumentException if maxValue or negativeSlope is < 0
+   */
+  public ReLU(
+      Ops tf,
+      String name,
+      float negativeSlope,
+      float maxValue,
+      float threshold,
+      Class type,
+      Options options) {
+    super(tf, name, true, type, options);
+    if (!Float.isNaN(maxValue) && maxValue < 0) {
+      throw new IllegalArgumentException("maxValue must be >= 0, got " + maxValue);
+    }
+    if (negativeSlope < 0) {
+      throw new IllegalArgumentException("negativeSlope must be >= 0, got " + negativeSlope);
+    }
+
+    this.maxValue = maxValue;
+    this.negativeSlope = negativeSlope;
+    this.threshold = threshold;
+    setSupportsMasking(true);
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  public List> call(
+      List> inputs,
+      List> masks,
+      boolean training,
+      Class resultType) {
+
+    org.tensorflow.framework.activations.ReLU reLU =
+        new org.tensorflow.framework.activations.ReLU<>(
+            getTF(), negativeSlope, maxValue, threshold);
+    List> tInputs = convertList(inputs);
+    List> results = new ArrayList<>();
+    tInputs.forEach(input -> results.add(reLU.call(input)));
+    return callPostProcess(convertTo(results, resultType), training);
+  }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/RepeatVector.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/RepeatVector.java
new file mode 100644
index 00000000000..3808fad8bff
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/RepeatVector.java
@@ -0,0 +1,123 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.layers;
+
+import org.tensorflow.Operand;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.TBool;
+import org.tensorflow.types.TInt32;
+import org.tensorflow.types.family.TFloating;
+import org.tensorflow.types.family.TType;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Repeats the input {@code repeatCount} times.
+ *
+ * @param the data type for the layer's weights and computation.
+ */
+public class RepeatVector extends Layer {
+
+  private final int repeatCount;
+
+  /**
+   * Creates a RepeatVector with a unique name generated based on {@link
+   * Class#getSimpleName()}.
+   *
+   * @param tf the TensorFlow Ops
+   * @param repeatCount the repetition factor.
+   * @param type the data type for the layer's weights and computation.
+   */
+  public RepeatVector(Ops tf, int repeatCount, Class type) {
+    this(tf, null, repeatCount, type, null);
+  }
+
+  /**
+   * Creates a RepeatVector with a unique name generated based on {@link
+   * Class#getSimpleName()}.
+   *
+   * @param tf the TensorFlow Ops
+   * @param repeatCount the repetition factor.
+   * @param type the data type for the layer's weights and computation.
+   * @param options the layer's options.
+   */
+  public RepeatVector(Ops tf, int repeatCount, Class type, Options options) {
+    this(tf, null, repeatCount, type, options);
+  }
+
+  /**
+   * Creates a RepeatVector
+   *
+   * @param tf the TensorFlow Ops
+   * @param name the unique name for this layer. If null, a unique name will be generated based on
+   *     {@link Class#getSimpleName()}.
+   * @param repeatCount the repetition factor.
+   * @param type the data type for the layer's weights and computation.
+   */
+  public RepeatVector(Ops tf, String name, int repeatCount, Class type) {
+    this(tf, name, repeatCount, type, null);
+  }
+
+  /**
+   * Creates a RepeatVector
+   *
+   * @param tf the TensorFlow Ops
+   * @param name the unique name for this layer. If null, a unique name will be generated based on
+   *     {@link Class#getSimpleName()}.
+   * @param repeatCount the repetition factor.
+   * @param type the data type for the layer's weights and computation.
+   * @param options the layer's options.
+   */
+  public RepeatVector(Ops tf, String name, int repeatCount, Class type, Options options) {
+    super(tf, name, true, type, options);
+    this.repeatCount = repeatCount;
+  }
+
+  /**
+   * Repeats each 2D input {@code repeatCount} times along a new axis.
+   *
+   * @param inputs the input Operands, 2D tensor of shape (num_samples, features)
+   * @param masks a list of masks, one for each input, to apply to the result, may be null
+   * @param training whether the call is in inference mode or training mode
+   * @param resultType the result type
+   * @param the data type of the result
+   * @return a 3D tensor of shape (num_samples, repeatCount, features)
+   */
+  @Override
+  public List> call(
+      List> inputs,
+      List> masks,
+      boolean training,
+      Class resultType) {
+    if (inputs.isEmpty()) {
+      return Collections.emptyList();
+    }
+    Ops tf = getTF();
+    List> outputs = new ArrayList<>();
+    for (Operand input : inputs) {
+      if (input.shape().numDimensions() != 2) {
+        throw new IllegalArgumentException("RepeatVector inputs must be rank 2.");
+      }
+      Operand output = input;
+      Operand one = tf.constant(1);
+      output = tf.expandDims(output, tf.constant(1));
+      Operand pattern = tf.stack(Arrays.asList(one, tf.constant(repeatCount), one));
+      output = tf.tile(output, pattern);
+      outputs.add(output);
+    }
+    return convertList(outputs, resultType);
+  }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Reshape.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Reshape.java
new file mode 100644
index 00000000000..0641c501612
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Reshape.java
@@ -0,0 +1,110 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.layers;
+
+import org.tensorflow.Operand;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.TBool;
+import org.tensorflow.types.TInt64;
+import org.tensorflow.types.family.TFloating;
+import org.tensorflow.types.family.TType;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import static org.tensorflow.framework.utils.CastHelper.cast;
+
+/**
+ * Layer that reshapes inputs into the given shape.
+ *
+ * @param the data type for the layer's weights and computation.
+ */
+public class Reshape extends Layer {
+
+  private final Shape targetShape;
+
+  /**
+   * Creates a Reshape layer with a unique name generated based on {@link
+   * Class#getSimpleName()}.
+   *
+   * @param tf the TensorFlow Ops
+   * @param targetShape the target shape. Does not include the batch size dimension.
+   * @param type the data type for the layer's weights and computation.
+   */
+  public Reshape(Ops tf, Shape targetShape, Class type) {
+    this(tf, null, targetShape, type, null);
+  }
+
+  /**
+   * Creates a Reshape layer with a unique name generated based on {@link
+   * Class#getSimpleName()}.
+   *
+   * @param tf the TensorFlow Ops
+   * @param targetShape the target shape. Does not include the batch size dimension.
+   * @param type the data type for the layer's weights and computation.
+   * @param options the layer's options.
+   */
+  public Reshape(Ops tf, Shape targetShape, Class type, Options options) {
+    this(tf, null, targetShape, type, options);
+  }
+
+  /**
+   * Creates a Reshape layer.
+   *
+   * @param tf the TensorFlow Ops
+   * @param name the name of this layer
+   * @param targetShape the target shape. Does not include the batch size dimension.
+   * @param type the data type for the layer's weights and computation.
+   */
+  public Reshape(Ops tf, String name, Shape targetShape, Class type) {
+    this(tf, name, targetShape, type, null);
+  }
+
+  /**
+   * Creates a Reshape layer.
+   *
+   * @param tf the TensorFlow Ops
+   * @param name the name of this layer
+   * @param targetShape the target shape. Does not include the batch size dimension.
+   * @param type the data type for the layer's weights and computation.
+   * @param options the layer's options.
+   */
+  public Reshape(Ops tf, String name, Shape targetShape, Class type, Options options) {
+    super(tf, name, true, type, options);
+    this.targetShape = targetShape;
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  public List> call(
+      List> inputs,
+      List> masks,
+      boolean training,
+      Class resultType) {
+    if (inputs.isEmpty()) {
+      return Collections.emptyList();
+    }
+    Ops tf = getTF();
+    Operand input = inputs.get(0);
+    long batchSize = input.shape().size(0);
+    Shape newShape = targetShape.prepend(batchSize);
+    List> result = new ArrayList<>();
+    Operand newShapeOp = tf.constant(newShape);
+    inputs.forEach(inp -> result.add(tf.reshape(cast(tf, inp, getType()), newShapeOp)));
+    return callPostProcess(convertList(result, resultType), training);
+  }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Softmax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Softmax.java
new file mode 100644
index 00000000000..a3f32fa29cf
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Softmax.java
@@ -0,0 +1,116 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.layers;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.op.FrameworkOps;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.TBool;
+import org.tensorflow.types.TFloat16;
+import org.tensorflow.types.family.TFloating;
+import org.tensorflow.types.family.TType;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.tensorflow.framework.utils.CastHelper.cast;
+
+/** Softmax activation function. */
+public class Softmax extends Layer {
+
+  private final int[] axes;
+
+  /**
+   * Creates a Softmax layer
+   *
+   * @param tf the TensorFlow Ops.
+   * @param name the unique name for this layer. If null, a unique name will be generated based on
+   *     {@link Class#getSimpleName()}.
+   * @param axes axes along which the softmax normalization is applied.
+   * @param type the data type for the layer's weights and computation.
+   */
+  public Softmax(Ops tf, String name, int[] axes, Class type) {
+    this(tf, name, axes, type, null);
+  }
+
+  /**
+   * Creates a Softmax layer
+   *
+   * @param tf the TensorFlow Ops.
+   * @param name the unique name for this layer. If null, a unique name will be generated based on
+   *     {@link Class#getSimpleName()}.
+   * @param axes axes along which the softmax normalization is applied.
+   * @param type the data type for the layer's weights and computation.
+   * @param options the layer's options.
+   */
+  public Softmax(Ops tf, String name, int[] axes, Class type, Options options) {
+    super(tf, name, true, type, options);
+    this.axes = axes;
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  public List> call(
+      List> inputs,
+      List> masks,
+      boolean training,
+      Class resultType) {
+    Ops tf = getTF();
+    FrameworkOps fops = FrameworkOps.create(tf);
+    // TODO mask
+
+    List> results = new ArrayList<>();
+
+    for (int i = 0; i < inputs.size(); i++) {
+      Operand input = cast(tf, inputs.get(i), getType());
+      Operand result;
+      if (masks != null && !masks.isEmpty()) {
+        // Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+        // masked positions, this operation will create a tensor which is 0.0 for
+        // positions we want to attend and -1e9 for masked positions.
+        Operand mask = masks.get(i);
+        Operand one = cast(tf, tf.constant(1), getType());
+
+        Operand adder =
+            tf.math.mul(tf.math.sub(one, cast(tf, mask, getType())), largeCompatibleNegative());
+        // Since we are adding it to the raw scores before the softmax, this is
+        // effectively the same as removing these entirely.
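+        // For example (illustrative values): with mask = [1, 1, 0] on a TFloat32 layer,
+        // adder = (1 - mask) * -1e9 = [0, 0, -1e9], so the masked logit's softmax
+        // weight underflows to ~0 after the addition below.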
+
+        input = tf.math.add(input, adder);
+      }
+      if (axes.length > 1) {
+        result = tf.math.exp(tf.math.sub(input, fops.math.reduceLogSumExp(input, axes, true)));
+      } else {
+        result = fops.nn.softmax(input, axes[0]);
+      }
+      results.add(result);
+    }
+    return callPostProcess(convertTo(results, resultType), training);
+  }
+
+  /**
+   * Gets a large negative number based on the data type
+   *
+   * @return a large negative number based on the data type
+   */
+  private Operand largeCompatibleNegative() {
+    Ops tf = getTF();
+    if (getType() == TFloat16.class) {
+      return cast(tf, tf.constant(-0xffdc), getType());
+    } else {
+      return cast(tf, tf.constant(-1e9), getType());
+    }
+  }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Subtract.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Subtract.java
new file mode 100644
index 00000000000..a114c64ae86
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Subtract.java
@@ -0,0 +1,105 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.layers;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.layers.impl.Merge;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TFloating;
+import org.tensorflow.types.family.TNumber;
+import org.tensorflow.types.family.TType;
+
+import java.util.List;
+
+import static org.tensorflow.framework.utils.CastHelper.cast;
+
+/**
+ * Layer that subtracts two inputs.
+ *
+ *

It takes as input a list of tensors of size 2, both of the same shape, and returns a single
+ * tensor, (inputs[0] - inputs[1]), also of the same shape.
+ *
+ * @param the data type for the layer's weights and computation.
+ */
+public class Subtract extends Merge {
+
+  /**
+   * Creates a Subtract Layer using {@link Class#getSimpleName()} as the layer name.
+   *
+   * @param tf the TensorFlow Ops
+   * @param type the data type for the weights and computation
+   */
+  public Subtract(Ops tf, Class type) {
+    this(tf, null, type, null);
+  }
+
+  /**
+   * Creates a Subtract Layer using {@link Class#getSimpleName()} as the layer name.
+   *
+   * @param tf the TensorFlow Ops
+   * @param type the data type for the weights and computation
+   * @param options the layer's options.
+   */
+  public Subtract(Ops tf, Class type, Options options) {
+    this(tf, null, type, options);
+  }
+
+  /**
+   * Creates a Subtract Layer
+   *
+   * @param tf the TensorFlow Ops
+   * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()}
+   * @param type the data type for the weights and computation
+   */
+  public Subtract(Ops tf, String name, Class type) {
+    this(tf, name, type, null);
+  }
+
+  /**
+   * Creates a Subtract Layer
+   *
+   * @param tf the TensorFlow Ops
+   * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()}
+   * @param type the data type for the weights and computation
+   * @param options the layer's options.
+   */
+  public Subtract(Ops tf, String name, Class type, Options options) {
+    super(tf, name, type, options);
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  protected void build(List inputShapes) {
+    if (inputShapes.size() != 2) {
+      throw new IllegalArgumentException("A Subtract layer should be called on exactly 2 inputs");
+    }
+    super.build(inputShapes);
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  protected Operand mergeFunction(List> inputs) {
+    if (inputs.size() != 2) {
+      throw new IllegalArgumentException("A Subtract layer should be called on exactly 2 inputs");
+    }
+    Ops tf = getTF();
+    Operand output = cast(tf, tf.identity(inputs.get(0)), getType());
+    return tf.math.sub(output, cast(tf, inputs.get(1), getType()));
+  }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ThresholdedReLU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ThresholdedReLU.java
new file mode 100644
index 00000000000..29cf026c0a6
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ThresholdedReLU.java
@@ -0,0 +1,117 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TType; + +import java.util.ArrayList; +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Thresholded Rectified Linear Unit. + * + *

It follows:
+ *

{@code
+ * f(x) = x for x > theta
+ * f(x) = 0 otherwise
+ * }
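+ *
+ * <p>For example (a sketch; {@code tf} is an assumed {@code Ops} instance):
+ *
+ * <pre>{@code
+ * // with theta = 1.0f: [-1.0, 0.5, 2.0] -> [0.0, 0.0, 2.0]
+ * ThresholdedReLU<TFloat32> tRelu = new ThresholdedReLU<>(tf, 1.0f, TFloat32.class, null);
+ * }</pre>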
+ *
+ * @param the data type for the layer's weights and computation.
+ */
+public class ThresholdedReLU extends Layer {
+  public static float DEFAULT_THETA = 1.03f;
+
+  private final float theta;
+
+  /**
+   * Creates a ThresholdedReLU Layer with a unique name generated based on {@link
+   * Class#getSimpleName()} and {@link #DEFAULT_THETA} for the theta value.
+   *
+   * @param tf the TensorFlow Ops
+   * @param type the data type for the layer's weights and computation.
+   */
+  public ThresholdedReLU(Ops tf, Class type) {
+
+    this(tf, null, DEFAULT_THETA, type, null);
+  }
+
+  /**
+   * Creates a ThresholdedReLU Layer with {@link #DEFAULT_THETA} for the theta value.
+   *
+   * @param tf the TensorFlow Ops.
+   * @param name the unique name for this layer. If null, a unique name will be generated based on
+   *     {@link Class#getSimpleName()}.
+   * @param type the data type for the layer's weights and computation.
+   */
+  public ThresholdedReLU(Ops tf, String name, Class type) {
+    this(tf, name, DEFAULT_THETA, type, null);
+  }
+
+  /**
+   * Creates a ThresholdedReLU Layer with a unique name generated based on {@link
+   * Class#getSimpleName()}.
+   *
+   * @param tf the TensorFlow Ops
+   * @param theta Threshold location of activation. Must be >= 0.
+   * @param type the data type for the layer's weights and computation.
+   * @param options the layer's options.
+   */
+  public ThresholdedReLU(Ops tf, float theta, Class type, Options options) {
+    this(tf, null, theta, type, options);
+  }
+  /**
+   * Creates a ThresholdedReLU Layer
+   *
+   * @param tf the TensorFlow Ops
+   * @param name the unique name for this layer. If null, a unique name will be generated based on
+   *     {@link Class#getSimpleName()}.
+   * @param theta Threshold location of activation. Must be >= 0.
+   * @param type the data type for the layer's weights and computation.
+   * @param options the layer's options.
+   * @throws IllegalArgumentException if theta is &lt; 0.
+   */
+  public ThresholdedReLU(Ops tf, String name, float theta, Class type, Options options) {
+    super(tf, name, true, type, options);
+    if (theta < 0) {
+      throw new IllegalArgumentException("theta must be >= 0, got " + theta);
+    }
+    this.theta = theta;
+    setSupportsMasking(true);
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  public List> call(
+      List> inputs,
+      List> masks,
+      boolean training,
+      Class resultType) {
+    Ops tf = getTF();
+    Operand tTheta = cast(tf, tf.constant(theta), getType());
+    List> tInputs = convertList(inputs);
+    List> results = new ArrayList<>();
+    tInputs.forEach(
+        input ->
+            results.add(tf.math.mul(input, cast(tf, tf.math.greater(input, tTheta), getType()))));
+    return callPostProcess(convertTo(results, resultType), training);
+  }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/InputSpec.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/InputSpec.java
new file mode 100644
index 00000000000..899a528e534
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/InputSpec.java
@@ -0,0 +1,488 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers.impl; + +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.types.family.TType; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +/** + * Specifies the rank, data type and shape of every input to a layer. + * + *

These objects enable the layer to run input compatibility checks for input structure, input + * rank, input shape, and input data type. + * + *
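+ * <p>A minimal usage sketch (hypothetical values; {@code input} is an assumed {@code Operand}):
+ *
+ * <pre>{@code
+ * // expect rank-2 input whose last axis has exactly 64 units
+ * InputSpec spec = new InputSpec(InputSpec.Options.create().rank(2).axesMap(-1, 64L));
+ * spec.assertInputCompatibility(input, "Dense"); // throws IllegalArgumentException on mismatch
+ * }</pre>
+ *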

A {@link Shape#UNKNOWN_SIZE} entry in a shape is compatible with any dimension; a {@link
+ * Shape#unknown()} shape is compatible with any shape.
+ */
+public class InputSpec {
+  private Class dataType;
+  private Shape shape;
+  private Integer rank;
+  private Integer maxRank;
+  private Integer minRank;
+  private Map axes;
+  private boolean allowLastAxisSqueeze;
+
+  public InputSpec() {}
+
+  public InputSpec(Options options) {
+    dataType = options.dataType;
+    rank = options.rank;
+    maxRank = options.maxRank;
+    minRank = options.minRank;
+    axes = options.axes;
+    allowLastAxisSqueeze = options.allowLastAxisSqueeze;
+
+    if (options.shape != null && options.shape.numDimensions() != Shape.UNKNOWN_SIZE) {
+      shape = options.shape;
+      rank = shape.numDimensions();
+    }
+
+    if (axes != null && (rank != null || maxRank != null)) {
+      maxRank = rank != null ? rank : maxRank;
+
+      Integer maxAxis = axes.keySet().stream().max(Integer::compare).get();
+      if (maxAxis >= maxRank) {
+        throw new IllegalArgumentException(
+            String.format(
+                "Axis %d is greater than the maximum allowed value: %d, %s",
+                maxAxis, maxRank, shape));
+      }
+    }
+  }
+
+  /**
+   * Returns a Shape object that matches the shape specifications.
+   *
+   *

If the InputSpec's {@link #shape} or expected {@link #rank} is defined, this method will
+   * return a fully or partially-known shape. Otherwise, the returned Shape is {@link
+   * Shape#unknown()}.
+   *
+   * @return the generated shape
+   */
+  public Shape toShape() {
+    if (rank == null && shape == null) {
+      return Shape.unknown();
+    } else if (shape != null) {
+      return shape;
+    } else {
+      long[] dims = new long[rank];
+      Arrays.fill(dims, Shape.UNKNOWN_SIZE);
+      if (axes != null) {
+        for (Integer key : axes.keySet()) {
+          int dimIdx = Math.floorMod(key, rank);
+          dims[dimIdx] = axes.get(key);
+        }
+      }
+      return Shape.of(dims);
+    }
+  }
+
+  /**
+   * Checks compatibility between the layer and provided inputs.
+   *
+   * @param input the input to check.
+   * @param layerName layer name for error message formatting.
+   * @param the data type for the input.
+   * @throws IllegalArgumentException if the provided input's shape is not compatible with this
+   *     InputSpec.
+   */
+  public void assertInputCompatibility(Operand input, String layerName) {
+    Shape staticShape = input.shape();
+
+    if (staticShape.numDimensions() != Shape.UNKNOWN_SIZE) {
+      if (rank != null && !isAllowLastAxisSqueeze()) {
+        if (staticShape.numDimensions() != rank) {
+          throw new IllegalArgumentException(
+              String.format(
+                  "Input of %s is incompatible with the layer: expected rank=%d, found rank=%d. Full shape received: %s",
+                  layerName, rank, staticShape.numDimensions(), staticShape));
+        }
+      }
+      if (maxRank != null) {
+        if (staticShape.numDimensions() > maxRank) {
+          throw new IllegalArgumentException(
+              String.format(
+                  "Input of %s is incompatible with the layer: expected max rank=%d, found rank=%d.",
+                  layerName, maxRank, staticShape.numDimensions()));
+        }
+      }
+      if (minRank != null) {
+        if (staticShape.numDimensions() < minRank) {
+          throw new IllegalArgumentException(
+              String.format(
+                  "Input of %s is incompatible with the layer: expected min rank=%d, found rank=%d.",
+                  layerName, minRank, staticShape.numDimensions()));
+        }
+      }
+
+      if (dataType != null && !dataType.equals(input.type())) {
+        throw new IllegalArgumentException(
+            String.format(
+                "Input of %s is incompatible with the layer: expected data type = %s, found data type = %s.",
+                layerName, dataType.getSimpleName(), input.type().getSimpleName()));
+      }
+
+      // check each axis against the actual input shape
+      if (axes != null) {
+        axes.forEach(
+            (x, v) -> {
+              if (staticShape.size(x) != Shape.UNKNOWN_SIZE && staticShape.size(x) != v) {
+                throw new IllegalArgumentException(
+                    String.format(
+                        "Input of %s is incompatible with the layer: expected axis = %d of input shape to have value %d, but received input with shape %s",
+                        layerName, x, v, staticShape));
+              }
+            });
+      }
+
+      // Check shape.
+      if (shape != null) {
+        Shape specShape = shape;
+        Shape inputShape = staticShape;
+        if (isAllowLastAxisSqueeze()) {
+          if (inputShape.size(inputShape.numDimensions() - 1) == 1) {
+            inputShape = inputShape.take(inputShape.numDimensions() - 1);
+          }
+          if (specShape.size(specShape.numDimensions() - 1) == 1) {
+            specShape = specShape.take(specShape.numDimensions() - 1);
+          }
+        }
+        for (int i = 0; i < specShape.numDimensions(); i++) {
+          if (specShape.size(i) != Shape.UNKNOWN_SIZE
+              && inputShape.size(i) != Shape.UNKNOWN_SIZE
+              && specShape.size(i) != inputShape.size(i)) {
+            throw new IllegalArgumentException(
+                String.format(
+                    "Input of %s is incompatible with the layer: expected shape=%s, found shape=%s",
+                    layerName, shape, staticShape));
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * Gets the expected Data Type of the input.
+   *
+   * @return the expected Data Type of the input.
+   */
+  public Class getDataType() {
+    return dataType;
+  }
+
+  /**
+   * Sets the expected Data Type of the input.
+   *
+   * @param dataType the expected Data Type of the input.
+   */
+  public void setDataType(Class dataType) {
+    this.dataType = dataType;
+  }
+
+  /**
+   * Gets the Dictionary mapping integer axes to a specific dimension value.
+   *
+   * @return the Dictionary mapping integer axes to a specific dimension value.
+   */
+  public Map getAxesMap() {
+    return axes;
+  }
+
+  /**
+   * Sets the Dictionary mapping integer axes to a specific dimension value.
+   *
+   * @param axes the Dictionary mapping integer axes to a specific dimension value.
+   */
+  public void setAxesMap(Map axes) {
+    this.axes = axes;
+  }
+
+  /**
+   * Gets the expected shape of the input (may include {@link Shape#UNKNOWN_SIZE} for unchecked
+   * axes). Includes the batch size.
+   *
+   * @return the expected shape of the input including batch size.
+   */
+  public Shape getShape() {
+    return shape;
+  }
+
+  /**
+   * Sets the expected shape of the input (may include {@link Shape#UNKNOWN_SIZE} for unchecked
+   * axes). Includes the batch size.
+   *
+   * @param shape the expected shape of the input including batch size.
+   */
+  public void setShape(Shape shape) {
+    this.shape = shape;
+  }
+
+  /**
+   * Gets the expected rank of the input
+   *
+   * @return the expected rank of the input
+   */
+  public Integer getRank() {
+    return rank;
+  }
+
+  /**
+   * Sets the expected rank of the input
+   *
+   * @param rank the expected rank of the input
+   */
+  public void setRank(Integer rank) {
+    this.rank = rank;
+  }
+
+  /**
+   * Gets the maximum rank of the input.
+   *
+   * @return the maximum rank of the input.
+   */
+  public Integer getMaxRank() {
+    return maxRank;
+  }
+
+  /**
+   * Sets the maximum rank of the input.
+   *
+   * @param maxRank the maximum rank of the input.
+   */
+  public void setMaxRank(Integer maxRank) {
+    this.maxRank = maxRank;
+  }
+
+  /**
+   * Gets the minimum rank of the input.
+   *
+   * @return the minimum rank of the input.
+   */
+  public Integer getMinRank() {
+    return minRank;
+  }
+
+  /**
+   * Sets the minimum rank of the input.
+   *
+   * @param minRank the minimum rank of the input.
+   */
+  public void setMinRank(Integer minRank) {
+    this.minRank = minRank;
+  }
+
+  /**
+   * Gets the allow last axis squeeze indicator for the input. If true, then allow inputs of rank
+   * N+1 as long as the last axis of the input is 1, as well as inputs of rank N-1 as long as the
+   * last axis of the spec is 1.
+   *
+   * @return the allow last axis squeeze indicator
+   */
+  public boolean isAllowLastAxisSqueeze() {
+    return allowLastAxisSqueeze;
+  }
+
+  /**
+   * Sets the allow last axis squeeze indicator for the input. If true, then allow inputs of rank
+   * N+1 as long as the last axis of the input is 1, as well as inputs of rank N-1 as long as the
+   * last axis of the spec is 1.
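+   * <p>For example, a spec with shape {@code (batch, 10)} would then also accept an input of
+   * shape {@code (batch, 10, 1)}, and a spec with shape {@code (batch, 10, 1)} would accept an
+   * input of shape {@code (batch, 10)}.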
+   *
+   * @param allowLastAxisSqueeze the allow last axis squeeze indicator
+   */
+  public void setAllowLastAxisSqueeze(boolean allowLastAxisSqueeze) {
+    this.allowLastAxisSqueeze = allowLastAxisSqueeze;
+  }
+
+  /** Optional attributes for {@link InputSpec} */
+  public static class Options {
+
+    private Class dataType;
+    private Shape shape;
+    private Integer rank;
+    private Integer maxRank;
+    private Integer minRank;
+    private Map axes;
+    private boolean allowLastAxisSqueeze;
+
+    /**
+     * Creates an InputSpec.Options instance
+     *
+     * @return the InputSpec.Options instance
+     */
+    public static Options create() {
+      return new Options();
+    }
+
+    /**
+     * Sets the expected Data Type of the input.
+     *
+     * @param dataType the expected Data Type of the input.
+     * @return this Options instance.
+     */
+    public Options dataType(Class dataType) {
+      this.dataType = dataType;
+      return this;
+    }
+
+    /**
+     * Sets the expected shape of the input (may include {@link Shape#UNKNOWN_SIZE} for unchecked
+     * axes). Includes the batch size.
+     *
+     * @param shape the expected shape of the input
+     * @return this Options instance.
+     */
+    public Options shape(Shape shape) {
+      this.shape = shape;
+      return this;
+    }
+
+    /**
+     * Sets the expected rank of the input
+     *
+     * @param rank the expected rank of the input
+     * @return this Options instance.
+     */
+    public Options rank(Integer rank) {
+      this.rank = rank;
+      return this;
+    }
+
+    /**
+     * Sets the maximum rank of the input.
+     *
+     * @param maxRank the maximum rank of the input.
+     * @return this Options instance.
+     */
+    public Options maxRank(Integer maxRank) {
+      this.maxRank = maxRank;
+      return this;
+    }
+
+    /**
+     * Sets the minimum rank of the input.
+     *
+     * @param minRank the minimum rank of the input.
+     * @return this Options instance.
+     */
+    public Options minRank(Integer minRank) {
+      this.minRank = minRank;
+      return this;
+    }
+    /**
+     * Sets the Dictionary mapping integer axes to a specific dimension value.
+     *
+     * @param axes the Dictionary mapping integer axes to a specific dimension value.
+     * @return this Options instance.
+     */
+    public Options axesMap(Map axes) {
+      this.axes = axes;
+      return this;
+    }
+    /**
+     * Sets the Dictionary mapping integer axes to a specific dimension value.
+     *
+     * @param key the integer axis
+     * @param dim the dimension value for the specified axis
+     * @return this Options instance.
+     */
+    public Options axesMap(Integer key, Long dim) {
+      if (this.axes == null) {
+        this.axes = new HashMap<>();
+      }
+      this.axes.put(key, dim);
+      return this;
+    }
+
+    /**
+     * Sets the allow last axis squeeze indicator for the input. If true, then allow inputs of rank
+     * N+1 as long as the last axis of the input is 1, as well as inputs of rank N-1 as long as the
+     * last axis of the spec is 1.
+     *
+     * @param allowLastAxisSqueeze the allow last axis squeeze indicator for the input
+     * @return this Options instance.
+     */
+    public Options allowLastAxisSqueeze(boolean allowLastAxisSqueeze) {
+      this.allowLastAxisSqueeze = allowLastAxisSqueeze;
+      return this;
+    }
+    /**
+     * Gets the expected Data Type of the input.
+     *
+     * @return the expected Data Type of the input.
+     */
+    public Class getDataType() {
+      return dataType;
+    }
+    /**
+     * Gets the expected shape of the input (may include {@link Shape#UNKNOWN_SIZE} for unchecked
+     * axes). Includes the batch size.
+     *
+     * @return the expected shape of the input including batch size.
+ */ + public Shape getShape() { + return shape; + } + /** + * Gets the expected rank of the input + * + * @return the expected rank of the input + */ + public Integer getRank() { + return rank; + } + /** + * Gets the maximum rank of the input. + * + * @return the maximum rank of the input. + */ + public Integer getMaxRank() { + return maxRank; + } + /** + * Gets the minimum rank of the input. + * + * @return the minimum rank of the input. + */ + public Integer getMinRank() { + return minRank; + } + + /** + * Gets the Dictionary mapping integer axes to a specific dimension value. + * + * @return the Dictionary mapping integer axes to a specific dimension value. + */ + public Map getAxesMap() { + return axes; + } + /** + * Gets the allow last axis squeeze indicator for the input. If true, then allow inputs of rank + * N+1 as long as the last axis of the input is 1, as well as inputs of rank N-1 as long as the + * last axis of the spec is 1. + * + * @return the allow last axis squeeze indicator + */ + public boolean isAllowLastAxisSqueeze() { + return allowLastAxisSqueeze; + } + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/Merge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/Merge.java new file mode 100644 index 00000000000..74d948e0262 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/Merge.java @@ -0,0 +1,382 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers.impl; + +import org.tensorflow.Operand; +import org.tensorflow.framework.layers.Layer; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Generic abstract merge layer for element-wise merge functions. + * + * @param the data type for the layer's weights and computation. + */ +public abstract class Merge extends Layer { + + private boolean reshapeRequired; + + /** + * Creates a Merge base class using {@link Class#getSimpleName()} for the name. + * + * @param tf the TensorFlow Ops, may not be null before the first call to the {@link #call} method + * method is called. + * @param type the data type for the weights and computation + */ + protected Merge(Ops tf, Class type) { + + this(tf, null, true, type, null); + } + + /** + * Creates a Merge base class using {@link Class#getSimpleName()} for the name. 
+   *
+   * @param tf the TensorFlow Ops, may not be null before the first call to the {@link #call}
+   *     method.
+   * @param type the data type for the weights and computation
+   */
+  protected Merge(Ops tf, Class type) {
+
+    this(tf, null, true, type, null);
+  }
+
+  /**
+   * Creates a Merge base class using {@link Class#getSimpleName()} for the name.
+   *
+   * @param tf the TensorFlow Ops, may not be null before the first call to the {@link #call}
+   *     method.
+   * @param type the data type for the weights and computation
+   * @param options the layer's options
+   */
+  protected Merge(Ops tf, Class type, Options options) {
+
+    this(tf, null, true, type, options);
+  }
+
+  /**
+   * Creates a Merge base class.
+   *
+   * @param tf the TensorFlow Ops, may not be null before the first call to the {@link #call}
+   *     method.
+   * @param name the unique name for this layer, if null will use {@link Class#getSimpleName()} for
+   *     the name.
+   * @param type the data type for the weights and computation
+   */
+  protected Merge(Ops tf, String name, Class type) {

+    this(tf, name, true, type, null);
+  }
+
+  /**
+   * Creates a Merge base class.
+   *
+   * @param tf the TensorFlow Ops, may not be null before the first call to the {@link #call}
+   *     method.
+   * @param name the unique name for this layer, if null will use {@link Class#getSimpleName()} for
+   *     the name.
+   * @param type the data type for the weights and computation
+   * @param options the layer's options
+   */
+  protected Merge(Ops tf, String name, Class type, Options options) {
+
+    this(tf, name, true, type, options);
+  }
+
+  /**
+   * Creates the Merge base class.
+   *
+   * @param tf the TensorFlow Ops, may not be null.
+   * @param name the unique name for this layer, if null will use {@link Class#getSimpleName()} for
+   *     the name.
+   * @param trainable whether the layer's variables should be trainable or not.
+   * @param type the data type for the layer's weights and computation.
+   * @param options the layer's options
+   */
+  protected Merge(Ops tf, String name, boolean trainable, Class type, Options options) {
+    super(tf, name, trainable, type, options);
+    this.setSupportsMasking(true);
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  public List> computeMask(
+      List> inputs, List> masks) {
+    if (masks == null || masks.isEmpty()) {
+      return null;
+    }
+    if (inputs.size() != masks.size()) {
+      throw new IllegalArgumentException("The lists inputs and masks should have the same length.");
+    }
+
+    boolean allNull = true;
+    for (Operand m : masks) {
+      if (m != null) {
+        allNull = false;
+        break;
+      }
+    }
+    if (allNull) {
+      return null;
+    }
+
+    final Ops tf = getTF();
+    List> rMasks =
+        masks.stream()
+            .map(m -> cast(getTF(), m, TBool.class))
+            .map(m -> tf.expandDims(m, tf.constant(0)))
+            .collect(Collectors.toList());
+
+    Operand concat = tf.concat(rMasks, tf.constant(0));
+    Operand bool = cast(tf, concat, TBool.class);
+    return Collections.singletonList(tf.reduceAll(bool, tf.constant(0)));
+  }
+
+  /**
+   * Computes the merged result
+   *
+   * @param inputs the inputs
+   * @return the merged result
+   */
+  protected abstract Operand mergeFunction(
+      List> inputs);
+
+  /** {@inheritDoc} */
+  @Override
+  public List> call(
+      List> inputs,
+      List> masks,
+      boolean training,
+      Class resultType) {
+    Ops tf = getTF();
+
+    if (reshapeRequired) {
+      List> reshapedInputs = new ArrayList<>();
+      List inputDimensions = new ArrayList<>();
+      inputs.forEach(s -> inputDimensions.add(s.shape().numDimensions()));
+      if (!inputDimensions.contains((int) Shape.UNKNOWN_SIZE)) {
+        // If ranks of all inputs are available,
+        // we simply expand each of them at axis=1
+        // until all of them have the same rank.
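+        // e.g. (illustrative) merging inputs of ranks 2 and 4: the (batch, features)
+        // input is expanded to (batch, 1, 1, features) so both inputs reach rank 4
+        // before the merge function runs.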
+        int maxDimension = Collections.max(inputDimensions);
+        for (Operand input : inputs) {
+          int numDims = input.shape().numDimensions();
+          for (int i = numDims; i < maxDimension; i++) {
+            input = tf.expandDims(input, tf.constant(1));
+          }
+          Operand tInput = cast(getTF(), input, getType());
+          reshapedInputs.add(tInput);
+        }
+        Operand result = cast(tf, mergeFunction(reshapedInputs), resultType);
+        return Collections.singletonList(result);
+
+      } else {
+        // Transpose all inputs so that batch size is the last dimension.
+        // (batch_size, dim1, dim2, ... ) -> (dim1, dim2, ... , batch_size)
+        boolean transposed = false;
+        for (Operand input : inputs) {
+          Operand tInput = cast(getTF(), input, getType());
+          int nDims = tInput.shape().numDimensions();
+          if (nDims == Shape.UNKNOWN_SIZE) {
+            org.tensorflow.op.core.Shape tShape = tf.shape(tInput);
+            Operand batchSize = tf.shape.size(tShape, tf.constant(0));
+            Operand remainderShape =
+                tf.shape.takeLast(tShape, tf.math.sub(tf.rank(tInput), tf.constant(1)));
+            Operand newShape =
+                tf.shape.append(remainderShape, tf.expandDims(batchSize, tf.constant(-1)));
+
+            Operand transposedInput =
+                tf.reshape(
+                    tInput,
+                    tf.shape.append(batchSize, tf.reduceProd(remainderShape, tf.constant(0))));
+
+            transposedInput = tf.linalg.transpose(transposedInput, tf.constant(new int[] {1, 0}));
+            transposedInput = tf.reshape(transposedInput, newShape);
+            reshapedInputs.add(transposedInput);
+            transposed = true;
+
+          } else if (nDims > 1) {
+            int[] perms = new int[nDims];
+            for (int i = 1; i < nDims; i++) {
+              perms[i - 1] = i;
+            }
+            perms[nDims - 1] = 0;
+            reshapedInputs.add(tf.linalg.transpose(tInput, tf.constant(perms)));
+          } else {
+            reshapedInputs.add(tInput);
+          }
+        }
+        Operand result = cast(tf, mergeFunction(reshapedInputs), resultType);
+
+        if (transposed) {
+          int nDim = result.shape().numDimensions();
+          if (nDim == Shape.UNKNOWN_SIZE) {
+            org.tensorflow.op.core.Shape rShape = tf.shape(result);
+            Operand batchSize = tf.shape.takeLast(rShape, tf.constant(1));
+            Operand baseShape =
+                tf.shape.take(rShape, tf.math.sub(tf.rank(result), tf.constant(1)));
+            Operand newShape = tf.shape.append(batchSize, baseShape);
+            result =
+                tf.reshape(
+                    result,
+                    tf.concat(
+                        Arrays.asList(tf.constant(new int[] {-1}), batchSize), tf.constant(0)));
+            result = tf.linalg.transpose(result, tf.constant(new int[] {1, 0}));
+            result = tf.reshape(result, newShape);
+          } else if (nDim > 1) {
+            int[] perms = new int[nDim];
+            perms[0] = nDim - 1;
+            for (int i = 0; i < nDim - 1; i++) {
+              perms[i + 1] = i;
+            }
+            result = tf.linalg.transpose(result, tf.constant(perms));
+          }
+        }
+        return callPostProcess(Collections.singletonList(result), training);
+      }
+    } else {
+      List> tInputs = new ArrayList<>();
+      inputs.forEach(i -> tInputs.add(cast(getTF(), i, getType())));
+      Operand merged = cast(tf, mergeFunction(tInputs), resultType);
+
+      return callPostProcess(Collections.singletonList(merged), training);
+    }
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  protected void build(List inputShapes) {
+    if (inputShapes == null || inputShapes.size() <= 1) {
+      throw new IllegalArgumentException(
+          String.format(
+              "A merge layer should be called on a list of at least 2 inputs. Got %d inputs",
+              inputShapes == null ? 0 : inputShapes.size()));
+    }
+    Set batchSizes = new HashSet<>();
+    inputShapes.forEach(s -> batchSizes.add(s.size(0)));
+    if (batchSizes.size() > 1) {
+      throw new IllegalArgumentException(
+          String.format(
+              "Can not merge tensors with different batch sizes.
 + Arrays.toString(inputShapes.toArray()))); + } + + Shape inputShape = inputShapes.get(0); + Shape outputShape = inputShape.takeLast(inputShape.numDimensions() - 1); + Shape shape; + for (int i = 1; i < inputShapes.size(); i++) { + shape = inputShapes.get(i).takeLast(inputShapes.get(i).numDimensions() - 1); + outputShape = computeElementWiseOpOutputShape(outputShape, shape); + } + + Set ranks = new HashSet<>(); + inputShapes.forEach(s -> ranks.add(s.numDimensions())); + boolean hasUnknown = false; + for (Shape s : inputShapes) { + if (s.isUnknown()) { + hasUnknown = true; + break; + } + } + reshapeRequired = hasUnknown || ranks.size() > 1; + super.build(inputShapes); + } + + /** {@inheritDoc} */ + @Override + public List computeOutputShape(List inputShapes) { + Shape outputShape; + if (inputShapes.isEmpty() || inputShapes.get(0) == null) { + outputShape = Shape.of(); + } else { + Shape shape1 = inputShapes.get(0); + if (shape1.numDimensions() > 0) { + outputShape = shape1.takeLast(shape1.numDimensions() - 1); + } else { + outputShape = Shape.of(); + } + } + Shape shape; + for (int i = 1; i < inputShapes.size(); i++) { + Shape shapei = inputShapes.get(i); + if (shapei == null) { + shape = Shape.of(); + } else { + if (shapei.numDimensions() > 0) { + shape = shapei.takeLast(shapei.numDimensions() - 1); + } else { + shape = Shape.of(); + } + } + outputShape = computeElementWiseOpOutputShape(outputShape, shape); + } + + Set batchSizes = new HashSet<>(); + for (Shape s : inputShapes) { + if (s != null) { + batchSizes.add(s.size(0)); + } + } + if (batchSizes.size() == 1) { + outputShape = outputShape.prepend(batchSizes.toArray(new Long[1])[0]); + } else { + outputShape = outputShape.prepend(Shape.UNKNOWN_SIZE); + } + + return Collections.singletonList(outputShape); + } + + /** + * Computes the shape of the resultant of an element-wise operation. 
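 + * + *

For example, a sketch of the broadcast rule (shapes here are illustrative): + * + *

{@code
+   * // shape1 = (4, 1, 5), shape2 = (3, 5)
+   * // aligned on the trailing dimensions; size 1 dimensions broadcast:
+   * // result = (4, 3, 5)
+   * }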
 + * + * @param shape1 Shape of the first tensor + * @param shape2 Shape of the second tensor + * @return expected output shape when an element-wise operation is carried out on 2 tensors with + * shapes shape1 and shape2 + */ + protected Shape computeElementWiseOpOutputShape(Shape shape1, Shape shape2) { + if (shape2 == null) { + return shape1; + } + if (shape1.isUnknown() || shape2.isUnknown()) { + return Shape.unknown(); + } + if (shape1.numDimensions() < shape2.numDimensions()) { + return computeElementWiseOpOutputShape(shape2, shape1); + } + Shape outputShape = shape1.take(shape1.numDimensions() - shape2.numDimensions()); + + for (int i = shape1.numDimensions() - shape2.numDimensions(), j = 0; + j < shape2.numDimensions(); + j++, i++) { + if (shape1.size(i) == Shape.UNKNOWN_SIZE || shape2.size(j) == Shape.UNKNOWN_SIZE) { + outputShape = outputShape.append(Shape.UNKNOWN_SIZE); + } else if (shape1.size(i) == 1) { + outputShape = outputShape.append(shape2.size(j)); + } else if (shape2.size(j) == 1) { + outputShape = outputShape.append(shape1.size(i)); + } else if (shape1.size(i) != shape2.size(j)) { + throw new IllegalArgumentException( + String.format( + "Operands could not be broadcast together with shapes %s %s", shape1, shape2)); + } else { + outputShape = outputShape.append(shape1.size(i)); + } + } + return outputShape; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/TensorFormat.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/TensorFormat.java new file mode 100644 index 00000000000..e3bd6a3ea49 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/TensorFormat.java @@ -0,0 +1,23 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers.impl; + +/* TODO remove after this enum is added to the api. + * PR: Created TensorFormat enum #191 + */ +public enum TensorFormat { + NCHW, + NHWC +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/VariableDef.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/VariableDef.java new file mode 100644 index 00000000000..370d1907a74 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/VariableDef.java @@ -0,0 +1,260 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers.impl; + +import org.tensorflow.Operand; +import org.tensorflow.framework.initializers.Glorot; +import org.tensorflow.framework.initializers.Initializer; +import org.tensorflow.framework.initializers.VarianceScaling; +import org.tensorflow.framework.initializers.Zeros; +import org.tensorflow.framework.regularizers.Regularizer; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.AssignAdd; +import org.tensorflow.op.core.AssignSub; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; + +import java.util.function.UnaryOperator; + +public class VariableDef { + private final Ops tf; + private final String name; + private final Shape shape; + private final Initializer initializer; + private final UnaryOperator> constraint; + private final Regularizer regularizer; + private final boolean trainable; + private final Variable variable; + private final Operand initOperand; + private final Class type; + + public VariableDef( + Ops tf, + String name, + Shape shape, + Initializer initializer, + UnaryOperator> constraint, + Regularizer regularizer, + boolean trainable, + long seed, + Class type) { + this.tf = tf.withName(name); + this.type = type; + this.name = name; + this.constraint = constraint; + this.regularizer = regularizer; + this.trainable = trainable; + + this.shape = shape == null ? Shape.scalar() : shape; + this.initializer = initializer == null ? getDefaultInitializer(seed) : initializer; + initOperand = this.initializer.call(tf.constant(this.shape), type); + variable = tf.withSubScope(name).variable(initOperand); + } + + public VariableDef( + Ops tf, + String name, + Variable variable, + Initializer initializer, + UnaryOperator> constraint, + Regularizer regularizer, + boolean trainable, + long seed) { + this.tf = tf.withName(name); + this.name = name == null ? variable.toString() : name; + this.constraint = constraint; + this.regularizer = regularizer; + this.trainable = trainable; + this.variable = variable; + shape = variable.shape(); + type = variable.type(); + this.initializer = initializer == null ? getDefaultInitializer(seed) : initializer; + initOperand = this.initializer.call(tf.constant(this.shape), type); + } + + /** + * Initializes the variable + * + * @return the operand that initializes this variable + */ + public Operand init() { + return assign(initOperand); + } + + /** + * Assigns a value to the variable, with locking set to false + * + * @param value the value to assign + * @return the operand that assigns the value to this variable + */ + public Operand assign(Operand value) { + return assign(value, false); + } + /** + * Assigns a value to the variable + * + * @param value the value to assign + * @param useLocking If true, use locking during the assignment. + * @return the operand that assigns the value to this variable + */ + public Operand assign(Operand value, boolean useLocking) { + return tf.assign(variable, value, Assign.useLocking(useLocking)); + } + + /** + * Adds a value to the variable, without locking. 
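 + * + *

A minimal usage sketch (assuming a scalar {@code TFloat32} variable; names are + * illustrative): + * + *

{@code
+   * // accumulates a delta into the variable; the returned operand applies the update when run
+   * Operand count = variableDef.assignAdd(tf.constant(1.0f));
+   * }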
 + * + * @param value the value to add + * @return the operand that adds the value to this variable + */ + public Operand assignAdd(Operand value) { + return assignAdd(value, false); + } + + /** + * Adds a value to the variable + * + * @param value the value to add + * @param useLocking If true, use locking during the assignment. + * @return the operand that adds the value to this variable + */ + public Operand assignAdd(Operand value, boolean useLocking) { + return tf.assignAdd(variable, value, AssignAdd.useLocking(useLocking)); + } + + /** + * Subtracts a value from the variable, without locking. + * + * @param value the value to subtract + * @return the operand that subtracts the value from this variable + */ + public Operand assignSub(Operand value) { + return assignSub(value, false); + } + + /** + * Subtracts a value from the variable + * + * @param value the value to subtract + * @param useLocking If true, use locking during the assignment. + * @return the operand that subtracts the value from this variable + */ + public Operand assignSub(Operand value, boolean useLocking) { + return tf.assignSub(variable, value, AssignSub.useLocking(useLocking)); + } + + /** + * Gets the default initializer based on type + * + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and type. + * @return the default initializer + */ + @SuppressWarnings("unchecked") + private Initializer getDefaultInitializer(long seed) { + Initializer initializer; + + if (TFloating.class.isAssignableFrom(type)) { + // this creates a "Casting 'new Glorot<>(...)' to 'Initializer' is redundant" warning. + // Ignored here because Glorot requires a TFloating type, which is checked in the if + // statement above. If you remove this cast, you'll get an error. + //noinspection RedundantCast + initializer = (Initializer) new Glorot<>(tf, VarianceScaling.Distribution.UNIFORM, seed); + } else { + initializer = new Zeros<>(tf); + } + return initializer; + } + + /** + * Gets the variable name + * + * @return the variable name + */ + public String getName() { + return name; + } + + /** + * Gets the variable shape + * + * @return the variable shape + */ + public Shape getShape() { + return shape; + } + + /** + * Gets the variable initializer + * + * @return the variable initializer + */ + public Initializer getInitializer() { + return initializer; + } + + /** + * Gets the variable constraint + * + * @return the variable constraint + */ + public UnaryOperator> getConstraint() { + return constraint; + } + + /** + * Gets the variable regularizer + * + * @return the variable regularizer + */ + public Regularizer getRegularizer() { + return regularizer; + } + + /** + * Gets the variable trainable indicator + * + * @return the variable trainable indicator + */ + public boolean isTrainable() { + return trainable; + } + + /** + * Gets the variable + * + * @return the variable + */ + public Variable getVariable() { + return variable; + } + + /** + * Gets the variable initialization operand. + * + * @return the variable initialization operand. 
 + */ + public Operand getInitOperand() { + return initOperand; + } + + /** + * Gets the variable data type + * + * @return the variable data type + */ + public Class getType() { + return type; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 9aa94cf7fcf..6700f2569f0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -19,6 +19,7 @@ import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; +import org.tensorflow.framework.op.FrameworkOps; import org.tensorflow.op.core.ReduceAll; import org.tensorflow.op.core.ReduceMax; import org.tensorflow.op.core.ReduceSum; @@ -26,6 +27,7 @@ import org.tensorflow.op.math.Softplus; import org.tensorflow.types.TBool; import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TFloating; import org.tensorflow.types.family.TNumber; import static org.tensorflow.framework.utils.CastHelper.cast; @@ -181,7 +183,8 @@ public static Operand binaryCrossentropy( */ private static Operand binaryCrossentropyHelper( Ops tf, Operand target, Operand output, boolean fromLogits) { - if (fromLogits) return tf.nn.sigmoidCrossEntropyWithLogits(target, output); + FrameworkOps fop = FrameworkOps.create(tf); + if (fromLogits) return fop.nn.sigmoidCrossEntropyWithLogits(target, output); /* TODO - skip this logic for now. It requires walking back the inputs which is not yet possible if (!(output instanceof Variable) && (!tf.scope().env().isEager())) { @@ -191,7 +194,7 @@ private static Operand binaryCrossentropyHelper( // TODO if (output.op().numInputess() != 1) // TODO throw new IllegalArgumentException("output can only have 1 output"); // TODO output = output.op().inout(0); - // TODO return tf.nn.sigmoidCrossEntropyWithLogits(target, output); + // TODO return fop.nn.sigmoidCrossEntropyWithLogits(target, output); // TODO} } */ @@ -235,6 +238,7 @@ public static Operand categoricalCrossentropy( boolean fromLogits, float labelSmoothing, int axis) { + FrameworkOps fop = FrameworkOps.create(tf); Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -245,7 +249,7 @@ tLabels = smoothCategoricalLabels(tf, tLabels, labelSmoothing); } if (fromLogits) { - return tf.nn.softmaxCrossEntropyWithLogits(tLabels, predictions, axis); + return fop.nn.softmaxCrossEntropyWithLogits(tLabels, predictions, axis); } /* TODO if (!(predictions instanceof Variable) && (!tf.scope().env().isEager())) { @@ -255,7 +259,7 @@ if (predictions.op().numOutputs() != 1) throw new IllegalArgumentException("output can only have 1 output"); predictions = predictions.op().output(0); - return tf.nn.softmaxCrossEntropyWithLogits(tLabels, predictions, -1); + return fop.nn.softmaxCrossEntropyWithLogits(tLabels, predictions, -1); } } */ @@ -334,13 +338,14 @@ public static Operand categoricalHinge( */ public static Operand cosineSimilarity( Ops tf, Operand labels, Operand predictions, int[] axis) { + FrameworkOps fops = FrameworkOps.create(tf); Operand tLabels = cast(tf, labels, predictions.type()); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, 
tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); - tLabels = l2Normalize(tf, tLabels, axis); - predictions = l2Normalize(tf, predictions, axis); + tLabels = fops.math.l2Normalize(tLabels, axis); + predictions = fops.math.l2Normalize(predictions, axis); Operand mathMul = tf.math.mul(tLabels, predictions); return tf.reduceSum(mathMul, tf.constant(axis), ReduceSum.keepDims(Boolean.FALSE)); } @@ -516,6 +521,7 @@ public static Operand sparseCategoricalCrossentropy( boolean fromLogits, int axis) { Class predictionType = predictions.type(); + FrameworkOps fop = FrameworkOps.create(tf); Operand epsilonConst = cast(tf, tf.constant(EPSILON), predictionType); Operand one = cast(tf, tf.constant(1), predictionType); Operand oneMinusEpsilonConst = tf.math.sub(one, epsilonConst); @@ -569,8 +575,7 @@ public static Operand sparseCategoricalCrossentropy( new long[] {-1L, predictionsShape.size(predictionsShape.numDimensions() - 1)})); } - @SuppressWarnings("unchecked") - Operand loss = tf.nn.sparseSoftmaxCrossEntropyWithLogits(iLabels, predictions); + Operand loss = fop.nn.sparseSoftmaxCrossEntropyWithLogits(iLabels, predictions); if (updateShape && predictionsRank >= 3) { Shape newShape = predictionsShape.take(predictionsShape.numDimensions() - 1); loss = tf.reshape(loss, tf.constant(newShape)); @@ -648,23 +653,7 @@ private static Operand smoothCategoricalLabels( return tf.math.add(tf.math.mul(labels, oneMinusSmoothing), tf.math.div(smoothing, numClasses)); } - // TODO this was tf.math.l2_normalize in TF Python - /** - * Normalizes along dimension axis using an L2 norm. - * - * @param tf The TensorFlow Ops - * @param x the input - * @param axis Dimension along which to normalize. - * @param the data type for the input and the result - * @return the normalized values based on L2 norm - */ - public static Operand l2Normalize(Ops tf, Operand x, int[] axis) { - Operand squareSum = - tf.reduceSum(tf.math.square(x), tf.constant(axis), ReduceSum.keepDims(Boolean.TRUE)); - Operand invNorm = - tf.math.rsqrt(tf.math.maximum(squareSum, cast(tf, tf.constant(1e-12F), x.type()))); - return tf.math.mul(x, invNorm); - } + /** * Converts binary labels into -1/1. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java index 22baab3d6cb..70cd826f625 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.framework.initializers.Zeros; -import org.tensorflow.framework.metrics.impl.MetricsHelper; +import org.tensorflow.framework.op.FrameworkOps; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; @@ -124,8 +124,8 @@ public Assign getInitializer() { * @param sampleWeights Optional weighting of each example. Defaults to 1, if null. Rank is either * 0, or the same rank as labels, and must be broadcastable to labels. 
 * @return the Operands that updates totalConfusionMatrix variable - * @throws IllegalArgumentException if the weights rank is not 0, and weights rank @{code !=} labels rank, - * and if the predictions size is not equal to the labels size + * @throws IllegalArgumentException if the weights rank is not 0, and weights rank {@code !=} + * labels rank, and if the predictions size is not equal to the labels size */ @Override public List updateStateList( @@ -167,10 +167,11 @@ public List updateStateList( tSampleWeights = getTF().shape.flatten(tSampleWeights); } + FrameworkOps fops = FrameworkOps.create(getTF()); // Accumulate the prediction to current confusion matrix. Operand currentCM = - MetricsHelper.confusionMatrix( - getTF(), tLabels, tPredictions, getTF().constant(numClasses), tSampleWeights, type); + fops.math.confusionMatrix( + tLabels, tPredictions, tSampleWeights, getTF().constant(numClasses)); return Collections.singletonList(getTF().assignAdd(totalConfusionMatrix, currentCM)); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 40336233d21..a4e19d58bcb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -18,7 +18,7 @@ import org.tensorflow.framework.losses.impl.LossTuple; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.framework.metrics.exceptions.NotBroadcastableException; -import org.tensorflow.framework.utils.SparseTensor; +import org.tensorflow.framework.op.FrameworkOps; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; @@ -26,7 +26,6 @@ import org.tensorflow.op.core.OneHot; import org.tensorflow.op.core.Rank; import org.tensorflow.op.core.Squeeze; -import org.tensorflow.op.core.Stack; import org.tensorflow.op.core.Variable; import org.tensorflow.op.math.Mean; import org.tensorflow.op.nn.TopK; @@ -94,7 +93,7 @@ public static Op assertBroadcastable( && !valuesShapeStatic.hasUnknownDimension()) { if (weightsRankStatic == 0) { return tf.withSubScope("staticScalarCheckSuccess") - .withControlDependencies(java.util.Collections.EMPTY_LIST) + .withControlDependencies(Collections.emptyList()) .noOp(); } if (weightsRankStatic != valuesRankStatic) { @@ -190,12 +189,13 @@ private static Operand canBroadcastNonscalarShapes( private static Operand canBroadcastDims( Ops tf, Operand weightsShape, Operand valuesShape) { tf = tf.withSubScope("canBroadcastDims"); + FrameworkOps fops = FrameworkOps.create(tf); Operand valuesShape2d = tf.expandDims(valuesShape, tf.constant(-1)); Operand validDims = tf.concat(Arrays.asList(valuesShape2d, tf.onesLike(valuesShape2d)), tf.constant(1)); Operand weightsShape2D = tf.expandDims(weightsShape, tf.constant(-1)); - Operand diffResult = SetsOps.difference(tf, weightsShape2D, validDims); + Operand diffResult = fops.sets.difference(weightsShape2D, validDims); Operand numInvalidDims = tf.size(diffResult); return tf.math.equal(tf.constant(0), numInvalidDims); } @@ -766,125 +766,7 @@ LossTuple raggedAssertCompatibleAndGetFlatValues( } /** - * Computes the confusion matrix from predictions and labels. - * - *

The matrix columns represent the prediction labels and the rows represent the real labels. - * The confusion matrix is always a 2-D array of shape {@code [n, n]}, where {@code n} is the - * number of valid labels for a given classification task. Both prediction and labels must be 1-D - * arrays of the same shape in order for this function to work. - * - *

If {@code numClasses} is null, then {@code numClasses} will be set to one plus the maximum - * value in either predictions or labels. Class labels are expected to start at 0. For example, if - * {@code numClasses}` is 3, then the possible labels would be {@code [0, 1, 2]}. - * - *

If {@code weights} is not null, then each prediction contributes its corresponding weight to - * the total value of the confusion matrix cell. - * - *

For example: - * - *

{@code
-   *     confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
-   *          [[0 0 0 0 0]
-   *           [0 0 1 0 0]
-   *           [0 0 1 0 0]
-   *           [0 0 0 0 0]
-   *           [0 0 0 0 1]]
-   * }
- * - * Note that the possible labels are assumed to be {@code [0, 1, 2, 3,4]}, resulting in a 5x5 - * confusion matrix. - * - * @param tf the TensorFlow Ops - * @param labels 1-D {@code Operand} of real labels for the classification task. - * @param predictions 1-D {@code Operand} of predictions for a given classification. - * @param numClasses The possible number of labels the classification task can have. If this value - * is not provided, it will be calculated using both predictions and labels array. - * @param weights optional weights to be applied to the confusion matrix - * @param type Data type of the confusion matrix. - * @param the type of Operands - * @return A {@code Operand} of type {@code type} with shape {@code [n, n]} - * representing the confusion matrix, where {@code n} is the number of possible labels in - * the classification task. - * @throws IllegalArgumentException If both {@code predictions} and {@code labels} do - * not have compatible shapes, or if {@code weights} is not{@code null} and its - * shape is not compatible with {@code predictions}. - */ - // TODO should this be moved to FramnworkOps under math. - public static Operand confusionMatrix( - Ops tf, - Operand labels, - Operand predictions, - Operand numClasses, - Operand weights, - Class type) { - if (!predictions.shape().isCompatibleWith(labels.shape())) - throw new IllegalArgumentException( - String.format( - "Prediction shape %s is not compatible with labels shape %s", - predictions.shape().toString(), labels.shape().toString())); - tf = tf.withSubScope("confusionMatrix"); - LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, predictions, labels, null); - Operand tPredictions = cast(tf, ops.getTarget(), TInt64.class); - Operand tLabels = cast(tf, ops.getLabels(), TInt64.class); - - List labelControls = new ArrayList<>(); - List predictionControls = new ArrayList<>(); - - labelControls.add( - tf.assertThat( - tf.reduceAny(tf.math.greaterEqual(tLabels, tf.constant(0L)), allAxes(tf, tLabels)), - Collections.singletonList(tf.constant("`labels` contains negative values")))); - - predictionControls.add( - tf.assertThat( - tf.reduceAny( - tf.math.greaterEqual(tPredictions, tf.constant(0L)), allAxes(tf, tPredictions)), - Collections.singletonList(tf.constant("`predictions` contains negative values")))); - if (numClasses == null) { - numClasses = - tf.math.maximum( - tf.reduceMax(tPredictions, allAxes(tf, tPredictions)), - tf.reduceMax(tLabels, allAxes(tf, tLabels))); - } else { - labelControls.add( - tf.assertThat( - tf.reduceAny(tf.math.less(tLabels, numClasses), allAxes(tf, tLabels)), - Collections.singletonList(tf.constant("``labels` out of bounds")))); - predictionControls.add( - tf.assertThat( - tf.reduceAny(tf.math.less(tPredictions, numClasses), allAxes(tf, tPredictions)), - Collections.singletonList(tf.constant("``predictions` out of bounds")))); - } - - if (weights != null) { - if (!tPredictions.shape().isCompatibleWith(weights.shape())) { - throw new IllegalArgumentException( - String.format( - "Prediction shape %s is not compatible with weights shape %s", - tPredictions.shape().toString(), weights.shape().toString())); - } - } - - Ops tfc = tf.withSubScope("confusionMatrixLabels").withControlDependencies(labelControls); - tLabels = tfc.identity(tLabels); - - tfc = tf.withSubScope("confusionMatrixPredictions").withControlDependencies(predictionControls); - tPredictions = tfc.identity(tPredictions); - - Operand shape = tf.stack(Arrays.asList(numClasses, numClasses)); - Operand indices = 
tf.stack(Arrays.asList(tLabels, tPredictions), Stack.axis(1L)); - Operand values = - weights == null ? cast(tf, tf.onesLike(tPredictions), type) : cast(tf, weights, type); - SparseTensor cmSparse = new SparseTensor<>(indices, values, shape); - Operand zeroMatrix = tf.zeros(shape, type); - - return tf.sparse.sparseTensorDenseAdd( - cmSparse.getIndices(), cmSparse.getValues(), cmSparse.getDenseShape(), zeroMatrix); - } - - /** - * Calculate the mean of the operand, along all axes and {@code keepDims} is {@code false - * } + * Calculate the mean of the operand, along all axes and {@code keepDims} is false * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java deleted file mode 100644 index 68157632557..00000000000 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java +++ /dev/null @@ -1,147 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.framework.metrics.impl; - -import org.tensorflow.Operand; -import org.tensorflow.op.Ops; -import org.tensorflow.op.SparseOps; -import org.tensorflow.op.sparse.DenseToDenseSetOperation; -import org.tensorflow.types.family.TNumber; - -import static org.tensorflow.framework.utils.CastHelper.cast; - -/** Implementation of set operations */ -public class SetsOps { - - /** - * Computes set difference of elements in last dimension of {@code a} and {@code b} with - * {@code aMinusB} set to true. - * - *

All but the last dimension of {@code a} and {@code b} must match - * - * @param tf the TensorFlow Ops - * @param a The first operand representing set {@code a} - * @param b The other operand representing set {@code b} - * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the - * last dimension the * same. Elements along the last dimension contain the results of the set - * operation. - */ - public static Operand difference(Ops tf, Operand a, Operand b) { - return difference(tf, a, b, true); - } - - /** - * Computes set difference of elements in last dimension of {@code a} and {@code b}. - * - *

All but the last dimension of {@code a} and {@code b} must match - * - * @param tf the TensorFlow Ops - * @param a The first operand representing set {@code a} - * @param b The other operand representing set {@code b} - * @param aMinusB whether to subtract b from a, vs vice versa. - * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the - * last dimension the * same. Elements along the last dimension contain the results of the set - * operation. - */ - public static Operand difference( - Ops tf, Operand a, Operand b, boolean aMinusB) { - return setOperation(tf, a, b, aMinusB ? Operation.A_MINUS_B : Operation.B_MINUS_A); - } - - /** - * Computes set union of elements in last dimension of {@code a} and {@code b}. - * - * @param tf the TensorFlow Ops - * @param a The first operand representing set {@code a} - * @param b The other operand representing set {@code b} - * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the - * last dimension the * same. Elements along the last dimension contain the results of the set - * operation. - */ - public static Operand union(Ops tf, Operand a, Operand b) { - return setOperation(tf, a, b, Operation.UNION); - } - - /** - * Computes set intersection of elements in last dimension of {@code a} and {@code b}. - * - * @param tf the TensorFlow Ops - * @param a The first operand representing set {@code a} - * @param b The other operand representing set {@code b} - * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the - * last dimension the * same. Elements along the last dimension contain the results of the set - * operation. - */ - public static Operand intersection(Ops tf, Operand a, Operand b) { - return setOperation(tf, a, b, Operation.INTERSECTION); - } - - /** - * Compute set operation of elements in last dimension of {@code a} and {@code b}. - * - * @param tf the TensorFlow Ops - * @param a The first set operation operand - * @param b The other et operation operand - * @param setOperation The set operation to perform, {@link Operation}. - * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the - * last dimension the same. Elements along the last dimension contain the results of the set - * operation. 
- */ - public static Operand setOperation( - Ops tf, Operand a, Operand b, Operation setOperation) { - - DenseToDenseSetOperation setOperationResult = - tf.sparse.denseToDenseSetOperation( - a, b, setOperation.getSetOperation(), DenseToDenseSetOperation.validateIndices(true)); - - return tf.sparse.sparseToDense( - setOperationResult.resultIndices(), - setOperationResult.resultShape(), - setOperationResult.resultValues(), - cast(tf, tf.constant(0), a.type())); - } - - /** - * Enumeration containing the string operation values to be passed to the TensorFlow Sparse Ops - * function {@link SparseOps#denseToDenseSetOperation} - */ - public enum Operation { - A_MINUS_B("a-b"), - B_MINUS_A("b-a"), - INTERSECTION("intersection"), - UNION("union"); - - private final String setOperation; - - Operation(String setOperation) { - this.setOperation = setOperation; - } - - /** - * Gets the set operation String value used to pass as the stringOperation value to {@link - * SparseOps#denseToDenseSetOperation} - * - * @return the set operation String value - */ - public String getSetOperation() { - return setOperation; - } - } -} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java index 6583465da2e..47d7f8ab737 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.metrics.impl; import org.tensorflow.Operand; +import org.tensorflow.framework.op.FrameworkOps; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; @@ -150,12 +151,13 @@ private static Operand hasValidNonscalarShape( private static Operand hasValidDims( Ops tf, Operand weightsShape, Operand valuesShape) { tf = tf.withSubScope("hasInvalidDims"); + FrameworkOps fops = FrameworkOps.create(tf); Operand valuesShape2d = tf.expandDims(valuesShape, tf.constant(-1)); Operand validDims = tf.concat(Arrays.asList(valuesShape2d, tf.onesLike(valuesShape2d)), tf.constant(1)); Operand weightsShape2d = tf.expandDims(weightsShape, tf.constant(-1)); - Operand invalidDims = SetsOps.difference(tf, weightsShape2d, validDims); + Operand invalidDims = fops.sets.difference(weightsShape2d, validDims); Operand numInvalidDims = tf.size(invalidDims, TInt32.class); return tf.math.equal(tf.constant(0), numInvalidDims); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java new file mode 100644 index 00000000000..f182d9d7b80 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java @@ -0,0 +1,165 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.op; + +import org.tensorflow.DeviceSpec; +import org.tensorflow.EagerSession; +import org.tensorflow.ExecutionEnvironment; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.Scope; + +/** + * An API for building framework operations as {@link Op Op}s + * + *

 These are higher level ops that may invoke core ops. Higher level Ops may perform the + * operation solely in the TensorFlow framework or do preprocessing of the Operands before invoking + * a core level Op. + */ +public class FrameworkOps { + public final Ops coreOps; + public final NnOps nn; + public final SetOps sets; + public final MathOps math; + public final LinalgOps linalg; + private final Scope scope; + + /** + * Creates a FrameworkOps instance with the provided scope + * + * @param scope the scope + */ + private FrameworkOps(Scope scope) { + this.coreOps = Ops.create(scope.env()); + this.scope = scope; + nn = new NnOps(this); + sets = new SetOps(this); + math = new MathOps(this); + linalg = new LinalgOps(this); + } + + /** + * Creates a FrameworkOps instance based on the provided Core Ops + * + * @param coreOps The TensorFlow Core Ops + */ + private FrameworkOps(Ops coreOps) { + this.coreOps = coreOps; + this.scope = coreOps.scope(); + nn = new NnOps(this); + sets = new SetOps(this); + math = new MathOps(this); + linalg = new LinalgOps(this); + } + + /** + * Creates an API for building operations in the provided execution environment + * + * @param env the execution environment + * @return the FrameworkOps + */ + public static FrameworkOps create(ExecutionEnvironment env) { + return new FrameworkOps(new Scope(env)); + } + + /** + * Creates an API for building operations in the default eager execution environment + * + *

 Invoking this method is equivalent to {@code + * FrameworkOps.create(EagerSession.getDefault())}. + * + * @return the FrameworkOps + */ + public static FrameworkOps create() { + return new FrameworkOps(new Scope(EagerSession.getDefault())); + } + + /** + * Creates an API for building framework operations using the provided core {@link Ops} and its + * execution environment + * + * @param coreOps the TensorFlow core Ops + * @return the FrameworkOps + */ + public static FrameworkOps create(Ops coreOps) { + return new FrameworkOps(coreOps); + } + + /** + * Returns the current {@link Scope scope} of this API + * + * @return the current {@link Scope scope} of this API + */ + public final Scope scope() { + return scope; + } + + /** + * Gets the core Ops + * + * @return the core Ops + */ + public final Ops coreOps() { + return coreOps; + } + + /** + * Returns an API that builds operations with the provided name prefix. + * 
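 + *

For example, a small usage sketch (assuming an existing {@code Ops tf}; names are + * illustrative): + * + *

{@code
+   * FrameworkOps fops = FrameworkOps.create(tf);
+   * FrameworkOps lossOps = fops.withSubScope("losses"); // ops are created under the "losses" prefix
+   * }
+ * + *
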

 {@link Scope#withSubScope(String)} + * + * @param childScopeName the name of the child scope + * @return the FrameworkOps + */ + public FrameworkOps withSubScope(String childScopeName) { + return new FrameworkOps(scope.withSubScope(childScopeName)); + } + + /** + * Returns an API that uses the provided name for an op. + * + *

 {@link Scope#withName(String)} + * + * @param opName the name for an op + * @return the FrameworkOps + */ + public FrameworkOps withName(String opName) { + return new FrameworkOps(scope.withName(opName)); + } + + /** + * Returns an API that places the created operations on the device(s) matching the provided spec. + * 
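 + *

For example, a small sketch (illustrative only; assumes a GPU device is available and a + * {@code FrameworkOps fops} instance): + * + *

{@code
+   * FrameworkOps gpuOps =
+   *     fops.withDevice(DeviceSpec.newBuilder().deviceType(DeviceSpec.DeviceType.GPU).build());
+   * }
+ * + *
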

{@link Scope#withDevice(DeviceSpec)} + * + * @param deviceSpec the device specification for the scope + * @return the FrameworkOps + */ + public FrameworkOps withDevice(DeviceSpec deviceSpec) { + return new FrameworkOps(scope.withDevice(deviceSpec)); + } + + /** + * Returns an API that adds operations to the graph with the provided control dependencies. + * + *

 {@link Scope#withControlDependencies(Iterable)} + * + * @param controls the operations to use as control dependencies + * @return the FrameworkOps + */ + public FrameworkOps withControlDependencies(Iterable controls) { + return new FrameworkOps(scope.withControlDependencies(controls)); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/LinalgOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/LinalgOps.java new file mode 100644 index 00000000000..931f7f851c2 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/LinalgOps.java @@ -0,0 +1,304 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.op; + +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.SparseTensor; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.dtypes.Cast; +import org.tensorflow.op.math.Conj; +import org.tensorflow.op.sparse.SparseMatMul; +import org.tensorflow.op.train.BatchMatMul; +import org.tensorflow.types.TBfloat16; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; + +public class LinalgOps { + private final Scope scope; + + private final FrameworkOps frameworkOps; + + /** + * Creates Framework {@code linalg} Operations + * + * @param frameworkOps the TensorFlow framework Ops + */ + LinalgOps(FrameworkOps frameworkOps) { + this.scope = frameworkOps.scope(); + this.frameworkOps = frameworkOps; + } + + /** + * Multiplies matrix {@code a} by matrix {@code b}, producing {@code a} * {@code b + * }. + * + *

The inputs must, following any transpositions, be tensors of {@code rank >= 2} where the inner 2 + * dimensions specify valid matrix multiplication dimensions, and any further outer dimensions + * specify matching batch size. + * + *

 Both matrices must be of the same type. The supported types are: {@code TBfloat16}, + * {@code TFloat16}, {@code TFloat32}, {@code TFloat64}, {@code TInt32}. + * + *

 Either matrix can be transposed or adjointed (conjugated and transposed) on the fly by + * setting the corresponding flag to {@code true}. These flags are {@code false} by default. + * + *

A simple 2-D tensor matrix multiplication: + * + *

{@code
+   * Operand a = tf.constant(new double[][] {
+   *         {-8.944851},
+   *         {4.1711287},
+   *         {-0.22380222}
+   *     });
+   * Operand b = tf.constant( new double[][] {
+   *         {-14.276086, -12.433481, -2.2447076, -1.5775859, 1.8588694}
+   *     });
+   * Operand result = fops.linalg.matmul(a, b);
+   * // result = {
+   * //     {127.69746,  111.21564,  20.078575,  14.111271,  -16.62731},
+   * //     {-59.547394, -51.861652, -9.362965,  -6.580314,    7.753584},
+   * //     {  3.1950197,  2.7826407, 0.50237054, 0.35306725, -0.4160191}
+   * //  }
+   *
+   * }
+ * + *

 Note: This is matrix product, not element-wise product. + * + * @param a an Operand of type {@code TFloat16}, {@code TFloat32}, {@code TFloat64} or + * {@code TInt32}, with {@code rank > 1} + * @param b an Operand with same type and rank as {@code a}. + * @param the data type of the Operands + * @return an Operand of the same type as {@code a} and {@code b} where each inner-most + * matrix is the product of the corresponding matrices in {@code a} and {@code b}. + * This is the matrix product, not an element-wise product. + * @throws java.lang.IllegalArgumentException If {@code transposeA} and {@code adjointA}, + * or {@code transposeB} and {@code adjointB}, are both set to {@code true}. + */ + @Endpoint(name = "matmul") + public Operand matmul(Operand a, Operand b) { + return matmul(a, b, false, false, false, false, false, false); + } + + /** + * Multiplies matrix {@code a} by matrix {@code b}, producing {@code a} * {@code b + * }. + * + *

The inputs must, following any transpositions, be tensors of {@code rank >= 2} where the inner 2 + * dimensions specify valid matrix multiplication dimensions, and any further outer dimensions + * specify matching batch size. + * + *

 Both matrices must be of the same type. The supported types are: {@code TBfloat16}, + * {@code TFloat16}, {@code TFloat32}, {@code TFloat64}, {@code TInt32}. + * + *

 Either matrix can be transposed or adjointed (conjugated and transposed) on the fly by + * setting the corresponding flag to {@code true}. These flags are {@code false} by default. + * + *

Note: This is matrix product, not element-wise product. + * + *

A simple 2-D tensor matrix multiplication: + * + *

{@code
+   * Operand a = tf.constant(new double[][] {
+   *         {-8.944851},
+   *         {4.1711287},
+   *         {-0.22380222}
+   *     });
+   * Operand b = tf.constant( new double[][] {
+   *         {-14.276086, -12.433481, -2.2447076, -1.5775859, 1.8588694}
+   *     });
+   * Operand result = fops.linalg.matmul(a, b);
+   * // result = {
+   * //     {127.69746,  111.21564,  20.078575,  14.111271,  -16.62731},
+   * //     {-59.547394, -51.861652, -9.362965,  -6.580314,    7.753584},
+   * //     {  3.1950197,  2.7826407, 0.50237054, 0.35306725, -0.4160191}
+   * //  }
+   *
+   * }
 + * + * @param a an Operand of type {@code TFloat16}, {@code TFloat32}, {@code TFloat64} or + * {@code TInt32}, with {@code rank > 1} + * @param b an Operand with same type and rank as {@code a}. + * @param transposeA If true, {@code a} is transposed before multiplication. + * @param transposeB If true, {@code b} is transposed before multiplication. + * @param the data type of the Operands + * @return an Operand of the same type as {@code a} and {@code b} where each inner-most + * matrix is the product of the corresponding matrices in {@code a} and {@code b}. + * This is the matrix product, not an element-wise product. + * @throws java.lang.IllegalArgumentException If {@code transposeA} and {@code adjointA}, + * or {@code transposeB} and {@code adjointB}, are both set to {@code true}. + */ + @Endpoint(name = "matmul") + public Operand matmul( + Operand a, Operand b, boolean transposeA, boolean transposeB) { + return matmul(a, b, transposeA, transposeB, false, false, false, false); + } + + /** + * Multiplies matrix {@code a} by matrix {@code b}, producing {@code a} * {@code b + * }. + * + *

The inputs must, following any transpositions, be tensors of {@code rank >= 2} where the inner 2 + * dimensions specify valid matrix multiplication dimensions, and any further outer dimensions + * specify matching batch size. + * + *

 Both matrices must be of the same type. The supported types are: {@code TBfloat16}, + * {@code TFloat16}, {@code TFloat32}, {@code TFloat64}, {@code TInt32}. + * + *

 Either matrix can be transposed or adjointed (conjugated and transposed) on the fly by + * setting the corresponding flag to {@code true}. These flags are {@code false} by default. + * + *

Note: This is matrix product, not element-wise product. + * + *

A simple 2-D tensor matrix multiplication: + * + *

{@code
+   * Operand a = tf.constant(new double[][] {
+   *         {-8.944851},
+   *         {4.1711287},
+   *         {-0.22380222}
+   *     });
+   * Operand b = tf.constant( new double[][] {
+   *         {-14.276086, -12.433481, -2.2447076, -1.5775859, 1.8588694}
+   *     });
+   * Operand result = fops.linalg.matmul(a, b);
+   * // result = {
+   * //     {127.69746,  111.21564,  20.078575,  14.111271,  -16.62731},
+   * //     {-59.547394, -51.861652, -9.362965,  -6.580314,    7.753584},
+   * //     {  3.1950197,  2.7826407, 0.50237054, 0.35306725, -0.4160191}
+   * //  }
+   *
+   * }
 + * + * @param a an Operand of type {@code TFloat16}, {@code TFloat32}, {@code TFloat64} or + * {@code TInt32}, with {@code rank > 1} + * @param b an Operand with same type and rank as {@code a}. + * @param transposeA If true, {@code a} is transposed before multiplication. + * @param transposeB If true, {@code b} is transposed before multiplication. + * @param adjointA If true, {@code a} is conjugated and transposed before multiplication. + * @param adjointB If true, {@code b} is conjugated and transposed before multiplication. + * @param aIsSparse If true, {@code a} is treated as a sparse matrix. Notice, this does + * not support {@link SparseTensor}, it just makes optimizations that assume most values + * in {@code a} are zero. + * @param bIsSparse If true, {@code b} is treated as a sparse matrix. Notice, this does + * not support {@link SparseTensor}, it just makes optimizations that assume most values + * in {@code b} are zero. + * @param the data type of the Operands + * @return an Operand of the same type as {@code a} and {@code b} where each inner-most + * matrix is the product of the corresponding matrices in {@code a} and {@code b}. + * This is the matrix product, not an element-wise product. + * @throws java.lang.IllegalArgumentException If {@code transposeA} and {@code adjointA}, + * or {@code transposeB} and {@code adjointB}, are both set to {@code true}. + */ + @SuppressWarnings("unchecked") + @Endpoint(name = "matmul") + public Operand matmul( + Operand a, + Operand b, + boolean transposeA, + boolean transposeB, + boolean adjointA, + boolean adjointB, + boolean aIsSparse, + boolean bIsSparse) { + Scope lscope = scope.withSubScope("MatMul"); + if (transposeA && adjointA) + throw new IllegalArgumentException("Only one of transposeA and adjointA can be true."); + if (transposeB && adjointB) + throw new IllegalArgumentException("Only one of transposeB and adjointB can be true."); + if (!(TFloating.class.isAssignableFrom(a.type()) || a.type().equals(TInt32.class))) + throw new IllegalArgumentException( + String.format( + "Operand 'a' must be of type 'TBfloat16', 'TFloat16', 'TFloat32', 'TFloat64' or 'TInt32'. found type: %s", + a.type().getSimpleName())); + if (!(TFloating.class.isAssignableFrom(b.type()) || b.type().equals(TInt32.class))) + throw new IllegalArgumentException( + String.format( + "Operand 'b' must be of type 'TBfloat16', 'TFloat16', 'TFloat32', 'TFloat64' or 'TInt32'. found type: %s", + b.type().getSimpleName())); + + Shape aShape = a.shape(); + Shape bShape = b.shape(); + if (aShape.numDimensions() != bShape.numDimensions()) + throw new IllegalArgumentException( + String.format( + "Parameters 'a' and 'b' must have the same rank: found a rank = %d, b rank = %d", + aShape.numDimensions(), bShape.numDimensions())); + boolean outputMayHaveNonEmptyBatchShape = + aShape.numDimensions() == Shape.UNKNOWN_SIZE + || aShape.numDimensions() > 2 + || bShape.numDimensions() == Shape.UNKNOWN_SIZE; + + if ((!aIsSparse && !bIsSparse) && outputMayHaveNonEmptyBatchShape) { + // BatchMatmul does not support transpose, so we conjugate the matrix and + // use adjoint instead. Conj() is a noop for real matrices. + if (transposeA) { + a = Conj.create(scope, a); + adjointA = true; + } + if (transposeB) { + b = Conj.create(scope, b); + adjointB = true; + } + return BatchMatMul.create( + lscope, a, b, BatchMatMul.adjX(adjointA), BatchMatMul.adjY(adjointB)); + } + + // Neither matmul nor sparse_matmul support adjoint, so we conjugate + // the matrix and use transpose instead. Conj() is a noop for real + // matrices. + if (adjointA) { + a = Conj.create(scope, a); + transposeA = true; + } + if (adjointB) { + b = Conj.create(scope, b); + transposeB = true; + } + + boolean useSparseMatmul = false; + if (aIsSparse || bIsSparse) { + useSparseMatmul = + (a.type().equals(TBfloat16.class) || a.type().equals(TFloat32.class)) + && (b.type().equals(TBfloat16.class) || b.type().equals(TFloat32.class)); + } + if ((a.type().equals(TBfloat16.class) || b.type().equals(TBfloat16.class)) + && !a.type().equals(b.type())) useSparseMatmul = true; + + if (useSparseMatmul) { + Operand result = + SparseMatMul.create( + lscope, + a, + b, + SparseMatMul.transposeA(transposeA), + SparseMatMul.transposeB(transposeB), + SparseMatMul.aIsSparse(aIsSparse), + SparseMatMul.bIsSparse(bIsSparse)); + if (a.type().equals(TFloat32.class)) return (Operand) result; + else return Cast.create(scope, result, a.type()); + } + + return org.tensorflow.op.linalg.MatMul.create( + lscope, + a, + b, + org.tensorflow.op.linalg.MatMul.transposeA(transposeA), + org.tensorflow.op.linalg.MatMul.transposeB(transposeB)); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java new file mode 100644 index 00000000000..8fda58806ca --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java @@ -0,0 +1,1142 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
 +=======================================================================*/ +package org.tensorflow.framework.op; + +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Session; +import org.tensorflow.framework.losses.impl.LossTuple; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.core.AssertThat; +import org.tensorflow.op.core.Concat; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Gather; +import org.tensorflow.op.core.Identity; +import org.tensorflow.op.core.OnesLike; +import org.tensorflow.op.core.Range; +import org.tensorflow.op.core.Rank; +import org.tensorflow.op.core.ReduceAll; +import org.tensorflow.op.core.ReduceMax; +import org.tensorflow.op.core.ReduceProd; +import org.tensorflow.op.core.ReduceSum; +import org.tensorflow.op.core.Reshape; +import org.tensorflow.op.core.ScatterNd; +import org.tensorflow.op.core.Select; +import org.tensorflow.op.core.SetDiff1d; +import org.tensorflow.op.core.Slice; +import org.tensorflow.op.core.Squeeze; +import org.tensorflow.op.core.Stack; +import org.tensorflow.op.core.StopGradient; +import org.tensorflow.op.core.ZerosLike; +import org.tensorflow.op.dtypes.Cast; +import org.tensorflow.op.linalg.Transpose; +import org.tensorflow.op.math.Add; +import org.tensorflow.op.math.Exp; +import org.tensorflow.op.math.GreaterEqual; +import org.tensorflow.op.math.IsFinite; +import org.tensorflow.op.math.Less; +import org.tensorflow.op.math.Log; +import org.tensorflow.op.math.Maximum; +import org.tensorflow.op.math.Mul; +import org.tensorflow.op.math.Rsqrt; +import org.tensorflow.op.math.Square; +import org.tensorflow.op.math.Sub; +import org.tensorflow.types.TBfloat16; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat16; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +public class MathOps { + private final Scope scope; + + private final FrameworkOps frameworkOps; + + /** + * Creates Framework {@code math} Operations + * + * @param frameworkOps the TensorFlow framework Ops + */ + MathOps(FrameworkOps frameworkOps) { + this.scope = frameworkOps.scope(); + this.frameworkOps = frameworkOps; + } + + /** + * Normalizes along dimension axis using an L2 norm. + * + * @param x the input + * @param axis Dimension along which to normalize. + * @param the data type for the input and the result + * @return the normalized values based on L2 norm + */ + public Operand l2Normalize(Operand x, int[] axis) { + Operand squareSum = + ReduceSum.create( + scope, + Square.create(scope, x), + Constant.vectorOf(scope, axis), + ReduceSum.keepDims(Boolean.TRUE)); + Operand invNorm = + Rsqrt.create( + scope, + Maximum.create( + scope, squareSum, Cast.create(scope, Constant.scalarOf(scope, 1e-12F), x.type()))); + return Mul.create(scope, x, invNorm); + } + + /** + * Computes the confusion matrix from predictions and labels. + * + *

The matrix columns represent the prediction labels and the rows represent the real labels. + * The confusion matrix is always a 2-D array of shape {@code [n,n]}, where {@code n} is the number of valid + * labels for a given classification task. Both prediction and labels must be 1-D arrays of the + * same shape in order for this function to work. + * + *

If {@code numClasses} is null, then {@code numClasses} will be set to one plus the maximum value in + * either predictions or labels. Class labels are expected to start at 0. For example, if + * {@code numClasses} is 3, then the possible labels would be {@code [0, 1, 2]}. + * + *

If {@code weights} is not null, then each prediction contributes its corresponding weight to the + * total value of the confusion matrix cell. + * + *

For example: + * + *

{@code
+   *     fops.math.confusionMatrix(tf.constant(new int[] {1, 2, 4}), tf.constant(new int[] {2, 2, 4})) ==>
+   *         [[0 0 0 0 0]
+   *          [0 0 1 0 0]
+   *          [0 0 1 0 0]
+   *          [0 0 0 0 0]
+   *          [0 0 0 0 1]]
+   * }
+ * + *

Note that the possible labels are assumed to be {@code [0, 1, 2, 3, 4]}, resulting in a 5x5 + * confusion matrix. + * + * @param labels 1-D Operand of real labels for the classification task. + * @param predictions 1-D Operand of predictions for a given classification. + * @param Data type of the confusion matrix. + * @return An Operand of type {@code type} with shape {@code [n, n]} representing the confusion + * matrix, where {@code n} is the number of possible labels in the classification task. + * @throws IllegalArgumentException If both predictions and labels are not 1-D vectors and have + * mismatched shapes, or if {@code weights} is not null and its shape doesn't match {@code + * predictions}. + */ + public Operand confusionMatrix(Operand labels, Operand predictions) { + return confusionMatrix(labels, predictions, null, null); + } + + /** + * Computes the confusion matrix from predictions and labels. + * + *

The matrix columns represent the prediction labels and the rows represent the real labels. + * The confusion matrix is always a 2-D array of shape {@code [n,n]}, where {@code n} is the number of valid + * labels for a given classification task. Both predictions and labels must be 1-D arrays of the + * same shape in order for this function to work. + * + *

If {@code numClasses} is null, then {@code numClasses} will be set to one plus the maximum value in + * either predictions or labels. Class labels are expected to start at 0. For example, if + * {@code numClasses} is 3, then the possible labels would be {@code [0, 1, 2]}. + * + *

If {@code weights} is not null, then each prediction contributes its corresponding weight to the + * total value of the confusion matrix cell. + * + *

For example: + * + *

{@code
+   *     fops.math.confusionMatrix(tf.constant(new int[] {1, 2, 4}), tf.constant(new int[] {2, 2, 4})) ==>
+   *         [[0 0 0 0 0]
+   *          [0 0 1 0 0]
+   *          [0 0 1 0 0]
+   *          [0 0 0 0 0]
+   *          [0 0 0 0 1]]
+   * }
+ * + *

Note that the possible labels are assumed to be {@code [0, 1, 2, 3, 4]}, resulting in a 5x5 + * confusion matrix. + * + * @param labels 1-D Operand of real labels for the classification task. + * @param predictions 1-D Operand of predictions for a given classification. + * @param weights An optional Operand whose shape matches {@code predictions}. + * @param Data type of the confusion matrix. + * @return An Operand of type {@code type} with shape {@code [n, n]} representing the confusion + * matrix, where {@code n} is the number of possible labels in the classification task. + * @throws IllegalArgumentException If both predictions and labels are not 1-D vectors and have + * mismatched shapes, or if {@code weights} is not null and its shape doesn't match {@code + * predictions}. + */ + public Operand confusionMatrix( + Operand labels, Operand predictions, Operand weights) { + return confusionMatrix(labels, predictions, weights, null); + } + + /** + * Computes the confusion matrix from predictions and labels. + * + *

The matrix columns represent the prediction labels and the rows represent the real labels. + * The confusion matrix is always a 2-D array of shape {@code [n,n]}, where {@code n} is the number of valid + * labels for a given classification task. Both predictions and labels must be 1-D arrays of the + * same shape in order for this function to work. + * + *

If {@code numClasses} is null, then {@code numClasses} will be set to one plus the maximum value in + * either predictions or labels. Class labels are expected to start at 0. For example, if + * {@code numClasses} is 3, then the possible labels would be {@code [0, 1, 2]}. + * + *

If {@code weights} is not null, then each prediction contributes its corresponding weight to the + * total value of the confusion matrix cell. + * + *

For example: + * + *

{@code
+   *     fops.math.confusionMatrix(tf.constant(new int[] {1, 2, 4}), tf.constant(new int[] {2, 2, 4})) ==>
+   *         [[0 0 0 0 0]
+   *          [0 0 1 0 0]
+   *          [0 0 1 0 0]
+   *          [0 0 0 0 0]
+   *          [0 0 0 0 1]]
+   * }
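+   * <p>As a further illustrative sketch (the argument types shown here are assumptions for the
+   * example, not part of this change), weights scale each prediction's contribution to its cell:
+   * <pre>{@code
+   * Operand<TInt64> cm = fops.math.confusionMatrix(
+   *     tf.constant(new long[] {1, 2, 4}),   // labels
+   *     tf.constant(new long[] {2, 2, 4}),   // predictions
+   *     tf.constant(new long[] {2, 1, 3}),   // weights
+   *     tf.constant(5L));                    // numClasses
+   * // cell [1,2] holds 2, cell [2,2] holds 1, cell [4,4] holds 3
+   * }</pre>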
+ * + *

Note that the possible labels are assumed to be {@code [0, 1, 2, 3, 4]}, resulting in a 5x5 + * confusion matrix. + * + * @param labels 1-D Operand of real labels for the classification task. + * @param predictions 1-D Operand of predictions for a given classification. + * @param weights An optional Operand whose shape matches {@code predictions}. + * @param numClasses The possible number of labels the classification task can have. If this value + * is null, it will be calculated using both predictions and labels. + * @param Data type of the confusion matrix. + * @return An Operand of type {@code type} with shape {@code [n, n]} representing the confusion + * matrix, where {@code n} is the number of possible labels in the classification task. + * @throws IllegalArgumentException If both predictions and labels are not 1-D vectors and have + * mismatched shapes, or if {@code weights} is not null and its shape doesn't match {@code + * predictions}. + */ + public Operand confusionMatrix( + Operand labels, Operand predictions, Operand weights, Operand numClasses) { + Scope lScope = scope.withSubScope("confusionMatrix"); + LossTuple tuple = removeSqueezableDimensions(labels, predictions, 0); + Operand lLabels = Cast.create(lScope, tuple.getLabels(), TInt64.class); + Operand lPredictions = Cast.create(lScope, tuple.getTarget(), TInt64.class); + + Operand zero = Constant.scalarOf(lScope, 0L); + Operand one = Constant.scalarOf(lScope, 1L); + + AssertThat labelsNonNegative = + AssertThat.create( + lScope, + ReduceAll.create(lScope, GreaterEqual.create(lScope, lLabels, zero), allAxes(lLabels)), + Collections.singletonList( + Constant.scalarOf(lScope, "labels contains negative values"))); + lLabels = + Identity.create( + lScope.withControlDependencies(Collections.singletonList(labelsNonNegative)), lLabels); + + AssertThat predictionsNonNegative = + AssertThat.create( + lScope, + ReduceAll.create( + lScope, GreaterEqual.create(lScope, lPredictions, zero), allAxes(lPredictions)), + Collections.singletonList( + Constant.scalarOf(lScope, "predictions contains negative values"))); + lPredictions = + Identity.create( + lScope.withControlDependencies(Collections.singletonList(predictionsNonNegative)), + lPredictions); + + Operand lNumClasses; + if (numClasses == null) { + lNumClasses = + Add.create( + lScope, + Maximum.create( + lScope, + ReduceMax.create(lScope, lPredictions, zero), + ReduceMax.create(lScope, lLabels, zero)), + one); + } else { + lNumClasses = Cast.create(lScope, numClasses, TInt64.class); + Operand less = Less.create(lScope, lLabels, lNumClasses); + AssertThat labelsLess = + AssertThat.create( + lScope, + ReduceAll.create(scope, less, allAxes(less), ReduceAll.keepDims(false)), + Collections.singletonList(Constant.scalarOf(lScope, "labels out of bounds"))); + lLabels = + Identity.create( + lScope.withControlDependencies(Collections.singletonList(labelsLess)), lLabels); + + less = Less.create(lScope, lPredictions, lNumClasses); + AssertThat predictionsLess = + AssertThat.create( + lScope, + ReduceAll.create(scope, less, allAxes(less), ReduceAll.keepDims(false)), + Collections.singletonList(Constant.scalarOf(lScope, "predictions out of bounds"))); + lPredictions = + Identity.create( + lScope.withControlDependencies(Collections.singletonList(predictionsLess)), + lPredictions); + } + + if (weights != null) { + if (!predictions.shape().isCompatibleWith(weights.shape())) { + throw new IllegalArgumentException( + String.format( + "predictions.shape() [%s], is not compatible with weights.shape() [ 
%s].", + predictions.shape(), weights.shape())); + } + } + + Operand shape = Stack.create(lScope, Arrays.asList(lNumClasses, lNumClasses)); + Operand indices = + Stack.create(lScope, Arrays.asList(lLabels, lPredictions), Stack.axis(1L)); + Operand values = weights == null ? OnesLike.create(lScope, predictions) : weights; + /// Operand zeroMatrix = Zeros.create(lScope, Cast.create(lScope, shape, TInt32.class), + // type); + + return ScatterNd.create(lScope, indices, values, shape); + } + + /** + * Squeeze last dim if ranks differ from expected by exactly 1. + * + * @param labels Label values, a {@code Operand} whose dimensions match {@code predictions + * }. + * @param predictions Predicted values, a {@code Tensor} of arbitrary dimensions. + * @param expectedRankDiff Expected result of {@code rank(predictions) - rank(labels)}. + * @param the data type for the labels, predictions and result + * @return {@code labels} and {@code predictions}, possibly with last dim squeezed. + */ + public LossTuple removeSqueezableDimensions( + Operand labels, Operand predictions, int expectedRankDiff) { + Scope lScope = scope.withSubScope("removeSqueezableDimensions"); + Shape predictionsShape = predictions.shape(); + int predictionsRank = predictionsShape.numDimensions(); + Shape labelsShape = labels.shape(); + int labelsRank = labelsShape.numDimensions(); + + if (predictionsRank != Shape.UNKNOWN_SIZE || labelsRank != Shape.UNKNOWN_SIZE) { + // Use rank. + int rankDiff = predictionsRank - labelsRank; + if (rankDiff == expectedRankDiff + 1 && Shape.isCompatible(predictionsShape.size(-1), 1)) { + predictions = Squeeze.create(lScope, predictions); + } else if (rankDiff == expectedRankDiff - 1 && Shape.isCompatible(labelsShape.size(-1), 1)) { + labels = Squeeze.create(lScope, labels); + } + return new LossTuple<>(labels, predictions); + } + // Use dynamic rank. + + // TODO: hold for lazy select feature, + // Operand rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels)); + if (predictionsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(predictionsShape.size(-1), 1)) { + /* + * TODO, if we ever get a select that does lazy evaluation, but for now do the tf.squeeze + * predictions = tf.select( tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff ), + * tf.squeeze(predictions, Squeeze.axis(Arrays.asList(-1L))), predictions ); * + */ + predictions = + Squeeze.create(lScope, predictions, Squeeze.axis(Collections.singletonList(-1L))); + } + if (labelsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(labelsShape.size(-1), 1)) { + /* + * TODO, if we ever get a select that does lazy evaluation labels = tf.select( + * tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff ), tf.squeeze(labels, + * Squeeze.axis(Arrays.asList(-1L))), predictions ); * + */ + labels = Squeeze.create(lScope, labels, Squeeze.axis(Collections.singletonList(-1L))); + } + return new LossTuple<>(labels, predictions); + } + + /** + * Creates an Operand that has all axes contained in the Operand's shape. + * + * @param op the Operand + * @return an Operand that has all axes contained in the Operand's shape.. 
+ */ + public Operand allAxes(Operand op) { + int rank = op.shape().numDimensions(); + if (rank != Shape.UNKNOWN_SIZE) { + int[] axes = new int[rank]; + for (int i = 0; i < rank; i++) { + axes[i] = i; + } + return Constant.vectorOf(scope, axes); + } else { + return Range.create( + scope, Constant.scalarOf(scope, 0), Rank.create(scope, op), Constant.scalarOf(scope, 1)); + } + } + + /** + * Transpose and reshape the input for contraction op. + * + *

This method is helpful in reducing {@code math.tensordot} to {@code math_ops.matmul} using + * {@code array_ops.transpose} and {@code array_ops.reshape}. The method takes a tensor and performs the + * correct transpose and reshape operation for a given set of indices. It returns the reshaped + * tensor as well as a list of indices necessary to reshape the tensor again after matrix + * multiplication. + * + * @param the type of Operand + * @param a the Tensor + * @param axis unique indices specifying valid axes of {@code a}. + * @param flipped whether to flip the dimensions or not + * @return A tuple (reshapedA, freeDims, freeDimsStatic) where reshapedA is a reshaped to allow + * contraction via matmul, freeDims is a TInt32 Operand, depending on whether the shape of a + * is fully specified, and freeDimsStatic is either a list of integers and null values, or + * None, representing the inferred shape of the free dimensions + */ + private Object[] tensordotReshape( + Operand a, Operand axis, boolean flipped) { + Shape aShape = a.shape(); + + if (!aShape.hasUnknownDimension()) { // calculate using values + long[] aShapeDims = aShape.asArray(); + if (aShapeDims == null) aShapeDims = new long[0]; + long[] aDimsIndex = new long[aShapeDims.length]; + for (int i = 0; i < aDimsIndex.length; i++) aDimsIndex[i] = i; + + // get int array from axis Operand + int[] iAxes = getIntArray(axis); + // Convert negative axes to positive + for (int i = 0; i < iAxes.length; i++) + iAxes[i] = iAxes[i] >= 0 ? iAxes[i] : Math.floorMod(iAxes[i], iAxes.length); + + // convert integer axis to long axis + long[] lAxes = Arrays.stream(iAxes).mapToLong(i -> i).toArray(); + + // create list of the axes, dims, and free axes + List axesList = Arrays.stream(lAxes).boxed().collect(Collectors.toList()); + List freeList = Arrays.stream(aDimsIndex).boxed().collect(Collectors.toList()); + freeList.removeAll(axesList); + + // create array of free dims + long[] free = freeList.stream().mapToLong(i -> i).toArray(); + long[] freeDims = new long[free.length]; + for (int i = 0; i < free.length; i++) freeDims[i] = aShapeDims[(int) free[i]]; + + // Calculate the free dim by doing a reduce prod + long prodFree = 1; + for (long i : freeDims) { + prodFree *= i; + } + + // calculate the used dims by doing a reduce prod + long prodAxis = 1; + for (long i : lAxes) { + prodAxis *= aShapeDims[(int) i]; + } + + // setup the permutations array for the transpose + long[] perm = new long[freeDims.length + lAxes.length]; + Shape newShape; + if (flipped) { + System.arraycopy(lAxes, 0, perm, 0, lAxes.length); + System.arraycopy(free, 0, perm, lAxes.length, free.length); + newShape = Shape.of(prodAxis, prodFree); + } else { + System.arraycopy(free, 0, perm, 0, free.length); + System.arraycopy(lAxes, 0, perm, freeDims.length, lAxes.length); + newShape = Shape.of(prodFree, prodAxis); + } + + Operand aTrans; + long[] arrange = new long[lAxes.length]; + for (int i = 0; i < arrange.length; i++) arrange[i] = i; + + // if the permutations is not equals to the natural order of the dims, then do a transpose + if (!Arrays.equals(perm, arrange)) { + aTrans = Transpose.create(scope, a, Constant.vectorOf(scope, perm)); + } else { + aTrans = a; + } + + // reshape the final result to the new Shape, if necessary + Operand aReshaped = + aTrans.asOutput().shape().equals(newShape) + ? 
aTrans + : Reshape.create(scope, aTrans, Constant.vectorOf(scope, newShape.asArray())); + // return a tuple for the reshaped Operand, and Operand for the free dimensions, and a long + // array for the free dimensions + return new Object[] {aReshaped, Constant.vectorOf(scope, freeDims), freeDims}; + + } else { // calculate dynamically + + long[] freeDimsStatic = null; + Operand one = Constant.scalarOf(scope, 1); + Operand minusOne = Constant.scalarOf(scope, -1); + Operand zero = Constant.scalarOf(scope, 0); + org.tensorflow.op.core.Shape tShape = org.tensorflow.op.core.Shape.create(scope, a); + Operand axesT; + Operand freeT; + if (aShape.numDimensions() + != Shape.UNKNOWN_SIZE) { // we know the rank, but there are unknown dimensions + long[] aShapeDims = aShape.asArray(); + if (aShapeDims == null) aShapeDims = new long[0]; + + // get int array from axis Operand + int[] iAxes = getIntArray(axis); + // Convert negative axes to positive + for (int i = 0; i < iAxes.length; i++) + iAxes[i] = iAxes[i] >= 0 ? iAxes[i] : Math.floorMod(iAxes[i], iAxes.length); + + // convert integer axis to long axis + long[] lAxes = Arrays.stream(iAxes).mapToLong(i -> i).toArray(); + + // create list of the axes, dims, and free axes + List axesList = Arrays.stream(lAxes).boxed().collect(Collectors.toList()); + List dimsList = Arrays.stream(aShapeDims).boxed().collect(Collectors.toList()); + List freeList = new ArrayList<>(axesList); + freeList.removeAll(dimsList); + + // create array of free dims + long[] freeDims = freeList.stream().mapToLong(i -> i).toArray(); + freeDimsStatic = freeDims; + + axesT = Constant.vectorOf(scope, iAxes); + freeT = Cast.create(scope, Constant.vectorOf(scope, freeDims), TInt32.class); + + } else { // we don't know the rank yet + Rank rank = Rank.create(scope, a); + + // convert axis to positive + axesT = + Select.create( + scope, + GreaterEqual.create(scope, axis, Constant.scalarOf(scope, 0)), + axis, + Add.create(scope, axis, rank)); + + SetDiff1d diff = + SetDiff1d.create( + scope, Range.create(scope, Constant.scalarOf(scope, 0), rank, one), axesT); + freeT = diff.out(); + } + Operand freeDims = Gather.create(scope, tShape, freeT, zero); + Operand axesDims = Gather.create(scope, tShape, axesT, zero); + Operand prodFreeDims = ReduceProd.create(scope, freeDims, minusOne); + Operand prodAxesDims = ReduceProd.create(scope, axesDims, minusOne); + Operand perm; + Operand newShape; + if (flipped) { + perm = Concat.create(scope, Arrays.asList(axesT, freeT), zero); + newShape = Stack.create(scope, Arrays.asList(prodAxesDims, prodFreeDims)); + } else { + perm = Concat.create(scope, Arrays.asList(freeT, axesT), zero); + newShape = Stack.create(scope, Arrays.asList(prodFreeDims, prodAxesDims)); + } + Operand aReshaped = Reshape.create(scope, Transpose.create(scope, a, perm), newShape); + return new Object[] {aReshaped, freeDims, freeDimsStatic}; + } + } + + /** + * Gets an int array from an Operand<TInt32> operand. 
+ * + * @param axes the Operand to fetch the values + * @return the int array from an Operand<TInt32> + */ + private int[] getIntArray(Operand axes) { + List result = new ArrayList<>(); + if (scope.env().isEager()) { + axes.asTensor().scalars().forEach(s -> result.add(s.getInt())); + } else { + try (Session session = new Session((Graph) scope.env()); + TInt32 tensor = (TInt32) session.runner().fetch(axes).run().get(0)) { + tensor.scalars().forEach(s -> result.add(s.getInt())); + } + } + return result.stream().mapToInt(i -> i).toArray(); + } + + /** + * Generates two sets of contraction axes for the two tensor arguments. + * + * @param a the Operand to analyze + * @param axis the axes + * @param the data type for the Operand + * @return the contraction axes + */ + @SuppressWarnings("unchecked") + private Operand[] tensordotAxes(Operand a, int axis) { + Shape aShape = a.asOutput().shape(); + if (axis < 0) { + throw new IllegalArgumentException("'axis' must be at least 0."); + } + int rank = aShape.numDimensions(); + Operand[] result = new Operand[2]; + if (rank != Shape.UNKNOWN_SIZE) { + if (axis > rank) { + throw new IllegalArgumentException( + String.format( + "'axis' must not be larger than the number of dimensions of tensor %s.", rank)); + } + int min = rank - axis; + int postRange = rank - min; + int[] postAxis = new int[postRange]; + for (int i = 0; i < postRange; i++) postAxis[i] = i + min; + + int[] preAxis = new int[axis]; + for (int i = 0; i < axis; i++) preAxis[i] = i; + + result[0] = Constant.vectorOf(scope, postAxis); + result[1] = Constant.vectorOf(scope, preAxis); + } else { + Rank rankT = Rank.create(scope, a); + Constant axisT = Constant.scalarOf(scope, axis); + Constant one = Constant.scalarOf(scope, 1); + Constant zero = Constant.scalarOf(scope, 0); + AssertThat assertion = + AssertThat.create( + scope, + Less.create(scope, axisT, rankT), + Arrays.asList( + Constant.scalarOf( + scope, "'axes' must not be larger than the number of dimensions of tensor "), + rankT)); + Scope scope1 = scope.withControlDependencies(Collections.singletonList(assertion)); + result[0] = Range.create(scope1, Sub.create(scope, rankT, axisT), rankT, one); + result[1] = Range.create(scope1, zero, axisT, one); + } + return result; + } + + /** + * Generates two sets of contraction axes for the two tensor arguments. + * + * @param a the Operand to analyze + * @param axes the axes + * @param the data type for the Operand + * @return the contraction axes + */ + @SuppressWarnings({"unchecked", "unused"}) + private Operand[] tensordotAxes(Operand a, int[] axes) { + if (axes.length != 2) + throw new IllegalArgumentException( + "'axes' must have length 1 or 2, provided with " + axes.length); + int[] aAxis = new int[] {axes[0]}; + int[] bAxis = new int[] {axes[1]}; + Operand[] result = new Operand[2]; + result[0] = Constant.vectorOf(scope, aAxis); + result[1] = Constant.vectorOf(scope, bAxis); + + return result; + } + + /** + * Generates two sets of contraction axes for the two tensor arguments. 
+ * + * @param a the Operand to analyze + * @param axes the axes + * @param the data type for the Operand + * @return the contraction axes + */ + @SuppressWarnings({"unchecked", "unused"}) + private Operand[] tensordotAxes(Operand a, int[][] axes) { + if (axes.length != 2) + throw new IllegalArgumentException( + "'axes' must have length 1 or 2, provided with " + axes.length); + int[] aAxis = axes[0]; + int[] bAxis = axes[1]; + if (aAxis.length != bAxis.length) + throw new IllegalArgumentException( + String.format( + "Different number of contraction axes 'a' and 'b', %d != %d", + aAxis.length, bAxis.length)); + Operand[] result = new Operand[2]; + result[0] = Constant.vectorOf(scope, aAxis); + result[1] = Constant.vectorOf(scope, bAxis); + return result; + } + + /** + * Generates two sets of contraction axes for the two tensor arguments. + * + * @param a the Operand to analyze + * @param axes the axes + * @param the data type for the Operand + * @return the contraction axes + */ + @SuppressWarnings({"unchecked", "unused"}) + private Operand[] tensordotAxes(Operand a, Operand axes) { + + Constant one = Constant.scalarOf(scope, 1); + Constant zero = Constant.scalarOf(scope, 0); + Operand[] result = new Operand[2]; + result[0] = + Slice.create( + scope, + axes, + Cast.create(scope, zero, TInt32.class), + Cast.create(scope, one, TInt32.class)); + result[1] = + Slice.create( + scope, + axes, + Cast.create(scope, one, TInt32.class), + Cast.create(scope, one, TInt32.class)); + return result; + } + + /** + * Tensor contraction of a and b along specified axes and outer product. + *

+ * Tensordot (also known as tensor contraction) sums the product of elements + * from {@code a} and {@code b} over the indices specified by + * {@code a_axes} and {@code b_axes}. The lists + * {@code a_axes} and {@code b_axes} specify those pairs of axes + * along which to contract the tensors. The axis {@code a_axes[i]} of + * {@code a} must have the same dimension as axis + * {@code b_axes[i]} of {@code b} for all {@code i} in + * {@code range(0, len(a_axes))}. The lists + * {@code a_axes} and {@code b_axes} must have identical length + * and consist of unique integers that specify valid axes for each of the + * tensors. Additionally outer product is supported by passing + * {@code axes=0}. + *

+ * This operation corresponds to {@code numpy.tensordot(a, b, axes)}. + *

+ * Example 1: When {@code a} and {@code b} are matrices (order 2), + * the case {@code axes = 1} is equivalent to matrix multiplication. + *

+ * Example 2: When {@code a} and {@code b} are matrices (order 2), + * the case + * {@code axes = [[1], [0]]} is equivalent to matrix multiplication. + *

+ * Example 3: When {@code a} and {@code b} are matrices (order 2), + * the case {@code axes=0} gives the outer product, a tensor of order + * 4. + *

+ * Example 4: Suppose that a<sub>ijk</sub> and b<sub>lmn</sub> + * represent two tensors of order 3. Then, {@code contract(a, b, [[0], [2]])} is the order 4 tensor + * c<sub>jklm</sub> whose entry corresponding to the indices + * (j,k,l,m) is given by: + * c<sub>jklm</sub> = Σ<sub>i</sub> a<sub>ijk</sub> + * b<sub>lmi</sub>. + *
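+   * <p>For illustration, a sketch of the scalar-axis form (assuming {@code fops} is a
+   * {@code FrameworkOps} instance, as in the other examples in this class):
+   * <pre>{@code
+   * Operand<TFloat32> a = tf.constant(new float[][] {{1f, 2f}, {3f, 4f}});
+   * Operand<TFloat32> b = tf.constant(new float[][] {{5f, 6f}, {7f, 8f}});
+   * // axis = 1 contracts the last axis of a with the first axis of b (matrix multiplication)
+   * Operand<TFloat32> c = fops.math.tensordot(a, b, 1); // {{19f, 22f}, {43f, 50f}}
+   * }</pre>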

+ * In general, {@code order(c) = order(a) + order(b) - 2*len(axes[0])}. + * + * @param a {@code Operand} of type {@code TFloat32} or {@code TFloat64}. + * @param b {@code Operand} with the same type as {@code a}. + * @param axis sum over the last N axes of a and the + * first N axes of b in order. If {@code axis=0}, computes the outer + * product between {@code a} and {@code b}. + * @param the datatype of the Operands, must be either TFloat32 or + * TFloat64 + * @return A {@code Operand} with the same type as {@code a}. + * @throws IllegalArgumentException if a is not a float32 or float64 data type and if a and b are not the same data type + */ + @Endpoint(name = "tensordot") + public Operand tensordot(Operand a, Operand b, int axis) { + + Operand[] abAxis = tensordotAxes(a, axis); + Operand aAxis = abAxis[0]; + Operand bAxis = abAxis[1]; + return tensordot(a, b, aAxis, bAxis); + } + + /** + * Tensor contraction of a and b along specified axes and outer product. + *

+ * Tensordot (also known as tensor contraction) sums the product of elements + * from {@code a} and {@code b} over the indices specified by + * {@code a_axes} and {@code b_axes}. The lists + * {@code a_axes} and {@code b_axes} specify those pairs of axes + * along which to contract the tensors. The axis {@code a_axes[i]} of + * {@code a} must have the same dimension as axis + * {@code b_axes[i]} of {@code b} for all {@code i} in + * {@code range(0, len(a_axes))}. The lists + * {@code a_axes} and {@code b_axes} must have identical length + * and consist of unique integers that specify valid axes for each of the + * tensors. Additionally outer product is supported by passing + * {@code axes=0}. + *

+ * This operation corresponds to {@code numpy.tensordot(a, b, axes)}. + *

+ * Example 1: When {@code a} and {@code b} are matrices (order 2), + * the case {@code axes = 1} is equivalent to matrix multiplication. + *

+ * Example 2: When {@code a} and {@code b} are matrices (order 2), + * the case + * {@code axes = [[1], [0]]} is equivalent to matrix multiplication. + *

+ * Example 3: When {@code a} and {@code b} are matrices (order 2), + * the case {@code axes=0} gives the outer product, a tensor of order + * 4. + *

+ * Example 4: Suppose that a<sub>ijk</sub> and b<sub>lmn</sub> + * represent two tensors of order 3. Then, {@code contract(a, b, [[0], [2]])} is the order 4 tensor + * c<sub>jklm</sub> whose entry corresponding to the indices + * (j,k,l,m) is given by: + *

+ * c<sub>jklm</sub> = Σ<sub>i</sub> a<sub>ijk</sub> + * b<sub>lmi</sub>. + *

+ * In general, {@code order(c) = order(a) + order(b) - 2*len(axes[0])}. + *

+ * + * @param a {@code Operand} of type {@code TFloat32} or {@code TFloat64}. + * @param b {@code Operand} with the same type as {@code a}. + * @param axes If axes is a scalar, sum over the last N axes of a and the + * first N axes of b in order. If axes is a list, the first and second row + * contain the set of unique integers specifying axes along which the + * contraction is computed, for {@code a} and {@code b}, respectively. The number of + * axes for {@code a} and {@code b} must be equal. If {@code axis=0}, computes the outer + * product between {@code a} and {@code b}. + * @param the datatype of the Operands, must be either TFloat32 or + * TFloat64 + * @return A {@code Operand} with the same type as {@code a}. + * @throws IllegalArgumentException if a is not a float32 or float64 data type and if a and b are not the same data type + */ + @Endpoint(name = "tensordot") + public Operand tensordot( + Operand a, Operand b, Operand axes) { + + Operand[] abAxis = tensordotAxes(a, axes); + Operand aAxis = abAxis[0]; + Operand bAxis = abAxis[1]; + + return tensordot(a, b, aAxis, bAxis); + } + + /** + * Tensor contraction of a and b along specified axes and outer product. + *

+ * Tensordot (also known as tensor contraction) sums the product of elements + * from {@code a} and {@code b} over the indices specified by + * {@code a_axes} and {@code b_axes}. The lists + * {@code a_axes} and {@code b_axes} specify those pairs of axes + * along which to contract the tensors. The axis {@code a_axes[i]} of + * {@code a} must have the same dimension as axis + * {@code b_axes[i]} of {@code b} for all {@code i} in + * {@code range(0, len(a_axes))}. The lists + * {@code a_axes} and {@code b_axes} must have identical length + * and consist of unique integers that specify valid axes for each of the + * tensors. Additionally outer product is supported by passing + * {@code axes=0}. + *

+ * This operation corresponds to {@code numpy.tensordot(a, b, axes)}. + *

+ * Example 1: When {@code a} and {@code b} are matrices (order 2), + * the case {@code axes = 1} is equivalent to matrix multiplication. + *

+ * Example 2: When {@code a} and {@code b} are matrices (order 2), + * the case + * {@code axes = [[1], [0]]} is equivalent to matrix multiplication. + *

+ * Example 3: When {@code a} and {@code b} are matrices (order 2), + * the case {@code axes=0} gives the outer product, a tensor of order + * 4. + *

+ * Example 4: Suppose that a<sub>ijk</sub> and b<sub>lmn</sub> + * represent two tensors of order 3. Then, {@code contract(a, b, [[0], [2]])} is the order 4 tensor + * c<sub>jklm</sub> whose entry corresponding to the indices + * (j,k,l,m) is given by: + *

+ * c<sub>jklm</sub> = Σ<sub>i</sub> a<sub>ijk</sub> + * b<sub>lmi</sub>. + *

+ * In general, {@code order(c) = order(a) + order(b) - 2*len(axes[0])}. + *
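+   * <p>A sketch for the {@code int[]} form (illustrative only; {@code fops} assumed as in the
+   * other examples): {@code axes[0]} names the contraction axis of {@code a} and {@code axes[1]}
+   * that of {@code b}:
+   * <pre>{@code
+   * Operand<TFloat32> c = fops.math.tensordot(a, b, new int[] {1, 0});
+   * }</pre>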

+ * + * @param a {@code Operand} of type {@code TFloat32} or {@code TFloat64}. + * @param b {@code Operand} with the same type as {@code a}. + * @param axes a pair of axes: {@code axes[0]} is the axis of {@code a} and {@code axes[1]} is + * the axis of {@code b} along which the contraction is computed. + * @param the datatype of the Operands, must be either TFloat32 or + * TFloat64 + * @return A {@code Operand} with the same type as {@code a}. + * @throws IllegalArgumentException if a is not a float32 or float64 data type and if a and b are not the same data type + */ + @Endpoint(name = "tensordot") + public Operand tensordot(Operand a, Operand b, int[] axes) { + + Operand[] abAxis = tensordotAxes(a, axes); + Operand aAxis = abAxis[0]; + Operand bAxis = abAxis[1]; + + return tensordot(a, b, aAxis, bAxis); + } + + /** + * Tensor contraction of a and b along specified axes and outer product. + *

+ * Tensordot (also known as tensor contraction) sums the product of elements + * from {@code a} and {@code b} over the indices specified by + * {@code a_axes} and {@code b_axes}. The lists + * {@code a_axes} and {@code b_axes} specify those pairs of axes + * along which to contract the tensors. The axis {@code a_axes[i]} of + * {@code a} must have the same dimension as axis + * {@code b_axes[i]} of {@code b} for all {@code i} in + * {@code range(0, len(a_axes))}. The lists + * {@code a_axes} and {@code b_axes} must have identical length + * and consist of unique integers that specify valid axes for each of the + * tensors. Additionally outer product is supported by passing + * {@code axes=0}. + *

+ * This operation corresponds to {@code numpy.tensordot(a, b, axes)}. + *

+ * Example 1: When {@code a} and {@code b} are matrices (order 2), + * the case {@code axes = 1} is equivalent to matrix multiplication. + *

+ * Example 2: When {@code a} and {@code b} are matrices (order 2), + * the case + * {@code axes = [[1], [0]]} is equivalent to matrix multiplication. + *

+ * Example 3: When {@code a} and {@code b} are matrices (order 2), + * the case {@code axes=0} gives the outer product, a tensor of order + * 4. + *

+ * Example 4: Suppose that a<sub>ijk</sub> and b<sub>lmn</sub> + * represent two tensors of order 3. Then, {@code contract(a, b, [[0], [2]])} is the order 4 tensor + * c<sub>jklm</sub> whose entry corresponding to the indices + * (j,k,l,m) is given by: + *

+ * c<sub>jklm</sub> = Σ<sub>i</sub> a<sub>ijk</sub> + * b<sub>lmi</sub>. + *

+ * In general, {@code order(c) = order(a) + order(b) - 2*len(axes[0])}. + *
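+   * <p>A sketch for the {@code int[][]} form (illustrative only; {@code fops} assumed),
+   * contracting axis 1 of {@code a} with axis 0 of {@code b}, i.e. matrix multiplication for
+   * order-2 tensors:
+   * <pre>{@code
+   * Operand<TFloat32> c = fops.math.tensordot(a, b, new int[][] {{1}, {0}});
+   * }</pre>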

+ * + * @param a {@code Operand} of type {@code TFloat32} or {@code TFloat64}. + * @param b {@code Operand} with the same type as {@code a}. + * @param axes the first and second row + * contain the set of unique integers specifying axes along which the + * contraction is computed, for {@code a} and {@code b}, respectively. The number of + * axes for {@code a} and {@code b} must be equal. + * @param the datatype of the Operands, must be either TFloat32 or + * TFloat64 + * @return A {@code Operand} with the same type as {@code a}. + * @throws IllegalArgumentException if a is not a float32 or float64 data type and if a and b are not the same data type + */ + @Endpoint(name = "tensordot") + public Operand tensordot(Operand a, Operand b, int[][] axes) { + + Operand[] abAxis = tensordotAxes(a, axes); + Operand aAxis = abAxis[0]; + Operand bAxis = abAxis[1]; + + return tensordot(a, b, aAxis, bAxis); + } + + /** + * Tensor contraction of a and b along specified axes and outer product. + *

+ * Tensordot (also known as tensor contraction) sums the product of elements + * from {@code a} and {@code b} over the indices specified by + * {@code a_axes} and {@code b_axes}. The lists + * {@code a_axes} and {@code b_axes} specify those pairs of axes + * along which to contract the tensors. The axis {@code a_axes[i]} of + * {@code a} must have the same dimension as axis + * {@code b_axes[i]} of {@code b} for all {@code i} in + * {@code range(0, len(a_axes))}. The lists + * {@code a_axes} and {@code b_axes} must have identical length + * and consist of unique integers that specify valid axes for each of the + * tensors. Additionally outer product is supported by passing + * {@code axes=0}. + *

+ * This operation corresponds to {@code numpy.tensordot(a, b, axes)}. + *

+ * Example 1: When {@code a} and {@code b} are matrices (order 2), + * the case {@code axes = 1} is equivalent to matrix multiplication. + *

+ * Example 2: When {@code a} and {@code b} are matrices (order 2), + * the case + * {@code axes = [[1], [0]]} is equivalent to matrix multiplication. + *

+ * Example 3: When {@code a} and {@code b} are matrices (order 2), + * the case {@code axes=0} gives the outer product, a tensor of order + * 4. + *

+ * Example 4: Suppose that a<sub>ijk</sub> and b<sub>lmn</sub> + * represent two tensors of order 3. Then, {@code contract(a, b, [[0], [2]])} is the order 4 tensor + * c<sub>jklm</sub> whose entry corresponding to the indices + * (j,k,l,m) is given by: + *

+ * c<sub>jklm</sub> = Σ<sub>i</sub> a<sub>ijk</sub> + * b<sub>lmi</sub>. + *

+ * In general, {@code order(c) = order(a) + order(b) - 2*len(axes[0])}. + *

+ * + * @param a {@code Operand} of type {@code TFloat32} or {@code TFloat64}. + * @param b {@code Operand} with the same type as {@code a}. + * @param aAxis axes for the a Operand + * @param bAxis axes for the b Operand + * @param the datatype of the Operands, must be either TFloat32 or + * TFloat64 + * @return A {@code Operand} with the same type as {@code a}. + * @throws IllegalArgumentException if a is not a float32 or float64 data type and if a and b are not the same data type + */ + @SuppressWarnings({"unchecked", "unused"}) + @Endpoint(name = "tensordot") + public Operand tensordot( + Operand a, Operand b, Operand aAxis, Operand bAxis) { + + if (a.type().equals(TBfloat16.class) || a.type().equals(TFloat16.class)) { + throw new IllegalArgumentException( + String.format( + "Operand 'a' must be either TFloat32 or TFloat64 DataType, 'a' is a %s DataType", + a.type().getSimpleName())); + } + if (!a.type().equals(b.type())) { + throw new IllegalArgumentException( + String.format( + "Operands a and b must be the same data type, a is %s DataType, b is %s DataType", + a.type().getSimpleName(), b.type().getSimpleName())); + } + + // first result is Operand, second result is Operand, third result is long[] and it is + // ignored here. + Object[] aResult = tensordotReshape(a, aAxis, false); + Operand reshapedA = (Operand) aResult[0]; + Operand aFreeDims = (Operand) aResult[1]; + long[] aFreeDimsStatic = (long[]) aResult[2]; + + // first result is Operand, second result is Operand, third result is long[] and it is + // ignored here. + Object[] bResult = tensordotReshape(b, bAxis, true); + Operand reshapedB = (Operand) bResult[0]; + Operand bFreeDims = (Operand) bResult[1]; + long[] bFreeDimsStatic = (long[]) bResult[2]; + + Operand abMatmul = frameworkOps.linalg.matmul(reshapedA, reshapedB); + long[] abDimsStatic = new long[aFreeDimsStatic.length + bFreeDimsStatic.length]; + System.arraycopy(aFreeDimsStatic, 0, abDimsStatic, 0, aFreeDimsStatic.length); + System.arraycopy( + bFreeDimsStatic, 0, abDimsStatic, aFreeDimsStatic.length, bFreeDimsStatic.length); + if (!abMatmul.shape().hasUnknownDimension() + && abMatmul.shape().equals(Shape.of(abDimsStatic))) { + return abMatmul; + } else { + return Reshape.create(scope, abMatmul, Constant.vectorOf(scope, abDimsStatic)); + } + } + + /** + * Computes log(sum(exp(elements across dimensions of a tensor))). Reduces {@code input_tensor} + * along the dimensions given in {@code axes}. + * + *

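+   * <p>An illustrative sketch of the numerical-stability benefit (assuming {@code fops} is a
+   * {@code FrameworkOps} instance):
+   * <pre>{@code
+   * Operand<TFloat32> x = tf.constant(new float[][] {{1000f, 1000f}, {-1000f, -1000f}});
+   * // a naive log(sum(exp(x))) would overflow/underflow; this returns
+   * // approximately {1000.693f, -999.307f}
+   * Operand<TFloat32> lse = fops.math.reduceLogSumExp(x, new int[] {1}, false);
+   * }</pre>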
Reduces {@code input} along the dimensions given in {@code axes}. Unless {@code keepdims} + * is true, the rank of the tensor is reduced by 1 for each of the entries in {@code axes}, which + * must be unique. If {@code keepdims} is true, the reduced dimensions are retained with length 1. + * If {@code axes} has no entries, all dimensions are reduced, and a tensor with a single element + * is returned. This function is more numerically stable than {@code log(sum(exp(input)))}. It + * avoids overflows caused by taking the exp of large inputs and underflows caused by taking the + * log of small inputs. + * + * @param input The tensor to reduce. + * @param axes The dimensions to reduce. If null, reduces all dimensions. Must be in the range + * {@code [-rank(input_tensor), rank(input_tensor)]}. + * @param keepDims If true, retains reduced dimensions with length 1. + * @param the data type for the input and the result + * @return The reduced tensor. + */ + @Endpoint(name = "reduceLogSumExp") + public Operand reduceLogSumExp( + Operand input, int[] axes, boolean keepDims) { + Operand reduceDims = reductionDims(input, axes); + Operand rawMax = reduceMaxWithDims(input, axes, keepDims, reduceDims); + Operand myMax = + StopGradient.create( + scope, + Select.create( + scope, IsFinite.create(scope, rawMax), rawMax, ZerosLike.create(scope, rawMax))); + + Operand result = + Log.create( + scope, + reduceSumWithDims( + Exp.create(scope, Sub.create(scope, input, myMax)), axes, keepDims, reduceDims)); + + if (!keepDims) { + myMax = Reshape.create(scope, myMax, org.tensorflow.op.core.Shape.create(scope, result)); + } + result = Add.create(scope, result, myMax); + return mayReduceToScalar(keepDims, axes, result); + } + + private Operand reduceSumWithDims( + Operand input, int[] axes, boolean keepDims, Operand dims) { + return mayReduceToScalar( + keepDims, axes, ReduceSum.create(scope, input, dims, ReduceSum.keepDims(keepDims))); + } + + private Operand reduceMaxWithDims( + Operand input, int[] axes, boolean keepDims, Operand dims) { + return mayReduceToScalar( + keepDims, axes, ReduceMax.create(scope, input, dims, ReduceMax.keepDims(keepDims))); + } + + /** + * Sets a reduction's output shape to be a scalar if possible. + * + * @return the operand, possibly reduced to a scalar. + */ + private Operand mayReduceToScalar( + boolean keepDims, int[] axes, Operand output) { + + if ((output.shape().numDimensions() == Shape.UNKNOWN_SIZE + || output.shape().hasUnknownDimension()) + && !keepDims + && axes == null) { + return Reshape.create(scope, output, Constant.tensorOf(scope, Shape.scalar())); + } else { + return output; + } + } + + /** + * Reduce dimensions based on axis + * + * @param input the input + * @param axes he dimensions to reduce, may be null + * @return the dimensions to be reduced. 
+ */ + private Operand reductionDims(Operand input, int[] axes) { + if (axes != null) { + return Constant.vectorOf(scope, axes); + } + long rank = input.shape().numDimensions(); + if (rank != Shape.UNKNOWN_SIZE) { + int[] dims = new int[(int) rank]; + for (int i = 0; i < rank; i++) { + dims[i] = i; + } + return Constant.vectorOf(scope, dims); + + } else { + return Range.create( + scope, + Constant.scalarOf(scope, 0), + Rank.create(scope, input), + Constant.scalarOf(scope, 1)); + } + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java new file mode 100644 index 00000000000..a1f85544f95 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java @@ -0,0 +1,249 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.op; + +import org.tensorflow.Operand; +import org.tensorflow.framework.op.nn.SigmoidCrossEntropyWithLogits; +import org.tensorflow.framework.op.nn.SoftmaxCrossEntropyWithLogits; +import org.tensorflow.framework.op.nn.SparseSoftmaxCrossEntropyWithLogits; +import org.tensorflow.op.Op; +import org.tensorflow.op.Scope; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.op.nn.NNhelper.wrap2DFunction; + +/** + * An API for building {@code nn} operations as {@link Op Op}s + * + *

These are higher level ops that may invoke core ops. Higher level Ops may perform the + * operation solely in the TensorFlow framework or do preprocessing of the Operands before invoking + * a core level Op. + * + *

@see {@link FrameworkOps} + */ +public class NnOps { + private final Scope scope; + + private final FrameworkOps frameworkOps; + + /** + * Creates Framework {@code nn} Operations + * + * @param frameworkOps the TensorFlow framework Ops + */ + NnOps(FrameworkOps frameworkOps) { + this.scope = frameworkOps.scope(); + this.frameworkOps = frameworkOps; + } + + /** + * Computes sigmoid cross entropy given {@code logits}. + * + *

Measures the probability error in discrete classification tasks in which each class is + * independent and not mutually exclusive. For instance, one could perform multilabel + * classification where a picture can contain both an elephant and a dog at the same time. + * + *

For brevity, let {@code x = logits}, {@code z = labels}. The logistic loss in pseudo-code is + * + *

+   *  z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
+   *   = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
+   *   = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
+   *   = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
+   *   = (1 - z) * x + log(1 + exp(-x))
+   *   = x - x * z + log(1 + exp(-x))
+   *  
+ * + *

For {@code x < 0}, to avoid overflow in {@code exp(-x)}, we reformulate the above + * + *

+   *  x - x * z + log(1 + exp(-x))
+   *   = log(exp(x)) - x * z + log(1 + exp(-x))
+   *   = - x * z + log(1 + exp(x))
+   *  
+ * + *

Hence, to ensure stability and avoid overflow, the implementation uses this equivalent + * formulation + * + *

+   *    max(x, 0) - x * z + log(1 + exp(-abs(x)))
+   *  
+ * + *

{@code logits} and {@code labels} must have the same type and shape. + * + *
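+   * <p>A usage sketch (illustrative only; {@code fops} is assumed to be a {@code FrameworkOps}
+   * instance):
+   * <pre>{@code
+   * Operand<TFloat32> labels = tf.constant(new float[] {1f, 0f, 1f});
+   * Operand<TFloat32> logits = tf.constant(new float[] {0.5f, -1f, 2f});
+   * Operand<TFloat32> losses = fops.nn.sigmoidCrossEntropyWithLogits(labels, logits);
+   * }</pre>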

+ * + * @param labels the labels + * @param logits the logits of type float32 or float64 + * @param the type of labels and logits + * @return the component-wise logistic losses. + * @throws IllegalArgumentException if logits and labels do not have the same shape + */ + public Operand sigmoidCrossEntropyWithLogits( + Operand labels, Operand logits) { + return SigmoidCrossEntropyWithLogits.sigmoidCrossEntropyWithLogits(scope, labels, logits); + } + + /** + * Computes softmax cross entropy between {@code logits} and {@code labels}. + * + *

Measures the probability error in discrete classification tasks in which the classes are + * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is + * labeled with one and only one label: an image can be a dog or a truck, but not both. + * + *

NOTE: + * + *

While the classes are mutually exclusive, their probabilities need not be. All that is + * required is that each row of {@code labels} is a valid probability distribution. If they are + * not, the computation of the gradient will be incorrect. + * + *

If using exclusive {@code labels} (wherein one and only one class is true at a time), see + * {@link org.tensorflow.op.NnOps#sparseSoftmaxCrossEntropyWithLogits} + * + *

Usage: + * + *

+   *    Operand<TFloat32> logits =
+   *        tf.constant(new float[][] {{4.0F, 2.0F, 1.0F}, {0.0F, 5.0F, 1.0F}} );
+   *    Operand<TFloat32> labels =
+   *        tf.constant(new float[][] {{1.0F, 0.0F, 0.0F}, {0.0F, 0.8F, 0.2F}} );
+   *    Operand<TFloat32> output =
+   *        fops.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
+   *    // output Shape = [2]
+   *    // dataType = FLOAT (1)
+   *    // values { 0.169846, 0.824745 }
+   *  
+ * + *

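+   * <p>As noted below, backpropagation flows into {@code labels} as well; an illustrative guard
+   * (a sketch, reusing the operands from the example above):
+   * <pre>{@code
+   * Operand<TFloat32> frozenLabels = tf.stopGradient(labels);
+   * Operand<TFloat32> output = fops.nn.softmaxCrossEntropyWithLogits(frozenLabels, logits, -1);
+   * }</pre>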
Backpropagation will happen into both {@code logits} and {@code labels}. To disallow + * backpropagation into {@code labels}, pass label tensors through {@code tf.stopGradient} before + * feeding it to this function. + * + * @param labels Each vector along the class dimension should hold a valid probability + * distribution e.g. for the case in which labels are of shape {@code [batch_size, + * num_classes] }, each row of {@code labels[i]} must be a valid probability distribution. + * @param logits Per-label activations, typically a linear output. These activation energies are + * interpreted as unnormalized log probabilities. + * @param axis The class dimension. -1 is the last dimension. + * @param the number type of the operands + * @param the data type for the labels. + * @return the softmax cross entropy loss. Its type is the same as {@code logits} and its shape is + * the same as {@code labels} except that it does not have the last dimension of {@code + * labels}. + */ + public Operand softmaxCrossEntropyWithLogits( + Operand labels, Operand logits, int axis) { + return SoftmaxCrossEntropyWithLogits.softmaxCrossEntropyWithLogits(scope, labels, logits, axis); + } + + /** + * Computes sparse softmax cross entropy between {@code logits} and {@code labels}. + * + *

Measures the probability error in discrete classification tasks in which the classes are + * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is + * labeled with one and only one label: an image can be a dog or a truck, but not both. + * + *

NOTE: + * + *

For this operation, the probability of a given label is considered exclusive. That is, soft + * classes are not allowed, and the {@code labels} vector must provide a single specific index for + * the true class for each row of {@code logits} (each minibatch entry). For soft softmax + * classification with a probability distribution for each entry, see {@link + * org.tensorflow.op.NnOps#softmaxCrossEntropyWithLogits}. + * + *

WARNING: + * + *

This op expects unscaled logits, since it performs a {@code softmax} on {@code logits} + * internally for efficiency. Do not call this op with the output of {@code softmax}, as it will + * produce incorrect results. + * + *

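+   * <p>A usage sketch (illustrative; {@code fops} assumed to be a {@code FrameworkOps} instance):
+   * <pre>{@code
+   * Operand<TFloat32> logits =
+   *     tf.constant(new float[][] {{4.0f, 2.0f, 1.0f}, {0.0f, 5.0f, 1.0f}});
+   * Operand<TInt32> labels = tf.constant(new int[] {0, 1}); // one class index per row
+   * Operand<TFloat32> loss = fops.nn.sparseSoftmaxCrossEntropyWithLogits(labels, logits);
+   * }</pre>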
A common use case is to have logits of shape {@code [batchSize, numClasses]} and have labels + * of shape {@code [batchSize]}, but higher dimensions are supported, in which case the {@code + * dim}-th dimension is assumed to be of size {@code numClasses}. {@code logits} must have the + * {@code dataType} of {@code TFloat16}, {@code TFloat32}, or {@code TFloat64}, and {@code + * labels} must have the dtype of {@code TInt32} or {@code TInt64}. + * + * @param labels {@code Tensor} of shape {@code [d_0, d_1, ..., d_{r-1}]} (where {@code r} is + * rank of {@code labels} and result) and the dataType is {@code TInt32} or {@code TInt64}. + * Each entry in {@code labels} must be an index in {@code [0, numClasses)}. Other values will + * raise an exception when this op is run on CPU, and return {@code NaN} for corresponding + * loss and gradient rows on GPU. + * @param logits Per-label activations (typically a linear output) of shape {@code [d_0, d_1, ..., + * d_{r-1}, numClasses]} and dataType of {@code TFloat16}, {@code TFloat32}, or {@code + * TFloat64}. These activation energies are interpreted as unnormalized log probabilities. + * @param the data type for the labels + * @param the data type for the loss and logits. + * @return the loss + * @throws IllegalArgumentException If logits are scalars (need to have {@code rank >= 1}) or if + * the rank of the labels is not equal to the rank of the logits minus one. + */ + public Operand sparseSoftmaxCrossEntropyWithLogits( + Operand labels, Operand logits) { + return SparseSoftmaxCrossEntropyWithLogits.sparseSoftmaxCrossEntropyWithLogits( + scope, labels, logits); + } + + /** + * Calculates a Softmax operation on the last dimension + * + * @param input the input + * @return the softmax of the input on the last dimension. + * @throws IllegalArgumentException if axis is not in the range [-rank, rank], exclusive + * @param the data type for the input and result + */ + public Operand softmax(Operand input) { + return wrap2DFunction(scope, input, org.tensorflow.op.nn.Softmax::create, -1); + } + + /** + * Calculates a Softmax operation. If the axis is not the last dimension, then the input axis is + * moved to the last axis before calling tf.nn.softmax, then restored before returning. + * + * @param input the input + * @param axis the axis + * @return the softmax of the input for the specified axis. + * @throws IllegalArgumentException if axis is not in the range [-rank, rank], exclusive + * @param the data type for the input and result + */ + public Operand softmax(Operand input, int axis) { + return wrap2DFunction(scope, input, org.tensorflow.op.nn.Softmax::create, axis); + } + + /** + * Calculates a Log Softmax operation on the last dimension + * + * @param input the input + * @return the log softmax of the input on the last dimension. + * @throws IllegalArgumentException if axis is not in the range [-rank, rank], exclusive + * @param the data type for the input and result + */ + public Operand logSoftmax(Operand input) { + return wrap2DFunction(scope, input, org.tensorflow.op.nn.LogSoftmax::create, -1); + } + + /** + * Calculates a Log Softmax operation. If the axis is not the last dimension, then the input axis + * is moved to the last axis before calling tf.nn.logSoftmax, then restored before returning. + * + * @param input the input + * @param axis the axis + * @return the log softmax of the input for the specified axis. 
+ * @throws IllegalArgumentException if axis is not in the range [-rank, rank], exclusive + * @param the data type for the input and result + */ + public Operand logSoftmax(Operand input, int axis) { + return wrap2DFunction(scope, input, org.tensorflow.op.nn.LogSoftmax::create, axis); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/SetsOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/SetsOps.java new file mode 100644 index 00000000000..d7833cdbb06 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/SetsOps.java @@ -0,0 +1,161 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.op; + +import org.tensorflow.Operand; +import org.tensorflow.op.Scope; +import org.tensorflow.op.SparseOps; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.dtypes.Cast; +import org.tensorflow.op.sparse.DenseToDenseSetOperation; +import org.tensorflow.op.sparse.SparseToDense; +import org.tensorflow.types.family.TNumber; + +/** Implementation of set operations */ +public class SetsOps { + + private final Scope scope; + + private final FrameworkOps frameworkOps; + + /** + * Creates Framework {@code sets} Operations + * + * @param frameworkOps the TensorFlow framework Ops + */ + SetsOps(FrameworkOps frameworkOps) { + this.scope = frameworkOps.scope(); + this.frameworkOps = frameworkOps; + } + + /** + * Computes set difference of elements in last dimension of a and b with + * aMinusB set to true. + * + *

All but the last dimension of a and b must match.
+   *
+   * @param a The first operand representing set a
+   * @param b The other operand representing set b
+   * @param <T> the data type for the sets
+   * @return An Operand with the same rank as a and b, and all but the
+   *     last dimension the same. Elements along the last dimension contain the results of the set
+   *     operation.
+   */
+  public <T extends TNumber> Operand<T> difference(Operand<T> a, Operand<T> b) {
+    return difference(a, b, true);
+  }
+
+  /**
+   * Computes set difference of elements in last dimension of a and b.
+   *
+   * <pre>

All but the last dimension of a and b must match.
+   *
+   * @param a The first operand representing set a
+   * @param b The other operand representing set b
+   * @param aMinusB whether to subtract b from a; if false, computes b minus a
+   * @param <T> the data type for the sets
+   * @return An Operand with the same rank as a and b, and all but the
+   *     last dimension the same. Elements along the last dimension contain the results of the set
+   *     operation.
+   */
+  public <T extends TNumber> Operand<T> difference(Operand<T> a, Operand<T> b, boolean aMinusB) {
+    return setOperation(a, b, aMinusB ? Operation.A_MINUS_B : Operation.B_MINUS_A);
+  }
+
+  /**
+   * Computes set union of elements in last dimension of a and b.
+   *
+   * @param a The first operand representing set a
+   * @param b The other operand representing set b
+   * @param <T> the data type for the sets
+   * @return An Operand with the same rank as a and b, and all but the
+   *     last dimension the same. Elements along the last dimension contain the results of the set
+   *     operation.
+   */
+  public <T extends TNumber> Operand<T> union(Operand<T> a, Operand<T> b) {
+    return setOperation(a, b, Operation.UNION);
+  }
+
+  /**
+   * Computes set intersection of elements in last dimension of a and b.
+   *
+   * @param a The first operand representing set a
+   * @param b The other operand representing set b
+   * @param <T> the data type for the sets
+   * @return An Operand with the same rank as a and b, and all but the
+   *     last dimension the same. Elements along the last dimension contain the results of the set
+   *     operation.
+   */
+  public <T extends TNumber> Operand<T> intersection(Operand<T> a, Operand<T> b) {
+    return setOperation(a, b, Operation.INTERSECTION);
+  }
+
+  /**
+   * Computes the set operation of elements in the last dimension of a and b.
+   *
+   * @param a The first set operation operand
+   * @param b The other set operation operand
+   * @param setOperation The set operation to perform, {@link Operation}.
+   * @param <T> the data type for the sets
+   * @return An Operand with the same rank as a and b, and all but the
+   *     last dimension the same. Elements along the last dimension contain the results of the set
+   *     operation.
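+   *
+   * <p>A sketch of typical usage (illustrative only; assumes a {@code SetsOps} instance named
+   * {@code sets} and an {@code Ops} instance named {@code tf}):
+   *
+   * <pre>{@code
+   * Operand<TInt32> a = tf.constant(new int[][] {{1, 2, 3}, {4, 5, 6}});
+   * Operand<TInt32> b = tf.constant(new int[][] {{1, 2, 2}, {5, 5, 7}});
+   * Operand<TInt32> inter = sets.intersection(a, b); // per-row sets {1, 2} and {5}
+   * Operand<TInt32> diff = sets.difference(a, b);    // per-row sets {3} and {4, 6}
+   * Operand<TInt32> all = sets.union(a, b);          // per-row sets {1, 2, 3} and {4, 5, 6, 7}
+   * // results are dense tensors; shorter result rows are padded with the default value 0
+   * }</pre>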
+   */
+  public <T extends TNumber> Operand<T> setOperation(
+      Operand<T> a, Operand<T> b, Operation setOperation) {
+
+    DenseToDenseSetOperation<T> setOperationResult =
+        DenseToDenseSetOperation.create(
+            scope,
+            a,
+            b,
+            setOperation.getSetOperation(),
+            DenseToDenseSetOperation.validateIndices(true));
+
+    return SparseToDense.create(
+        scope,
+        setOperationResult.resultIndices(),
+        setOperationResult.resultShape(),
+        setOperationResult.resultValues(),
+        Cast.create(scope, Constant.scalarOf(scope, 0), a.type()));
+  }
+
+  /**
+   * Enumeration containing the string operation values to be passed to the TensorFlow Sparse Ops
+   * function {@link SparseOps#denseToDenseSetOperation}
+   */
+  public enum Operation {
+    A_MINUS_B("a-b"),
+    B_MINUS_A("b-a"),
+    INTERSECTION("intersection"),
+    UNION("union");
+
+    private final String setOperation;
+
+    Operation(String setOperation) {
+      this.setOperation = setOperation;
+    }
+
+    /**
+     * Gets the set operation String value used to pass as the stringOperation value to {@link
+     * SparseOps#denseToDenseSetOperation}
+     *
+     * @return the set operation String value
+     */
+    public String getSetOperation() {
+      return setOperation;
+    }
+  }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/NNhelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/NNhelper.java
new file mode 100644
index 00000000000..3af3c606b4a
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/NNhelper.java
@@ -0,0 +1,127 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.op.nn;
+
+import org.tensorflow.Operand;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.op.Scope;
+import org.tensorflow.op.core.Concat;
+import org.tensorflow.op.core.Constant;
+import org.tensorflow.op.core.ExpandDims;
+import org.tensorflow.op.core.Range;
+import org.tensorflow.op.core.Rank;
+import org.tensorflow.op.core.Reshape;
+import org.tensorflow.op.linalg.Transpose;
+import org.tensorflow.op.math.Add;
+import org.tensorflow.op.math.Sub;
+import org.tensorflow.types.TInt32;
+import org.tensorflow.types.family.TFloating;
+
+import java.util.Arrays;
+import java.util.function.BiFunction;
+
+/** Helper class for nn functions */
+public class NNhelper {
+  /**
+   * Helper function for ops that accept and return 2-D inputs of the same shape.
+   *
+   * <pre>

It reshapes and transposes the inputs into a 2-D Tensor and then invokes the given function.
+   * The output is then transposed and reshaped back to the original shape.
+   *
+   * @param scope the TensorFlow Scope
+   * @param input the input
+   * @param computeOp The function to wrap. Must accept the scope as the first argument, and the
+   *     input as the second argument (e.g. {@code org.tensorflow.op.nn.Softmax::create}).
+   * @param axis The dimension the operation should operate on. {@code -1} indicates the last
+   *     dimension.
+   * @param <T> the data type for the input and the result.
+   * @return the result of the operation, the same shape as the input
+   * @throws IllegalArgumentException if axis is not in the range {@code [-rank, rank)}
+   */
+  public static <T extends TFloating> Operand<T> wrap2DFunction(
+      Scope scope,
+      Operand<T> input,
+      BiFunction<Scope, Operand<T>, Operand<T>> computeOp,
+      int axis) {
+    Shape shape = input.shape();
+    boolean isLastDim = axis == -1 || axis == shape.numDimensions() - 1;
+    if (isLastDim) {
+      return computeOp.apply(scope, input);
+    }
+
+    // validate axis
+    if (!(-shape.numDimensions() <= axis && axis < shape.numDimensions())) {
+      throw new IllegalArgumentException(
+          String.format(
+              "Axis (%d) must be in the range [%d, %d) where %d is the number of dimensions in the input.",
+              axis, -shape.numDimensions(), shape.numDimensions(), shape.numDimensions()));
+    }
+
+    // If axis is not the last dimension, we have to do a transpose so that we can
+    // still perform the op on its last dimension.
+
+    // In case axis is negative (and is not the last dimension, -1), convert to positive
+    int lAxis = Math.floorMod(axis, shape.numDimensions());
+    Operand<TInt32> inputRank = Rank.create(scope, input);
+    Operand<TInt32> axisOp = Constant.scalarOf(scope, lAxis);
+    Operand<TInt32> one = Constant.scalarOf(scope, 1);
+    Operand<TInt32> lastIndex = Sub.create(scope, inputRank, one);
+    Operand<T> swappedInputs = swapAxis(scope, input, axisOp, lastIndex);
+    Operand<T> output = computeOp.apply(scope, swappedInputs);
+    return fixOutput(scope, output, shape, axisOp, lastIndex);
+  }
+
+  /**
+   * Restores the specified axis, then reshapes the input to the provided shape.
+   *
+   * @param scope The TensorFlow scope
+   * @param output the output
+   * @param shape the desired shape
+   * @param axis the dimension to move
+   * @param lastIndex the index of the last dimension
+   * @return the restored output based on the dimension and shape.
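+   * <p>For example, for a rank-4 input with {@code axis = 1}, {@code swapAxis} applies the
+   * permutation {@code [0, 3, 2, 1]}. That permutation is its own inverse, so applying the
+   * same swap here restores the original axis order before the final reshape.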
+   */
+  private static <T extends TFloating> Operand<T> fixOutput(
+      Scope scope,
+      Operand<T> output,
+      Shape shape,
+      Operand<TInt32> axis,
+      Operand<TInt32> lastIndex) {
+    Operand<T> result = swapAxis(scope, output, axis, lastIndex);
+    return Reshape.create(scope, result, Constant.tensorOf(scope, shape));
+  }
+
+  /**
+   * Moves the specified axis to the last dimension
+   *
+   * @param scope the TensorFlow scope
+   * @param input the input
+   * @param axis the dimension to move
+   * @param lastIndex the index of the last dimension
+   * @return the input with the specified dimension swapped with the last dimension
+   */
+  private static <T extends TFloating> Operand<T> swapAxis(
+      Scope scope, Operand<T> input, Operand<TInt32> axis, Operand<TInt32> lastIndex) {
+
+    Operand<TInt32> zero = Constant.scalarOf(scope, 0);
+    Operand<TInt32> one = Constant.scalarOf(scope, 1);
+    Operand<TInt32> minus1 = Constant.scalarOf(scope, -1);
+    Operand<TInt32> range1 = Range.create(scope, zero, axis, one);
+    Operand<TInt32> range2 = Range.create(scope, Add.create(scope, axis, one), lastIndex, one);
+    Operand<TInt32> xDim = ExpandDims.create(scope, axis, minus1);
+    Operand<TInt32> xLastIndex = ExpandDims.create(scope, lastIndex, minus1);
+
+    return Transpose.create(
+        scope, input, Concat.create(scope, Arrays.asList(range1, xLastIndex, range2, xDim), zero));
+  }
+}
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SigmoidCrossEntropyWithLogits.java
similarity index 83%
rename from tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java
rename to tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SigmoidCrossEntropyWithLogits.java
index 92c413f7e52..432e1b47a3f 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SigmoidCrossEntropyWithLogits.java
@@ -1,29 +1,32 @@
-package org.tensorflow.op.nn;
+package org.tensorflow.framework.op.nn;
 
 import org.tensorflow.Operand;
 import org.tensorflow.ndarray.Shape;
 import org.tensorflow.op.Scope;
-import org.tensorflow.op.annotation.Endpoint;
-import org.tensorflow.op.annotation.Operator;
 import org.tensorflow.op.core.Select;
 import org.tensorflow.op.core.ZerosLike;
 import org.tensorflow.op.dtypes.Cast;
-import org.tensorflow.op.math.*;
+import org.tensorflow.op.math.Add;
+import org.tensorflow.op.math.Exp;
+import org.tensorflow.op.math.GreaterEqual;
+import org.tensorflow.op.math.Log1p;
+import org.tensorflow.op.math.Mul;
+import org.tensorflow.op.math.Neg;
+import org.tensorflow.op.math.Sub;
 import org.tensorflow.types.TBool;
 import org.tensorflow.types.family.TNumber;
 
-@Operator(group = "nn")
+// @Operator(group = "nn")
 public class SigmoidCrossEntropyWithLogits {
 
   /**
-   * Computes sigmoid cross entropy given logits.
+   * Computes sigmoid cross entropy given {@code logits}.
    *
   * <pre>

Measures the probability error in discrete classification tasks in which each class is * independent and not mutually exclusive. For instance, one could perform multilabel * classification where a picture can contain both an elephant and a dog at the same time. * - *

For brevity, let x = logits, z = labels. The logistic loss in - * pseudo-code is + *

For brevity, let {@code x = logits}, {@code z = labels}. The logistic loss in pseudo-code is * *

    * z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
@@ -34,7 +37,7 @@ public class SigmoidCrossEntropyWithLogits {
    *  = x - x * z + log(1 + exp(-x))
    * 
* - *

For x < 0, to avoid overflow in exp(-x), we reformulate the above + *

For {@code x < 0}, to avoid overflow in {@code exp(-x)}, we reformulate the above * *

    * x - x * z + log(1 + exp(-x))
@@ -49,7 +52,7 @@ public class SigmoidCrossEntropyWithLogits {
    *   max(x, 0) - x * z + log(1 + exp(-abs(x)))
    * 
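+   * <p>For example, with {@code x = -100} and {@code z = 1} the naive form requires
+   * {@code log(1 + exp(100))}, which overflows in single precision, while the stable form
+   * evaluates {@code max(-100, 0) - (-100 * 1) + log(1 + exp(-100))}, which is approximately
+   * {@code 100}.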
* - *

logits and labels must have the same type and shape. + *

{@code logits} and {@code labels} must have the same type and shape. * *

*
 *
@@ -60,7 +63,7 @@ public class SigmoidCrossEntropyWithLogits {
    * @return the component-wise logistic losses.
    * @throws IllegalArgumentException if logits and labels do not have the same shape
    */
-  @Endpoint(name = "sigmoidCrossEntropyWithLogits")
+  // @Endpoint(name = "sigmoidCrossEntropyWithLogits")
   public static <T extends TNumber> Operand<T> sigmoidCrossEntropyWithLogits(
       Scope scope, Operand<T> labels, Operand<T> logits) {
     if (!isCompatible(labels.shape(), logits.shape())) {
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java
similarity index 86%
rename from tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java
rename to tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java
index ddeacbea4d4..7d59941f27a 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java
@@ -1,11 +1,15 @@
-package org.tensorflow.op.nn;
+package org.tensorflow.framework.op.nn;
 
 import org.tensorflow.Operand;
 import org.tensorflow.ndarray.Shape;
 import org.tensorflow.op.Scope;
 import org.tensorflow.op.annotation.Endpoint;
-import org.tensorflow.op.annotation.Operator;
-import org.tensorflow.op.core.*;
+import org.tensorflow.op.core.Concat;
+import org.tensorflow.op.core.Constant;
+import org.tensorflow.op.core.Range;
+import org.tensorflow.op.core.Rank;
+import org.tensorflow.op.core.Reshape;
+import org.tensorflow.op.core.Slice;
 import org.tensorflow.op.dtypes.Cast;
 import org.tensorflow.op.linalg.Transpose;
 import org.tensorflow.op.math.Sub;
@@ -14,12 +18,11 @@
 import org.tensorflow.types.TFloat32;
 import org.tensorflow.types.TInt64;
 import org.tensorflow.types.family.TNumber;
-import org.tensorflow.types.family.TType;
 
 import java.util.Arrays;
 import java.util.List;
 
-@Operator(group = "nn")
+// @Operator(group = "nn")
 public class SoftmaxCrossEntropyWithLogits {
 
   /**
@@ -63,11 +66,13 @@ public class SoftmaxCrossEntropyWithLogits {
    * @param logits Per-label activations, typically a linear output. These activation energies are
    *     interpreted as unnormalized log probabilities.
    * @param axis The class dimension. -1 is the last dimension.
-   * @param <T> the number type of the operands
+   * @param <T> the data type for the logits and return operand
+   * @param <U> the data type for the labels
   * @return the softmax cross entropy loss. Its type is the same as logits and its
   *     shape is the same as labels except that it does not have the last dimension of
   *     labels.
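+   *
+   * <p>A usage sketch (illustrative only; {@code tf} is an {@code Ops} instance and {@code
+   * tf.scope()} supplies the {@code Scope}):
+   *
+   * <pre>{@code
+   * Operand<TFloat32> labels = tf.constant(new float[][] {{0f, 1f, 0f}, {1f, 0f, 0f}});
+   * Operand<TFloat32> logits = tf.constant(new float[][] {{2f, 1f, 0.1f}, {0.5f, 2f, 0.3f}});
+   * Operand<TFloat32> loss =
+   *     SoftmaxCrossEntropyWithLogits.softmaxCrossEntropyWithLogits(tf.scope(), labels, logits, -1);
+   * // loss has shape [2]: one cross-entropy value per minibatch row
+   * }</pre>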
*/ + @SuppressWarnings("unchecked") @Endpoint(name = "softmaxCrossEntropyWithLogits") public static Operand softmaxCrossEntropyWithLogits( Scope scope, Operand labels, Operand logits, int axis) { @@ -78,7 +83,9 @@ public static Operand softmaxCrossEntr } if (logits.asOutput().type() == TFloat16.class || logits.asOutput().type() == TBfloat16.class) { - Operand result = softmaxCrossEntropyWithLogits(scope, + Operand result = + softmaxCrossEntropyWithLogits( + scope, Cast.create(scope, labels, TFloat32.class), Cast.create(scope, logits, TFloat32.class), axis); @@ -86,10 +93,8 @@ public static Operand softmaxCrossEntr } if (logits.asOutput().type() != labels.asOutput().type()) { - return softmaxCrossEntropyWithLogits(scope, - Cast.create(scope, labels, logits.asOutput().type()), - logits, - axis); + return softmaxCrossEntropyWithLogits( + scope, Cast.create(scope, labels, logits.asOutput().type()), logits, axis); } Operand inputRank = Cast.create(scope, Rank.create(scope, logits), TInt64.class); @@ -101,13 +106,20 @@ public static Operand softmaxCrossEntr labels = moveDimToEnd(scope, labels, axis, inputRank); } + Operand tLabels; + if (labels.type() != logits.type()) { + tLabels = Cast.create(scope, labels, logits.type()); + } else { + // Unchecked warning checked in if statement. + tLabels = (Operand) labels; + } + Shape inputShape = logits.shape(); logits = flattenOuterDims(scope, logits); - labels = flattenOuterDims(scope, labels); + tLabels = flattenOuterDims(scope, tLabels); - org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits smax = - org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits.create( - scope, logits, (Operand)labels); + org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits smax = + org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits.create(scope, logits, tLabels); /* cannot use generic on cost, because cost may be recast later. 
*/ Operand cost = smax.loss(); Operand outputShape = @@ -119,6 +131,9 @@ public static Operand softmaxCrossEntr cost = Reshape.create(scope, cost, outputShape); if (scope.env().isGraph() && !shape.hasUnknownDimension()) { long[] array = shape.asArray(); + if (array == null) { + array = new long[0]; + } long[] newArray = new long[array.length - 1]; if (axis < 0) { axis = shape.numDimensions() + axis; @@ -153,7 +168,7 @@ private static Operand flattenOuterDims(Scope scope, Oper boolean productValid = true; for (int i = ndims - 2; i >= 0; i--) { long d = shape.size(i); - if (d == org.tensorflow.ndarray.Shape.UNKNOWN_SIZE) { + if (d == Shape.UNKNOWN_SIZE) { productValid = false; break; } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java similarity index 50% rename from tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java rename to tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java index 54b32bb5c63..553adf90aad 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java @@ -1,11 +1,10 @@ -package org.tensorflow.op.nn; +package org.tensorflow.framework.op.nn; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.op.annotation.Operator; import org.tensorflow.op.core.AssertThat; import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Reshape; @@ -15,18 +14,22 @@ import org.tensorflow.types.TBfloat16; import org.tensorflow.types.TFloat16; import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TIntegral; import org.tensorflow.types.family.TNumber; import java.util.ArrayList; import java.util.Collections; import java.util.List; -@Operator(group = "nn") +// @Operator(group = "nn") public class SparseSoftmaxCrossEntropyWithLogits { /** - * Computes sparse softmax cross entropy between logits and labels. + * Computes sparse softmax cross entropy between {@code logits} and {@code labels}. * *

Measures the probability error in discrete classification tasks in which the classes are * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is @@ -35,51 +38,61 @@ public class SparseSoftmaxCrossEntropyWithLogits { *

NOTE: * *

For this operation, the probability of a given label is considered exclusive. That is, soft
-   * classes are not allowed, and the labels vector must provide a single specific
-   * index for the true class for each row of logits (each minibatch entry). For soft
-   * softmax classification with a probability distribution for each entry, {@link
+   * classes are not allowed, and the {@code labels} vector must provide a single specific index for
+   * the true class for each row of {@code logits} (each minibatch entry). For soft softmax
+   * classification with a probability distribution for each entry, see {@link
   * org.tensorflow.op.NnOps#softmaxCrossEntropyWithLogits}.
   *
   * <pre>

WARNING: * - *

This op expects unscaled logits, since it performs a softmax on logits - * internally for efficiency. Do not call this op with the output of softmax, - * as it will produce incorrect results. + *

This op expects unscaled logits, since it performs a {@code softmax} on {@code logits}
+   * internally for efficiency. Do not call this op with the output of {@code softmax}, as it will
+   * produce incorrect results.
+   *
-   * <pre>

A common use case is to have logits of shape [batchSize, numClasses] and have - * labels of shape [batchSize], but higher dimensions are supported, in which case - * the dim-th dimension is assumed to be of size numClasses. - * logits must have the dataType of TFloat16, TFloat32 - * , or TFloat64, and labels must have the dtype of TInt32 - * or TInt64. + *

A common use case is to have logits of shape {@code [batchSize, numClasses]} and have labels
+   * of shape {@code [batchSize]}, but higher dimensions are supported, in which case the {@code
+   * dim}-th dimension is assumed to be of size {@code numClasses}. {@code logits} must have the
+   * {@code dataType} of {@code TFloat16}, {@code TFloat32}, or {@code TFloat64}, and {@code
+   * labels} must have the dtype of {@code TInt32} or {@code TInt64}.
+   *
+   * @param scope current scope
+   * @param labels {@code Tensor} of shape {@code [d_0, d_1, ..., d_{r-1}]} (where {@code r} is
+   *     rank of {@code labels} and result) and the dataType is {@code TInt32} or {@code TInt64}.
+   *     Each entry in {@code labels} must be an index in {@code [0, numClasses)}. Other values will
+   *     raise an exception when this op is run on CPU, and return {@code NaN} for corresponding
+   *     loss and gradient rows on GPU.
+   * @param logits Per-label activations (typically a linear output) of shape {@code [d_0, d_1, ...,
+   *     d_{r-1}, numClasses]} and dataType of {@code TFloat16}, {@code TFloat32}, or {@code
+   *     TFloat64}. These activation energies are interpreted as unnormalized log probabilities.
+   * @param <T> the data type for the labels
+   * @param <U> the data type for the loss and logits.
+   * @return the loss
+   * @throws IllegalArgumentException If logits are scalars (need to have {@code rank >= 1}) or if
+   *     the rank of the labels is not equal to the rank of the logits minus one.
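+   *
+   * <p>A usage sketch (illustrative only; {@code tf} is an {@code Ops} instance):
+   *
+   * <pre>{@code
+   * Operand<TInt32> labels = tf.constant(new int[] {1, 0}); // class indices, shape [2]
+   * Operand<TFloat32> logits = tf.constant(new float[][] {{2f, 1f, 0.1f}, {0.5f, 2f, 0.3f}});
+   * Operand<TFloat32> loss =
+   *     SparseSoftmaxCrossEntropyWithLogits.sparseSoftmaxCrossEntropyWithLogits(
+   *         tf.scope(), labels, logits);
+   * // loss has shape [2], matching labels
+   * }</pre>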
*/ + @SuppressWarnings("unchecked") @Endpoint(name = "sparseSoftmaxCrossEntropyWithLogits") - public static Operand sparseSoftmaxCrossEntropyWithLogits( - Scope scope, Operand labels, Operand logits) { + public static + Operand sparseSoftmaxCrossEntropyWithLogits( + Scope scope, Operand labels, Operand logits) { scope = scope.withSubScope("SparseSoftmaxCrossEntropyWithLogits"); - /** cannot use generics on preciseLogits as it may be recast later */ - Operand preciseLogits = logits; + Operand preciseLogits; if (logits.asOutput().type() == TFloat16.class || logits.asOutput().type() == TBfloat16.class) { preciseLogits = Cast.create(scope, logits, TFloat32.class); + } else if (TFloating.class.isAssignableFrom(logits.type())) { + preciseLogits = (Operand) logits; + } else { + preciseLogits = Cast.create(scope, logits, TFloat64.class); } - Shape labelsStaticShape = labels.shape(); + Operand iLabels; + if (TIntegral.class.isAssignableFrom(labels.type())) { + iLabels = (Operand) labels; + } else { + iLabels = Cast.create(scope, labels, TInt64.class); + } + Shape labelsStaticShape = iLabels.shape(); org.tensorflow.op.core.Shape labelsShape = - org.tensorflow.op.core.Shape.create(scope, labels); + org.tensorflow.op.core.Shape.create(scope, iLabels); Shape logitsShape = logits.shape(); Shape logitsShortened = logitsShape.take(logitsShape.numDimensions() - 1); @@ -108,14 +121,16 @@ public static Operand sparseSoftmaxCrossE } // Check if no reshapes are required. if (logitsShape.numDimensions() == 2) { - org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits smax = - org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits.create( - scope, preciseLogits, labels); - Operand loss = smax.loss(); - if (logits.asOutput().type() == TFloat16.class) { - loss = Cast.create(scope, loss, TFloat16.class); + org.tensorflow.op.nn.SparseSoftmaxCrossEntropyWithLogits smax = + org.tensorflow.op.nn.SparseSoftmaxCrossEntropyWithLogits.create( + scope, preciseLogits, iLabels); + Operand cost = smax.loss(); + if (cost.type() != logits.type()) { + return Cast.create(scope, cost, logits.type()); + } else { + // Unchecked cast already checked with previous if + return (Operand) cost; } - return loss; } List shapeChecks = new ArrayList<>(); @@ -126,7 +141,7 @@ public static Operand sparseSoftmaxCrossE scope, Equal.create( scope, - org.tensorflow.op.core.Shape.create(scope, labels), + org.tensorflow.op.core.Shape.create(scope, iLabels), Shapes.take( scope, org.tensorflow.op.core.Shape.create(scope, logits), @@ -143,16 +158,19 @@ public static Operand sparseSoftmaxCrossE long numClassses = logitsShape.size(-1); preciseLogits = Reshape.create(scope, preciseLogits, Constant.arrayOf(scope, -1L, numClassses)); - labels = Reshape.create(scope, labels, Constant.scalarOf(scope, -1)); + iLabels = Reshape.create(scope, iLabels, Constant.scalarOf(scope, -1)); scope.withControlDependencies(shapeChecks); - org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits smax = - org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits.create( - scope, preciseLogits, labels); - Operand cost = smax.loss(); + // call raw op + org.tensorflow.op.nn.SparseSoftmaxCrossEntropyWithLogits smax = + org.tensorflow.op.nn.SparseSoftmaxCrossEntropyWithLogits.create( + scope, preciseLogits, iLabels); + Operand cost = smax.loss(); cost = Reshape.create(scope, cost, labelsShape); - if (logits.asOutput().type() == TFloat16.class) { - cost = Cast.create(scope, cost, TFloat16.class); + if (cost.type() != logits.type()) { + return Cast.create(scope, 
cost, logits.type()); + } else { + // Unchecked cast already checked with previous if + return (Operand) cost; } - return cost; } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ActivationTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ActivationTest.java new file mode 100644 index 00000000000..0f779567ddc --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ActivationTest.java @@ -0,0 +1,48 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.activations.ReLU; +import org.tensorflow.framework.activations.Tanh; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +class ActivationTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testReLU() { + float[][] input = {{1, -2}, {3, -4}, {-1, 2}, {-3, 4}}; + float[][] expected = {{1, 0}, {3, 0}, {0, 2}, {0, 4}}; + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + ReLU relu = new ReLU<>(tf); + + Activation instance = new Activation<>(tf, relu, TFloat32.class); + Operand result = instance.call(tf.constant(input), true, TFloat32.class); + + session.evaluate(tf.constant(expected), result); + } + } + + /** Test of Tanh call method. */ + @Test + public void testCallDouble() { + double[] input = {1, -2, 3, -4, -5, 6, -7, 8}; + double[] expected = { + 0.76159416, -0.96402758, + 0.99505475, -0.9993293, + -0.9999092, 0.99998771, + -0.99999834, 0.99999977 + }; + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Tanh tanh = new Tanh<>(tf); + Activation instance = new Activation<>(tf, tanh, TFloat64.class); + Operand result = instance.call(tf.constant(input), false, TFloat64.class); + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/AddTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/AddTest.java new file mode 100644 index 00000000000..cf2c3a76117 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/AddTest.java @@ -0,0 +1,237 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.Tensor; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.family.TType; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class AddTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + double[][][] x1 = { + { + {0.13570025, 0.55677077, 0.06648757, 0.58405729, 0.61086578}, + {0.18659685, 0.39331301, 0.68069423, 0.09510652, 0.86098578}, + {0.99338463, 0.37543824, 0.74858191, 0.31828287, 0.94056888}, + {0.76202298, 0.05605309, 0.73475366, 0.9313434, 0.48606332} + }, + { + {0.13023652, 0.39567908, 0.89910993, 0.71302943, 0.73722061}, + {0.6212917, 0.62624375, 0.8184835, 0.16864979, 
0.96730508}, + {0.39645622, 0.35834793, 0.39924944, 0.90297727, 0.82857399}, + {0.70014157, 0.95498672, 0.6179583, 0.33104554, 0.11978174} + } + }; + double[][][] x2 = { + { + {0.82828211, 0.28889298, 0.7159566, 0.93377237, 0.32654201}, + {0.73234341, 0.17123203, 0.62582661, 0.96272026, 0.58700802}, + {0.12527705, 0.64175689, 0.64915537, 0.80589999, 0.26400939}, + {0.79376476, 0.24171677, 0.0677271, 0.07027092, 0.29195821} + }, + { + {0.56599224, 0.10611362, 0.83370522, 0.72514044, 0.08126704}, + {0.48173969, 0.16509515, 0.21040572, 0.44414272, 0.70656624}, + {0.89191749, 0.73008498, 0.9177326, 0.31897888, 0.56743576}, + {0.36304201, 0.36696309, 0.60722209, 0.79244879, 0.63492784} + } + }; + double[][][] x3 = { + { + {0.90545522, 0.55172128, 0.87254455, 0.1396359, 0.1538656}, + {0.04276304, 0.9315817, 0.91360492, 0.00604873, 0.04174153}, + {0.60856471, 0.37386072, 0.68937889, 0.21272655, 0.65082257}, + {0.44925012, 0.29825938, 0.20043074, 0.84906101, 0.78397795} + }, + { + {0.70855776, 0.17650269, 0.02422264, 0.84612297, 0.72450389}, + {0.05133022, 0.61175015, 0.56296539, 0.66780478, 0.63326012}, + {0.11212696, 0.50675282, 0.58170013, 0.21101392, 0.83090424}, + {0.91830915, 0.42113009, 0.49795942, 0.2814478, 0.11920788} + } + }; + double[][][] xsum = { + { + {1.86943758, 1.39738503, 1.65498872, 1.65746556, 1.09127339}, + {0.9617033, 1.49612674, 2.22012576, 1.06387551, 1.48973533}, + {1.7272264, 1.39105585, 2.08711617, 1.33690942, 1.85540083}, + {2.00503786, 0.59602925, 1.00291149, 1.85067532, 1.56199948} + }, + { + {1.40478652, 0.67829538, 1.75703778, 2.28429285, 1.54299154}, + {1.1543616, 1.40308904, 1.59185462, 1.28059728, 2.30713144}, + {1.40050067, 1.59518573, 1.89868217, 1.43297007, 2.22691399}, + {1.98149274, 1.7430799, 1.72313981, 1.40494213, 0.87391746} + } + }; + + @Test + public void testAdd() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Input i1 = + new Input<>( + tf, + "l1", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input i2 = + new Input<>( + tf, + "l2", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input i3 = + new Input<>( + tf, + "l3", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Add instance = new Add<>(tf, TFloat64.class); + List> resultList = + instance.call( + Arrays.asList( + i1.getOutput(TFloat64.class), + i2.getOutput(TFloat64.class), + i3.getOutput(TFloat64.class)), + TFloat64.class); + + Operand result = resultList.get(0); + + assertArrayEquals(new long[] {Shape.UNKNOWN_SIZE, 4, 5}, result.shape().asArray()); + + Operand x1Op = tf.constant(x1); + Operand x2Op = tf.constant(x2); + Operand x3Op = tf.constant(x3); + + try (TFloat64 x1Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x1Op).run().get(0); + TFloat64 x2Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x2Op).run().get(0); + TFloat64 x3Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x3Op).run().get(0)) { + Map, Tensor> feedMap = new HashMap<>(); + feedMap.put(i1.getOutput(TFloat64.class), x1Tensor); + feedMap.put(i2.getOutput(TFloat64.class), x2Tensor); + feedMap.put(i3.getOutput(TFloat64.class), x3Tensor); + session.evaluate(tf.constant(xsum), result, feedMap); + } + } + } + + @Test + public void testMask() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Input i1 = + new Input<>( + tf, + "l1", + TFloat64.class, + 
TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input i2 = + new Input<>( + tf, + "l2", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input i3 = + new Input<>( + tf, + "l3", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Add instance = new Add<>(tf, TFloat64.class); + List> inputs = + Arrays.asList( + i1.getOutput(TFloat64.class), + i2.getOutput(TFloat64.class), + i3.getOutput(TFloat64.class)); + List> mask = Arrays.asList(null, null, null); + + List> result = instance.computeMask(inputs, mask); + assertNull(result); + + Operand x1Op = tf.constant(x1); + Operand x2Op = tf.constant(x2); + Operand x3Op = tf.constant(x3); + mask = Arrays.asList(x1Op, x2Op, x3Op); + + try (TFloat64 x1Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x1Op).run().get(0); + TFloat64 x2Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x2Op).run().get(0); + TFloat64 x3Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x3Op).run().get(0)) { + Map, Tensor> feedMap = new HashMap<>(); + feedMap.put(i1.getOutput(TFloat64.class), x1Tensor); + feedMap.put(i2.getOutput(TFloat64.class), x2Tensor); + feedMap.put(i3.getOutput(TFloat64.class), x3Tensor); + result = instance.computeMask(inputs, mask); + Boolean[] expected = new Boolean[(int) result.get(0).size()]; + Arrays.fill(expected, true); + session.evaluate(expected, result.get(0), feedMap); + } + } + } + + @Test + public void testMaskInvalidLengths() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Input i1 = + new Input<>( + tf, + "l1", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input i2 = + new Input<>( + tf, + "l2", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input i3 = + new Input<>( + tf, + "l3", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Add instance = new Add<>(tf, TFloat64.class); + List> inputs = + Arrays.asList( + i1.getOutput(TFloat64.class), + i2.getOutput(TFloat64.class), + i3.getOutput(TFloat64.class)); + List> mask = Arrays.asList(null, null); + instance.computeMask(inputs, mask); + } + }); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/AlphaDropoutTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/AlphaDropoutTest.java new file mode 100644 index 00000000000..cc7972e97cd --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/AlphaDropoutTest.java @@ -0,0 +1,53 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class AlphaDropoutTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testAlphaDropout() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + long seed = 1001L; + TestSession.setEpsilon(1e-4F); + Shape expectedShape = Shape.of(3, 2, 3); + float[][][] x = + new float[][][] { + {{0.14517927f, 0.2574964f, 0.2291325f}, {0.9145494f, 0.9378068f, 0.6827883f}}, + {{0.27121753f, 
0.08317473f, 0.3770739f}, {0.25451255f, 0.18511271f, 0.5620538f}}, + {{0.40101776f, 0.25205433f, 0.05103926f}, {0.08764106f, 0.00593294f, 0.37244815f}} + }; + AlphaDropout instance = + new AlphaDropout<>( + tf, 0.2f, seed, TFloat32.class, Layer.Options.create().inputShape(Shape.of(3, 2, 3))); + Operand input = tf.constant(x); + + Operand result = instance.call(input, false, TFloat32.class); + assertEquals(expectedShape, result.shape()); + session.evaluate(tf.identity(input), result); + + float[][][] exp = { + {{-1.236160f, 0.535354f, 0.510425f}, {1.112841f, 1.133282f, 0.909145f}}, + {{0.547414f, 0.382143f, -1.236160f}, {-1.236160f, 0.471736f, -1.236160f}}, + {{0.661496f, 0.530571f, -1.236160f}, {0.386068f, 0.314254f, 0.636386f}} + }; + + Operand expected = tf.constant(exp); + result = instance.call(input, true, TFloat32.class); + + assertEquals(expectedShape, result.shape()); + + // NOTE: result can only be evaluated once, otherwise new random numbers + // will be generated and won't match the expected + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/AverageTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/AverageTest.java new file mode 100644 index 00000000000..aec95201b32 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/AverageTest.java @@ -0,0 +1,109 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.Tensor; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.family.TType; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + +class AverageTest { + + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + double[][][] x1 = { + { + {0.13570025, 0.55677077, 0.06648757, 0.58405729, 0.61086578}, + {0.18659685, 0.39331301, 0.68069423, 0.09510652, 0.86098578}, + {0.99338463, 0.37543824, 0.74858191, 0.31828287, 0.94056888}, + {0.76202298, 0.05605309, 0.73475366, 0.9313434, 0.48606332} + }, + { + {0.13023652, 0.39567908, 0.89910993, 0.71302943, 0.73722061}, + {0.6212917, 0.62624375, 0.8184835, 0.16864979, 0.96730508}, + {0.39645622, 0.35834793, 0.39924944, 0.90297727, 0.82857399}, + {0.70014157, 0.95498672, 0.6179583, 0.33104554, 0.11978174} + } + }; + double[][][] x2 = { + { + {0.82828211, 0.28889298, 0.7159566, 0.93377237, 0.32654201}, + {0.73234341, 0.17123203, 0.62582661, 0.96272026, 0.58700802}, + {0.12527705, 0.64175689, 0.64915537, 0.80589999, 0.26400939}, + {0.79376476, 0.24171677, 0.0677271, 0.07027092, 0.29195821} + }, + { + {0.56599224, 0.10611362, 0.83370522, 0.72514044, 0.08126704}, + {0.48173969, 0.16509515, 0.21040572, 0.44414272, 0.70656624}, + {0.89191749, 0.73008498, 0.9177326, 0.31897888, 0.56743576}, + {0.36304201, 0.36696309, 0.60722209, 0.79244879, 0.63492784} + } + }; + + double[][][] xavg = { + { + {0.48199118, 0.42283187, 0.39122208, 0.75891483, 0.46870389}, + {0.45947013, 0.28227252, 0.65326042, 0.52891339, 0.7239969}, + {0.55933084, 0.50859756, 0.69886864, 0.56209143, 0.60228914}, + {0.77789387, 0.14888493, 0.40124038, 0.50080716, 0.38901076} + }, + { + {0.34811438, 0.25089635, 0.86640757, 0.71908493, 0.40924383}, + {0.55151569, 0.39566945, 0.51444461, 0.30639626, 
0.83693566}, + {0.64418686, 0.54421645, 0.65849102, 0.61097808, 0.69800487}, + {0.53159179, 0.66097491, 0.6125902, 0.56174716, 0.37735479} + } + }; + + @Test + public void testAverage() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Input i1 = + new Input<>( + tf, + "l1", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input i2 = + new Input<>( + tf, + "l2", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Average instance = new Average<>(tf, TFloat64.class); + List> resultList = + instance.call( + Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)), + TFloat64.class); + + Operand result = resultList.get(0); + + assertArrayEquals(new long[] {Shape.UNKNOWN_SIZE, 4, 5}, result.shape().asArray()); + + Operand x1Op = tf.constant(x1); + Operand x2Op = tf.constant(x2); + + try (TFloat64 x1Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x1Op).run().get(0); + TFloat64 x2Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x2Op).run().get(0)) { + Map, Tensor> feedMap = new HashMap<>(); + feedMap.put(i1.getOutput(TFloat64.class), x1Tensor); + feedMap.put(i2.getOutput(TFloat64.class), x2Tensor); + session.evaluate(tf.constant(xavg), result, feedMap); + } + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ConcatenateTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ConcatenateTest.java new file mode 100644 index 00000000000..bb69145182b --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ConcatenateTest.java @@ -0,0 +1,213 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.Tensor; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.family.TType; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class ConcatenateTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + double[][][] x1 = { + { + {5.67710153e-02, 5.68608495e-01, 6.94753423e-01, 7.06106392e-01, 9.55901476e-01}, + {1.16221311e-01, 2.77955841e-01, 8.48163908e-01, 6.65887805e-01, 8.48399407e-01}, + {1.32232733e-01, 6.07996978e-01, 5.04046847e-01, 9.79583238e-02, 6.71959629e-01}, + {9.69122927e-01, 2.65313461e-01, 7.25259997e-01, 2.95230608e-02, 2.68600949e-01} + }, + { + {9.26675552e-02, 9.11034266e-01, 9.42616405e-01, 1.76616001e-01, 4.35131783e-01}, + {3.42867908e-01, 4.42621793e-02, 1.86904412e-01, 2.30573118e-05, 1.40271865e-01}, + {9.92634263e-01, 3.50624173e-01, 9.53986246e-01, 6.98818650e-01, 9.82469750e-01}, + {7.84919140e-01, 5.03811516e-01, 2.99471974e-01, 4.13124006e-01, 1.67204622e-01} + } + }; + double[][][] x2 = { + { + {0.28151136, 0.99996448, 0.94123237, 0.92673981, 0.58165141}, + {0.41634875, 0.87652871, 0.52327084, 0.60899574, 0.97460049}, + {0.77076745, 0.46439171, 0.25499671, 0.18764164, 0.13748069}, + {0.19368776, 0.11778548, 0.55451791, 0.06335824, 0.63534461} + }, + { + {0.52078045, 0.85837043, 0.44845609, 
0.69742864, 0.99834278}, + {0.23162816, 0.63328557, 0.24782906, 0.37476312, 0.16915018}, + {0.96264864, 0.97704619, 0.58534633, 0.87405632, 0.4750216}, + {0.73685149, 0.13915827, 0.23992944, 0.06455061, 0.30500096} + } + }; + + double[][][] x = { + { + {5.67710153e-02, 5.68608495e-01, 6.94753423e-01, 7.06106392e-01, 9.55901476e-01}, + {1.16221311e-01, 2.77955841e-01, 8.48163908e-01, 6.65887805e-01, 8.48399407e-01}, + {1.32232733e-01, 6.07996978e-01, 5.04046847e-01, 9.79583238e-02, 6.71959629e-01}, + {9.69122927e-01, 2.65313461e-01, 7.25259997e-01, 2.95230608e-02, 2.68600949e-01}, + {2.81511360e-01, 9.99964484e-01, 9.41232373e-01, 9.26739808e-01, 5.81651412e-01}, + {4.16348754e-01, 8.76528710e-01, 5.23270835e-01, 6.08995742e-01, 9.74600488e-01}, + {7.70767447e-01, 4.64391706e-01, 2.54996707e-01, 1.87641636e-01, 1.37480691e-01}, + {1.93687759e-01, 1.17785480e-01, 5.54517906e-01, 6.33582392e-02, 6.35344611e-01} + }, + { + {9.26675552e-02, 9.11034266e-01, 9.42616405e-01, 1.76616001e-01, 4.35131783e-01}, + {3.42867908e-01, 4.42621793e-02, 1.86904412e-01, 2.30573118e-05, 1.40271865e-01}, + {9.92634263e-01, 3.50624173e-01, 9.53986246e-01, 6.98818650e-01, 9.82469750e-01}, + {7.84919140e-01, 5.03811516e-01, 2.99471974e-01, 4.13124006e-01, 1.67204622e-01}, + {5.20780455e-01, 8.58370427e-01, 4.48456095e-01, 6.97428643e-01, 9.98342781e-01}, + {2.31628161e-01, 6.33285571e-01, 2.47829057e-01, 3.74763124e-01, 1.69150184e-01}, + {9.62648639e-01, 9.77046190e-01, 5.85346335e-01, 8.74056318e-01, 4.75021602e-01}, + {7.36851488e-01, 1.39158268e-01, 2.39929436e-01, 6.45506139e-02, 3.05000963e-01} + } + }; + + @Test + public void testConcatenate() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Input i1 = + new Input<>( + tf, + "l1", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input i2 = + new Input<>( + tf, + "l2", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Concatenate instance = new Concatenate<>(tf, 1, TFloat64.class); + List> resultList = + instance.call( + Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)), + TFloat64.class); + + Operand result = resultList.get(0); + + assertArrayEquals(new long[] {Shape.UNKNOWN_SIZE, 8, 5}, result.shape().asArray()); + + Operand x1Op = tf.constant(x1); + Operand x2Op = tf.constant(x2); + + try (TFloat64 x1Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x1Op).run().get(0); + TFloat64 x2Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x2Op).run().get(0)) { + Map, Tensor> feedMap = new HashMap<>(); + feedMap.put(i1.getOutput(TFloat64.class), x1Tensor); + feedMap.put(i2.getOutput(TFloat64.class), x2Tensor); + session.evaluate(tf.constant(x), result, feedMap); + } + } + } + + @Test + public void testMask() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Input i1 = + new Input<>( + tf, + "l1", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input i2 = + new Input<>( + tf, + "l2", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Concatenate instance = new Concatenate<>(tf, TFloat64.class); + List> inputs = + Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)); + List> mask = Arrays.asList(null, null); + + List> result = instance.computeMask(inputs, mask); + assertNull(result); + + Operand x1Op = tf.constant(x1); + 
Operand x2Op = tf.constant(x2); + mask = Arrays.asList(x1Op, x2Op); + + try (TFloat64 x1Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x1Op).run().get(0); + TFloat64 x2Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x2Op).run().get(0)) { + Map, Tensor> feedMap = new HashMap<>(); + feedMap.put(i1.getOutput(TFloat64.class), x1Tensor); + feedMap.put(i2.getOutput(TFloat64.class), x2Tensor); + result = instance.computeMask(inputs, mask); + Boolean[] expected = new Boolean[(int) result.get(0).size()]; + Arrays.fill(expected, true); + session.evaluate(expected, result.get(0), feedMap); + } + } + } + + @Test + public void testMaskInvalidMaskSize() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Input i1 = + new Input<>( + tf, + "l1", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input i2 = + new Input<>( + tf, + "l2", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Concatenate instance = new Concatenate<>(tf, TFloat64.class); + List> inputs = + Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)); + List> mask = Arrays.asList(null, null, null); + + List> result = instance.computeMask(inputs, mask); + assertNull(result); + + Operand x1Op = tf.constant(x1); + Operand x2Op = tf.constant(x2); + mask = Arrays.asList(x1Op, x2Op); + + try (TFloat64 x1Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x1Op).run().get(0); + TFloat64 x2Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x2Op).run().get(0)) { + Map, Tensor> feedMap = new HashMap<>(); + feedMap.put(i1.getOutput(TFloat64.class), x1Tensor); + feedMap.put(i2.getOutput(TFloat64.class), x2Tensor); + result = instance.computeMask(inputs, mask); + Boolean[] expected = new Boolean[(int) result.get(0).size()]; + Arrays.fill(expected, true); + session.evaluate(expected, result.get(0), feedMap); + } + } + }); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/DenseTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/DenseTest.java new file mode 100644 index 00000000000..78dfe67ec06 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/DenseTest.java @@ -0,0 +1,633 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.constraints.MinMaxNorm; +import org.tensorflow.framework.constraints.NonNeg; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; + +import java.util.Collections; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; + +public class DenseTest { + + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testShape3_2() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Shape inputShape = Shape.of(3, 2); + int units = 3; + + Dense instance = + new Dense<>( + tf, units, 1001L, TFloat32.class, Layer.Options.create().inputShape(inputShape)); + + float[][] data = { + {6.600953f, 4.659476f}, + {6.943807f, 2.113826f}, + {4.667166f, 6.931125f} + }; + float[][] expected = { + {5.866010f, 8.906790f, 7.214075f}, + {5.409174f, 8.020676f, 7.272593f}, + {5.140971f, 8.056995f, 5.513147f} + }; + + List> weights = instance.getWeights(); + instance.setWeights(weights); + Operand input = tf.constant(data); + + Operand y = instance.call(input, TFloat32.class); + session.run(tf.init()); + + List computed = instance.computeOutputShape(Collections.singletonList(inputShape)); + assertEquals(1, computed.size()); + assertEquals(computed.get(0), y.shape()); + Shape expectedOutput = Shape.of(3, units); + assertEquals(expectedOutput, y.shape()); + session.evaluate(tf.constant(expected), y); + } + } + + @Test + public void testShape4_2() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Shape inputShape = Shape.of(4, 2); + int units = 3; + + Dense instance = + new Dense<>( + tf, units, 1001L, TFloat32.class, Layer.Options.create().inputShape(inputShape)); + + float[][] inputArray = { + {6.600953f, 4.659476f}, + {6.943807f, 2.113826f}, + {4.667166f, 6.931125f}, + {7.716860f, 3.205337f} + }; + + List> weights = instance.getWeights(); + instance.setWeights(weights); + Operand input = tf.reshape(tf.constant(inputArray), tf.constant(inputShape)); + + Operand y = instance.call(input, TFloat32.class); + session.run(tf.init()); + List computedShapes = + instance.computeOutputShape(Collections.singletonList(input.shape())); + assertFalse(computedShapes.isEmpty()); + Shape computedShape = computedShapes.get(0); + Shape expectedShape = Shape.of(4, units); + assertEquals(expectedShape, computedShape); + assertEquals(expectedShape, y.shape()); + + float[][] expected = { + {5.866010f, 8.906790f, 7.214075f}, + {5.409174f, 8.020676f, 7.272593f}, + {5.140971f, 8.056996f, 5.513148f}, + {6.245262f, 9.327854f, 8.179358f} + }; + + session.evaluate(tf.constant(expected), y); + } + } + + @Test + public void testShapeN_N_2() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Shape inputShape = Shape.of(Shape.UNKNOWN_SIZE, Shape.UNKNOWN_SIZE, 2); + int units = 3; + + Dense instance = + new Dense<>( + tf, units, 1001L, TFloat32.class, Layer.Options.create().inputShape(inputShape)); + + Shape fullShape = Shape.of(5, 10, 2); + float[][][] data = { + { + {6.600953f, 4.659476f}, + {6.943807f, 2.113826f}, + 
{4.667166f, 6.931125f}, + {7.716860f, 3.205337f}, + {8.066205f, 2.362994f}, + {0.686355f, 8.934626f}, + {1.293296f, 9.073912f}, + {4.554000f, 0.347209f}, + {6.760708f, 8.464749f}, + {9.203295f, 6.147404f} + }, + { + {7.022987f, 3.022041f}, + {0.175645f, 7.057390f}, + {4.537057f, 3.270523f}, + {5.694380f, 0.481678f}, + {1.267088f, 4.573346f}, + {7.239103f, 2.671200f}, + {4.631621f, 1.366283f}, + {4.380660f, 0.902928f}, + {7.663558f, 8.725193f}, + {4.102549f, 2.243720f} + }, + { + {0.251945f, 1.804798f}, + {5.300526f, 7.791917f}, + {-0.071388f, 9.458032f}, + {7.492148f, 1.584492f}, + {6.854610f, 2.461785f}, + {4.187295f, 3.974617f}, + {-0.015711f, 1.355883f}, + {1.855492f, 7.734279f}, + {3.403170f, 7.473061f}, + {4.243813f, 6.584970f} + }, + { + {1.645227f, 0.730085f}, + {3.999032f, 5.628812f}, + {5.522727f, 3.001995f}, + {2.459637f, 9.221226f}, + {0.305633f, 9.156766f}, + {8.218584f, 7.329232f}, + {2.657161f, 3.237010f}, + {3.008971f, 7.147655f}, + {2.788105f, 2.895133f}, + {2.805755f, 3.646185f} + }, + { + {2.086996f, 5.481725f}, + {4.222548f, 4.396897f}, + {1.799221f, 7.522835f}, + {3.549520f, 9.244308f}, + {4.980303f, 0.475735f}, + {3.644282f, 0.544247f}, + {6.282454f, 8.306262f}, + {3.650939f, 1.386086f}, + {3.526051f, 1.671946f}, + {7.763572f, 6.653723f} + }, + }; + + float[][][] expected = { + { + {5.866010f, 8.906790f, 7.214075f}, + {5.409174f, 8.020676f, 7.272593f}, + {5.140971f, 8.056995f, 5.513147f}, + {6.245262f, 9.327854f, 8.179358f}, + {6.258242f, 9.272379f, 8.437642f}, + {2.918294f, 5.014478f, 1.708536f}, + {3.378673f, 5.693542f, 2.339058f}, + {3.263673f, 4.757503f, 4.651771f}, + {7.016673f, 10.908866f, 7.807479f}, + {8.083269f, 12.249319f, 10.018546f} + }, + { + {5.712371f, 8.539888f, 7.455799f}, + {2.050113f, 3.591536f, 0.978359f}, + {4.050456f, 6.154792f, 4.966178f}, + {4.093921f, 5.971835f, 5.822024f}, + {2.131000f, 3.489655f, 1.802051f}, + {5.766911f, 8.587945f, 7.634893f}, + {3.596069f, 5.328780f, 4.845973f}, + {3.294865f, 4.851680f, 4.539239f}, + {7.716053f, 11.944765f, 8.751445f}, + {3.467615f, 5.220104f, 4.409636f} + }, + { + {0.668335f, 1.127111f, 0.459879f}, + {5.816830f, 9.111765f, 6.252261f}, + {2.534012f, 4.504061f, 1.000444f}, + {5.646128f, 8.317189f, 7.767926f}, + {5.442162f, 8.099133f, 7.221718f}, + {3.999421f, 6.142959f, 4.691792f}, + {0.359459f, 0.640173f, 0.137874f}, + {3.403915f, 5.611978f, 2.756518f}, + {4.409483f, 7.045343f, 4.294413f}, + {4.751827f, 7.462863f, 5.045105f} + }, + { + {1.344244f, 2.011289f, 1.749129f}, + {4.320304f, 6.753564f, 4.688737f}, + {4.662963f, 7.018229f, 5.934030f}, + {4.230494f, 6.940253f, 3.537062f}, + {2.714058f, 4.738264f, 1.348129f}, + {7.720919f, 11.828723f, 9.155255f}, + {2.733208f, 4.244021f, 3.058378f}, + {4.046294f, 6.490631f, 3.858251f}, + {2.730931f, 4.210578f, 3.152225f}, + {2.948380f, 4.591742f, 3.255287f} + }, + { + {2.949665f, 4.755452f, 2.735502f}, + {4.139306f, 6.382794f, 4.775392f}, + {3.306999f, 5.452967f, 2.675543f}, + {4.995176f, 8.049803f, 4.643537f}, + {3.595419f, 5.249314f, 5.098118f}, + {2.684487f, 3.936022f, 3.752738f}, + {6.640594f, 10.350204f, 7.305118f}, + {2.919088f, 4.350031f, 3.854963f}, + {2.910276f, 4.362474f, 3.760896f}, + {7.219775f, 11.043336f, 8.617791f} + } + }; + + List> weights = instance.getWeights(); + instance.setWeights(weights); + Operand input = tf.constant(data); + + Operand y = instance.call(input, TFloat32.class); + session.run(tf.init()); + + List computed = instance.computeOutputShape(Collections.singletonList(fullShape)); + assertEquals(1, computed.size()); + 
assertEquals(computed.get(0), y.shape()); + Shape expectedOutput = Shape.of(5, 10, units); + assertEquals(expectedOutput, y.shape()); + session.evaluate(tf.constant(expected), y); + } + } + + @Test + public void testShape3_4_5_2() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + int units = 3; + + Shape inputShape = Shape.of(3, 4, 5, 2); + + Dense instance = + new Dense<>( + tf, units, 1001L, TFloat32.class, Layer.Options.create().inputShape(inputShape)); + assertEquals("Dense", instance.getName()); + session.run(tf.init()); + + float[][][][] data = { + { + { + {6.600953f, 4.659476f}, + {6.943807f, 2.113826f}, + {4.667166f, 6.931125f}, + {7.716860f, 3.205337f}, + {8.066205f, 2.362994f}, + }, + { + {0.686355f, 8.934626f}, + {1.293296f, 9.073912f}, + {4.554000f, 0.347209f}, + {6.760708f, 8.464749f}, + {9.203295f, 6.147404f}, + }, + { + {7.022987f, 3.022041f}, + {0.175645f, 7.057390f}, + {4.537057f, 3.270523f}, + {5.694380f, 0.481678f}, + {1.267088f, 4.573346f}, + }, + { + {7.239103f, 2.671200f}, + {4.631621f, 1.366283f}, + {4.380660f, 0.902928f}, + {7.663558f, 8.725193f}, + {4.102549f, 2.243720f}, + }, + }, + { + { + {0.251945f, 1.804798f}, + {5.300526f, 7.791917f}, + {-0.071388f, 9.458032f}, + {7.492148f, 1.584492f}, + {6.854610f, 2.461785f}, + }, + { + {4.187295f, 3.974617f}, + {-0.015711f, 1.355883f}, + {1.855492f, 7.734279f}, + {3.403170f, 7.473061f}, + {4.243813f, 6.584970f}, + }, + { + {1.645227f, 0.730085f}, + {3.999032f, 5.628812f}, + {5.522727f, 3.001995f}, + {2.459637f, 9.221226f}, + {0.305633f, 9.156766f}, + }, + { + {8.218584f, 7.329232f}, + {2.657161f, 3.237010f}, + {3.008971f, 7.147655f}, + {2.788105f, 2.895133f}, + {2.805755f, 3.646185f}, + }, + }, + { + { + {2.086996f, 5.481725f}, + {4.222548f, 4.396897f}, + {1.799221f, 7.522835f}, + {3.549520f, 9.244308f}, + {4.980303f, 0.475735f}, + }, + { + {3.644282f, 0.544247f}, + {6.282454f, 8.306262f}, + {3.650939f, 1.386086f}, + {3.526051f, 1.671946f}, + {7.763572f, 6.653723f}, + }, + { + {2.367239f, 3.317834f}, + {2.330428f, 9.358873f}, + {3.638705f, 5.096712f}, + {9.156695f, 4.436713f}, + {-0.416358f, 8.118915f}, + }, + { + {6.330701f, 6.326071f}, + {4.724874f, -0.368026f}, + {3.975863f, 0.017570f}, + {3.545376f, 7.946171f}, + {-0.495031f, 7.853283f}, + } + } + }; + + float[][][][] expected = { + { + { + {5.866010f, 8.906790f, 7.214075f}, + {5.409174f, 8.020676f, 7.272593f}, + {5.140971f, 8.056995f, 5.513147f}, + {6.245262f, 9.327854f, 8.179358f}, + {6.258242f, 9.272379f, 8.437642f}, + }, + { + {2.918294f, 5.014478f, 1.708536f}, + {3.378673f, 5.693542f, 2.339058f}, + {3.263673f, 4.757503f, 4.651771f}, + {7.016673f, 10.908866f, 7.807479f}, + {8.083269f, 12.249319f, 10.018546f}, + }, + { + {5.712371f, 8.539888f, 7.455799f}, + {2.050113f, 3.591536f, 0.978359f}, + {4.050456f, 6.154792f, 4.966178f}, + {4.093921f, 5.971835f, 5.822024f}, + {2.131000f, 3.489655f, 1.802051f}, + }, + { + {5.766911f, 8.587945f, 7.634893f}, + {3.596069f, 5.328780f, 4.845973f}, + {3.294865f, 4.851680f, 4.539239f}, + {7.716053f, 11.944765f, 8.751445f}, + {3.467615f, 5.220104f, 4.409636f}, + } + }, + { + { + {0.668335f, 1.127111f, 0.459879f}, + {5.816830f, 9.111765f, 6.252261f}, + {2.534012f, 4.504061f, 1.000444f}, + {5.646128f, 8.317189f, 7.767926f}, + {5.442162f, 8.099133f, 7.221718f}, + }, + { + {3.999421f, 6.142959f, 4.691792f}, + {0.359459f, 0.640173f, 0.137874f}, + {3.403915f, 5.611978f, 2.756518f}, + {4.409483f, 7.045343f, 4.294413f}, + {4.751827f, 7.462863f, 5.045105f}, + }, + { + {1.344244f, 2.011289f, 
1.749129f}, + {4.320304f, 6.753564f, 4.688737f}, + {4.662963f, 7.018229f, 5.934030f}, + {4.230494f, 6.940253f, 3.537062f}, + {2.714058f, 4.738264f, 1.348129f}, + }, + { + {7.720919f, 11.828723f, 9.155255f}, + {2.733208f, 4.244021f, 3.058378f}, + {4.046294f, 6.490631f, 3.858251f}, + {2.730931f, 4.210578f, 3.152225f}, + {2.948380f, 4.591742f, 3.255287f}, + } + }, + { + { + {2.949665f, 4.755452f, 2.735502f}, + {4.139306f, 6.382794f, 4.775392f}, + {3.306999f, 5.452967f, 2.675543f}, + {4.995176f, 8.049803f, 4.643537f}, + {3.595419f, 5.249314f, 5.098118f}, + }, + { + {2.684487f, 3.936022f, 3.752738f}, + {6.640594f, 10.350204f, 7.305118f}, + {2.919088f, 4.350031f, 3.854963f}, + {2.910276f, 4.362474f, 3.760896f}, + {7.219775f, 11.043336f, 8.617791f}, + }, + { + {2.553549f, 3.990942f, 2.773906f}, + {4.178188f, 6.876633f, 3.421808f}, + {3.924220f, 6.132985f, 4.263438f}, + {7.583528f, 11.374685f, 9.777319f}, + {1.928159f, 3.508506f, 0.499166f}, + }, + { + {6.133229f, 9.440765f, 7.129385f}, + {3.187190f, 4.583663f, 4.743713f}, + {2.771337f, 4.015369f, 4.028832f}, + {4.637676f, 7.417559f, 4.492103f}, + {1.800851f, 3.300701f, 0.389355f}, + } + } + }; + + List> weights = instance.getWeights(); + instance.setWeights(weights); + Operand input = tf.constant(data); + + Operand y = instance.call(input, TFloat32.class); + session.run(tf.init()); + + List computed = instance.computeOutputShape(Collections.singletonList(inputShape)); + assertEquals(1, computed.size()); + assertEquals(computed.get(0), y.shape()); + Shape expectedOutput = Shape.of(3, 4, 5, units); + assertEquals(expectedOutput, y.shape()); + session.evaluate(tf.constant(expected), y); + } + } + + @Test + public void testConstraintsNonNeg() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Shape inputShape = Shape.of(3, 2); + int units = 3; + + NonNeg nonNeg = new NonNeg(tf); + + Dense instance = + new Dense<>( + tf, + "constraintTest", + units, + null, + true, + null, + null, + null, + null, + null, + nonNeg::call, + nonNeg::call, + 1001L, + TFloat32.class, + Layer.Options.create().inputShape(inputShape)); + + float[][] data = { + {6.600953f, 4.659476f}, + {6.943807f, 2.113826f}, + {4.667166f, 6.931125f} + }; + + float[][] constraintInput = { + {-1, 2, 5}, + {-2, 4, -4} + }; + float[][] constraintExpected = { + {-0, 2, 5}, + {-0, 4, -0} + }; + + float[] biasConstraintInput = {-1, 2, 5}; + float[] biasConstraintExpected = {-0, 2, 5}; + + Operand input = tf.constant(data); + + @SuppressWarnings("unused") + Operand y = instance.call(input, TFloat32.class); + // initialize variables + session.run(tf.init()); + + List> weights = instance.getWeights(); + instance.setWeights(weights); + + // Test kernel + Variable kernel = instance.getKernel(); + Operand varUpdate = instance.assign(kernel, tf.constant(constraintInput)); + session.run(varUpdate); + session.evaluate(tf.constant(constraintExpected), instance.applyConstraint(kernel)); + + // test bias + Variable bias = instance.getBias(); + assertEquals(Shape.of(units), bias.shape()); + varUpdate = instance.assignAdd(bias, tf.constant(biasConstraintInput)); + session.run(varUpdate); + session.evaluate(tf.constant(biasConstraintExpected), instance.applyConstraint(bias)); + } + } + + @Test + public void testConstraintsMinMaxNorm() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Shape inputShape = Shape.of(3, 2); + int units = 3; + + MinMaxNorm minMaxNorm = new MinMaxNorm(tf); + + Dense instance = + new 
Dense<>( + tf, + "constraintTest", + units, + null, + true, + null, + null, + null, + null, + null, + minMaxNorm::call, + minMaxNorm::call, + 1001L, + TFloat32.class, + Layer.Options.create().inputShape(inputShape)); + + float[][] data = { + {6.600953f, 4.659476f}, + {6.943807f, 2.113826f}, + {4.667166f, 6.931125f} + }; + + float[][] constraintInput = { + {1, 0.5f, 2}, + {-2, 0.75f, 0} + }; + float[][] constraintExpected = { + {0.447214f, 0.5f, 1}, + {-0.894427f, 0.75f, 0} + }; + + float[] biasConstraintInput = {-1, 2, 5}; + float[] biasConstraintExpected = {-0.182574f, 0.365148f, 0.912871f}; + + Operand<TFloat32> input = tf.constant(data); + + @SuppressWarnings("unused") + Operand<TFloat32> y = instance.call(input, TFloat32.class); + // initialize variables + session.run(tf.init()); + + List<Variable<TFloat32>> weights = instance.getWeights(); + instance.setWeights(weights); + + // Test kernel + Variable<TFloat32> kernel = instance.getKernel(); + Operand<TFloat32> varUpdate = instance.assign(kernel, tf.constant(constraintInput)); + session.run(varUpdate); + session.evaluate(tf.constant(constraintExpected), instance.applyConstraint(kernel)); + + // test bias + Variable<TFloat32> bias = instance.getBias(); + assertEquals(Shape.of(units), bias.shape()); + + varUpdate = instance.assignAdd(bias, tf.constant(biasConstraintInput)); + session.run(varUpdate); + session.evaluate(tf.constant(biasConstraintExpected), instance.applyConstraint(bias)); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/DotTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/DotTest.java new file mode 100644 index 00000000000..c8856dcbba2 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/DotTest.java @@ -0,0 +1,133 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.Tensor; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.family.TType; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; + +class DotTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + double[][] x1 = { + {0.04867243, 0.42833055, 0.57495679, 0.04191259}, {0.48993384, 0.80122145, 0.8199583, 0.0552641} + }; + double[][] x2 = { + {0.37530763, 0.65938955, 0.69901548, 0.87864686}, + {0.79027356, 0.29017831, 0.62662979, 0.34575866} + }; + + double[][] xdot = {{0.73943388}, {1.15259719}}; + + @Test + public void testDot() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Input<TFloat64> i1 = + new Input<>( + tf, + "l1", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4))); + Input<TFloat64> i2 = + new Input<>( + tf, + "l2", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4))); + Dot<TFloat64> instance = new Dot<>(tf, 1, TFloat64.class); + List<Operand<TFloat64>> resultList = + instance.call( + Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)), + TFloat64.class); + + Operand<TFloat64> result = resultList.get(0); + + assertArrayEquals(new long[] {Shape.UNKNOWN_SIZE, 1}, result.shape().asArray()); + + Operand<TFloat64> x1Op = tf.constant(x1); + Operand<TFloat64> x2Op = tf.constant(x2); + + try (TFloat64 
x1Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x1Op).run().get(0); + TFloat64 x2Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x2Op).run().get(0)) { + Map<Operand<? extends TType>, Tensor> feedMap = new HashMap<>(); + feedMap.put(i1.getOutput(TFloat64.class), x1Tensor); + feedMap.put(i2.getOutput(TFloat64.class), x2Tensor); + session.evaluate(tf.constant(xdot), result, feedMap); + } + } + } + + @Test + public void testDotNegativeAxis() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Input<TFloat64> i1 = + new Input<>( + tf, + "l1", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4))); + Input<TFloat64> i2 = + new Input<>( + tf, + "l2", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4))); + Dot<TFloat64> instance = new Dot<>(tf, new int[] {-1, -1}, TFloat64.class); + List<Operand<TFloat64>> resultList = + instance.call( + Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)), + TFloat64.class); + + Operand<TFloat64> result = resultList.get(0); + + assertArrayEquals(new long[] {Shape.UNKNOWN_SIZE, 1}, result.shape().asArray()); + + Operand<TFloat64> x1Op = tf.constant(x1); + Operand<TFloat64> x2Op = tf.constant(x2); + + try (TFloat64 x1Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x1Op).run().get(0); + TFloat64 x2Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x2Op).run().get(0)) { + Map<Operand<? extends TType>, Tensor> feedMap = new HashMap<>(); + feedMap.put(i1.getOutput(TFloat64.class), x1Tensor); + feedMap.put(i2.getOutput(TFloat64.class), x2Tensor); + session.evaluate(tf.constant(xdot), result, feedMap); + } + } + } + + @Test + public void testDotComputeOutputShape() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Dot<TFloat32> dot = new Dot<>(tf, -1, TFloat32.class); + + List<Shape> outputShapes = + dot.computeOutputShape(Arrays.asList(Shape.of(4, 5), Shape.of(4, 5))); + assertFalse(outputShapes.isEmpty()); + assertArrayEquals(new long[] {4, 1}, outputShapes.get(0).asArray()); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/DropoutTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/DropoutTest.java new file mode 100644 index 00000000000..cf093bd7c6d --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/DropoutTest.java @@ -0,0 +1,89 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class DropoutTest { + + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testShape3_2() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + long seed = 1001L; + Shape expectedShape = Shape.of(3, 2); + Operand<TFloat32> input = + tf.constant( + new float[][] { + {1.3463433f, 7.2481093f}, + {5.4018216f, 0.6772865f}, + {3.4442706f, 0.95697135f} + }); + + Dropout<TFloat32> instance = new Dropout<>(tf, 0.5f, seed, TFloat32.class); + + // first pass, trainable is false, so there should be no dropout + Operand<TFloat32> result = instance.call(input, false, TFloat32.class); + assertEquals(expectedShape, result.shape()); + session.evaluate(tf.identity(input), result); + + Operand<TFloat32> expected =
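+ // with rate = 0.5, surviving activations are scaled by 1/(1 - rate) = 2 (e.g. 7.2481093f * 2 = 14.496219f); the exact mask below comes from the RNG seeded with 1001L +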
tf.constant( + new float[][] { + {0f, 14.496219f}, + {0f, 0f}, + {0f, 0f} + }); + + // second pass, trainable is true, so there should be dropout + result = instance.call(input, true, TFloat32.class); + assertEquals(expectedShape, result.shape()); + session.evaluate(expected, result); + } + } + + @Test + public void testShape3_2Noise() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + long seed = 1001L; + Shape expectedShape = Shape.of(3, 2); + Operand<TFloat32> input = + tf.constant( + new float[][] { + {1.3463433f, 7.2481093f}, + {5.4018216f, 0.6772865f}, + {3.4442706f, 0.95697135f} + }); + + Dropout<TFloat32> instance = new Dropout<>(tf, 0.5f, Shape.of(3, 1), seed, TFloat32.class); + + Float[] expected = new Float[] {0f, 0f, 10.803643f, 1.354573f, 0f, 0f}; + + // trainable is true, so there should be dropout + Operand<TFloat32> result = instance.call(input, true, TFloat32.class); + assertEquals(expectedShape, result.shape()); + // Note: this can only be evaluated once, or else the result will be updated with + // new values and will not match expected. + session.evaluate(expected, result); + } + } + + @Test + public void testSupportsMasking() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + long seed = 1001L; + Dropout<TFloat32> instance = new Dropout<>(tf, 0.5f, seed, TFloat32.class); + assertTrue(instance.isSupportsMasking()); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ELUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ELUTest.java new file mode 100644 index 00000000000..b492d0c7c8d --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ELUTest.java @@ -0,0 +1,117 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +class ELUTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + float[][][] inputArray = { + { + {2.70857435f, 8.25453567f, 9.75479311f, 1.10273526f}, + {8.69836437f, 2.27818352f, 8.60856328f, 1.43265882f}, + {0.75845834f, 5.60463474f, 7.35998787f, 0.06365667f} + }, + { + {4.87355239f, 9.90221978f, 5.39014402f, 2.05263398f}, + {5.91652733f, 0.9186602f, 0.91375672f, 0.56053326f}, + {2.08046551f, 8.53763374f, 6.40378721f, 5.83284758f} + } + }; + + @Test + public void testCallAlpha0() { + + float[][][] expectedArray = { + { + {2.7085743f, 8.254536f, 9.754793f, 1.1027353f}, + {8.698364f, 2.2781835f, 8.608563f, 1.4326588f}, + {0.7584583f, 5.604635f, 
7.3599877f, 0.06365667f} + }, + { + {4.8735523f, 9.90222f, 5.390144f, 2.052634f}, + {5.9165273f, 0.9186602f, 0.9137567f, 0.5605333f}, + {2.0804656f, 8.537634f, 6.403787f, 5.8328476f} + } + }; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + float alpha = 0f; + ELU<TFloat32> instance = + new ELU<>( + tf, alpha, TFloat32.class, Layer.Options.create().inputShape(Shape.of(2, 3, 4))); + + Operand<TFloat32> result = instance.call(tf.constant(inputArray), TFloat32.class); + + session.evaluate(tf.constant(expectedArray), result); + } + } + + @Test + public void testCallAlpha0Point5() { + + float[][][] expectedArray = { + { + {2.7085743f, 8.254536f, 9.754793f, 1.1027353f}, + {8.698364f, 2.2781835f, 8.608563f, 1.4326588f}, + {0.7584583f, 5.604635f, 7.3599877f, 0.06365667f} + }, + { + {4.8735523f, 9.90222f, 5.390144f, 2.052634f}, + {5.9165273f, 0.9186602f, 0.9137567f, 0.5605333f}, + {2.0804656f, 8.537634f, 6.403787f, 5.8328476f} + } + }; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + float alpha = 0.5f; + ELU<TFloat32> instance = + new ELU<>( + tf, alpha, TFloat32.class, Layer.Options.create().inputShape(Shape.of(2, 3, 4))); + + Operand<TFloat32> result = instance.call(tf.constant(inputArray), TFloat32.class); + + session.evaluate(tf.constant(expectedArray), result); + } + } + + @Test + public void testCallAlphaMinus1() { + + double[][][] expectedArray = { + { + {2.7085743, 8.254536, 9.754793, 1.1027353}, + {8.698364, 2.2781835, 8.608563, 1.4326588}, + {0.7584583, 5.604635, 7.3599877, 0.06365667} + }, + { + {4.8735523, 9.90222, 5.390144, 2.052634}, + {5.9165273, 0.9186602, 0.9137567, 0.5605333}, + {2.0804656, 8.537634, 6.403787, 5.8328476} + } + }; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + float alpha = -1.f; + ELU<TFloat64> instance = + new ELU<>( + tf, alpha, TFloat64.class, Layer.Options.create().inputShape(Shape.of(2, 3, 4))); + + Operand<TFloat64> result = + instance.call(tf.dtypes.cast(tf.constant(inputArray), TFloat64.class), TFloat64.class); + + session.evaluate(tf.constant(expectedArray), result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/FlattenTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/FlattenTest.java new file mode 100644 index 00000000000..0c360c6ecdf --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/FlattenTest.java @@ -0,0 +1,74 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.layers.impl.TensorFormat; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +import java.util.Collections; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class FlattenTest { + + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testCall() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Shape inputShape = Shape.of(1, 3, 2); + float[] a = {1F, 2F, 3F, 4F, 5F, 6F}; + Float[] expected = {1F, 2F, 3F, 4F, 5F, 6F}; + Shape expectedShape = Shape.of(1, 6); + Operand<TFloat32> input = tf.reshape(tf.constant(a), tf.constant(inputShape)); + Flatten<TFloat32> layer = new Flatten<>(tf, TFloat32.class); + Operand<TFloat32> output = layer.call(input, TFloat32.class); + assertEquals(expectedShape, output.shape()); + session.evaluate(expected, output); + } + } + + @Test + public void testCallChannelsFirst() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + float[] a = { + 0.12911275f, + 0.16172077f, + 0.7024991f, + 0.3936557f, + 0.8216052f, + 0.04838822f, + 0.96763366f, + 0.1477106f, + 0.03416549f, + 0.40088153f + }; + Shape expectedShape = Shape.of(10, 1); + Operand<TFloat32> input = tf.constant(a); + Flatten<TFloat32> layer = new Flatten<>(tf, TensorFormat.NCHW, TFloat32.class); + Operand<TFloat32> output = layer.call(input, TFloat32.class); + assertEquals(expectedShape, output.asOutput().shape()); + Operand<TFloat32> expected = tf.expandDims(input, tf.constant(-1)); + session.evaluate(expected, output); + } + } + + /** Test of 
computeOutputShape method, of class Flatten. */ + @Test + public void testComputeOutputShape() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Shape inputShape = Shape.of(1, 3, 2); + Shape expectedShape = Shape.of(1, 6); + Flatten<TFloat32> layer = new Flatten<>(tf, TFloat32.class); + List<Shape> computedShapes = layer.computeOutputShape(Collections.singletonList(inputShape)); + assertEquals(expectedShape, computedShapes.get(0)); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/GaussianDropoutTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/GaussianDropoutTest.java new file mode 100644 index 00000000000..c031b8c3387 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/GaussianDropoutTest.java @@ -0,0 +1,62 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class GaussianDropoutTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testShape3_2_3() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + long seed = 1001L; + Shape expectedShape = Shape.of(3, 2, 3); + Operand<TFloat64> input = + tf.constant( + new double[][][] { + {{3.22382299, 1.41224385, 7.265976}, {9.1436238, 6.15759347, 6.79954284}}, + {{6.41459591, 2.16451569, 4.12015256}, {2.42915398, 2.27193001, 1.09604702}}, + {{5.13626611, 4.34388458, 1.32951124}, {8.47118881, 6.70455732, 8.57420547}} + }); + + GaussianDropout<TFloat64> instance = new GaussianDropout<>(tf, 0.5f, seed, TFloat64.class); + + // first pass, trainable is false, so there should be no dropout + Operand<TFloat64> result = instance.call(input, false, TFloat64.class); + assertEquals(expectedShape, result.shape()); + session.evaluate(tf.identity(input), result); + + Operand<TFloat64> expected = + tf.constant( + new double[][][] { + {{0.734139, 0.398936, 5.600555}, {0.817308, 6.134556, 2.074016}}, + {{3.984928, 1.046533, 2.354332}, {1.876065, 1.218514, 1.014165}}, + {{4.405925, 3.813551, 1.100304}, {4.984621, 1.846423, 2.097348}} + }); + + // second pass, trainable is true, so there should be dropout + result = instance.call(input, true, TFloat64.class); + assertEquals(expectedShape, result.shape()); + session.evaluate(expected, result); + } + } + + @Test + public void testSupportsMasking() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + long seed = 1001L; + GaussianDropout<TFloat32> instance = new GaussianDropout<>(tf, 0.5f, seed, TFloat32.class); + assertTrue(instance.isSupportsMasking()); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/GaussianNoiseTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/GaussianNoiseTest.java new file mode 100644 index 00000000000..ab7cdde0906 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/GaussianNoiseTest.java @@ -0,0 +1,65 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import 
org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class GaussianNoiseTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testShape3_2_3() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + long seed = 1001L; + Shape expectedShape = Shape.of(3, 2, 3); + Operand<TFloat64> input = + tf.constant( + new double[][][] { + {{3.451546, 1.694727, 8.036768}, {9.233009, 6.462616, 7.226611}}, + {{6.644345, 2.605274, 4.224330}, {2.912649, 2.843349, 1.097698}}, + {{5.672600, 5.269178, 2.187318}, {9.298789, 7.292978, 8.849604}} + }); + + GaussianNoise<TFloat64> instance = new GaussianNoise<>(tf, 1.f, seed, TFloat64.class); + + // first pass, trainable is false, so no noise should be applied + Operand<TFloat64> result = instance.call(input, false, TFloat64.class); + assertEquals(expectedShape, result.shape()); + session.evaluate(tf.identity(input), result); + + Operand<TFloat64> expected = + tf.constant( + new double[][][] { + {{3.679269, 1.977210, 8.807560}, {9.322395, 6.767639, 7.653679}}, + {{6.874095, 3.046032, 4.328507}, {3.396144, 3.414768, 1.099349}}, + {{6.208934, 6.194471, 3.045125}, {10.126389, 7.881398, 9.125002}} + }); + + // second pass, trainable is true, so there should be noise applied + result = instance.call(input, true, TFloat64.class); + assertEquals(expectedShape, result.shape()); + // cannot evaluate more than once, else it doesn't match expected + // because of random number generation. + // session.print(result); + session.evaluate(expected, result); + } + } + + @Test + public void testSupportsMasking() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + long seed = 1001L; + GaussianNoise<TFloat32> instance = new GaussianNoise<>(tf, 0.5f, seed, TFloat32.class); + assertTrue(instance.isSupportsMasking()); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/InputTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/InputTest.java new file mode 100644 index 00000000000..185262ee5dd --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/InputTest.java @@ -0,0 +1,101 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.Tensor; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Placeholder; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.family.TType; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class InputTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + void call() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + float[][] array = new float[][] {{0, 1}, {2, 3}}; + Operand<TFloat32> input = tf.constant(array); + Input<TFloat32> instance = new Input<>(tf, input, TFloat32.class); + List<Operand<TFloat32>> result = + instance.call(Collections.singletonList(input), null, false, TFloat32.class); + + assertNotNull(result); + assertEquals(1, result.size()); + + session.evaluate(input, result.get(0)); + } + } + + @Test + void getOutput() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + float[][] array = new float[][] {{0, 1}, {2, 3}}; + Operand<TFloat32> input = tf.constant(array); + Input<TFloat32> instance = new Input<>(tf, input, TFloat32.class); + Operand<TFloat32> result = instance.getOutput(TFloat32.class); + + session.evaluate(input, result); + } + } + + @Test + void isPlaceholder() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + float[][] array = new float[][] {{0, 1}, {2, 3}}; + Operand<TFloat32> input = tf.constant(array); + Input<TFloat32> instance = + new Input<>( + tf, TFloat32.class, TFloat32.class, Layer.Options.create().inputShape(input.shape())); + + assertTrue(instance.isPlaceholder()); + Operand<TFloat32> result = instance.getOutput(TFloat32.class); + assertTrue(result instanceof Placeholder); + try (TFloat32 inputTensor = + (TFloat32) session.getGraphSession().runner().fetch(input).run().get(0)) { + Map<Operand<? extends TType>, Tensor> feedMap = + Collections.singletonMap(result, inputTensor); + session.evaluate(tf.constant(array), tf.identity(result), feedMap); + } + } + } + + @Test + void getInputType() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + float[][] array = new float[][] {{0, 1}, {2, 3}}; + Operand<TFloat32> input = tf.constant(array); + Input<TFloat32> instance = new Input<>(tf, input, TFloat32.class); + Operand<TFloat32> result = instance.getOutput(TFloat32.class); + + session.evaluate(input, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/LambdaTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/LambdaTest.java new file mode 100644 index 00000000000..738d5daf5e0 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/LambdaTest.java @@ -0,0 +1,43 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static 
org.tensorflow.framework.utils.CastHelper.cast; + +class LambdaTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testCallLambda() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Shape shape = Shape.of(3, 2); + + Lambda<TFloat32> instance = new Lambda<>(tf, TFloat32.class); + instance.setLambda((t, y) -> t.math.mul(cast(t, t.constant(2), y.type()), y)); + + double[][] array = { + {0.41448207, 0.71509451}, {0.21307868, 0.76890945}, {0.37533432, 0.7761148} + }; + double[][] expected = new double[array.length][array[0].length]; + for (int i = 0; i < array.length; i++) { + for (int j = 0; j < array[0].length; j++) { + expected[i][j] = array[i][j] * 2; + } + } + + Operand<TFloat64> result = + instance.call(tf.dtypes.cast(tf.constant(array), TFloat64.class), TFloat64.class); + + assertEquals(shape, result.shape()); + session.evaluate(tf.constant(expected), result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/LeakyReLUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/LeakyReLUTest.java new file mode 100644 index 00000000000..363894827cb --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/LeakyReLUTest.java @@ -0,0 +1,116 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +class LeakyReLUTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + float[][][] inputArray = { + { + {2.70857435f, 8.25453567f, 9.75479311f, 1.10273526f}, + {8.69836437f, 2.27818352f, 8.60856328f, 1.43265882f}, + {0.75845834f, 5.60463474f, 7.35998787f, 0.06365667f} + }, + { + {4.87355239f, 9.90221978f, 5.39014402f, 2.05263398f}, + {5.91652733f, 0.9186602f, 0.91375672f, 0.56053326f}, + {2.08046551f, 8.53763374f, 6.40378721f, 5.83284758f} + } + }; + + @Test + public void testCallAlpha0() { + + float[][][] expectedArray = { + { + {2.7085743f, 8.254536f, 9.754793f, 1.1027353f}, + {8.698364f, 2.2781835f, 8.608563f, 1.4326588f}, + {0.7584583f, 5.604635f, 7.3599877f, 0.06365667f} + }, + { + {4.8735523f, 9.90222f, 5.390144f, 2.052634f}, + {5.9165273f, 0.9186602f, 0.9137567f, 0.5605333f}, + {2.0804656f, 8.537634f, 6.403787f, 5.8328476f} + } + }; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + float alpha = 0f; + LeakyReLU<TFloat32> instance = + new LeakyReLU<>( + tf, alpha, TFloat32.class, Layer.Options.create().inputShape(Shape.of(2, 3, 4))); + + Operand<TFloat32> result = instance.call(tf.constant(inputArray), TFloat32.class); + + session.evaluate(tf.constant(expectedArray), result); + } + } + + @Test + public void testCallAlpha0Point5() { + + float[][][] expectedArray = { + { + {2.7085743f, 8.254536f, 9.754793f, 1.1027353f}, + {8.698364f, 2.2781835f, 8.608563f, 1.4326588f}, + {0.7584583f, 5.604635f, 7.3599877f, 0.06365667f} + }, + { + {4.8735523f, 9.90222f, 5.390144f, 2.052634f}, + {5.9165273f, 0.9186602f, 0.9137567f, 0.5605333f}, + {2.0804656f, 8.537634f, 6.403787f, 5.8328476f} + } + }; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + float alpha = 0.5f; + LeakyReLU<TFloat32> instance = + new LeakyReLU<>( + tf, alpha, TFloat32.class, Layer.Options.create().inputShape(Shape.of(2, 3, 
4))); + + Operand result = instance.call(tf.constant(inputArray), TFloat32.class); + + session.evaluate(tf.constant(expectedArray), result); + } + } + + @Test + public void testCallAlphaMinus1() { + + double[][][] expectedArray = { + { + {2.7085743, 8.254536, 9.754793, 1.1027353}, + {8.698364, 2.2781835, 8.608563, 1.4326588}, + {0.7584583, 5.604635, 7.3599877, 0.06365667} + }, + { + {4.8735523, 9.90222, 5.390144, 2.052634}, + {5.9165273, 0.9186602, 0.9137567, 0.5605333}, + {2.0804656, 8.537634, 6.403787, 5.8328476} + } + }; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + float alpha = -1.f; + LeakyReLU instance = + new LeakyReLU<>( + tf, alpha, TFloat64.class, Layer.Options.create().inputShape(Shape.of(2, 3, 4))); + + Operand result = instance.call(tf.constant(inputArray), TFloat64.class); + + session.evaluate(tf.constant(expectedArray), result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MaximumTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MaximumTest.java new file mode 100644 index 00000000000..4e7ed046c34 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MaximumTest.java @@ -0,0 +1,108 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.Tensor; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.family.TType; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + +class MaximumTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + double[][][] x1 = { + { + {0.13570025, 0.55677077, 0.06648757, 0.58405729, 0.61086578}, + {0.18659685, 0.39331301, 0.68069423, 0.09510652, 0.86098578}, + {0.99338463, 0.37543824, 0.74858191, 0.31828287, 0.94056888}, + {0.76202298, 0.05605309, 0.73475366, 0.9313434, 0.48606332} + }, + { + {0.13023652, 0.39567908, 0.89910993, 0.71302943, 0.73722061}, + {0.6212917, 0.62624375, 0.8184835, 0.16864979, 0.96730508}, + {0.39645622, 0.35834793, 0.39924944, 0.90297727, 0.82857399}, + {0.70014157, 0.95498672, 0.6179583, 0.33104554, 0.11978174} + } + }; + double[][][] x2 = { + { + {0.82828211, 0.28889298, 0.7159566, 0.93377237, 0.32654201}, + {0.73234341, 0.17123203, 0.62582661, 0.96272026, 0.58700802}, + {0.12527705, 0.64175689, 0.64915537, 0.80589999, 0.26400939}, + {0.79376476, 0.24171677, 0.0677271, 0.07027092, 0.29195821} + }, + { + {0.56599224, 0.10611362, 0.83370522, 0.72514044, 0.08126704}, + {0.48173969, 0.16509515, 0.21040572, 0.44414272, 0.70656624}, + {0.89191749, 0.73008498, 0.9177326, 0.31897888, 0.56743576}, + {0.36304201, 0.36696309, 0.60722209, 0.79244879, 0.63492784} + } + }; + + double[][][] xmax = { + { + {0.82828211, 0.55677077, 0.7159566, 0.93377237, 0.61086578}, + {0.73234341, 0.39331301, 0.68069423, 0.96272026, 0.86098578}, + {0.99338463, 0.64175689, 0.74858191, 0.80589999, 0.94056888}, + {0.79376476, 0.24171677, 0.73475366, 0.9313434, 0.48606332} + }, + { + {0.56599224, 0.39567908, 0.89910993, 0.72514044, 0.73722061}, + {0.6212917, 0.62624375, 0.8184835, 0.44414272, 0.96730508}, + {0.89191749, 0.73008498, 0.9177326, 0.90297727, 0.82857399}, + {0.70014157, 0.95498672, 0.6179583, 0.79244879, 0.63492784} + } + }; + + @Test 
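+ // verifies the element-wise maximum of x1 and x2 (expected values in xmax above), fed through two Input placeholders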
+ public void testMaximum() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Input<TFloat64> i1 = + new Input<>( + tf, + "l1", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input<TFloat64> i2 = + new Input<>( + tf, + "l2", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Maximum<TFloat64> instance = new Maximum<>(tf, TFloat64.class); + List<Operand<TFloat64>> resultList = + instance.call( + Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)), + TFloat64.class); + + Operand<TFloat64> result = resultList.get(0); + + assertArrayEquals(new long[] {Shape.UNKNOWN_SIZE, 4, 5}, result.shape().asArray()); + + Operand<TFloat64> x1Op = tf.constant(x1); + Operand<TFloat64> x2Op = tf.constant(x2); + + try (TFloat64 x1Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x1Op).run().get(0); + TFloat64 x2Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x2Op).run().get(0)) { + Map<Operand<? extends TType>, Tensor> feedMap = new HashMap<>(); + feedMap.put(i1.getOutput(TFloat64.class), x1Tensor); + feedMap.put(i2.getOutput(TFloat64.class), x2Tensor); + session.evaluate(tf.constant(xmax), result, feedMap); + } + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MinimumTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MinimumTest.java new file mode 100644 index 00000000000..c9132472d41 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MinimumTest.java @@ -0,0 +1,108 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.Tensor; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.family.TType; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + +class MinimumTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + double[][][] x1 = { + { + {0.13570025, 0.55677077, 0.06648757, 0.58405729, 0.61086578}, + {0.18659685, 0.39331301, 0.68069423, 0.09510652, 0.86098578}, + {0.99338463, 0.37543824, 0.74858191, 0.31828287, 0.94056888}, + {0.76202298, 0.05605309, 0.73475366, 0.9313434, 0.48606332} + }, + { + {0.13023652, 0.39567908, 0.89910993, 0.71302943, 0.73722061}, + {0.6212917, 0.62624375, 0.8184835, 0.16864979, 0.96730508}, + {0.39645622, 0.35834793, 0.39924944, 0.90297727, 0.82857399}, + {0.70014157, 0.95498672, 0.6179583, 0.33104554, 0.11978174} + } + }; + double[][][] x2 = { + { + {0.82828211, 0.28889298, 0.7159566, 0.93377237, 0.32654201}, + {0.73234341, 0.17123203, 0.62582661, 0.96272026, 0.58700802}, + {0.12527705, 0.64175689, 0.64915537, 0.80589999, 0.26400939}, + {0.79376476, 0.24171677, 0.0677271, 0.07027092, 0.29195821} + }, + { + {0.56599224, 0.10611362, 0.83370522, 0.72514044, 0.08126704}, + {0.48173969, 0.16509515, 0.21040572, 0.44414272, 0.70656624}, + {0.89191749, 0.73008498, 0.9177326, 0.31897888, 0.56743576}, + {0.36304201, 0.36696309, 0.60722209, 0.79244879, 0.63492784} + } + }; + + double[][][] xmin = { + { + {0.13570025, 0.28889298, 0.06648757, 0.58405729, 0.32654201}, + {0.18659685, 0.17123203, 0.62582661, 0.09510652, 0.58700802}, + {0.12527705, 0.37543824, 0.64915537, 0.31828287, 0.26400939}, + {0.76202298, 0.05605309, 0.0677271, 0.07027092, 0.29195821} 
+ }, + { + {0.13023652, 0.10611362, 0.83370522, 0.71302943, 0.08126704}, + {0.48173969, 0.16509515, 0.21040572, 0.16864979, 0.70656624}, + {0.39645622, 0.35834793, 0.39924944, 0.31897888, 0.56743576}, + {0.36304201, 0.36696309, 0.60722209, 0.33104554, 0.11978174} + } + }; + + @Test + public void testMinimum() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Input<TFloat64> i1 = + new Input<>( + tf, + "l1", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input<TFloat64> i2 = + new Input<>( + tf, + "l2", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Minimum<TFloat64> instance = new Minimum<>(tf, TFloat64.class); + List<Operand<TFloat64>> resultList = + instance.call( + Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)), + TFloat64.class); + + Operand<TFloat64> result = resultList.get(0); + + assertArrayEquals(new long[] {Shape.UNKNOWN_SIZE, 4, 5}, result.shape().asArray()); + + Operand<TFloat64> x1Op = tf.constant(x1); + Operand<TFloat64> x2Op = tf.constant(x2); + + try (TFloat64 x1Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x1Op).run().get(0); + TFloat64 x2Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x2Op).run().get(0)) { + Map<Operand<? extends TType>, Tensor> feedMap = new HashMap<>(); + feedMap.put(i1.getOutput(TFloat64.class), x1Tensor); + feedMap.put(i2.getOutput(TFloat64.class), x2Tensor); + session.evaluate(tf.constant(xmin), result, feedMap); + } + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MultiplyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MultiplyTest.java new file mode 100644 index 00000000000..33bc5775f96 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MultiplyTest.java @@ -0,0 +1,135 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.Tensor; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.family.TType; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + +class MultiplyTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + double[][][] x1 = { + { + {0.13570025, 0.55677077, 0.06648757, 0.58405729, 0.61086578}, + {0.18659685, 0.39331301, 0.68069423, 0.09510652, 0.86098578}, + {0.99338463, 0.37543824, 0.74858191, 0.31828287, 0.94056888}, + {0.76202298, 0.05605309, 0.73475366, 0.9313434, 0.48606332} + }, + { + {0.13023652, 0.39567908, 0.89910993, 0.71302943, 0.73722061}, + {0.6212917, 0.62624375, 0.8184835, 0.16864979, 0.96730508}, + {0.39645622, 0.35834793, 0.39924944, 0.90297727, 0.82857399}, + {0.70014157, 0.95498672, 0.6179583, 0.33104554, 0.11978174} + } + }; + double[][][] x2 = { + { + {0.82828211, 0.28889298, 0.7159566, 0.93377237, 0.32654201}, + {0.73234341, 0.17123203, 0.62582661, 0.96272026, 0.58700802}, + {0.12527705, 0.64175689, 0.64915537, 0.80589999, 0.26400939}, + {0.79376476, 0.24171677, 0.0677271, 0.07027092, 0.29195821} + }, + { + {0.56599224, 0.10611362, 0.83370522, 0.72514044, 0.08126704}, + {0.48173969, 0.16509515, 0.21040572, 0.44414272, 0.70656624}, + {0.89191749, 0.73008498, 0.9177326, 0.31897888, 0.56743576}, + {0.36304201, 0.36696309, 0.60722209, 0.79244879, 0.63492784} + } + 
}; + double[][][] x3 = { + { + {0.90545522, 0.55172128, 0.87254455, 0.1396359, 0.1538656}, + {0.04276304, 0.9315817, 0.91360492, 0.00604873, 0.04174153}, + {0.60856471, 0.37386072, 0.68937889, 0.21272655, 0.65082257}, + {0.44925012, 0.29825938, 0.20043074, 0.84906101, 0.78397795} + }, + { + {0.70855776, 0.17650269, 0.02422264, 0.84612297, 0.72450389}, + {0.05133022, 0.61175015, 0.56296539, 0.66780478, 0.63326012}, + {0.11212696, 0.50675282, 0.58170013, 0.21101392, 0.83090424}, + {0.91830915, 0.42113009, 0.49795942, 0.2814478, 0.11920788} + } + }; + double[][][] xmul = { + { + {0.10177144, 0.0887428, 0.04153505, 0.07615415, 0.03069209}, + {0.0058437, 0.06273996, 0.38919256, 0.00055383, 0.0210964}, + {0.07573484, 0.09007803, 0.33500089, 0.05456525, 0.16161162}, + {0.27173657, 0.00404111, 0.00997398, 0.05556795, 0.11125445} + }, + { + {0.05222982, 0.00741081, 0.01815711, 0.4374849, 0.04340629}, + {0.01536318, 0.06324873, 0.0969503, 0.05002163, 0.4328112}, + {0.03964879, 0.13257892, 0.21313739, 0.06077848, 0.39066002}, + {0.23341656, 0.14758288, 0.18685326, 0.07383407, 0.00906609} + } + }; + + @Test + public void testMultiply() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Input<TFloat64> i1 = + new Input<>( + tf, + "l1", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input<TFloat64> i2 = + new Input<>( + tf, + "l2", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input<TFloat64> i3 = + new Input<>( + tf, + "l3", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Multiply<TFloat64> instance = new Multiply<>(tf, TFloat64.class); + List<Operand<TFloat64>> resultList = + instance.call( + Arrays.asList( + i1.getOutput(TFloat64.class), + i2.getOutput(TFloat64.class), + i3.getOutput(TFloat64.class)), + TFloat64.class); + + Operand<TFloat64> result = resultList.get(0); + + assertArrayEquals(new long[] {Shape.UNKNOWN_SIZE, 4, 5}, result.shape().asArray()); + + Operand<TFloat64> x1Op = tf.constant(x1); + Operand<TFloat64> x2Op = tf.constant(x2); + Operand<TFloat64> x3Op = tf.constant(x3); + + try (TFloat64 x1Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x1Op).run().get(0); + TFloat64 x2Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x2Op).run().get(0); + TFloat64 x3Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x3Op).run().get(0)) { + Map<Operand<? extends TType>, Tensor> feedMap = new HashMap<>(); + feedMap.put(i1.getOutput(TFloat64.class), x1Tensor); + feedMap.put(i2.getOutput(TFloat64.class), x2Tensor); + feedMap.put(i3.getOutput(TFloat64.class), x3Tensor); + session.evaluate(tf.constant(xmul), result, feedMap); + } + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ReLUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ReLUTest.java new file mode 100644 index 00000000000..d2ad1988124 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ReLUTest.java @@ -0,0 +1,130 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +class ReLUTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + float[][][] inputArray = { + { + {2.70857435f, 8.25453567f, 9.75479311f, 1.10273526f}, + {8.69836437f, 2.27818352f, 8.60856328f, 
1.43265882f}, + {0.75845834f, 5.60463474f, 7.35998787f, 0.06365667f} + }, + { + {4.87355239f, 9.90221978f, 5.39014402f, 2.05263398f}, + {5.91652733f, 0.9186602f, 0.91375672f, 0.56053326f}, + {2.08046551f, 8.53763374f, 6.40378721f, 5.83284758f} + } + }; + + @Test + public void testCallMaxValue10() { + + float[][][] expectedArray = { + { + {2.7085743f, 8.254536f, 9.754793f, 1.1027353f}, + {8.698364f, 2.2781835f, 8.608563f, 1.4326588f}, + {0.7584583f, 5.604635f, 7.3599877f, 0.06365667f} + }, + { + {4.8735523f, 9.90222f, 5.390144f, 2.052634f}, + {5.9165273f, 0.9186602f, 0.9137567f, 0.5605333f}, + {2.0804656f, 8.537634f, 6.403787f, 5.8328476f} + } + }; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + float maxValue = 10f; + ReLU instance = + new ReLU<>( + tf, + ReLU.DEFAULT_NEGATIVE_SLOPE, + maxValue, + ReLU.DEFAULT_THRESHOLD, + TFloat32.class, + Layer.Options.create().inputShape(Shape.of(2, 3, 4))); + + Operand result = instance.call(tf.constant(inputArray), TFloat32.class); + + session.evaluate(tf.constant(expectedArray), result); + } + } + + @Test + public void testCallNegativeSlope() { + + float[][][] expectedArray = { + { + {2.7085743f, 8.254536f, 9.754793f, 1.1027353f}, + {8.698364f, 2.2781835f, 8.608563f, 1.4326588f}, + {0.7584583f, 5.604635f, 7.3599877f, 0.06365667f} + }, + { + {4.8735523f, 9.90222f, 5.390144f, 2.052634f}, + {5.9165273f, 0.9186602f, 0.9137567f, 0.5605333f}, + {2.0804656f, 8.537634f, 6.403787f, 5.8328476f} + } + }; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + float negativeSlope = 0.2f; + ReLU instance = + new ReLU<>( + tf, + negativeSlope, + TFloat32.class, + Layer.Options.create().inputShape(Shape.of(2, 3, 4))); + + Operand result = instance.call(tf.constant(inputArray), TFloat32.class); + + session.evaluate(tf.constant(expectedArray), result); + } + } + + @Test + public void testCallMaxValue6() { + + double[][][] expectedArray = { + { + {2.7085743, 6., 6., 1.1027353}, + {6., 2.2781835, 6., 1.4326588}, + {0.7584583, 5.604635, 6., 0.06365667} + }, + { + {4.8735523, 6., 5.390144, 2.052634}, + {5.9165273, 0.9186602, 0.9137567, 0.5605333}, + {2.0804656, 6., 6., 5.8328476} + } + }; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + float maxValue = 6f; + ReLU instance = + new ReLU<>( + tf, + ReLU.DEFAULT_NEGATIVE_SLOPE, + maxValue, + ReLU.DEFAULT_THRESHOLD, + TFloat32.class, + Layer.Options.create().inputShape(Shape.of(2, 3, 4))); + + Operand result = + instance.call(tf.dtypes.cast(tf.constant(inputArray), TFloat64.class), TFloat64.class); + + session.evaluate(tf.constant(expectedArray), result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/RepeatVectorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/RepeatVectorTest.java new file mode 100644 index 00000000000..77d75999a3c --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/RepeatVectorTest.java @@ -0,0 +1,43 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class RepeatVectorTest { + + private final TestSession.Mode tfMode = 
TestSession.Mode.GRAPH; + + @Test + public void testCall3_2() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Shape shape = Shape.of(3, 2); + RepeatVector<TFloat32> instance = new RepeatVector<>(tf, 3, TFloat32.class); + + double[][] array = { + {0.41448207, 0.71509451}, {0.21307868, 0.76890945}, {0.37533432, 0.7761148} + }; + + Shape expectedShape = Shape.of(3, 3, 2); + + double[][][] expected = { + {{0.41448206, 0.7150945}, {0.41448206, 0.7150945}, {0.41448206, 0.7150945}}, + {{0.21307868, 0.76890945}, {0.21307868, 0.76890945}, {0.21307868, 0.76890945}}, + {{0.37533432, 0.7761148}, {0.37533432, 0.7761148}, {0.37533432, 0.7761148}} + }; + + Operand<TFloat64> result = + instance.call(tf.dtypes.cast(tf.constant(array), TFloat64.class), TFloat64.class); + + assertEquals(expectedShape, result.shape()); + session.evaluate(tf.constant(expected), result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ReshapeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ReshapeTest.java new file mode 100644 index 00000000000..a31c10973e9 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ReshapeTest.java @@ -0,0 +1,126 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + +class ReshapeTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + float[][][] inputArray = { + { + {2.70857435f, 8.25453567f, 9.75479311f, 1.10273526f}, + {8.69836437f, 2.27818352f, 8.60856328f, 1.43265882f}, + {0.75845834f, 5.60463474f, 7.35998787f, 0.06365667f} + }, + { + {4.87355239f, 9.90221978f, 5.39014402f, 2.05263398f}, + {5.91652733f, 0.9186602f, 0.91375672f, 0.56053326f}, + {2.08046551f, 8.53763374f, 6.40378721f, 5.83284758f} + } + }; + + float[][][] inputArrayNN2 = { + { + {2.70857435f, 8.25453567f}, {9.75479311f, 1.10273526f}, + {8.69836437f, 2.27818352f}, {8.60856328f, 1.43265882f}, + {0.75845834f, 5.60463474f}, {7.35998787f, 0.06365667f} + }, + { + {4.87355239f, 9.90221978f}, {5.39014402f, 2.05263398f}, + {5.91652733f, 0.9186602f}, {0.91375672f, 0.56053326f}, + {2.08046551f, 8.53763374f}, {6.40378721f, 5.83284758f} + } + }; + + @Test + public void testCall43() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Shape targetShape = Shape.of(4, 3); + long batchSize = 2; + Reshape<TFloat32> instance = + new Reshape<>( + tf, + targetShape, + TFloat32.class, + Layer.Options.create().inputShape(Shape.of(batchSize, 3, 4))); + + Operand<TFloat64> result = + instance.call(tf.dtypes.cast(tf.constant(inputArray), TFloat64.class), TFloat64.class); + + assertArrayEquals(targetShape.prepend(batchSize).asArray(), result.shape().asArray()); + } + } + + @Test + public void testCallUnknown_1() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Shape targetShape = Shape.of(Shape.UNKNOWN_SIZE, 1); + long batchSize = 2; + Reshape<TFloat32> instance = + new Reshape<>( + tf, + targetShape, + TFloat32.class, + Layer.Options.create().inputShape(Shape.of(batchSize, 3, 4))); + + Shape expectedShape = Shape.of(batchSize, 12, 1); + + Operand<TFloat64> result =
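+ // each batch entry has 3 * 4 = 12 elements, so the unknown target dimension resolves to 12 +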
instance.call(tf.dtypes.cast(tf.constant(inputArray), TFloat64.class), TFloat64.class); + + assertArrayEquals(expectedShape.asArray(), result.shape().asArray()); + } + } + + @Test + public void testCall1_Unknown_1() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Shape targetShape = Shape.of(1, Shape.UNKNOWN_SIZE); + long batchSize = 2; + Reshape<TFloat32> instance = + new Reshape<>( + tf, + targetShape, + TFloat32.class, + Layer.Options.create().inputShape(Shape.of(batchSize, 3, 4))); + + Operand<TFloat64> result = + instance.call(tf.dtypes.cast(tf.constant(inputArray), TFloat64.class), TFloat64.class); + + Shape expectedShape = Shape.of(batchSize, 1, 12); + assertArrayEquals(expectedShape.asArray(), result.shape().asArray()); + } + } + + @Test + public void testCallUnknownUnknown2() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Shape targetShape = Shape.of(Shape.UNKNOWN_SIZE, 1); + long batchSize = 2; + Reshape<TFloat32> instance = + new Reshape<>( + tf, + targetShape, + TFloat32.class, + Layer.Options.create() + .inputShape(Shape.of(Shape.UNKNOWN_SIZE, Shape.UNKNOWN_SIZE, 2))); + + Operand<TFloat64> result = + instance.call(tf.dtypes.cast(tf.constant(inputArrayNN2), TFloat64.class), TFloat64.class); + + Shape expectedShape = Shape.of(batchSize, 12, 1); + assertArrayEquals(expectedShape.asArray(), result.shape().asArray()); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/SequentialLayersTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/SequentialLayersTest.java new file mode 100644 index 00000000000..11879b85ff2 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/SequentialLayersTest.java @@ -0,0 +1,67 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +import java.util.Arrays; +import java.util.List; + +public class SequentialLayersTest { + + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + /** Tests running an input through a sequence of layers. 
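Input feeds Dense, then Dropout (rate 0.3), then Flatten; with the fixed seed 1001L the dropout mask, and hence the expected activations, are deterministic in graph mode. 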
*/ + @Test + public void testSequentialLayers() { + long seed = 1001L; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + float[][] inputArray = { + {6.600953f, 4.659476f}, + {6.943807f, 2.113826f}, + {4.667166f, 6.931125f}, + {7.716860f, 3.205337f} + }; + Operand input = tf.constant(inputArray); + List> sequencedLayers = + Arrays.asList( + new Input<>(tf, input, TFloat32.class), + new Dense<>(tf, 3, seed, TFloat32.class), + new Dropout<>(tf, 0.3f, seed, TFloat32.class), + new Flatten<>(tf, TFloat32.class)); + + Operand result = input; + for (Layer layer : sequencedLayers) { + result = layer.call(result, TFloat32.class); + } + session.run(tf.init()); + float[][] expected = + new float[][] { + {0f, 12.723986f, 0f}, + {0f, 0f, 0f}, + {7.344245f, 11.509995f, 7.875926f}, + {8.921803f, 0f, 0f} + }; + + session.evaluate(tf.constant(expected), result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/SubtractTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/SubtractTest.java new file mode 100644 index 00000000000..2a20e999907 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/SubtractTest.java @@ -0,0 +1,191 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.Tensor; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.family.TType; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class SubtractTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + double[][][] x1 = { + { + {0.13570025, 0.55677077, 0.06648757, 0.58405729, 0.61086578}, + {0.18659685, 0.39331301, 0.68069423, 0.09510652, 0.86098578}, + {0.99338463, 0.37543824, 0.74858191, 0.31828287, 0.94056888}, + {0.76202298, 0.05605309, 0.73475366, 0.9313434, 0.48606332} + }, + { + {0.13023652, 0.39567908, 0.89910993, 0.71302943, 0.73722061}, + {0.6212917, 0.62624375, 0.8184835, 0.16864979, 0.96730508}, + {0.39645622, 0.35834793, 0.39924944, 0.90297727, 0.82857399}, + {0.70014157, 0.95498672, 0.6179583, 0.33104554, 0.11978174} + } + }; + double[][][] x2 = { + { + {0.82828211, 0.28889298, 0.7159566, 0.93377237, 0.32654201}, + {0.73234341, 0.17123203, 0.62582661, 0.96272026, 0.58700802}, + {0.12527705, 0.64175689, 0.64915537, 0.80589999, 0.26400939}, + {0.79376476, 0.24171677, 0.0677271, 0.07027092, 0.29195821} + }, + { + {0.56599224, 0.10611362, 0.83370522, 0.72514044, 0.08126704}, + {0.48173969, 0.16509515, 0.21040572, 0.44414272, 0.70656624}, + {0.89191749, 0.73008498, 0.9177326, 0.31897888, 0.56743576}, + {0.36304201, 0.36696309, 0.60722209, 0.79244879, 0.63492784} + } + }; + + double[][][] xsub = { + { + {-0.69258186, 0.26787779, -0.64946903, -0.34971508, 0.28432377}, + {-0.54574656, 0.22208098, 0.05486762, -0.86761374, 0.27397776}, + {0.86810758, -0.26631865, 0.09942654, -0.48761712, 0.67655949}, + {-0.03174178, -0.18566368, 0.66702656, 0.86107248, 0.19410511} + }, + { + {-0.43575572, 0.28956546, 0.06540471, -0.01211101, 0.65595357}, + {0.13955201, 0.4611486, 0.60807778, 
-0.27549293, 0.26073884}, + {-0.49546127, -0.37173705, -0.51848316, 0.58399839, 0.26113823}, + {0.33709956, 0.58802363, 0.01073621, -0.46140325, -0.5151461} + } + }; + + @Test + public void testSubtract() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Input i1 = + new Input<>( + tf, + "l1", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input i2 = + new Input<>( + tf, + "l2", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Subtract instance = new Subtract<>(tf, TFloat64.class); + List> resultList = + instance.call( + Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)), + TFloat64.class); + + Operand result = resultList.get(0); + + assertArrayEquals(new long[] {Shape.UNKNOWN_SIZE, 4, 5}, result.shape().asArray()); + + Operand x1Op = tf.constant(x1); + Operand x2Op = tf.constant(x2); + + try (TFloat64 x1Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x1Op).run().get(0); + TFloat64 x2Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x2Op).run().get(0)) { + Map, Tensor> feedMap = new HashMap<>(); + feedMap.put(i1.getOutput(TFloat64.class), x1Tensor); + feedMap.put(i2.getOutput(TFloat64.class), x2Tensor); + session.evaluate(tf.constant(xsub), result, feedMap); + } + } + } + + @Test + public void testSubtractInvalidInputsLength() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Input i1 = + new Input<>( + tf, + "l1", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input i2 = + new Input<>( + tf, + "l2", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Subtract instance = new Subtract<>(tf, TFloat64.class); + + // not used, should throw exception + List> resultList = + instance.call( + Arrays.asList( + i1.getOutput(TFloat64.class), + i2.getOutput(TFloat64.class), + i2.getOutput(TFloat64.class)), + TFloat64.class); + } + }); + } + + @Test + public void testMask() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Input i1 = + new Input<>( + tf, + "l1", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input i2 = + new Input<>( + tf, + "l2", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + + Subtract instance = new Subtract<>(tf, TFloat64.class); + List> inputs = + Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)); + List> mask = Arrays.asList(null, null); + List> result = instance.computeMask(inputs, mask); + assertNull(result); + + Operand x1Op = tf.constant(x1); + Operand x2Op = tf.constant(x2); + mask = Arrays.asList(x1Op, x2Op); + + try (TFloat64 x1Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x1Op).run().get(0); + TFloat64 x2Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x2Op).run().get(0)) { + Map, Tensor> feedMap = new HashMap<>(); + feedMap.put(i1.getOutput(TFloat64.class), x1Tensor); + feedMap.put(i2.getOutput(TFloat64.class), x2Tensor); + result = instance.computeMask(inputs, mask); + Boolean[] expected = new Boolean[(int) result.get(0).size()]; + Arrays.fill(expected, true); + session.evaluate(expected, result.get(0), feedMap); + } + } + } +} diff --git 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ThresholdedReLUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ThresholdedReLUTest.java new file mode 100644 index 00000000000..bc4ac177989 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ThresholdedReLUTest.java @@ -0,0 +1,56 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat64; + +class ThresholdedReLUTest { + + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + double[][][] inputArray = { + { + {2.70857435f, 8.25453567f, 9.75479311f, 1.10273526f}, + {8.69836437f, 2.27818352f, 8.60856328f, 1.43265882f}, + {0.75845834f, 5.60463474f, 7.35998787f, 0.06365667f} + }, + { + {4.87355239f, 9.90221978f, 5.39014402f, 2.05263398f}, + {5.91652733f, 0.9186602f, 0.91375672f, 0.56053326f}, + {2.08046551f, 8.53763374f, 6.40378721f, 5.83284758f} + } + }; + + @Test + public void testCallThetaPoint5() { + + double[][][] expectedArray = { + { + {2.7085743, 8.254536, 9.754793, 1.1027353}, + {8.698364, 2.2781835, 8.608563, 1.4326588}, + {0.7584583, 5.604635, 7.3599877, 0.} + }, + { + {4.8735523, 9.90222, 5.390144, 2.052634}, + {5.9165273, 0.9186602, 0.9137567, 0.5605333}, + {2.0804656, 8.537634, 6.403787, 5.8328476} + } + }; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + float theta = 0.5f; + ThresholdedReLU instance = + new ThresholdedReLU<>( + tf, theta, TFloat64.class, Layer.Options.create().inputShape(Shape.of(2, 3, 4))); + + Operand result = instance.call(tf.constant(inputArray), TFloat64.class); + + session.evaluate(tf.constant(expectedArray), result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/impl/InputSpecTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/impl/InputSpecTest.java new file mode 100644 index 00000000000..1a8336e33f5 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/impl/InputSpecTest.java @@ -0,0 +1,68 @@ +package org.tensorflow.framework.layers.impl; + +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.Shape; + +import java.util.Arrays; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class InputSpecTest { + + @Test + public void testAxis() { + + InputSpec instance = + new InputSpec( + InputSpec.Options.create() + .shape(Shape.of(1, Shape.UNKNOWN_SIZE, 2, 3)) + .axesMap(3, 5L) + .axesMap(2, 2L)); + + assertThrows( + java.lang.IllegalArgumentException.class, + () -> { + InputSpec instance1 = + new InputSpec( + InputSpec.Options.create() + .shape(Shape.of(1, Shape.UNKNOWN_SIZE, 2, 3)) + .axesMap(4, 5L)); + }); + } + + @Test + public void testDefinedShape() { + Shape expected = Shape.of(1, Shape.UNKNOWN_SIZE, 2, 3); + InputSpec instance = new InputSpec(InputSpec.Options.create().shape(expected)); + assertArrayEquals(expected.asArray(), instance.toShape().asArray()); + } + + @Test + public void testDefinedRank() { + InputSpec instance = new InputSpec(InputSpec.Options.create().rank(5)); + long[] dims = new long[5]; + Arrays.fill(dims, Shape.UNKNOWN_SIZE); + assertArrayEquals(dims, 
instance.toShape().asArray()); + + instance = new InputSpec(InputSpec.Options.create().rank(0)); + dims = new long[0]; + assertArrayEquals(dims, instance.toShape().asArray()); + + instance = new InputSpec(InputSpec.Options.create().rank(3).axesMap(1, 3L).axesMap(-1, 2L)); + dims = new long[] {Shape.UNKNOWN_SIZE, 3, 2}; + assertArrayEquals(dims, instance.toShape().asArray()); + } + + @Test + public void testUndefinedShapes() { + InputSpec instance = new InputSpec(InputSpec.Options.create().maxRank(5)); + Shape genShaped = instance.toShape(); + assertTrue(genShaped.isUnknown()); + + instance = new InputSpec(InputSpec.Options.create().minRank(5).maxRank(5)); + genShaped = instance.toShape(); + assertTrue(genShaped.isUnknown()); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/op/LinalgOpsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/op/LinalgOpsTest.java new file mode 100644 index 00000000000..76d86a95e85 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/op/LinalgOpsTest.java @@ -0,0 +1,60 @@ +package org.tensorflow.framework.op; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +class LinalgOpsTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + @Test + public void test2D() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); + Operand a = tf.constant(new float[][] {{3.7213619f}}); + Operand b = tf.constant(new float[][] {{8.153921f}}); + + Operand ans = fops.linalg.matmul(a, b); + Operand expected = tf.constant(new float[][] {{30.34369f}}); + session.evaluate(expected, ans); + + Operand a64 = + tf.constant(new double[][] {{-8.944851}, {4.1711287}, {-0.22380222}}); + Operand b64 = + tf.constant( + new double[][] {{-14.276086, -12.433481, -2.2447076, -1.5775859, 1.8588694}}); + + Operand ans64 = fops.linalg.matmul(a64, b64); + Operand expected64 = + tf.constant( + new double[][] { + {127.69746, 111.21564, 20.078575, 14.111271, -16.62731}, + {-59.547394, -51.861652, -9.362965, -6.580314, 7.753584}, + {3.1950197, 2.7826407, 0.50237054, 0.35306725, -0.4160191} + }); + session.evaluate(expected64, ans64); + + a64 = + tf.constant( + new double[][] { + {-9.189821, -1.588742, -8.684379}, + {-10.953391, -8.473055, -6.8909864}, + {-11.712155, -6.6350083, -2.4441578}, + {1.4037079, -11.279383, 0.9129576}, + {0.11368857, 2.3792067, -11.218701}, + }); + b64 = tf.constant(new double[][] {{-4.933953}, {-12.692161}, {-10.192119}}); + ans64 = fops.linalg.matmul(a64, b64); + expected64 = + tf.constant( + new double[][] {{154.01892}, {231.81863}, {166.91096}, {126.92895}, {83.58413}}); + TestSession.setEpsilon(1e-4f); + session.evaluate(expected64, ans64); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/op/MathOpsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/op/MathOpsTest.java new file mode 100644 index 00000000000..dda5a7c6eaa --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/op/MathOpsTest.java @@ -0,0 +1,501 @@ +package org.tensorflow.framework.op; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import 
org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt64; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +class MathOpsTest { + + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + double[][][] array = + new double[][][] { + { + {4.17021990e-01, 7.20324516e-01, 1.14374816e-04}, + {3.02332580e-01, 1.46755889e-01, 9.23385918e-02}, + {1.86260208e-01, 3.45560730e-01, 3.96767467e-01}, + {5.38816750e-01, 4.19194520e-01, 6.85219526e-01}, + {2.04452246e-01, 8.78117442e-01, 2.73875929e-02}, + {6.70467496e-01, 4.17304814e-01, 5.58689833e-01}, + {1.40386939e-01, 1.98101491e-01, 8.00744593e-01} + }, + { + {9.68261600e-01, 3.13424170e-01, 6.92322612e-01}, + {8.76389146e-01, 8.94606650e-01, 8.50442126e-02}, + {3.90547849e-02, 1.69830427e-01, 8.78142476e-01}, + {9.83468369e-02, 4.21107620e-01, 9.57889557e-01}, + {5.33165276e-01, 6.91877127e-01, 3.15515637e-01}, + {6.86500907e-01, 8.34625661e-01, 1.82882771e-02}, + {7.50144303e-01, 9.88861084e-01, 7.48165667e-01} + }, + { + {2.80443996e-01, 7.89279342e-01, 1.03226006e-01}, + {4.47893530e-01, 9.08595502e-01, 2.93614149e-01}, + {2.87775338e-01, 1.30028576e-01, 1.93669572e-02}, + {6.78835511e-01, 2.11628109e-01, 2.65546650e-01}, + {4.91573155e-01, 5.33625446e-02, 5.74117601e-01}, + {1.46728575e-01, 5.89305520e-01, 6.99758351e-01}, + {1.02334432e-01, 4.14055973e-01, 6.94400132e-01} + }, + { + {4.14179265e-01, 4.99534607e-02, 5.35896420e-01}, + {6.63794637e-01, 5.14889121e-01, 9.44594741e-01}, + {5.86555064e-01, 9.03401911e-01, 1.37474701e-01}, + {1.39276341e-01, 8.07391286e-01, 3.97676826e-01}, + {1.65354192e-01, 9.27508593e-01, 3.47765863e-01}, + {7.50812113e-01, 7.25997984e-01, 8.83306086e-01}, + {6.23672187e-01, 7.50942409e-01, 3.48898351e-01} + }, + { + {2.69927889e-01, 8.95886242e-01, 4.28091198e-01}, + {9.64840055e-01, 6.63441479e-01, 6.21695697e-01}, + {1.14745975e-01, 9.49489236e-01, 4.49912131e-01}, + {5.78389585e-01, 4.08136815e-01, 2.37026975e-01}, + {9.03379500e-01, 5.73679507e-01, 2.87032709e-03}, + {6.17144942e-01, 3.26644897e-01, 5.27058125e-01}, + {8.85942101e-01, 3.57269764e-01, 9.08535123e-01} + }, + { + {6.23360097e-01, 1.58212427e-02, 9.29437220e-01}, + {6.90896928e-01, 9.97322857e-01, 1.72340512e-01}, + {1.37135744e-01, 9.32595491e-01, 6.96818173e-01}, + {6.60001710e-02, 7.55463064e-01, 7.53876209e-01}, + {9.23024535e-01, 7.11524785e-01, 1.24270961e-01}, + {1.98801346e-02, 2.62109861e-02, 2.83064879e-02}, + {2.46211067e-01, 8.60027969e-01, 5.38831055e-01} + }, + { + {5.52821994e-01, 8.42030883e-01, 1.24173313e-01}, + {2.79183686e-01, 5.85759282e-01, 9.69595730e-01}, + {5.61030209e-01, 1.86472889e-02, 8.00632656e-01}, + {2.32974276e-01, 8.07105184e-01, 3.87860656e-01}, + {8.63541842e-01, 7.47121632e-01, 5.56240261e-01}, + {1.36455223e-01, 5.99176884e-02, 1.21343456e-01}, + {4.45518792e-02, 1.07494131e-01, 2.25709334e-01} + }, + { + {7.12988973e-01, 5.59717000e-01, 1.25559801e-02}, + {7.19742775e-02, 9.67276335e-01, 5.68100452e-01}, + {2.03293234e-01, 2.52325743e-01, 7.43825853e-01}, + {1.95429474e-01, 5.81358910e-01, 9.70019996e-01}, + {8.46828818e-01, 2.39847764e-01, 4.93769705e-01}, + {6.19955719e-01, 8.28980923e-01, 1.56791389e-01}, + {1.85762029e-02, 7.00221434e-02, 4.86345112e-01} + }, + { + {6.06329441e-01, 5.68851411e-01, 3.17362398e-01}, + {9.88616168e-01, 5.79745233e-01, 3.80141169e-01}, + 
{5.50948203e-01, 7.45334446e-01, 6.69232905e-01}, + {2.64919549e-01, 6.63348362e-02, 3.70084196e-01}, + {6.29717529e-01, 2.10174009e-01, 7.52755582e-01}, + {6.65364787e-02, 2.60315090e-01, 8.04754555e-01}, + {1.93434283e-01, 6.39460862e-01, 5.24670303e-01} + }, + { + {9.24807966e-01, 2.63296783e-01, 6.59610927e-02}, + {7.35065937e-01, 7.72178054e-01, 9.07815874e-01}, + {9.31972086e-01, 1.39515726e-02, 2.34362081e-01}, + {6.16778374e-01, 9.49016333e-01, 9.50176120e-01}, + {5.56653202e-01, 9.15606380e-01, 6.41566217e-01}, + {3.90007704e-01, 4.85990673e-01, 6.04310513e-01}, + {5.49547911e-01, 9.26181436e-01, 9.18733418e-01} + }, + { + {3.94875616e-01, 9.63262558e-01, 1.73955664e-01}, + {1.26329526e-01, 1.35079160e-01, 5.05662143e-01}, + {2.15248056e-02, 9.47970212e-01, 8.27115476e-01}, + {1.50189810e-02, 1.76196262e-01, 3.32063586e-01}, + {1.30996838e-01, 8.09490681e-01, 3.44736665e-01}, + {9.40107465e-01, 5.82014203e-01, 8.78831983e-01}, + {8.44734430e-01, 9.05392289e-01, 4.59880263e-01} + }, + { + {5.46346843e-01, 7.98603594e-01, 2.85718858e-01}, + {4.90253508e-01, 5.99110305e-01, 1.55332759e-02}, + {5.93481421e-01, 4.33676362e-01, 8.07360530e-01}, + {3.15244794e-01, 8.92888725e-01, 5.77857196e-01}, + {1.84010208e-01, 7.87929237e-01, 6.12031162e-01}, + {5.39092720e-02, 4.20193672e-01, 6.79068863e-01}, + {9.18601751e-01, 4.02024889e-04, 9.76759136e-01} + }, + { + {3.76580328e-01, 9.73783553e-01, 6.04716122e-01}, + {8.28845799e-01, 5.74711502e-01, 6.28076196e-01}, + {2.85576284e-01, 5.86833358e-01, 7.50021756e-01}, + {8.58313859e-01, 7.55082190e-01, 6.98057234e-01}, + {8.64479423e-01, 3.22681010e-01, 6.70788765e-01}, + {4.50873941e-01, 3.82102758e-01, 4.10811365e-01}, + {4.01479572e-01, 3.17383945e-01, 6.21919394e-01} + }, + { + {4.30247277e-01, 9.73802090e-01, 6.77800894e-01}, + {1.98569894e-01, 4.26701009e-01, 3.43346238e-01}, + {7.97638834e-01, 8.79998267e-01, 9.03841972e-01}, + {6.62719786e-01, 2.70208269e-01, 2.52366692e-01}, + {8.54897916e-01, 5.27714670e-01, 8.02161098e-01}, + {5.72488546e-01, 7.33142555e-01, 5.19011617e-01}, + {7.70883918e-01, 5.68857968e-01, 4.65709865e-01} + }, + { + {3.42688918e-01, 6.82093501e-02, 3.77924174e-01}, + {7.96260759e-02, 9.82817113e-01, 1.81612849e-01}, + {8.11858714e-01, 8.74961674e-01, 6.88413262e-01}, + {5.69494426e-01, 1.60971433e-01, 4.66880023e-01}, + {3.45172048e-01, 2.25039959e-01, 5.92511892e-01}, + {3.12269837e-01, 9.16305542e-01, 9.09635544e-01}, + {2.57118285e-01, 1.10891297e-01, 1.92962736e-01} + }, + { + {4.99584168e-01, 7.28585660e-01, 2.08194435e-01}, + {2.48033553e-01, 8.51671875e-01, 4.15848732e-01}, + {6.16685092e-01, 2.33666137e-01, 1.01967260e-01}, + {5.15857041e-01, 4.77140993e-01, 1.52671650e-01}, + {6.21806204e-01, 5.44010103e-01, 6.54137373e-01}, + {1.44545540e-01, 7.51527846e-01, 2.22049147e-01}, + {5.19351840e-01, 7.85296023e-01, 2.23304275e-02} + }, + { + {3.24362457e-01, 8.72922361e-01, 8.44709635e-01}, + {5.38440585e-01, 8.66608262e-01, 9.49805975e-01}, + {8.26407015e-01, 8.54115427e-01, 9.87434015e-02}, + {6.51304305e-01, 7.03516960e-01, 6.10240817e-01}, + {7.99615264e-01, 3.45712192e-02, 7.70238757e-01}, + {7.31728613e-01, 2.59698391e-01, 2.57069290e-01}, + {6.32303298e-01, 3.45297456e-01, 7.96588659e-01} + }, + { + {4.46146220e-01, 7.82749414e-01, 9.90471780e-01}, + {3.00248325e-01, 1.43005833e-01, 9.01308417e-01}, + {5.41559398e-01, 9.74740386e-01, 6.36604428e-01}, + {9.93912995e-01, 5.46070814e-01, 5.26425958e-01}, + {1.35427907e-01, 3.55705172e-01, 2.62185670e-02}, + {1.60395175e-01, 7.45637178e-01, 3.03996895e-02}, 
+ {3.66543084e-01, 8.62346232e-01, 6.92677736e-01} + }, + { + {6.90942168e-01, 1.88636795e-01, 4.41904277e-01}, + {5.81577420e-01, 9.89751697e-01, 2.03906223e-01}, + {2.47732908e-01, 2.62173086e-01, 7.50172436e-01}, + {4.56975341e-01, 5.69294393e-02, 5.08516252e-01}, + {2.11960167e-01, 7.98604250e-01, 2.97331393e-01}, + {2.76060123e-02, 5.93432426e-01, 8.43840420e-01}, + {3.81016135e-01, 7.49858320e-01, 5.11141479e-01} + }, + { + {5.40951788e-01, 9.59434330e-01, 8.03960919e-01}, + {3.23230661e-02, 7.09387243e-01, 4.65001494e-01}, + {9.47548926e-01, 2.21432731e-01, 2.67072022e-01}, + {8.14739615e-02, 4.28618819e-01, 1.09018765e-01}, + {6.33786738e-01, 8.02963257e-01, 6.96800470e-01}, + {7.66211390e-01, 3.42454106e-01, 8.45851481e-01}, + {4.28768784e-01, 8.24009895e-01, 6.26496136e-01} + } + }; + + double[][][] expectedArray = { + { + {3.45350616e-02, 5.96526116e-02, 9.47178160e-06}, + {2.50372272e-02, 1.21533722e-02, 7.64688430e-03}, + {1.54248644e-02, 2.86171008e-02, 3.28577124e-02}, + {4.46213149e-02, 3.47149745e-02, 5.67454435e-02}, + {1.69314109e-02, 7.27199987e-02, 2.26806314e-03}, + {5.55237755e-02, 3.45584825e-02, 4.62670736e-02}, + {1.16259372e-02, 1.64054818e-02, 6.63124844e-02} + }, + { + {8.01851526e-02, 2.59557609e-02, 5.73336743e-02}, + {7.25768730e-02, 7.40855262e-02, 7.04281079e-03}, + {3.23426444e-03, 1.40642561e-02, 7.27220699e-02}, + {8.14444851e-03, 3.48734073e-02, 7.93262124e-02}, + {4.41532955e-02, 5.72967827e-02, 2.61289626e-02}, + {5.68515584e-02, 6.91182911e-02, 1.51451665e-03}, + {6.21220917e-02, 8.18910673e-02, 6.19582348e-02} + }, + { + {2.32245550e-02, 6.53630048e-02, 8.54850933e-03}, + {3.70916426e-02, 7.52439946e-02, 2.43152231e-02}, + {2.38316897e-02, 1.07681248e-02, 1.60384597e-03}, + {5.62167615e-02, 1.75256692e-02, 2.19908543e-02}, + {4.07089069e-02, 4.41914052e-03, 4.75447029e-02}, + {1.21511100e-02, 4.88024652e-02, 5.79494536e-02}, + {8.47467501e-03, 3.42894346e-02, 5.75057231e-02} + }, + { + {3.42996456e-02, 4.13682219e-03, 4.43794727e-02}, + {5.49711734e-02, 4.26397808e-02, 7.82252178e-02}, + {4.85746935e-02, 7.48138949e-02, 1.13847647e-02}, + {1.15339644e-02, 6.68629184e-02, 3.29330191e-02}, + {1.36935636e-02, 7.68102556e-02, 2.87997164e-02}, + {6.21773973e-02, 6.01224527e-02, 7.31496885e-02}, + {5.16484901e-02, 6.21881858e-02, 2.88935024e-02} + }, + { + {2.23536789e-02, 7.41914958e-02, 3.54517400e-02}, + {7.99018070e-02, 5.49419262e-02, 5.14848121e-02}, + {9.50251892e-03, 7.86305517e-02, 3.72588076e-02}, + {4.78984788e-02, 3.37992460e-02, 1.96290389e-02}, + {7.48120397e-02, 4.75084223e-02, 2.37701897e-04}, + {5.11079468e-02, 2.70506144e-02, 4.36475389e-02}, + {7.33679906e-02, 2.95867678e-02, 7.52389953e-02} + }, + { + {5.16226478e-02, 1.31021289e-03, 7.69699737e-02}, + {5.72156087e-02, 8.25918168e-02, 1.42721254e-02}, + {1.13566946e-02, 7.72315189e-02, 5.77059686e-02}, + {5.46570681e-03, 6.25625551e-02, 6.24311455e-02}, + {7.64389113e-02, 5.89238741e-02, 1.02913165e-02}, + {1.64634397e-03, 2.17062421e-03, 2.34416011e-03}, + {2.03896053e-02, 7.12219477e-02, 4.46224995e-02} + }, + { + {4.57811356e-02, 6.97315410e-02, 1.02832299e-02}, + {2.31201854e-02, 4.85087894e-02, 8.02956372e-02}, + {4.64608893e-02, 1.54424773e-03, 6.63032085e-02}, + {1.92934200e-02, 6.68392256e-02, 3.21201086e-02}, + {7.15129450e-02, 6.18717745e-02, 4.60642166e-02}, + {1.13003375e-02, 4.96199494e-03, 1.00488793e-02}, + {3.68949817e-03, 8.90196767e-03, 1.86917856e-02} + }, + { + {5.90451285e-02, 4.63521369e-02, 1.03980501e-03}, + {5.96044352e-03, 8.01035613e-02, 4.70464006e-02}, + 
{1.68354288e-02, 2.08959840e-02, 6.15988411e-02}, + {1.61842033e-02, 4.81443815e-02, 8.03307742e-02}, + {7.01288804e-02, 1.98626388e-02, 4.08908091e-02}, + {5.13407178e-02, 6.86508343e-02, 1.29844472e-02}, + {1.53836084e-03, 5.79878036e-03, 4.02759537e-02} + }, + { + {5.02122790e-02, 4.71085906e-02, 2.62818988e-02}, + {8.18707868e-02, 4.80107442e-02, 3.14808302e-02}, + {4.56259623e-02, 6.17237724e-02, 5.54215349e-02}, + {2.19389219e-02, 5.49342157e-03, 3.06479763e-02}, + {5.21491282e-02, 1.74052510e-02, 6.23383410e-02}, + {5.51012019e-03, 2.15576105e-02, 6.66445568e-02}, + {1.60189737e-02, 5.29560074e-02, 4.34497967e-02} + }, + { + {7.65866041e-02, 2.18045339e-02, 5.46247046e-03}, + {6.08734004e-02, 6.39467835e-02, 7.51794279e-02}, + {7.71798939e-02, 1.15537888e-03, 1.94083489e-02}, + {5.10775894e-02, 7.85913840e-02, 7.86874294e-02}, + {4.60984148e-02, 7.58245885e-02, 5.31303585e-02}, + {3.22979130e-02, 4.02465984e-02, 5.00450842e-02}, + {4.55099978e-02, 7.67003447e-02, 7.60835484e-02} + }, + { + {3.27010415e-02, 7.97711685e-02, 1.44058811e-02}, + {1.04617933e-02, 1.11863809e-02, 4.18756641e-02}, + {1.78254500e-03, 7.85047561e-02, 6.84963465e-02}, + {1.24377478e-03, 1.45914331e-02, 2.74993554e-02}, + {1.08483098e-02, 6.70367777e-02, 2.85488572e-02}, + {7.78536126e-02, 4.81986478e-02, 7.27791712e-02}, + {6.99554384e-02, 7.49787241e-02, 3.80843058e-02} + }, + { + {4.52449061e-02, 6.61351755e-02, 2.36613862e-02}, + {4.05996218e-02, 4.96144369e-02, 1.28636532e-03}, + {4.91482876e-02, 3.59142683e-02, 6.68603703e-02}, + {2.61065327e-02, 7.39432648e-02, 4.78543900e-02}, + {1.52385337e-02, 6.52511939e-02, 5.06844558e-02}, + {4.46441676e-03, 3.47977169e-02, 5.62360846e-02}, + {7.60726482e-02, 3.32930977e-05, 8.08888674e-02} + }, + { + {3.11859436e-02, 8.06424469e-02, 5.00786714e-02}, + {6.86396435e-02, 4.75938842e-02, 5.20132035e-02}, + {2.36495789e-02, 4.85977381e-02, 6.21119440e-02}, + {7.10799918e-02, 6.25310168e-02, 5.78085780e-02}, + {7.15905875e-02, 2.67223511e-02, 5.55503815e-02}, + {3.73384580e-02, 3.16432752e-02, 3.40207368e-02}, + {3.32479365e-02, 2.62836833e-02, 5.15033379e-02} + }, + { + {3.56302932e-02, 8.06439817e-02, 5.61310798e-02}, + {1.64442733e-02, 3.53366137e-02, 2.84337122e-02}, + {6.60552830e-02, 7.28757605e-02, 7.48503357e-02}, + {5.48821613e-02, 2.23768987e-02, 2.08993759e-02}, + {7.07971081e-02, 4.37019095e-02, 6.64297864e-02}, + {4.74097952e-02, 6.07141182e-02, 4.29811813e-02}, + {6.38396144e-02, 4.71091345e-02, 3.85670736e-02} + }, + { + {2.83792764e-02, 5.64865675e-03, 3.12972330e-02}, + {6.59411587e-03, 8.13905448e-02, 1.50400000e-02}, + {6.72328845e-02, 7.24586621e-02, 5.70099279e-02}, + {4.71618399e-02, 1.33306114e-02, 3.86639796e-02}, + {2.85849143e-02, 1.86363515e-02, 4.90679964e-02}, + {2.58601662e-02, 7.58824944e-02, 7.53301233e-02}, + {2.12928709e-02, 9.18329880e-03, 1.59799233e-02} + }, + { + {4.13723253e-02, 6.03367463e-02, 1.72413141e-02}, + {2.05405317e-02, 7.05299526e-02, 3.44378985e-02}, + {5.10698669e-02, 1.93507168e-02, 8.44426826e-03}, + {4.27199379e-02, 3.95137258e-02, 1.26432776e-02}, + {5.14939614e-02, 4.50513922e-02, 5.41714206e-02}, + {1.19703254e-02, 6.22366704e-02, 1.83886718e-02}, + {4.30093557e-02, 6.50331303e-02, 1.84926135e-03} + }, + { + {2.68615987e-02, 7.22897798e-02, 6.99533820e-02}, + {4.45901640e-02, 7.17668831e-02, 7.86567777e-02}, + {6.84376806e-02, 7.07323104e-02, 8.17728881e-03}, + {5.39368056e-02, 5.82607202e-02, 5.05361930e-02}, + {6.62189573e-02, 2.86296452e-03, 6.37861863e-02}, + {6.05970249e-02, 2.15065386e-02, 2.12888140e-02}, 
+ {5.23632653e-02, 2.85952985e-02, 6.59683123e-02} + }, + { + {3.69469412e-02, 6.48222342e-02, 8.20244551e-02}, + {2.48646215e-02, 1.18428171e-02, 7.46405274e-02}, + {4.48484421e-02, 8.07216838e-02, 5.27194552e-02}, + {8.23094398e-02, 4.52220477e-02, 4.35951874e-02}, + {1.12152621e-02, 2.94571985e-02, 2.17125192e-03}, + {1.32828895e-02, 6.17488436e-02, 2.51750532e-03}, + {3.03547252e-02, 7.14139268e-02, 5.73630854e-02} + }, + { + {5.72193563e-02, 1.56216780e-02, 3.65956500e-02}, + {4.81624752e-02, 8.19648281e-02, 1.68861933e-02}, + {2.05156356e-02, 2.17114780e-02, 6.21244237e-02}, + {3.78437378e-02, 4.71452763e-03, 4.21120226e-02}, + {1.75531674e-02, 6.61352351e-02, 2.46230606e-02}, + {2.28615105e-03, 4.91442308e-02, 6.98814020e-02}, + {3.15532871e-02, 6.20984100e-02, 4.23294269e-02} + }, + { + {4.47981246e-02, 7.94541389e-02, 6.65788352e-02}, + {2.67678709e-03, 5.87468557e-02, 3.85084115e-02}, + {7.84698650e-02, 1.83376241e-02, 2.21171752e-02}, + {6.74714567e-03, 3.54954340e-02, 9.02822800e-03}, + {5.24861142e-02, 6.64962158e-02, 5.77045009e-02}, + {6.34526685e-02, 2.83598304e-02, 7.00479448e-02}, + {3.55078541e-02, 6.82391599e-02, 5.18823527e-02} + } + }; + + @Test + public void testL2Normalize() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); + Operand input = tf.constant(array); + Operand result = fops.math.l2Normalize(tf.constant(array), new int[] {0, 1, 2}); + session.evaluate(tf.constant(expectedArray), result); + } + } + + @Test + public void testConfusionMatrix() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); + long[] labels = new long[] {2, 0, 2, 2, 0, 1}; + long[] predictions = new long[] {0, 0, 2, 2, 0, 2}; + Operand result = + fops.math.confusionMatrix(tf.constant(labels), tf.constant(predictions)); + long[][] expected = + new long[][] { + {2, 0, 0}, + {0, 0, 1}, + {1, 0, 2} + }; + session.evaluate(tf.constant(expected), result); + } + } + + @Test + public void testTensorDotValid() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); + int[] axes1 = new int[] {1, 2}; + int[][] axes2 = new int[][] {{1}, {2}}; + int[][] axes3 = new int[2][0]; + int axes4 = 0; + + Operand a = tf.ones(tf.constant(Shape.of(3, 3)), TFloat32.class); + Operand b = tf.constant(new float[][][] {{{2, 3, 1}}}); + + Operand ans = fops.math.tensordot(a, b, axes1); + Operand expected = tf.constant(new float[][][] {{{6}}, {{6}}, {{6}}}); + session.evaluate(expected, ans); + + ans = fops.math.tensordot(a, b, axes2); + expected = tf.constant(new float[][][] {{{6}}, {{6}}, {{6}}}); + session.evaluate(expected, ans); + + ans = fops.math.tensordot(a, b, axes3); + + float[][][][][] expectedArray = + new float[][][][][] { + {{{{2, 3, 1}}}, {{{2, 3, 1}}}, {{{2, 3, 1}}}}, + {{{{2, 3, 1}}}, {{{2, 3, 1}}}, {{{2, 3, 1}}}}, + {{{{2, 3, 1}}}, {{{2, 3, 1}}}, {{{2, 3, 1}}}} + }; + ans = fops.math.tensordot(a, b, axes3); + expected = tf.constant(expectedArray); + session.evaluate(expected, ans); + + ans = fops.math.tensordot(a, b, axes4); + expected = tf.constant(expectedArray); + session.evaluate(expected, ans); + } + } + + @Test + public void testTensorDotInValidAxis() { + for (TestSession.Mode tfMode : tfModes) 
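+ // tensordot should reject axes that are negative, exceed the operand's rank, or whose + // per-operand axis lists have mismatched lengths; each case below is expected to throw.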
+ try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); + Operand a = tf.constant(new float[][] {{1, 2}, {3, 4}}); + Operand b = tf.constant(new float[][] {{1, 2}, {3, 4}}); + assertThrows(IllegalArgumentException.class, () -> fops.math.tensordot(a, b, -1)); + assertThrows(IllegalArgumentException.class, () -> fops.math.tensordot(a, b, 3)); + assertThrows( + IllegalArgumentException.class, () -> fops.math.tensordot(a, b, new int[] {1})); + assertThrows( + IllegalArgumentException.class, () -> fops.math.tensordot(a, b, new int[][] {{1}})); + assertThrows( + IllegalArgumentException.class, + () -> fops.math.tensordot(a, b, new int[][] {{1}, {0, 1}})); + + assertThrows( + ArrayIndexOutOfBoundsException.class, + () -> fops.math.tensordot(a, b, new int[][] {{0}, {7}})); + } + } + + @Test + public void testReduceLogSumExp() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); + Operand x = + tf.constant( + new float[][] { + {0.43346116f, 0.8569728f, 0.57155997f, 0.0743812f, 0.63846475f}, + {0.8165283f, 0.26554802f, 0.37025765f, 0.8255019f, 0.45682374f}, + {0.93511814f, 0.52291054f, 0.80983895f, 0.11580781f, 0.8111686f}, + {0.49967498f, 0.27537802f, 0.48554695f, 0.28238368f, 0.7989301f}, + {0.8958915f, 0.84870094f, 0.56874424f, 0.08818512f, 0.13915819f} + }); + + Operand result = fops.math.reduceLogSumExp(x, new int[] {0, 1}, false); + session.evaluate(3.7911222f, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/op/NnOpsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/op/NnOpsTest.java new file mode 100644 index 00000000000..1f538493d91 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/op/NnOpsTest.java @@ -0,0 +1,396 @@ +package org.tensorflow.framework.op; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +class NnOpsTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + @Test + public void testSigmoidCrossEntropyWithLogits() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); + float[] x = new float[] {-100, -2, -2, 0, 2, 2, 2, 100}; + float[] y = new float[] {0, 0, 1, 0, 0, 1, 0.5f, 1}; + + Operand logits = tf.constant(x); + Operand targets = tf.constant(y); + Operand loss = fops.nn.sigmoidCrossEntropyWithLogits(targets, logits); + Operand expected = + tf.constant( + new float[] { + 0.f, 0.126928f, 2.126928f, 0.6931472f, + 2.126928f, 0.126928f, 1.126928f, 0.f + }); + session.evaluate(expected, loss); + } + } + + @Test + public void testSoftmaxCrossEntropyWithLogits() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); + float[] x = new float[] {-100, -2, -2, 0, 2, 2, 2, 100}; + float[] y = new float[] {0, 0, 1, 0, 0, 1, 0.5f, 1}; + + Operand logits = tf.constant(x); + Operand targets = tf.constant(y); + Operand loss = 
fops.nn.softmaxCrossEntropyWithLogits(targets, logits, 0); + + session.evaluate(249.0f, loss); + } + } + + @Test + public void testSparseSoftmaxCrossEntropyWithLogits() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); + float[][] x = new float[][] {{0, 0}}; + int[] y = new int[] {0}; + + Operand logits = tf.constant(x); + Operand labels = tf.constant(y); + Operand loss = fops.nn.sparseSoftmaxCrossEntropyWithLogits(labels, logits); + + session.evaluate(0.69314718f, loss); + } + } + + @Test + public void testSoftmax() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); + + double[][] x = { + { + 1.53975978e-01, + -5.55871308e-01, + 1.06272554e+00, + -7.75577792e-05, + -1.07574403e+00, + -1.70856595e+00, + 6.31895363e-01, + 2.69239008e-01, + 5.44192731e-01, + 6.31500483e-01 + }, + { + -1.15359895e-01, + -2.49849468e-01, + 8.04671764e-01, + 6.24943256e-01, + -4.80956525e-01, + 5.99363089e-01, + 7.44674265e-01, + 1.03888428e+00, + -2.00478077e-01, + 5.33391297e-01 + }, + { + 3.77073050e-01, + -4.92661327e-01, + -8.23421478e-01, + -6.53828621e-01, + 2.00867987e+00, + 2.94002771e-01, + 1.70056212e+00, + 5.50198834e-04, + 6.12767756e-01, + -2.29190066e-01 + }, + { + -1.60519981e+00, + -3.48692238e-01, + -3.25094163e-03, + 4.39969897e-01, + 1.50762582e+00, + 9.69331264e-01, + -1.18115306e+00, + 1.34852254e+00, + -1.24402285e+00, + -3.12961072e-01 + }, + { + -1.40357280e+00, + -1.08287978e+00, + -3.79449308e-01, + 1.51061141e+00, + 7.71783948e-01, + 5.29040515e-01, + 8.77655566e-01, + -1.53738844e+00, + 9.32778895e-01, + 3.69026303e-01 + } + }; + + double[][] expected = { + { + 0.09007322, + 0.04429074, + 0.2234913, + 0.07721311, + 0.0263351, + 0.01398634, + 0.14526248, + 0.10107733, + 0.13306525, + 0.14520513 + }, + { + 0.05687012, + 0.04971369, + 0.14270815, + 0.11923224, + 0.0394555, + 0.11622094, + 0.13439782, + 0.1803707, + 0.05222973, + 0.10880107 + }, + { + 0.06962293, + 0.02917639, + 0.02095966, + 0.02483347, + 0.35591814, + 0.06407304, + 0.26153892, + 0.04777828, + 0.08812784, + 0.03797131 + }, + { + 0.01272309, + 0.0446979, + 0.06314083, + 0.0983555, + 0.28607225, + 0.16699266, + 0.01944258, + 0.24399339, + 0.01825786, + 0.04632387 + }, + { + 0.0151052, + 0.02081621, + 0.04206274, + 0.2784457, + 0.13300617, + 0.10433973, + 0.1478602, + 0.01321329, + 0.15623957, + 0.08891118 + } + }; + + Operand input = tf.constant(x); + Operand expectedResult = tf.constant(expected); + Operand result = fops.nn.softmax(input); + session.evaluate(expectedResult, result); + } + } + + @Test + public void testSoftmaxAxes() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); + + double[][] arr = { + {0., 0.09090909, 0.18181818, 0.27272727}, + {0.36363636, 0.45454545, 0.54545455, 0.63636364}, + {0.72727273, 0.81818182, 0.90909091, 1.} + }; + + Operand arrTF = tf.constant(arr); + + Operand xNegAxisResult = fops.nn.softmax(arrTF, -2); + Operand yPosAxisResult = fops.nn.softmax(arrTF, 0); + Operand zGtAxisResult = fops.nn.softmax(arrTF, 0); + session.evaluate(xNegAxisResult, yPosAxisResult); + session.evaluate(yPosAxisResult, zGtAxisResult); + } + } + + @Test + public void testLogSoftmax() { + for 
(TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); + + double[][] array = { + { + -0.37646714, + -0.5311734, + 0.82353556, + -1.0500005, + -1.2197578, + -1.0560939, + 0.7242568, + -0.93800896, + 0.47922453, + 0.96604276 + }, + { + 1.8431032, + 0.63521856, + -0.82236594, + -1.8610067, + 0.890422, + -1.8440033, + -1.5645103, + -0.31505722, + 1.7022362, + 0.5422927 + }, + { + -2.3798232, + 0.56610274, + -0.28281465, + 1.37052, + -0.08637848, + 0.3824045, + -0.7390341, + 0.38309613, + -0.05741333, + 0.41976207 + }, + { + 1.4530851, + -0.8334874, + 0.14740701, + 0.00373064, + -0.86982375, + -0.6652942, + 0.665558, + 1.1553634, + 1.5083209, + 0.04152437 + }, + { + -0.3040565, + -0.86586237, + 1.0949674, + -0.4449086, + -0.48374927, + 0.6941735, + 0.21010222, + -0.20612952, + -0.32806364, + 1.6194562 + } + }; + + double[][] expectedArray = { + { + -2.7961895, + -2.9508958, + -1.5961869, + -3.4697227, + -3.63948, + -3.4758162, + -1.6954656, + -3.3577313, + -1.9404979, + -1.4536797 + }, + { + -1.1292517, + -2.3371363, + -3.794721, + -4.8333616, + -2.081933, + -4.8163586, + -4.536865, + -3.2874122, + -1.2701187, + -2.4300623 + }, + { + -4.97046, + -2.024534, + -2.8734512, + -1.2201167, + -2.6770153, + -2.2082322, + -3.329671, + -2.2075405, + -2.64805, + -2.1708746 + }, + { + -1.464081, + -3.7506535, + -2.7697592, + -2.9134355, + -3.78699, + -3.5824602, + -2.2516081, + -1.7618027, + -1.4088452, + -2.8756418 + }, + { + -3.0270078, + -3.5888138, + -1.6279839, + -3.1678598, + -3.2067006, + -2.0287778, + -2.512849, + -2.929081, + -3.051015, + -1.1034951 + } + }; + + Operand input = tf.constant(array); + + Operand result = fops.nn.logSoftmax(input); + Operand expected = tf.constant(expectedArray); + session.evaluate(expected, result); + } + } + + @Test + public void testLogSoftmaxAxes() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); + + double[][] arr = { + {0., 0.09090909, 0.18181818, 0.27272727}, + {0.36363636, 0.45454545, 0.54545455, 0.63636364}, + {0.72727273, 0.81818182, 0.90909091, 1.} + }; + + Operand arrTF = tf.constant(arr); + + Operand xNegAxisResult = fops.nn.logSoftmax(arrTF, -2); + Operand yPosAxisResult = fops.nn.logSoftmax(arrTF, 0); + Operand zGtAxisResult = fops.nn.logSoftmax(arrTF, 0); + session.evaluate(xNegAxisResult, yPosAxisResult); + session.evaluate(yPosAxisResult, zGtAxisResult); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/op/SetOpsTest.java similarity index 86% rename from tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java rename to tensorflow-framework/src/test/java/org/tensorflow/framework/op/SetOpsTest.java index eceff2797f8..0c4b6ab9a51 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/op/SetOpsTest.java @@ -1,4 +1,4 @@ -package org.tensorflow.framework.metrics.impl; +package org.tensorflow.framework.op; import org.junit.jupiter.api.Test; import org.tensorflow.Operand; @@ -15,7 +15,7 @@ import static org.tensorflow.framework.utils.CastHelper.cast; -class SetsOpsTest { +class SetOpsTest { private final TestSession.Mode[] 
tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; @@ -28,14 +28,16 @@ public void testSetIntersectionMultirow2() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); Operand a = tf.constant(new int[][] {{9, 1, 5}, {2, 4, 3}}); Operand b = tf.constant(new int[][] {{1, 9}, {1, 5}}); int[][] expected = new int[][] {{1, 9}, {0, 0}}; Shape expectedShape = Shape.of(2, 2); for (Class type : types) { + // Use the raw type here because the element type changes on each loop iteration Operand aa = cast(tf, a, type); Operand bb = cast(tf, b, type); - Operand intersection = SetsOps.intersection(tf, aa, bb); + Operand intersection = fops.sets.intersection(aa, bb); session.evaluate(cast(tf, tf.constant(expected), type), intersection); session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); } @@ -49,6 +51,7 @@ public void testSetIntersectionDuplicates2d() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); Operand a = tf.constant(new int[][] {{1, 1, 3}}); Operand b = tf.constant(new int[][] {{1, 1}}); int[][] expected = {{1}}; @@ -56,7 +59,7 @@ public void testSetIntersectionDuplicates2d() { for (Class type : types) { Operand aa = cast(tf, a, type); Operand bb = cast(tf, b, type); - Operand intersection = SetsOps.intersection(tf, aa, bb); + Operand intersection = fops.sets.intersection(aa, bb); session.evaluate(cast(tf, tf.constant(expected), type), intersection); @@ -72,6 +75,7 @@ public void testDenseSetDifferenceMultirow2d() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); Operand a = tf.constant(new int[][] {{1, 5, 9}, {4, 5, 3}}); Operand b = tf.constant(new int[][] {{1, 2, 6}, {1, 2, 2}}); @@ -81,14 +85,14 @@ public void testDenseSetDifferenceMultirow2d() { int[][] expected = {{5, 9, 0}, {3, 4, 5}}; // a - b Shape expectedShape = Shape.of(2, 3); - Operand intersection = SetsOps.difference(tf, aa, bb); + Operand intersection = fops.sets.difference(aa, bb); session.evaluate(cast(tf, tf.constant(expected), type), intersection); session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); // b - a expected = new int[][] {{2, 6}, {1, 2}}; expectedShape = Shape.of(2, 2); - intersection = SetsOps.difference(tf, aa, bb, false); + intersection = fops.sets.difference(aa, bb, false); session.evaluate(cast(tf, tf.constant(expected), type), intersection); session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); @@ -103,6 +107,7 @@ public void testDenseUnionMultirow2d() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); Operand a = tf.constant(new int[][] {{9, 1, 5}, {2, 4, 3}}); Operand b = tf.constant(new int[][] {{1, 9}, {1, 2}}); int[][] expected = new int[][] {{5, 0}, {3, 4}}; @@ -111,7 +116,7 @@ public void testDenseUnionMultirow2d() { Operand bb = cast(tf, b, type); Shape expectedShape = Shape.of(2, 2); // a - b - Operand intersection = SetsOps.difference(tf, aa, bb); + Operand intersection = fops.sets.difference(aa, bb); session.evaluate(cast(tf, tf.constant(expected), type), intersection); 
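// The dense set result pads shorter rows to the longest row (with 0, as in the expected arrays above), so the shape check below sees the padded Shape.of(2, 2).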
session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java index d6786b71972..acaf6bcaa67 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java @@ -1,13 +1,17 @@ package org.tensorflow.framework.optimizers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.tensorflow.Graph; import org.tensorflow.Session; import org.tensorflow.Tensor; import org.tensorflow.framework.initializers.Glorot; import org.tensorflow.framework.initializers.VarianceScaling; import org.tensorflow.framework.utils.TestSession; -import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.buffer.DataBuffers; import org.tensorflow.op.Op; @@ -26,10 +30,8 @@ import org.tensorflow.types.family.TType; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; -import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; /** Test cases for GradientDescent Optimizer */ @@ -162,7 +164,7 @@ public void testDeterminism() { tf.withName("output").placeholder(TFloat32.class, Placeholder.shape(Shape.of(-1, 2))); Mean loss = tf.math.mean( - tf.nn.raw.softmaxCrossEntropyWithLogits(output, placeholder).loss(), tf.constant(0)); + tf.nn.softmaxCrossEntropyWithLogits(output, placeholder).loss(), tf.constant(0)); lossName = loss.op().name(); GradientDescent gd = new GradientDescent(g, 10.0f); @@ -205,12 +207,15 @@ public void testDeterminism() { .fetch(outputBiasName) .run()); - TFloat32 lossVal = (TFloat32) s.runner() - .addTarget(trainName) - .feed("input", dataTensor) - .feed("output", targetTensor) - .fetch(lossName) - .run().get(0); + TFloat32 lossVal = + (TFloat32) + s.runner() + .addTarget(trainName) + .feed("input", dataTensor) + .feed("output", targetTensor) + .fetch(lossName) + .run() + .get(0); initialLoss[i] = lossVal.getFloat(); lossVal.close(); @@ -222,12 +227,15 @@ public void testDeterminism() { .fetch(outputBiasName) .run()); - lossVal = (TFloat32) s.runner() - .addTarget(trainName) - .feed("input", dataTensor) - .feed("output", targetTensor) - .fetch(lossName) - .run().get(0); + lossVal = + (TFloat32) + s.runner() + .addTarget(trainName) + .feed("input", dataTensor) + .feed("output", targetTensor) + .fetch(lossName) + .run() + .get(0); postTrainingLoss[i] = lossVal.getFloat(); lossVal.close(); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java index 7884308c9fb..6dc43ebfe64 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java @@ -14,22 +14,21 @@ =======================================================================*/ package org.tensorflow.framework.utils; -import org.tensorflow.*; +import 
org.tensorflow.EagerSession; +import org.tensorflow.Operand; +import org.tensorflow.Session; +import org.tensorflow.Tensor; import org.tensorflow.ndarray.FloatNdArray; -import org.tensorflow.ndarray.IntNdArray; -import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; -import org.tensorflow.types.*; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; import java.io.PrintWriter; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; +import java.util.Map; import java.util.function.Predicate; -import static org.junit.jupiter.api.Assertions.*; - /** Eager Mode Test Session */ public class EagerTestSession extends TestSession { @@ -83,676 +82,70 @@ public EagerSession getEagerSession() { /** {@inheritDoc} */ @Override - public void evaluate(double expected, Operand input) { - Class inputType = input.type(); - if (inputType == TFloat32.class) { - Operand o = (Operand) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); - } - index.set(0); - o.asTensor().scalars().forEach(f -> assertEquals(expected, f.getFloat(), epsilon)); - } else if (inputType == TFloat64.class) { - Operand o = (Operand) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); - } - index.set(0); - o.asTensor().scalars().forEach(f -> assertEquals(expected, f.getDouble(), epsilon)); - } else if (inputType == TInt32.class) { - Operand o = (Operand) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); - } - index.set(0); - o.asTensor().scalars().forEach(f -> assertEquals((int) expected, f.getInt())); - } else if (inputType == TInt64.class) { - Operand o = (Operand) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); - } - index.set(0); - o.asTensor().scalars().forEach(f -> assertEquals((long) expected, f.getLong())); - } else if (inputType == TUint8.class) { - Operand o = (Operand) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); - } - index.set(0); - o.asTensor().scalars().forEach(f -> assertEquals((long) expected, f.getByte())); - } + public void evaluate( + double expected, Operand input, Map, Tensor> feedMap) { + super.evaluate(expected, input.asTensor(), input.type()); } /** {@inheritDoc} */ @Override - public void evaluate(Number[] expected, Output input) { - int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); - assertEquals( - expected.length, - size, - () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); - Class inputType = input.type(); - if (inputType == TFloat32.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). 
%f\n", index.getAndIncrement(), f.getFloat())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach( - f -> - assertEquals( - expected[index.getAndIncrement()].floatValue(), f.getFloat(), epsilon)); - } else if (inputType == TFloat64.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach( - f -> - assertEquals( - expected[index.getAndIncrement()].doubleValue(), f.getDouble(), epsilon)); - } else if (inputType == TInt32.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()].intValue(), f.getInt())); - } else if (inputType == TInt64.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getLong())); - } else if (inputType == TUint8.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%x). %d\n", index.getAndIncrement(), f.getByte())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()].byteValue(), f.getByte())); - } + public void evaluate( + Number[] expected, Operand input, Map, Tensor> feedMap) { + super.evaluate(expected, input.asTensor(), input.type()); } /** {@inheritDoc} */ @Override - public void evaluate(FloatNdArray expected, Output input) { - Class inputType = input.type(); - if (inputType == TFloat32.class) { - Output o = (Output) input; - AtomicLong index = new AtomicLong(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach( - f -> assertEquals(expected.getFloat(index.getAndIncrement()), f.getFloat(), epsilon)); - } else if (inputType == TFloat64.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach( - f -> - assertEquals(expected.getFloat(index.getAndIncrement()), f.getDouble(), epsilon)); - } else if (inputType == TInt32.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); - } - index.set(0); - for (IntNdArray f : o.asTensor().scalars()) { - assertEquals((int) expected.getFloat(index.getAndIncrement()), f.getInt()); - } - } else if (inputType == TInt64.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). 
%d\n", index.getAndIncrement(), f.getLong())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach( - f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getLong())); - } else if (inputType == TUint8.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach( - f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getByte())); - } + public void evaluate( + FloatNdArray expected, Operand input, Map, Tensor> feedMap) { + super.evaluate(expected, input.asTensor(), input.type()); } /** {@inheritDoc} */ @Override - public void evaluateString(Output input, Predicate predicate) { - AtomicInteger index = new AtomicInteger(); - boolean isScalar = input.shape().equals(Shape.scalar()); - if (debug) { - if (isScalar) { - System.out.printf( - "0). %b <==> %s\n", predicate.test(input.asTensor().getObject()), input.asTensor().getObject()); - } else { - input - .asTensor() - .scalars() - .forEachIndexed( - (idx, s) -> - System.out.printf( - "%d). %b <==> %s\n", - index.getAndIncrement(), predicate.test(s.getObject()), s.getObject())); - } - } - index.set(0); - if (isScalar) { - assertTrue(predicate.test(input.asTensor().getObject())); - } else { - input.asTensor().scalars().forEachIndexed((idx, s) -> assertTrue(predicate.test(s.getObject()))); - } + public void evaluateString( + Operand input, + Predicate predicate, + Map, Tensor> feedMap) { + + super.evaluateString(input.asTensor(), predicate); } /** {@inheritDoc} */ @Override - public void evaluate(Output input, Predicate predicate) { - AtomicInteger index = new AtomicInteger(); - Class inputType = input.type(); - boolean isScalar = input.shape().equals(Shape.scalar()); - if (inputType == TFloat32.class) { - Output o = (Output) input; - if (debug) { - if (isScalar) { - System.out.printf( - "0). %b <==> %f\n", predicate.test(o.asTensor().getFloat()), o.asTensor().getFloat()); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %f\n", - index.getAndIncrement(), predicate.test(f.getFloat()), f.getFloat())); - } - } - index.set(0); - if (isScalar) { - assertTrue(predicate.test(o.asTensor().getFloat())); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getFloat()))); - } - } else if (inputType == TFloat64.class) { - Output o = (Output) input; - if (debug) { - if (isScalar) { - System.out.printf( - "0). %b <==> %f\n", predicate.test(o.asTensor().getDouble()), o.asTensor().getDouble()); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %f\n", - index.getAndIncrement(), predicate.test(f.getDouble()), f.getDouble())); - } - } - index.set(0); - if (isScalar) { - assertTrue(predicate.test(o.asTensor().getDouble())); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getDouble()))); - } - } else if (inputType == TFloat16.class) { - Output o = (Output) input; - if (debug) { - if (isScalar) { - System.out.printf( - "0). %b <==> %f\n", predicate.test(o.asTensor().getFloat()), o.asTensor().getFloat()); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). 
%b <==> %f\n", - index.getAndIncrement(), predicate.test(f.getFloat()), f.getFloat())); - } - } - index.set(0); - if (isScalar) { - assertTrue(predicate.test(o.asTensor().getFloat())); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getFloat()))); - } - } else if (inputType == TInt32.class) { - Output o = (Output) input; - if (debug) { - if (isScalar) { - System.out.printf( - "0). %b <==> %d\n", predicate.test(o.asTensor().getInt()), o.asTensor().getInt()); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %d\n", - index.getAndIncrement(), predicate.test(f.getInt()), f.getInt())); - } - } - index.set(0); - if (isScalar) { - assertTrue(predicate.test(o.asTensor().getInt())); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getInt()))); - } - } else if (inputType == TInt64.class) { - Output o = (Output) input; - if (debug) { - if (isScalar) { - System.out.printf( - "0). %b <==> %d\n", predicate.test(o.asTensor().getLong()), o.asTensor().getLong()); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %d\n", - index.getAndIncrement(), predicate.test(f.getLong()), f.getLong())); - } - } - index.set(0); - if (isScalar) { - assertTrue(predicate.test(o.asTensor().getLong())); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getLong()))); - } - } else if (inputType == TUint8.class) { - Output o = (Output) input; - if (debug) { - if (isScalar) { - System.out.printf( - "0). %b <==> %x\n", predicate.test(o.asTensor().getByte()), o.asTensor().getByte()); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %x\n", - index.getAndIncrement(), predicate.test(f.getByte()), f.getByte())); - } - } - index.set(0); - if (isScalar) { - assertTrue(predicate.test(o.asTensor().getByte())); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getByte()))); - } - } else { - fail("Unexpected Class: " + inputType); - } + public void evaluate( + Operand input, + Predicate predicate, + Map, Tensor> feedMap) { + super.evaluate(input.asTensor(), input.type(), predicate); } /** {@inheritDoc} */ @Override - public void evaluate(String[] expected, Output input) { - int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); - assertEquals( - expected.length, - size, - () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); - AtomicInteger index = new AtomicInteger(); - if (debug) { - input - .asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); - } - index.set(0); - input - .asTensor() - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject())); + public void evaluate( + String[] expected, Operand input, Map, Tensor> feedMap) { + super.evaluate(expected, input.asTensor()); } /** {@inheritDoc} */ @Override - public void evaluate(Boolean[] expected, Output input) { - int size = input.shape().size() == 0 ? 
1 : (int) input.shape().size(); - assertEquals( - expected.length, - size, - () -> String.format("expected size (%d) != to input length (%d)", expected.length, size)); - AtomicInteger index = new AtomicInteger(); - if (debug) { - input - .asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %b\n", index.getAndIncrement(), f.getBoolean())); - } - index.set(0); - input - .asTensor() - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getBoolean())); + public void evaluate( + Boolean[] expected, Operand input, Map, Tensor> feedMap) { + super.evaluate(expected, input.asTensor()); } /** {@inheritDoc} */ @Override - public void evaluate(Output expected, Output input) { - assert input.shape().equals(expected.shape()) - : String.format( - "expected shape (%s) != to input shape (%s)", - expected.shape().toString(), input.shape().toString()); - Class inputType = input.asOutput().type(); - boolean isScalar = input.shape().equals(Shape.scalar()); - if (inputType == TFloat32.class) { - Output x = (Output) expected; - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - if (isScalar) { - System.out.printf("0). %f <==> %f\n", x.asTensor().getFloat(), o.asTensor().getFloat()); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %f <==> %f\n", - index.getAndIncrement(), x.asTensor().getFloat(idx), f.getFloat())); - } - } - index.set(0); - if (isScalar) { - assertEquals(x.asTensor().getFloat(), o.asTensor().getFloat(), epsilon); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> assertEquals(x.asTensor().getFloat(idx), f.getFloat(), epsilon)); - } - } else if (inputType == TFloat64.class) { - Output x = (Output) expected; - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - if (isScalar) { - System.out.printf("0). %f <==> %f\n", x.asTensor().getDouble(), o.asTensor().getDouble()); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %f <==> %f\n", - index.getAndIncrement(), x.asTensor().getDouble(idx), f.getDouble())); - } - } - index.set(0); - if (isScalar) { - assertEquals(x.asTensor().getDouble(), o.asTensor().getDouble(), epsilon); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> assertEquals(x.asTensor().getDouble(idx), f.getDouble(), epsilon)); - } - } else if (inputType == TInt32.class) { - Output x = (Output) expected; - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - if (isScalar) { - System.out.printf("0). %d <==> %d\n", x.asTensor().getInt(), o.asTensor().getInt()); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %d <==> %d\n", - index.getAndIncrement(), x.asTensor().getInt(idx), f.getInt())); - } - } - index.set(0); - if (isScalar) { - assertEquals(x.asTensor().getInt(), o.asTensor().getInt()); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getInt(idx), f.getInt())); - } - } else if (inputType == TInt64.class) { - Output x = (Output) expected; - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - if (isScalar) { - System.out.printf("0). %d <==> %d\n", x.asTensor().getLong(), o.asTensor().getLong()); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). 
%d <==> %d\n", - index.getAndIncrement(), x.asTensor().getLong(idx), f.getLong())); - } - } - index.set(0); - if (isScalar) { - assertEquals(x.asTensor().getLong(), o.asTensor().getLong()); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getLong(idx), f.getLong())); - } - } else if (inputType == TUint8.class) { - Output x = (Output) expected; - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - if (isScalar) { - System.out.printf("0). %x <==> %x\n", x.asTensor().getByte(), o.asTensor().getByte()); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %x <==> %x\n", - index.getAndIncrement(), x.asTensor().getByte(idx), f.getByte())); - } - } - index.set(0); - if (isScalar) { - assertEquals(x.asTensor().getByte(), o.asTensor().getByte()); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getByte(idx), f.getByte())); - } - } else if (inputType == TString.class) { - Output x = (Output) expected; - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - if (isScalar) { - System.out.printf("0). %s <==> %s\n", x.asTensor().getObject(), o.asTensor().getObject()); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %s <==> %s\n", - index.getAndIncrement(), x.asTensor().getObject(idx), f.getObject())); - } - } - index.set(0); - if (isScalar) { - assertEquals(x.asTensor().getObject(), o.asTensor().getObject()); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getObject(idx), f.getObject())); - } - } else if (inputType == TBool.class) { - Output x = (Output) expected; - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - if (isScalar) { - System.out.printf("0). %b <==> %b\n", x.asTensor().getBoolean(), o.asTensor().getBoolean()); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %b\n", - index.getAndIncrement(), x.asTensor().getBoolean(idx), f.getBoolean())); - } - } - index.set(0); - if (isScalar) { - assertEquals(x.asTensor().getBoolean(), o.asTensor().getBoolean()); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getBoolean(idx), f.getBoolean())); - } - } + public void evaluate( + Operand expected, Operand input, Map, Tensor> feedMap) { + + super.evaluate(expected.asTensor(), input.asTensor(), input.type()); } /** {@inheritDoc} */ @Override - public void print(PrintWriter writer, Output input) { - Class inputType = input.asOutput().type(); - if (inputType == TFloat32.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); - } else if (inputType == TFloat64.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); - } else if (inputType == TInt32.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). 
%d\n", index.getAndIncrement(), f.getInt())); - } else if (inputType == TInt64.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); - } else if (inputType == TUint8.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); - } else if (inputType == TString.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); - } else if (inputType == TBool.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %b\n", index.getAndIncrement(), f.getBoolean())); - } else { - writer.println("Unexpected Class: " + inputType); - } - writer.flush(); + public void print( + PrintWriter writer, Operand input, Map, Tensor> feedMap) { + super.print(writer, input.asTensor(), input.type()); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java index 43c0642939e..0583913a6c6 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java @@ -14,34 +14,31 @@ =======================================================================*/ package org.tensorflow.framework.utils; -import org.tensorflow.*; +import org.tensorflow.EagerSession; +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Session; +import org.tensorflow.Tensor; import org.tensorflow.ndarray.FloatNdArray; -import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.types.*; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; import java.io.PrintWriter; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; +import java.util.Map; import java.util.function.Predicate; -import static org.junit.jupiter.api.Assertions.*; - -/** - * Graph Mode Test Session - */ +/** Graph Mode Test Session */ public class GraphTestSession extends TestSession { private final Graph graph; private final Session session; private final Ops tf; - /** - * Create a Graph mode test session. - */ + /** Create a Graph mode test session. */ public GraphTestSession() { graph = new Graph(); session = new Session(graph); @@ -49,16 +46,25 @@ public GraphTestSession() { } /** - * {@inheritDoc} + * Create a Graph mode test session. 
+ * + * @param graph the graph + * @param session the session + * @param tf the TensorFlow Ops */ + public GraphTestSession(Graph graph, Session session, Ops tf) { + this.graph = graph; + this.session = session; + this.tf = tf; + } + + /** {@inheritDoc} */ @Override public Ops getTF() { return tf; } - /** - * Get the Graph object that is represented by this Test Session - */ + /** Get the Graph object that is represented by this Test Session */ public Graph getGraph() { return graph; } @@ -72,1051 +78,154 @@ public Session getSession() { return session; } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override public void close() { session.close(); graph.close(); } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override public boolean isEager() { return false; } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override public Session getGraphSession() { return this.session; } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override public EagerSession getEagerSession() { return null; } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override public void initialize() { graph.initializers().forEach(initializer -> session.runner().addTarget(initializer).run()); } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override - public void run(Op op) { - session.run(op); + public void run(Op op, Map, Tensor> feedMap) { + createRunner(op, feedMap).run(); } /** - * {@inheritDoc} + * Create a runner for the Operation + * + * @param feedMap the dictionary of values to use for the runner's feed operations. Required when + * placeholders are used. + * @return the runner */ - @Override - public void evaluate(double expected, Operand input) { - Class inputType = input.type(); - if (inputType == TFloat32.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); - } - } - index.set(0); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result.scalars().forEach(f -> assertEquals((float) expected, f.getFloat(), epsilon)); - } - } else if (inputType == TFloat64.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); - } - } - index.set(0); - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result.scalars().forEach(f -> assertEquals(expected, f.getDouble(), epsilon)); - } - } else if (inputType == TInt32.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); - } - } - index.set(0); - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result.scalars().forEach(f -> assertEquals((int) expected, f.getInt())); - } - } else if (inputType == TInt64.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). 
%d\n", index.getAndIncrement(), f.getLong())); - } - } - index.set(0); - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result.scalars().forEach(f -> assertEquals((long) expected, f.getLong())); - } - } else if (inputType == TUint8.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getByte())); - } - } - index.set(0); - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - result.scalars().forEach(f -> assertEquals((long) expected, f.getByte())); - } - } else { - fail("Unexpected type class: " + inputType); - } + public final Session.Runner createRunner(Map, Tensor> feedMap) { + return createRunner(null, feedMap); } /** - * {@inheritDoc} + * Create a runner for the Operation + * + * @param op the operation + * @param feedMap the dictionary of values to use for the runner's feed operations. Required when + * placeholders are used. + * @return the runner */ + public final Session.Runner createRunner(Op op, Map, Tensor> feedMap) { + Session.Runner runner = session.runner(); + if (op != null) runner.addTarget(op.op()); + if (feedMap != null) feedMap.forEach((operand, tensor) -> runner.feed(operand, tensor)); + + return runner; + } + + /** {@inheritDoc} */ @Override - public void evaluate(Number[] expected, Output input) { - int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); - if (size != Shape.UNKNOWN_SIZE) { - assertEquals( - expected.length, - size, - () -> - String.format("expected length (%d) != to input length (%d)", expected.length, size)); - } - Class inputType = input.type(); - if (inputType == TFloat32.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); - } - } - index.set(0); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach( - f -> - assertEquals( - expected[index.getAndIncrement()].floatValue(), f.getFloat(), epsilon)); - } - } else if (inputType == TFloat64.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); - } - } - index.set(0); - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach( - f -> - assertEquals( - expected[index.getAndIncrement()].doubleValue(), f.getDouble(), epsilon)); - } - } else if (inputType == TInt32.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). 
%d\n", index.getAndIncrement(), f.getInt())); - } - } - index.set(0); - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()].intValue(), f.getInt())); - } - } else if (inputType == TInt64.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); - } - } - index.set(0); - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getLong())); - } - } else if (inputType == TUint8.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getByte())); - } - } - index.set(0); - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getByte())); - } - } else { - fail("Unexpected type class: " + inputType); + public void evaluate( + double expected, Operand input, Map, Tensor> feedMap) { + try (Tensor tensor = createRunner(feedMap).fetch(input).run().get(0)) { + super.evaluate(expected, tensor, input.type()); } } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override - public void evaluate(FloatNdArray expected, Output input) { - Class inputType = input.type(); - if (inputType == TFloat32.class) { - AtomicLong index = new AtomicLong(); - if (debug) { - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); - } - } - index.set(0); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach( - f -> - assertEquals( - expected.getFloat(index.getAndIncrement()), f.getFloat(), epsilon)); - } - } else if (inputType == TFloat64.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); - } - } - index.set(0); - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach( - f -> - assertEquals( - expected.getFloat(index.getAndIncrement()), f.getDouble(), epsilon)); - } - } else if (inputType == TInt32.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). 
%d\n", index.getAndIncrement(), f.getInt())); - } - } - index.set(0); - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach( - f -> assertEquals((int) expected.getFloat(index.getAndIncrement()), f.getInt())); - } - } else if (inputType == TInt64.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); - } - } - index.set(0); - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach( - f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getLong())); - } - } else if (inputType == TUint8.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getByte())); - } - } - index.set(0); - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach( - f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getByte())); - } - } else { - fail("Unexpected type class: " + inputType); + public void evaluate( + Number[] expected, Operand input, Map, Tensor> feedMap) { + + try (Tensor tensor = createRunner(feedMap).fetch(input).run().get(0)) { + super.evaluate(expected, tensor, input.type()); } } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override - public void evaluate(String[] expected, Output input) { - int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); - if (size != Shape.UNKNOWN_SIZE) { - assertEquals( - expected.length, - size, - () -> - String.format("expected length (%d) != to input length (%d)", expected.length, size)); - } - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); - } - } - index.set(0); - try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject())); + public void evaluate( + FloatNdArray expected, Operand input, Map, Tensor> feedMap) { + try (Tensor tensor = createRunner(feedMap).fetch(input).run().get(0)) { + super.evaluate(expected, tensor, input.type()); } } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override - public void evaluate(Boolean[] expected, Output input) { - int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); - assertEquals( - expected.length, - size, - () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TBool result = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). 
%b\n", index.getAndIncrement(), f.getObject())); - } - } - index.set(0); - try (TBool result = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject())); + public void evaluateString( + Operand input, + Predicate predicate, + Map, Tensor> feedMap) { + try (TString tensor = (TString) createRunner(feedMap).fetch(input).run().get(0)) { + super.evaluateString(tensor, predicate); } } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override - public void evaluate(Output expected, Output input) { - assert input.shape().equals(expected.shape()) - : String.format( - "expected shape (%s) != to input shape (%s)", - expected.shape().toString(), input.shape().toString()); - AtomicInteger index = new AtomicInteger(); - Class inputType = input.type(); - if (!inputType.equals(expected.type())) { - throw new IllegalArgumentException( - String.format( - "Both data type must be equal, inout = %s, expected = %s", - inputType, expected.dataType())); - } - boolean isScalar = input.shape().equals(Shape.scalar()); - if (inputType == TFloat32.class) { - final Output finalExpected = (Output) expected; - if (debug) { - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0); - TFloat32 expectedResult = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - System.out.printf( - "0). %f <==> %f\n", expectedResult.getFloat(), result.getFloat()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %f <==> %f\n", - index.getAndIncrement(), - finalExpected.asTensor().getFloat(idx), - f.getFloat())); - } - } - } - index.set(0); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0); - TFloat32 expectedResult = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getFloat(), result.getFloat(), epsilon); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - assertEquals(expectedResult.getFloat(idx), f.getFloat(), epsilon)); - } - } - } else if (inputType == TFloat64.class) { - final Output finalExpected = (Output) expected; - if (debug) { - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0); - TFloat64 expectedResult = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - System.out.printf( - "0). %f <==> %f\n", expectedResult.getDouble(), result.getDouble()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). 
%f <==> %f\n", - index.getAndIncrement(), - finalExpected.asTensor().getDouble(idx), - f.getDouble())); - } - } - } - index.set(0); - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0); - TFloat64 expectedResult = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getDouble(), result.getDouble(), epsilon); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - assertEquals(expectedResult.getDouble(idx), f.getDouble(), epsilon)); - } - } - } else if (inputType == TFloat16.class) { - final Output finalExpected = (Output) expected; - if (debug) { - try (TFloat16 result = - (TFloat16)this.getGraphSession().runner().fetch(input).run().get(0); - TFloat16 expectedResult = - (TFloat16)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - System.out.printf( - "0). %f <==> %f\n", expectedResult.getFloat(), result.getFloat()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %f <==> %f\n", - index.getAndIncrement(), - finalExpected.asTensor().getFloat(idx), - f.getFloat())); - } - } - } - index.set(0); - try (TFloat16 result = - (TFloat16)this.getGraphSession().runner().fetch(input).run().get(0); - TFloat16 expectedResult = - (TFloat16)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getFloat(), result.getFloat(), epsilon); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - assertEquals(expectedResult.getFloat(idx), f.getFloat(), epsilon)); - } - } - } else if (inputType == TInt32.class) { - final Output finalExpected = (Output) expected; - if (debug) { - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0); - TInt32 expectedResult = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - System.out.printf( - "0). %d <==> %d\n", expectedResult.getInt(), result.getInt()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %d <==> %d\n", - index.getAndIncrement(), finalExpected.asTensor().getInt(idx), f.getInt())); - } - } - } - index.set(0); - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0); - TInt32 expectedResult = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getInt(), result.getInt(), epsilon); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> assertEquals(expectedResult.getInt(idx), f.getInt(), epsilon)); - } - } - } else if (inputType == TInt64.class) { - final Output finalExpected = (Output) expected; - if (debug) { - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0); - TInt64 expectedResult = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - System.out.printf( - "0). %d <==> %d\n", expectedResult.getLong(), result.getLong()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). 
%d <==> %d\n", - index.getAndIncrement(), - finalExpected.asTensor().getLong(idx), - f.getLong())); - } - } - } - index.set(0); - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0); - TInt64 expectedResult = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getLong(), result.getLong(), epsilon); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - assertEquals(expectedResult.getLong(idx), f.getLong(), epsilon)); - } - } - } else if (inputType == TUint8.class) { - final Output finalExpected = (Output) expected; - if (debug) { - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0); - TUint8 expectedResult = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - System.out.printf( - "0). %d <==> %d\n", expectedResult.getByte(), result.getByte()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %d <==> %d\n", - index.getAndIncrement(), - finalExpected.asTensor().getByte(idx), - f.getByte())); - } - } - } - index.set(0); - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0); - TUint8 expectedResult = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getByte(), result.getByte(), epsilon); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - assertEquals(expectedResult.getByte(idx), f.getByte(), epsilon)); - } - } - } else if (inputType == TBool.class) { - final Output finalExpected = (Output) expected; - if (debug) { - try (TBool result = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0); - TBool expectedResult = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - System.out.printf( - "0). %b <==> %b\n", expectedResult.getBoolean(), result.getBoolean()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %b\n", - index.getAndIncrement(), - finalExpected.asTensor().getBoolean(idx), - f.getBoolean())); - } - } - } - index.set(0); - try (TBool result = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0); - TBool expectedResult = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getBoolean(), result.getBoolean()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> assertEquals(expectedResult.getBoolean(idx), f.getBoolean())); - } - } - } else if (inputType == TString.class) { - final Output finalExpected = (Output) expected; - if (debug) { - try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0); - TString expectedResult = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - System.out.printf( - "0). %s <==> %s\n", expectedResult.getObject(), result.getObject()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). 
%s <==> %s\n", - index.getAndIncrement(), - finalExpected.asTensor().getObject(idx), - f.getObject())); - } - } - } - index.set(0); - try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0); - TString expectedResult = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getObject(), result.getObject()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> assertEquals(expectedResult.getObject(idx), f.getObject())); - } - } - } else { - fail("Unexpected type class: " + inputType); + public void evaluate( + Operand input, + Predicate predicate, + Map, Tensor> feedMap) { + try (Tensor tensor = createRunner(feedMap).fetch(input).run().get(0)) { + super.evaluate(tensor, input.type(), predicate); } } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override - public void evaluateString(Output input, Predicate predicate) { - boolean isScalar = input.shape().equals(Shape.scalar()); - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - System.out.printf( - "0). %b <==> %s\n", - predicate.test(result.getObject()), result.getObject()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %s\n", - index.getAndIncrement(), predicate.test(f.getObject()), f.getObject())); - } - } - } - index.set(0); - try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertTrue(predicate.test(result.getObject())); - } else { - result - .scalars() - .forEachIndexed((idx, s) -> assertTrue(predicate.test(s.getObject()))); - } + public void evaluate( + String[] expected, Operand input, Map, Tensor> feedMap) { + try (TString tensor = (TString) createRunner(feedMap).fetch(input).run().get(0)) { + super.evaluate(expected, tensor); } } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override - public void evaluate(Output input, Predicate predicate) { - AtomicInteger index = new AtomicInteger(); - Class inputType = input.type(); - boolean isScalar = input.shape().equals(Shape.scalar()); - if (inputType == TFloat32.class) { - if (debug) { - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - System.out.printf( - "0). %b <==> %f\n", - predicate.test(result.getFloat()), result.getFloat()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %f\n", - index.getAndIncrement(), predicate.test(f.getFloat()), f.getFloat())); - } - } - } - index.set(0); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertTrue(predicate.test(result.getFloat())); - } else { - result - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getFloat()))); - } - } - } else if (inputType == TFloat64.class) { - if (debug) { - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - System.out.printf( - "0). %b <==> %f\n", - predicate.test(result.getDouble()), result.getDouble()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). 
%b <==> %f\n", - index.getAndIncrement(), predicate.test(f.getDouble()), f.getDouble())); - } - } - } - index.set(0); - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertTrue(predicate.test(result.getDouble())); - } else { - result - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getDouble()))); - } - } - } else if (inputType == TInt32.class) { - if (debug) { - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - System.out.printf( - "0). %b <==> %d\n", predicate.test(result.getInt()), result.getInt()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %d\n", - index.getAndIncrement(), predicate.test(f.getInt()), f.getInt())); - } - } - } - index.set(0); - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertTrue(predicate.test(result.getInt())); - } else { - result - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getInt()))); - } - } - } else if (inputType == TInt64.class) { - if (debug) { - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - System.out.printf( - "0). %b <==> %d\n", - predicate.test(result.getLong()), result.getLong()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %d\n", - index.getAndIncrement(), predicate.test(f.getLong()), f.getLong())); - } - } - } - index.set(0); - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertTrue(predicate.test(result.getLong())); - } else { - result - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getLong()))); - } - } - } else if (inputType == TUint8.class) { - if (debug) { - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - System.out.printf( - "0). %b <==> %d\n", - predicate.test(result.getByte()), result.getByte()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %d\n", - index.getAndIncrement(), predicate.test(f.getByte()), f.getByte())); - } - } - } - index.set(0); - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertTrue(predicate.test(result.getByte())); - } else { - result - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getByte()))); - } - } - } else { - fail("Unexpected type class: " + inputType); + public void evaluate( + Boolean[] expected, Operand input, Map, Tensor> feedMap) { + try (TBool tensor = (TBool) createRunner(feedMap).fetch(input).run().get(0)) { + super.evaluate(expected, tensor); } } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override - public void print(PrintWriter writer, Output input) { - boolean isScalar = input.shape().size() == 1; - - Class inputType = input.type(); - if (inputType == TFloat32.class) { - AtomicInteger index = new AtomicInteger(); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - writer.printf("%d). %f\n", index.getAndIncrement(), result.getFloat()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> writer.printf("%d). 
%f\n", index.getAndIncrement(), f.getFloat())); - } - } - } else if (inputType == TFloat64.class) { - AtomicInteger index = new AtomicInteger(); - - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - writer.printf( - "%d). %f\n", index.getAndIncrement(), result.getDouble()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> writer.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); - } - } - } else if (inputType == TInt32.class) { - AtomicInteger index = new AtomicInteger(); - - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - writer.printf( - "%d). %d\n", index.getAndIncrement(),result.getInt()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> writer.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); - } - } - } else if (inputType == TInt64.class) { - AtomicInteger index = new AtomicInteger(); - - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - writer.printf( - "%d). %d\n", index.getAndIncrement(), result.getLong()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> writer.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); - } - } - } else if (inputType == TUint8.class) { - AtomicInteger index = new AtomicInteger(); - - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - writer.printf( - "%d). %x\n", index.getAndIncrement(), result.getByte()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> writer.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); - } - } - } else if (inputType == TBool.class) { - AtomicInteger index = new AtomicInteger(); - - try (TBool result = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - writer.printf( - "%d). %b\n", index.getAndIncrement(), result.getBoolean()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> writer.printf("%d). %b\n", index.getAndIncrement(), f.getBoolean())); - } - } - } else if (inputType == TString.class) { - AtomicInteger index = new AtomicInteger(); + public void evaluate( + Operand expected, Operand input, Map, Tensor> feedMap) { + try (Tensor tensor = createRunner(feedMap).fetch(input).run().get(0); + Tensor expectedTensor = createRunner(feedMap).fetch(expected).run().get(0)) { + super.evaluate(expectedTensor, tensor, input.type()); + } + } - try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - writer.printf( - "%d). %s\n", index.getAndIncrement(), result.getObject()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> writer.printf("%d). 
%s\n", index.getAndIncrement(), f.getObject())); - } - } - } else { - writer.println("Unexpected type class: " + inputType); + /** {@inheritDoc} */ + @Override + public void print( + PrintWriter writer, Operand input, Map, Tensor> feedMap) { + try (Tensor tensor = createRunner(feedMap).fetch(input).run().get(0)) { + super.print(writer, tensor, input.type()); } - writer.flush(); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java index c0c0f12fbf9..a103ef9884a 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java @@ -14,7 +14,11 @@ =======================================================================*/ package org.tensorflow.framework.utils; -import org.tensorflow.ndarray.*; +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; import java.util.Arrays; import java.util.concurrent.atomic.AtomicBoolean; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/QuadConsumer.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/QuadConsumer.java new file mode 100644 index 00000000000..7f576e8b614 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/QuadConsumer.java @@ -0,0 +1,42 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.utils; + +/** + * Represents an operation that accepts four input arguments and returns no result. This is the + * quaternary specialization of {@link java.util.function.Consumer}. Unlike most other functional + * interfaces, {@code QuadConsumer} is expected to operate via side-effects. + * + *

This is a functional interface whose functional method is {@link #accept(Object, Object,
+ * Object, Object)}.
+ *
+ * @param <T> the type of the first argument to the operation
+ * @param <S> the type of the second argument to the operation
+ * @param <U> the type of the third argument to the operation
+ * @param <V> the type of the fourth argument to the operation
+ */
+@FunctionalInterface
+interface QuadConsumer<T, S, U, V> {
+
+  /**
+   * Performs this operation on the given arguments.
+   *
+   * @param t the first input argument
+   * @param s the second input argument
+   * @param u the third input argument
+   * @param v the fourth input argument
+   */
+  void accept(T t, S s, U u, V v);
+}
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java
index 2c252d467c7..d5578042321 100644
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java
@@ -14,13 +14,25 @@
 =======================================================================*/
 package org.tensorflow.framework.utils;
 
-import org.tensorflow.*;
-import org.tensorflow.ndarray.DoubleNdArray;
+import org.tensorflow.EagerSession;
+import org.tensorflow.Graph;
+import org.tensorflow.Operand;
+import org.tensorflow.Session;
+import org.tensorflow.Tensor;
 import org.tensorflow.ndarray.FloatNdArray;
+import org.tensorflow.ndarray.NdArray;
+import org.tensorflow.ndarray.Shape;
 import org.tensorflow.op.Op;
 import org.tensorflow.op.Ops;
+import org.tensorflow.types.TBfloat16;
 import org.tensorflow.types.TBool;
+import org.tensorflow.types.TFloat16;
+import org.tensorflow.types.TFloat32;
+import org.tensorflow.types.TFloat64;
+import org.tensorflow.types.TInt32;
+import org.tensorflow.types.TInt64;
 import org.tensorflow.types.TString;
+import org.tensorflow.types.TUint8;
 import org.tensorflow.types.family.TNumber;
 import org.tensorflow.types.family.TType;
@@ -28,14 +40,229 @@
 import java.io.OutputStreamWriter;
 import java.io.PrintWriter;
 import java.io.Writer;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiConsumer;
 import java.util.function.Predicate;
 
+import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.fail;
 
 /** Base class for Test Session */
 public abstract class TestSession implements AutoCloseable {
-  protected float epsilon = 1e-5F;
+  protected static final Map<
+          Class<? extends TType>, TriConsumer<PrintWriter, long[], NdArray<?>>>
+      printMap =
+          new HashMap<
+              Class<? extends TType>,
+              TriConsumer<PrintWriter, long[], NdArray<?>>>() {
+            {
+              put(
+                  TUint8.class,
+                  (writer, idx, o) ->
+                      writer.printf(
+                          "%s. %s\n", Arrays.toString(idx), ((Number) o.getObject()).byteValue()));
+              put(
+                  TInt32.class,
+                  (writer, idx, o) ->
+                      writer.printf(
+                          "%s. %d\n", Arrays.toString(idx), ((Number) o.getObject()).intValue()));
+              put(
+                  TInt64.class,
+                  (writer, idx, o) ->
+                      writer.printf(
+                          "%s. %d\n", Arrays.toString(idx), ((Number) o.getObject()).longValue()));
+              put(
+                  TFloat32.class,
+                  (writer, idx, o) ->
+                      writer.printf(
+                          "%s. %f\n", Arrays.toString(idx), ((Number) o.getObject()).floatValue()));
+              put(
+                  TFloat64.class,
+                  (writer, idx, o) ->
+                      writer.printf(
+                          "%s. %f\n",
+                          Arrays.toString(idx), ((Number) o.getObject()).doubleValue()));
+              put(
+                  TBfloat16.class,
+                  (writer, idx, o) ->
+                      writer.printf(
+                          "%s. 
%f\n", Arrays.toString(idx), ((Number) o.getObject()).floatValue())); + put( + TFloat16.class, + (writer, idx, o) -> + writer.printf( + "%s. %f\n", Arrays.toString(idx), ((Number) o.getObject()).floatValue())); + put( + TBool.class, + (writer, idx, o) -> + writer.printf("%s. %b\n", Arrays.toString(idx), o.getObject())); + put( + TString.class, + (writer, idx, o) -> + writer.printf("%s. %s\n", Arrays.toString(idx), o.getObject())); + } + }; + protected static final Map< + Class, + QuadConsumer, NdArray>> + printPredicate = + new HashMap< + Class, + QuadConsumer, NdArray>>() { + { + put( + TUint8.class, + (writer, idx, predicate, o) -> + writer.printf( + "%s. %b <==> %d\n", + Arrays.toString(idx), + predicate.test(((Number) o.getObject()).byteValue()), + ((Number) o.getObject()).byteValue())); + put( + TInt32.class, + (writer, idx, predicate, o) -> + writer.printf( + "%s. %b <==> %d\n", + Arrays.toString(idx), + predicate.test(((Number) o.getObject()).intValue()), + ((Number) o.getObject()).intValue())); + put( + TInt64.class, + (writer, idx, predicate, o) -> + writer.printf( + "%s. %b <==> %d\n", + Arrays.toString(idx), + predicate.test(((Number) o.getObject()).longValue()), + ((Number) o.getObject()).longValue())); + put( + TFloat32.class, + (writer, idx, predicate, o) -> + writer.printf( + "%s. %b <==> %f\n", + Arrays.toString(idx), + predicate.test(((Number) o.getObject()).floatValue()), + ((Number) o.getObject()).floatValue())); + put( + TFloat64.class, + (writer, idx, predicate, o) -> + writer.printf( + "%s. %b <==> %f\n", + Arrays.toString(idx), + predicate.test(((Number) o.getObject()).doubleValue()), + ((Number) o.getObject()).doubleValue())); + put( + TBfloat16.class, + (writer, idx, predicate, o) -> + writer.printf( + "%s. %b <==> %f\n", + Arrays.toString(idx), + predicate.test(((Number) o.getObject()).floatValue()), + ((Number) o.getObject()).floatValue())); + put( + TFloat16.class, + (writer, idx, predicate, o) -> + writer.printf( + "%s. 
%b <==> %f\n", + Arrays.toString(idx), + predicate.test(((Number) o.getObject()).floatValue()), + ((Number) o.getObject()).floatValue())); + } + }; + protected static final Map< + Class, BiConsumer, NdArray>> + evalPredicate = + new HashMap< + Class, BiConsumer, NdArray>>() { + { + put( + TUint8.class, + (predicate, o) -> + assertTrue(predicate.test(((Number) o.getObject()).byteValue()))); + put( + TInt32.class, + (predicate, o) -> + assertTrue(predicate.test(((Number) o.getObject()).intValue()))); + put( + TInt64.class, + (predicate, o) -> + assertTrue(predicate.test(((Number) o.getObject()).longValue()))); + put( + TFloat32.class, + (predicate, o) -> + assertTrue(predicate.test(((Number) o.getObject()).floatValue()))); + put( + TFloat64.class, + (predicate, o) -> + assertTrue(predicate.test(((Number) o.getObject()).doubleValue()))); + put( + TBfloat16.class, + (predicate, o) -> + assertTrue(predicate.test(((Number) o.getObject()).floatValue()))); + put( + TFloat16.class, + (predicate, o) -> + assertTrue(predicate.test(((Number) o.getObject()).floatValue()))); + } + }; + private static final long[] ZERO_IDX = new long[0]; + private static final PrintWriter DEFAULT_WRITER = new PrintWriter(System.out); + protected static float epsilon = 1e-5F; + protected static final Map, BiConsumer>> + evalMap = + new HashMap, BiConsumer>>() { + { + put( + TUint8.class, + (expected, o) -> + assertEquals( + ((Number) expected).byteValue(), ((Number) o.getObject()).byteValue())); + put( + TInt32.class, + (expected, o) -> + assertEquals( + ((Number) expected).intValue(), ((Number) o.getObject()).intValue())); + put( + TInt64.class, + (expected, o) -> + assertEquals( + ((Number) expected).longValue(), ((Number) o.getObject()).longValue())); + put( + TFloat32.class, + (expected, o) -> + assertEquals( + ((Number) expected).floatValue(), + ((Number) o.getObject()).floatValue(), + epsilon)); + put( + TFloat64.class, + (expected, o) -> + assertEquals( + ((Number) expected).doubleValue(), + ((Number) o.getObject()).doubleValue(), + epsilon)); + put( + TBfloat16.class, + (expected, o) -> + assertEquals( + ((Number) expected).floatValue(), + ((Number) o.getObject()).floatValue(), + epsilon)); + put( + TFloat16.class, + (expected, o) -> + assertEquals( + ((Number) expected).floatValue(), + ((Number) o.getObject()).floatValue(), + epsilon)); + put(TBool.class, (expected, o) -> assertEquals(expected, o.getObject())); + put(TString.class, (expected, o) -> assertEquals(expected, o.getObject().toString())); + } + }; protected boolean debug; /** @@ -56,6 +283,18 @@ public static TestSession createGraphSession() { return new GraphTestSession(); } + /** + * Creates a Graph Test Session without creating its own graph + * + * @param graph the graph + * @param session the session + * @param tf the TensorFlow Ops + * @return the Graph Test Session + */ + public static TestSession createGraphSession(Graph graph, Session session, Ops tf) { + return new GraphTestSession(graph, session, tf); + } + /** * Creates a Test Session * @@ -66,17 +305,47 @@ public static TestSession createTestSession(Mode mode) { return mode == Mode.EAGER ? 
createEagerSession() : createGraphSession();
  }
 
+  /**
+   * Get the epsilon value for evaluating float values
+   *
+   * @return the epsilon value for evaluating float values
+   */
+  public static float getEpsilon() {
+    return epsilon;
+  }
+
+  /**
+   * Set the epsilon value for evaluating float values
+   *
+   * @param epsilonValue the epsilon value for evaluating float values
+   */
+  public static void setEpsilon(float epsilonValue) {
+    epsilon = epsilonValue;
+  }
+
   /** Initializes the Test Session, default implementation is do nothing. */
   public void initialize() {
     // empty
   }
 
   /**
-   * Runs the Operation
+   * Runs the Operation; in eager mode this does nothing
    *
-   * @param op the Operation to run
+   * @param op the Operation to run.
    */
+  @SuppressWarnings("unused")
   public void run(Op op) {
+    run(op, null);
+  }
+
+  /**
+   * Runs the Operation; in eager mode this does nothing
+   *
+   * @param op the Operation to run.
+   * @param feedMap an optional Map of values to feed to the runner; required when placeholders are used.
+   */
+  @SuppressWarnings("unused")
+  public void run(Op op, Map<Operand<? extends TType>, Tensor> feedMap) {
     // empty
   }
 
@@ -98,7 +367,7 @@ public Graph getGraph() {
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
   public <T extends TNumber> void evaluate(Number expected, Operand<T> input) {
-    evaluate(new Number[] {expected}, input);
+    evaluate(new Number[] {expected}, input, null);
   }
 
   /**
@@ -106,22 +375,25 @@ public <T extends TNumber> void evaluate(Number expected, Operand<T> input) {
    *
    * @param expected the expected value
    * @param input the operand to evaluate
+   * @param feedMap an optional Map of values to feed to the runner; required when placeholders are used.
+   * @param <T> the data type of the input
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
-  public void evaluate(Number expected, Op input) {
-    evaluate(new Number[] {expected}, input);
+  public <T extends TNumber> void evaluate(
+      Number expected, Operand<T> input, Map<Operand<? extends TType>, Tensor> feedMap) {
+    evaluate(new Number[] {expected}, input, feedMap);
   }
 
   /**
-   * Evaluates the input against the expected values
+   * Evaluates the input against the expected value
    *
-   * @param expected the expected values
+   * @param expected the expected value
    * @param input the operand to evaluate
+   * @param <T> the data type of the input
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
-  public void evaluate(Number[] expected, Op input) {
-    Output output = input.op().output(0);
-    evaluate(expected, output);
+  public <T extends TNumber> void evaluate(byte expected, Operand<T> input) {
+    evaluate((double) expected, input, null);
   }
 
   /**
@@ -129,12 +401,13 @@ public void evaluate(Number[] expected, Op input) {
    *
    * @param expected the expected value
    * @param input the operand to evaluate
+   * @param feedMap an optional Map of values to feed to the runner; required when placeholders are used.
 
@@ -98,7 +367,7 @@ public Graph getGraph() {
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
   public void evaluate(Number expected, Operand input) {
-    evaluate(new Number[] {expected}, input);
+    evaluate(new Number[] {expected}, input, null);
   }
 
   /**
@@ -106,22 +375,25 @@ public void evaluate(Number expected, Operand input) {
    *
    * @param expected the expected value
    * @param input the operand to evaluate
+   * @param feedMap an optional Map to feed to the run session when placeholders are used.
+   * @param <T> the data type of the input
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
-  public void evaluate(Number expected, Op input) {
-    evaluate(new Number[] {expected}, input);
+  public void evaluate(
+      Number expected, Operand input, Map, Tensor> feedMap) {
+    evaluate(new Number[] {expected}, input, feedMap);
   }
 
   /**
-   * Evaluates the input against the expected values
+   * Evaluates the input against the expected value
    *
-   * @param expected the expected values
+   * @param expected the expected value
    * @param input the operand to evaluate
+   * @param <T> the data type of the input
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
-  public void evaluate(Number[] expected, Op input) {
-    Output output = input.op().output(0);
-    evaluate(expected, output);
+  public void evaluate(byte expected, Operand input) {
+    evaluate((double) expected, input, null);
   }
 
   /**
@@ -129,12 +401,13 @@ public void evaluate(Number[] expected, Op input) {
    *
    * @param expected the expected value
    * @param input the operand to evaluate
+   * @param feedMap an optional Map to feed to the run session when placeholders are used.
    * @param <T> the data type of the input
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
-  public void evaluate(Number[] expected, Operand input) {
-    Output output = input.asOutput();
-    evaluate(expected, output);
+  public void evaluate(
+      byte expected, Operand input, Map, Tensor> feedMap) {
+    evaluate((double) expected, input, feedMap);
   }
 
   /**
@@ -145,8 +418,8 @@ public void evaluate(Number[] expected, Operand input) {
    * @param <T> the data type of the input
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
-  public void evaluate(byte expected, Operand input) {
-    evaluate((double) expected, input);
+  public void evaluate(int expected, Operand input) {
+    evaluate((double) expected, input, null);
   }
 
   /**
@@ -154,11 +427,13 @@ public void evaluate(byte expected, Operand input) {
    *
    * @param expected the expected value
    * @param input the operand to evaluate
+   * @param feedMap an optional Map to feed to the run session when placeholders are used.
    * @param <T> the data type of the input
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
-  public void evaluate(int expected, Operand input) {
-    evaluate((double) expected, input);
+  public void evaluate(
+      int expected, Operand input, Map, Tensor> feedMap) {
+    evaluate((double) expected, input, feedMap);
   }
 
   /**
@@ -170,7 +445,21 @@ public void evaluate(int expected, Operand input) {
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
   public void evaluate(long expected, Operand input) {
-    evaluate((double) expected, input);
+    evaluate((double) expected, input, null);
+  }
+
+  /**
+   * Evaluates the input against the expected value
+   *
+   * @param expected the expected value
+   * @param input the operand to evaluate
+   * @param feedMap an optional Map to feed to the run session when placeholders are used.
+   * @param <T> the data type of the input
+   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   */
+  public void evaluate(
+      long expected, Operand input, Map, Tensor> feedMap) {
+    evaluate((double) expected, input, feedMap);
   }
 
   /**
@@ -182,7 +471,21 @@ public void evaluate(long expected, Operand input) {
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
   public void evaluate(float expected, Operand input) {
-    evaluate((double) expected, input);
+    evaluate((double) expected, input, null);
+  }
+
+  /**
+   * Evaluates the input against the expected value
+   *
+   * @param expected the expected value
+   * @param input the operand to evaluate
+   * @param feedMap an optional Map to feed to the run session when placeholders are used.
+   * @param <T> the data type of the input
+   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   */
+  public void evaluate(
+      float expected, Operand input, Map, Tensor> feedMap) {
+    evaluate((double) expected, input, feedMap);
   }
 
   /**
@@ -193,7 +496,21 @@ public void evaluate(float expected, Operand input) {
    * @param <T> the data type of the input
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
-  public abstract void evaluate(double expected, Operand input);
+  public void evaluate(double expected, Operand input) {
+    evaluate(expected, input, null);
+  }
+
+  /**
+   * Evaluates the input against the expected value
+   *
+   * @param expected the expected value
+   * @param input the operand to evaluate
+   * @param feedMap an optional Map to feed to the run session when placeholders are used.
+   * @param <T> the data type of the input
+   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   */
+  public abstract void evaluate(
+      double expected, Operand input, Map, Tensor> feedMap);
 
   /**
    * Evaluates the input against the expected value
@@ -204,9 +521,23 @@ public void evaluate(float expected, Operand input) {
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
   public void evaluate(byte[] expected, Operand input) {
+    evaluate(expected, input, null);
+  }
+
+  /**
+   * Evaluates the input against the expected values
+   *
+   * @param expected the expected values
+   * @param input the operand to evaluate
+   * @param feedMap an optional Map to feed to the run session when placeholders are used.
+   * @param <T> the data type of the input
+   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   */
+  public void evaluate(
+      byte[] expected, Operand input, Map, Tensor> feedMap) {
     Byte[] iArray = new Byte[expected.length];
     for (int i = 0; i < expected.length; i++) iArray[i] = expected[i];
-    evaluate(iArray, input);
+    evaluate(iArray, input, feedMap);
   }
 
   /**
@@ -218,9 +549,23 @@ public void evaluate(byte[] expected, Operand input) {
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
   public void evaluate(int[] expected, Operand input) {
+    evaluate(expected, input, null);
+  }
+
+  /**
+   * Evaluates the input against the expected values
+   *
+   * @param expected the expected values
+   * @param input the operand to evaluate
+   * @param feedMap an optional Map to feed to the run session when placeholders are used.
+   * @param <T> the data type of the input
+   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   */
+  public void evaluate(
+      int[] expected, Operand input, Map, Tensor> feedMap) {
     Integer[] iArray = new Integer[expected.length];
     for (int i = 0; i < expected.length; i++) iArray[i] = expected[i];
-    evaluate(iArray, input);
+    evaluate(iArray, input, feedMap);
   }
 
   /**
@@ -232,9 +577,23 @@ public void evaluate(int[] expected, Operand input) {
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
   public void evaluate(long[] expected, Operand input) {
+    evaluate(expected, input, null);
+  }
+
+  /**
+   * Evaluates the input against the expected values
+   *
+   * @param expected the expected values
+   * @param input the operand to evaluate
+   * @param feedMap an optional Map to feed to the run session when placeholders are used.
+   * @param <T> the data type of the input
+   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   */
+  public void evaluate(
+      long[] expected, Operand input, Map, Tensor> feedMap) {
     Long[] iArray = new Long[expected.length];
     for (int i = 0; i < expected.length; i++) iArray[i] = expected[i];
-    evaluate(iArray, input);
+    evaluate(iArray, input, feedMap);
   }
 
   /**
@@ -246,23 +605,23 @@ public void evaluate(long[] expected, Operand input) {
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
   public void evaluate(float[] expected, Operand input) {
-    Float[] iArray = new Float[expected.length];
-    for (int i = 0; i < expected.length; i++) iArray[i] = expected[i];
-    evaluate(iArray, input);
+    evaluate(expected, input, null);
   }
 
   /**
-   * Evaluates the input against the expected value
+   * Evaluates the input against the expected values
    *
-   * @param expected the expected value
+   * @param expected the expected values
    * @param input the operand to evaluate
+   * @param feedMap an optional Map to feed to the run session when placeholders are used.
    * @param <T> the data type of the input
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
-  public void evaluate(double[] expected, Operand input) {
-    Double[] iArray = new Double[expected.length];
+  public void evaluate(
+      float[] expected, Operand input, Map, Tensor> feedMap) {
+    Float[] iArray = new Float[expected.length];
     for (int i = 0; i < expected.length; i++) iArray[i] = expected[i];
-    evaluate(iArray, input);
+    evaluate(iArray, input, feedMap);
   }
 
   /**
@@ -273,17 +632,24 @@ public void evaluate(double[] expected, Operand input) {
    * @param <T> the data type of the input
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
-  public abstract void evaluate(Number[] expected, Output input);
+  public void evaluate(double[] expected, Operand input) {
+    evaluate(expected, input, null);
+  }
 
   /**
    * Evaluates the input against the expected value
    *
    * @param expected the expected value
    * @param input the operand to evaluate
+   * @param feedMap an optional Map to feed to the run session when placeholders are used.
+   * @param <T> the data type of the input
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
-  public void evaluate(String expected, Operand input) {
-    evaluate(new String[] {expected}, input);
+  public void evaluate(
+      double[] expected, Operand input, Map, Tensor> feedMap) {
+    Double[] iArray = new Double[expected.length];
+    for (int i = 0; i < expected.length; i++) iArray[i] = expected[i];
+    evaluate(iArray, input, feedMap);
   }
 
   /**
@@ -291,10 +657,11 @@ public void evaluate(String expected, Operand input) {
    *
    * @param expected the expected value
    * @param input the operand to evaluate
+   * @param <T> the data type of the input
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
-  public void evaluate(String expected, Op input) {
-    evaluate(new String[] {expected}, input);
+  public void evaluate(Number[] expected, Operand input) {
+    evaluate(expected, input, null);
   }
 
   /**
@@ -302,12 +669,12 @@ public void evaluate(String expected, Op input) {
    *
    * @param expected the expected value
    * @param input the operand to evaluate
+   * @param feedMap an optional Map to feed to the run session when placeholders are used.
+   * @param <T> the data type of the input
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
-  public void evaluate(String[] expected, Op input) {
-    Output output = input.op().output(0);
-    evaluate(expected, output);
-  }
+  public abstract void evaluate(
+      Number[] expected, Operand input, Map, Tensor> feedMap);
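Every scalar and array convenience overload above funnels into this abstract evaluate(Number[], Operand, feedMap), so a concrete session only has to implement one comparison loop. A short eager-mode sketch (no placeholders, hence no feed map; names are again illustrative):

    TestSession session = TestSession.createTestSession(TestSession.Mode.EAGER);
    Ops tf = session.getTF();
    Operand<TFloat32> sum =
        tf.math.add(tf.constant(new float[] {1f, 2f}), tf.constant(new float[] {3f, 4f}));
    session.evaluate(new float[] {4f, 6f}, sum); // boxed to Number[] and delegated with a null feed map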
 
   /**
    * Evaluates the input against the expected value
@@ -316,9 +683,8 @@ public void evaluate(String[] expected, Op input) {
    * @param input the operand to evaluate
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
-  public void evaluate(String[] expected, Operand input) {
-    Output output = input.asOutput();
-    evaluate(expected, output);
+  public void evaluate(String expected, Operand input) {
+    evaluate(new String[] {expected}, input, null);
   }
 
   /**
@@ -326,9 +692,13 @@ public void evaluate(String[] expected, Operand input) {
    *
    * @param expected the expected value
    * @param input the operand to evaluate
+   * @param feedMap an optional Map to feed to the run session when placeholders are used.
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
-  public abstract void evaluate(String[] expected, Output input);
+  public void evaluate(
+      String expected, Operand input, Map, Tensor> feedMap) {
+    evaluate(new String[] {expected}, input, feedMap);
+  }
 
   /**
    * Evaluates the input against the expected value
@@ -337,8 +707,8 @@ public void evaluate(String[] expected, Operand input) {
    * @param input the operand to evaluate
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
-  public void evaluate(Boolean expected, Operand input) {
-    evaluate(new Boolean[] {expected}, input);
+  public void evaluate(String[] expected, Operand input) {
+    evaluate(expected, input, null);
   }
 
   /**
@@ -346,11 +716,11 @@ public void evaluate(Boolean expected, Operand input) {
    *
    * @param expected the expected value
    * @param input the operand to evaluate
+   * @param feedMap an optional Map to feed to the run session when placeholders are used.
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
-  public void evaluate(Boolean expected, Op input) {
-    evaluate(new Boolean[] {expected}, input);
-  }
+  public abstract void evaluate(
+      String[] expected, Operand input, Map, Tensor> feedMap);
 
   /**
    * Evaluates the input against the expected value
@@ -359,9 +729,8 @@ public void evaluate(Boolean expected, Op input) {
    * @param input the operand to evaluate
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
-  public void evaluate(Boolean[] expected, Op input) {
-    Output output = input.op().output(0);
-    evaluate(expected, output);
+  public void evaluate(Boolean expected, Operand input) {
+    evaluate(new Boolean[] {expected}, input, null);
   }
 
   /**
@@ -369,11 +738,12 @@ public void evaluate(Boolean[] expected, Op input) {
    *
    * @param expected the expected value
    * @param input the operand to evaluate
+   * @param feedMap an optional Map to feed to the run session when placeholders are used.
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
-  public void evaluate(Boolean[] expected, Operand input) {
-    Output output = input.asOutput();
-    evaluate(expected, output);
+  public void evaluate(
+      Boolean expected, Operand input, Map, Tensor> feedMap) {
+    evaluate(new Boolean[] {expected}, input, feedMap);
   }
 
   /**
@@ -383,20 +753,20 @@ public void evaluate(Boolean[] expected, Operand input) {
    * @param input the operand to evaluate
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
-  public abstract void evaluate(Boolean[] expected, Output input);
+  public void evaluate(Boolean[] expected, Operand input) {
+    evaluate(expected, input, null);
+  }
 
   /**
    * Evaluates the input against the expected value
    *
    * @param expected the expected value
    * @param input the operand to evaluate
-   * @param <T> the data type of the expected Operand
+   * @param feedMap an optional Map to feed to the run session when placeholders are used.
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
-  public void evaluate(Operand expected, Op input) {
-    Output output = input.op().output(0);
-    evaluate(expected, output);
-  }
+  public abstract void evaluate(
+      Boolean[] expected, Operand input, Map, Tensor> feedMap);
 
   /**
    * Evaluates the input against the expected value
@@ -407,7 +777,7 @@ public void evaluate(Operand expected, Op input) {
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
   public void evaluate(Operand expected, Operand input) {
-    evaluate(expected.asOutput(), input.asOutput());
+    evaluate(expected, input, null);
   }
 
   /**
@@ -415,10 +785,12 @@ public void evaluate(Operand expected, Operand input) {
    *
    * @param expected the expected value
    * @param input the operand to evaluate
+   * @param feedMap an optional Map to feed to the run session when placeholders are used.
    * @param <T> the data type of the input
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
-  public abstract void evaluate(Output expected, Output input);
+  public abstract void evaluate(
+      Operand expected, Operand input, Map, Tensor> feedMap);
 
   /**
    * Evaluates the input against the expected value
@@ -429,7 +801,7 @@ public void evaluate(Operand expected, Operand input) {
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
   public void evaluate(FloatNdArray expected, Operand input) {
-    evaluate(expected, input.asOutput());
+    evaluate(expected, input, null);
   }
 
   /**
@@ -437,21 +809,23 @@ public void evaluate(FloatNdArray expected, Operand input)
    *
    * @param expected the expected value
    * @param input the operand to evaluate
+   * @param feedMap an optional Map to feed to the run session when placeholders are used.
    * @param <T> the data type of the input
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
-  public abstract void evaluate(FloatNdArray expected, Output input);
+  public abstract void evaluate(
+      FloatNdArray expected, Operand input, Map, Tensor> feedMap);
 
   /**
    * Evaluates the input against the expected value
    *
    * @param input the operand to evaluate
-   * @param predicate the Predicate
+   * @param predicate the Predicate that evaluates each value from the input
    * @param <T> the data type of the input
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
   public void evaluate(Operand input, Predicate predicate) {
-    evaluate(input.asOutput(), predicate);
+    evaluate(input, predicate, null);
   }
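The predicate overloads assert a property of every element rather than exact values, which suits range or finiteness checks where a literal expected array would be brittle. A sketch, assuming the same illustrative session and Ops as above:

    Operand<TFloat32> probs = tf.nn.softmax(tf.constant(new float[] {0.1f, 2.0f, -1.5f}));
    // every softmax probability must lie in (0, 1]
    session.evaluate(probs, v -> v.floatValue() > 0f && v.floatValue() <= 1f);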
 
   /**
@@ -459,10 +833,12 @@
    *
    * @param input the operand to evaluate
    * @param predicate The Predicate that evaluates the each value from input
+   * @param feedMap an optional Map to feed to the run session when placeholders are used.
    * @param <T> the data type of the input
    * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
-  public abstract void evaluate(Output input, Predicate predicate);
+  public abstract void evaluate(
+      Operand input, Predicate predicate, Map, Tensor> feedMap);
 
   /**
    * Evaluates the input against the expected string value
@@ -471,7 +847,7 @@ public void evaluate(Operand input, Predicate predi
    * @param predicate The Predicate that evaluates the each value from input
    */
   public void evaluateString(Operand input, Predicate predicate) {
-    evaluateString(input.asOutput(), predicate);
+    evaluateString(input, predicate, null);
   }
 
   /**
@@ -479,8 +855,13 @@ public void evaluateString(Operand input, Predicate predicate)
    *
    * @param input the operand to evaluate
    * @param predicate The Predicate that evaluates the each value from input
+   * @param feedMap an optional Map to feed to the run session when placeholders are used.
+   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
-  public abstract void evaluateString(Output input, Predicate predicate);
+  public abstract void evaluateString(
+      Operand input,
+      Predicate predicate,
+      Map, Tensor> feedMap);
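evaluateString is the string-typed counterpart, since TString elements cannot pass through the Number-based predicate path. A sketch, assuming a String[] constant overload is available on Ops:

    Operand<TString> names = tf.constant(new String[] {"alpha", "beta"});
    session.evaluateString(names, s -> !s.isEmpty()); // fails if any element is empty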
 
   /**
    * Evaluates the input against the expected value
@@ -494,106 +875,110 @@ public void evaluate(FloatNdArray input, Predicate predicate) {
   }
 
   /**
-   * Evaluates the input against the expected value
-   *
-   * @param input the operand to evaluate
-   * @param predicate The Predicate that evaluates the each value from input
-   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
-   */
-  public void evaluate(DoubleNdArray input, Predicate predicate) {
-    input.scalars().forEach(f -> assertTrue(predicate.test(f.getDouble())));
-  }
-
-  /**
-   * Print the input
+   * Prints the input's values to standard out
    *
-   * @param out the output stream
    * @param input the operand to print
    * @param <T> the data type of the input
+   * @throws IllegalArgumentException if the data type for the input does not have a print function
+   *     registered.
    */
-  public void print(OutputStream out, Operand input) {
-    print(new PrintWriter(new OutputStreamWriter(out)), input.asOutput());
+  public void print(Operand input) {
+    print(new PrintWriter(new OutputStreamWriter(System.out)), input, null);
   }
 
   /**
-   * Print the input to standard out
+   * Prints the input's values to standard out
    *
-   * @param input the op to print
+   * @param input the operand to print
+   * @param feedMap an optional Map to feed to the run session when placeholders are used.
+   * @param <T> the data type of the input
+   * @throws IllegalArgumentException if the data type for the input does not have a print function
+   *     registered.
   */
-  public void print(Op input) {
-    print(new PrintWriter(new OutputStreamWriter(System.out)), input.op().output(0));
+  public void print(
+      Operand input, Map, Tensor> feedMap) {
+    print(new PrintWriter(new OutputStreamWriter(System.out)), input, feedMap);
  }
 
   /**
-   * Print the input
+   * Prints the input's values to the output stream
    *
    * @param out the output stream
-   * @param input the op to print
-   */
-  public void print(OutputStream out, Op input) {
-    print(new PrintWriter(new OutputStreamWriter(out)), input.op().output(0));
-  }
-
-  /**
-   * Print the input to standard out
-   *
-   * @param input the op to print
+   * @param input the operand to print
    * @param <T> the data type of the input
+   * @throws IllegalArgumentException if the data type for the input does not have a print function
+   *     registered.
    */
-  public void print(Output input) {
-    print(new PrintWriter(new OutputStreamWriter(System.out)), input);
+  public void print(OutputStream out, Operand input) {
+    print(new PrintWriter(new OutputStreamWriter(out)), input, null);
  }
 
   /**
-   * Print the input
+   * Prints the input's values to the output stream
    *
    * @param out the output stream
-   * @param input the op to print
+   * @param input the operand to print
+   * @param feedMap an optional Map to feed to the run session when placeholders are used.
    * @param <T> the data type of the input
+   * @throws IllegalArgumentException if the data type for the input does not have a print function
+   *     registered.
   */
-  public void print(OutputStream out, Output input) {
-    print(new PrintWriter(new OutputStreamWriter(out)), input);
+  public void print(
+      OutputStream out, Operand input, Map, Tensor> feedMap) {
+    print(new PrintWriter(new OutputStreamWriter(out)), input, feedMap);
  }
 
   /**
-   * Print the input
+   * Prints the input's values to the PrintWriter
    *
    * @param writer the output writer
    * @param input the operand to print
    * @param <T> the data type of the input
+   * @throws IllegalArgumentException if the data type for the input does not have a print function
+   *     registered.
    */
   public void print(Writer writer, Operand input) {
-    print(new PrintWriter(writer), input.asOutput());
+    print(new PrintWriter(writer), input, null);
  }
 
   /**
-   * Print the input
+   * Prints the input's values to the PrintWriter
    *
    * @param writer the output writer
    * @param input the op to print
+   * @param feedMap an optional Map to feed to the run session when placeholders are used.
+   * @param <T> the data type of the input
+   * @throws IllegalArgumentException if the data type for the input does not have a print function
+   *     registered.
   */
-  public void print(Writer writer, Op input) {
-    print(new PrintWriter(writer), input.op().output(0));
+  public void print(
+      Writer writer, Operand input, Map, Tensor> feedMap) {
+    print(new PrintWriter(writer), input, feedMap);
  }
 
   /**
-   * Print the input
+   * Prints the input's values to the PrintWriter
    *
    * @param writer the output writer
    * @param input the op to print
-   * @param <T> the data type of the input
+   * @throws IllegalArgumentException if the data type for the input does not have a print function
+   *     registered.
   */
-  public void print(Writer writer, Output input) {
-    print(new PrintWriter(writer), input);
+  public void print(PrintWriter writer, Operand input) {
+    print(writer, input, null);
  }
 
   /**
-   * Print the input
+   * Prints the input's values to the PrintWriter
    *
    * @param writer the output writer
    * @param input the op to print
+   * @param feedMap an optional Map to feed to the run session when placeholders are used.
+   * @throws IllegalArgumentException if the data type for the input does not have a print function
+   *     registered.
   */
-  public abstract void print(PrintWriter writer, Output input);
+  public abstract void print(
+      PrintWriter writer, Operand input, Map, Tensor> feedMap);
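The print family mirrors the evaluate overloads but writes each value with its index instead of asserting, which pairs naturally with the debug flag when a comparison fails. An illustrative sketch reusing the operands from the earlier examples:

    session.print(System.out, probs); // dump to standard out, no feed map
    StringWriter log = new StringWriter();
    session.print(log, y, Collections.singletonMap(x, fed)); // capture the dump, feeding x first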
 
   /**
    * Get the TensorFlow Ops
@@ -619,36 +1004,246 @@ public boolean isGraph() {
   }
 
   /**
-   * Get the epsilon value for evaluating float values
+   * Get the TensorFlow session object associated with this Test Session
    *
-   * @return the epsilon value for evaluating float values
+   * @return a TensorFlow session if this is a Graph session, otherwise null
+   */
+  public abstract Session getGraphSession();
+
+  /**
+   * Get the TensorFlow eager session object associated with this Test Session
+   *
+   * @return a TensorFlow session if this is an eager session, otherwise null
+   */
+  public abstract EagerSession getEagerSession();
+
+  // The following methods are called by the subclasses,
+  // after resolving the tensor for the Operands
+
+  /**
+   * Evaluates the tensor's values against the expected value
+   *
+   * @param expected the expected value
+   * @param tensor the tensor whose values are compared to the expected values.
+   * @param type the data type of the tensor
+   * @param <T> the data type of the tensor
+   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
-  public float getEpsilon() {
-    return this.epsilon;
+  @SuppressWarnings("unchecked")
+  protected void evaluate(double expected, Tensor tensor, Class type) {
+    boolean isScalar = tensor.shape().equals(Shape.scalar());
+    if (debug) {
+      print(DEFAULT_WRITER, tensor, type);
+    }
+    BiConsumer> evaluateFunc = evalMap.get(type);
+    if (evaluateFunc == null) fail("Unexpected Type Class: " + type);
+    if (isScalar) evaluateFunc.accept(expected, (NdArray) tensor);
+    else ((NdArray) tensor).scalars().forEach(f -> evaluateFunc.accept(expected, f));
   }
 
   /**
-   * Set the epsilon value for evaluating float values
+   * Evaluates the tensor's values against the expected values
    *
-   * @param epsilon the epsilon value for evaluating float values
+   * @param expected the expected values
+   * @param tensor the tensor whose values are compared to the expected values.
+   * @param type the data type of the tensor
+   * @param <T> the data type of the tensor
+   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
-  public void setEpsilon(float epsilon) {
-    this.epsilon = epsilon;
+  @SuppressWarnings("unchecked")
+  protected void evaluate(Number[] expected, Tensor tensor, Class type) {
+    int size = tensor.shape().size() == 0 ? 1 : (int) tensor.shape().size();
+    assertEquals(
+        expected.length,
+        size,
+        () -> String.format("expected length (%d) != input length (%d)", expected.length, size));
+
+    boolean isScalar = tensor.shape().equals(Shape.scalar());
+    AtomicInteger index = new AtomicInteger();
+    if (debug) {
+      print(DEFAULT_WRITER, tensor, type);
+    }
+    BiConsumer> evaluateFunc = evalMap.get(type);
+    if (evaluateFunc == null) fail("Unexpected Type Class: " + type);
+    if (isScalar) evaluateFunc.accept(expected[0], (NdArray) tensor);
+    else
+      ((NdArray) tensor)
+          .scalars()
+          .forEach(f -> evaluateFunc.accept(expected[index.getAndIncrement()], f));
   }
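Both protected helpers share one shape: look up a consumer in a static map keyed by the TType class, fail fast when no entry is registered, then apply it either to the scalar tensor or to each scalar slice. A self-contained miniature of that dispatch pattern, with illustrative value types (not the patch's exact generics, which work over NdArray views):

    Map<Class<?>, BiConsumer<Number, Number>> dispatch = new HashMap<>();
    dispatch.put(TFloat32.class, (e, a) -> assertEquals(e.floatValue(), a.floatValue(), 1e-5f));
    dispatch.put(TInt64.class, (e, a) -> assertEquals(e.longValue(), a.longValue()));

    BiConsumer<Number, Number> check = dispatch.get(TFloat32.class);
    if (check == null) fail("Unexpected Type Class: " + TFloat32.class);
    check.accept(1.0, 1.0f + 1e-7f); // within epsilon, so it passes

One map lookup replaces a chain of instanceof tests and keeps each type's comparison next to its key.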
 
   /**
-   * Get the TensorFlow session object associated with this Test Session
+   * Evaluates the tensor's values against the expected values
    *
-   * @return a TensorFlow session if this is a Graph session, otherwise null
+   * @param expected the expected values
+   * @param tensor the tensor whose values are compared to the expected values.
+   * @param type the data type of the tensor
+   * @param <T> the data type of the tensor
+   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
-  public abstract Session getGraphSession();
+  @SuppressWarnings("unchecked")
+  protected void evaluate(FloatNdArray expected, Tensor tensor, Class type) {
+    boolean isScalar = tensor.shape().equals(Shape.scalar());
+    AtomicInteger index = new AtomicInteger();
+    if (debug) {
+      print(DEFAULT_WRITER, tensor, type);
+    }
+    BiConsumer> evaluateFunc = evalMap.get(type);
+    if (evaluateFunc == null) fail("Unexpected Type Class: " + type);
+    if (isScalar)
+      evaluateFunc.accept(expected.getObject(index.getAndIncrement()), (NdArray) tensor);
+    else
+      ((NdArray) tensor)
+          .scalars()
+          .forEach(f -> evaluateFunc.accept(expected.getObject(index.getAndIncrement()), f));
+  }
 
   /**
-   * Get the TensorFlow eager session object associated with this Test Session
+   * Evaluates the tensor's values against the predicate test
    *
-   * @return a TensorFlow session if this is an eager session, otherwise null
+   * @param tensor the tensor to evaluate
+   * @param predicate the predicate used to test each value of the tensor
+   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
    */
-  public abstract EagerSession getEagerSession();
+  protected void evaluateString(TString tensor, Predicate predicate) {
+    AtomicInteger index = new AtomicInteger();
+    boolean isScalar = tensor.shape().equals(Shape.scalar());
+    if (debug) {
+      print(DEFAULT_WRITER, tensor, TString.class);
+    }
+    index.set(0);
+    if (isScalar) {
+      assertTrue(predicate.test(tensor.getObject()));
+    } else {
+      tensor.scalars().forEachIndexed((idx, s) -> assertTrue(predicate.test(s.getObject())));
+    }
+  }
+
+  /**
+   * Evaluates the tensor's values against the predicate test
+   *
+   * @param tensor the tensor to evaluate
+   * @param type the data type of the tensor
+   * @param predicate the predicate used to test each value of the tensor
+   * @param <T> the data type of the tensor
+   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   */
+  @SuppressWarnings("unchecked")
+  protected void evaluate(
+      Tensor tensor, Class type, Predicate predicate) {
+    boolean isScalar = tensor.shape().equals(Shape.scalar());
+    if (debug) {
+      print(DEFAULT_WRITER, tensor, type);
+    }
+    BiConsumer, NdArray> evalFunc = evalPredicate.get(type);
+    if (evalFunc == null) fail("Unexpected Type Class: " + type);
+    if (isScalar) {
+      evalFunc.accept(predicate, (NdArray) tensor);
+    } else {
+      ((NdArray) tensor).scalars().forEach(f -> evalFunc.accept(predicate, f));
+    }
+  }
+
+  /**
+   * Evaluates the tensor's values against the expected values
+   *
+   * @param expected the expected values
+   * @param tensor the tensor whose values are compared to the expected values
+   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   */
+  protected void evaluate(String[] expected, TString tensor) {
+    int size = tensor.shape().size() == 0 ? 1 : (int) tensor.shape().size();
+    assertEquals(
+        expected.length,
+        size,
+        () -> String.format("expected length (%d) != input length (%d)", expected.length, size));
+    AtomicInteger index = new AtomicInteger();
+    boolean isScalar = tensor.shape().equals(Shape.scalar());
+    if (debug) {
+      print(DEFAULT_WRITER, tensor, TString.class);
+    }
+    if (isScalar) assertEquals(expected[0], tensor.getObject());
+    else
+      tensor.scalars().forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject()));
+  }
+
+  /**
+   * Evaluates the tensor's values against the expected values
+   *
+   * @param expected the expected values
+   * @param tensor the tensor whose values are compared to the expected values
+   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   */
+  protected void evaluate(Boolean[] expected, TBool tensor) {
+    int size = tensor.shape().size() == 0 ? 1 : (int) tensor.shape().size();
+    assertEquals(
+        expected.length,
+        size,
+        () -> String.format("expected size (%d) != input length (%d)", expected.length, size));
+    AtomicInteger index = new AtomicInteger();
+    boolean isScalar = tensor.shape().equals(Shape.scalar());
+    if (debug) {
+      print(DEFAULT_WRITER, tensor, TBool.class);
+    }
+    if (isScalar) assertEquals(expected[index.getAndIncrement()], tensor.getBoolean());
+    else
+      tensor
+          .scalars()
+          .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getBoolean()));
+  }
+
+  /**
+   * Evaluates the tensor's values against the expected tensor's values
+   *
+   * @param expected the tensor whose values are expected
+   * @param tensor the tensor whose values are compared to the expected values.
+   * @param type the data type of the tensor
+   * @param <T> the data type of the tensor
+   * @throws org.opentest4j.AssertionFailedError if the evaluation fails
+   */
+  @SuppressWarnings("unchecked")
+  protected void evaluate(Tensor expected, Tensor tensor, Class type) {
+    assert tensor.shape().equals(expected.shape())
+        : String.format(
+            "expected shape (%s) != input shape (%s)",
+            expected.shape().toString(), tensor.shape().toString());
+    boolean isScalar = tensor.shape().equals(Shape.scalar());
+
+    AtomicInteger index = new AtomicInteger();
+    if (debug) {
+      print(DEFAULT_WRITER, tensor, type);
+    }
+    index.set(0);
+    BiConsumer> evaluateFunc = evalMap.get(type);
+    if (evaluateFunc == null) fail("Unexpected Type Class: " + type);
+    NdArray expectedArray = (NdArray) expected;
+    if (isScalar) evaluateFunc.accept(expectedArray.getObject(), (NdArray) tensor);
+    else
+      ((NdArray) tensor)
+          .scalars()
+          .forEachIndexed((idx, f) -> evaluateFunc.accept(expectedArray.getObject(idx), f));
+  }
+
+  /**
+   * Prints the tensor's values to the print writer
+   *
+   * @param writer the output writer
+   * @param tensor the tensor to print
+   * @param type the data type of the tensor
+   * @param <T> the data type of the tensor
+   * @throws IllegalArgumentException if the data type for the tensor does not have a print function
+   *     registered.
+   */
+  @SuppressWarnings("unchecked")
+  protected void print(PrintWriter writer, Tensor tensor, Class type) {
+    boolean isScalar = tensor.shape().equals(Shape.scalar());
+    TriConsumer> printFunc = printMap.get(type);
+    if (printFunc == null) throw new IllegalArgumentException("Unexpected Type Class: " + type);
+    if (isScalar) printFunc.accept(writer, ZERO_IDX, (NdArray) tensor);
+    else
+      ((NdArray) tensor).scalars().forEachIndexed((idx, f) -> printFunc.accept(writer, idx, f));
+    writer.flush();
+  }
 
   /** {@inheritDoc} */
   @Override
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TriConsumer.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TriConsumer.java
new file mode 100644
index 00000000000..e67829eca92
--- /dev/null
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TriConsumer.java
@@ -0,0 +1,40 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.utils;
+
+/**
+ * Represents an operation that accepts three input arguments and returns no result. This is the
+ * three-arity specialization of {@link java.util.function.Consumer}. Unlike most other functional
+ * interfaces, {@code TriConsumer} is expected to operate via side-effects.
+ *
+ * <p>This is a functional interface whose functional method is {@link #accept(Object, Object,
+ * Object)}.
+ *
+ * @param <T> the type of the first argument to the operation
+ * @param <U> the type of the second argument to the operation
+ * @param <V> the type of the third argument to the operation
+ */
+@FunctionalInterface
+interface TriConsumer<T, U, V> {
+
+  /**
+   * Performs this operation on the given arguments.
+   *
+   * @param t the first input argument
+   * @param u the second input argument
+   * @param v the third input argument
+   */
+  void accept(T t, U u, V v);
+}
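A lambda satisfies the interface directly; the printMap entries bind a writer, a scalar's coordinates, and the scalar itself in one call. A conforming sketch (same package, since the interface is package-private; NdArrays is the ndarray library's factory, and the formatting here is illustrative):

    TriConsumer<PrintWriter, long[], NdArray<Float>> printer =
        (writer, idx, scalar) ->
            writer.printf("%s. %f%n", Arrays.toString(idx), scalar.getObject());
    printer.accept(new PrintWriter(System.out, true), new long[] {0}, NdArrays.scalarOf(3.14f));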