diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java index 5fe51121b13..6686abd9148 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java @@ -55,29 +55,7 @@ * * @param the type of values to be mapped */ -public interface NdArray { - - /** - * @return the shape of this N-dimensional array - */ - Shape shape(); - - /** - * @return the rank of this N-dimensional array - */ - default int rank() { - return shape().numDimensions(); - } - - /** - * Computes and returns the total size of this N-dimensional array, in number of values. - * - *

For example, given a 3x3x2 matrix, the return value will be 18. - * @return total size of this nd array - */ - default long size() { - return shape().size(); - } +public interface NdArray extends Shaped { /** * Returns a sequence of all elements at a given dimension. diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java b/ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java index 21b33402e98..8ad55cae7ed 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java @@ -65,7 +65,7 @@ public static ByteNdArray vectorOf(byte... values) { if (values == null) { throw new IllegalArgumentException("Values cannot be null"); } - return wrap(DataBuffers.of(values, false, false), Shape.of(values.length)); + return wrap(Shape.of(values.length), DataBuffers.of(values, false, false)); } /** @@ -81,19 +81,19 @@ public static ByteNdArray ofBytes(Shape shape) { if (shape == null) { throw new IllegalArgumentException("Shape cannot be null"); } - return wrap(DataBuffers.ofBytes(shape.size()), shape); + return wrap(shape, DataBuffers.ofBytes(shape.size())); } /** * Wraps a buffer in a byte N-dimensional array of a given shape. * - * @param buffer buffer to wrap * @param shape shape of the array + * @param buffer buffer to wrap * @return new byte N-dimensional array * @throws IllegalArgumentException if shape is null, has unknown dimensions or has size bigger * in the buffer size */ - public static ByteNdArray wrap(ByteDataBuffer buffer, Shape shape) { + public static ByteNdArray wrap(Shape shape, ByteDataBuffer buffer) { return ByteDenseNdArray.create(buffer, shape); } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/Shaped.java b/ndarray/src/main/java/org/tensorflow/ndarray/Shaped.java new file mode 100644 index 00000000000..fbe19d75623 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/Shaped.java @@ -0,0 +1,51 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +package org.tensorflow.ndarray; + +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.index.Index; + +/** + * Any data container with a given {@link Shape}. + */ +public interface Shaped { + + /** + * @return the shape of this container + */ + Shape shape(); + + /** + * @return the rank of this container + */ + default int rank() { + return shape().numDimensions(); + } + + /** + * Computes and returns the total size of this container, in number of values. + * + *

For example, given a 3x3x2 matrix, the return value will be 18. + * + * @return number of values in this element + */ + default long size() { + return shape().size(); + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_BarrierIncompleteSize.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_BarrierIncompleteSize.pbtxt index fb11b18e951..a2fed2a43de 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_BarrierIncompleteSize.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_BarrierIncompleteSize.pbtxt @@ -1,3 +1,7 @@ op { graph_op_name: "BarrierIncompleteSize" + out_arg { + name: "size" + rename_to: "output" + } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_BarrierReadySize.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_BarrierReadySize.pbtxt index 0ed50b25799..0f768476610 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_BarrierReadySize.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_BarrierReadySize.pbtxt @@ -1,3 +1,7 @@ op { graph_op_name: "BarrierReadySize" + out_arg { + name: "size" + rename_to: "output" + } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LookupTableSizeV2.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LookupTableSizeV2.pbtxt index ad646e25a6b..b5526230d76 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LookupTableSizeV2.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_LookupTableSizeV2.pbtxt @@ -3,4 +3,8 @@ op { endpoint { name: "LookupTableSize" } + out_arg { + name: "size" + rename_to: "output" + } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_MapIncompleteSize.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_MapIncompleteSize.pbtxt index 659993e42b0..2472209d20a 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_MapIncompleteSize.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_MapIncompleteSize.pbtxt @@ -1,3 +1,7 @@ op { graph_op_name: "MapIncompleteSize" + out_arg { + name: "size" + rename_to: "output" + } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_MapSize.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_MapSize.pbtxt index 4da151152c9..fe1d5701b4e 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_MapSize.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_MapSize.pbtxt @@ -1,3 +1,7 @@ op { graph_op_name: "MapSize" + out_arg { + name: "size" + rename_to: "output" + } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_OrderedMapIncompleteSize.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_OrderedMapIncompleteSize.pbtxt index c609e9e50a2..27d68e2d99d 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_OrderedMapIncompleteSize.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_OrderedMapIncompleteSize.pbtxt @@ -1,3 +1,7 @@ op { graph_op_name: "OrderedMapIncompleteSize" + out_arg { + name: "size" + rename_to: "output" + } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_OrderedMapSize.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_OrderedMapSize.pbtxt index 7beef3f376b..30e6215a0ee 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_OrderedMapSize.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_OrderedMapSize.pbtxt @@ -1,3 +1,7 @@ op { graph_op_name: "OrderedMapSize" + out_arg { + name: "size" + rename_to: "output" + } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_QueueSizeV2.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_QueueSizeV2.pbtxt index e93e07a2b32..bc17c8daf96 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_QueueSizeV2.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_QueueSizeV2.pbtxt @@ -3,4 +3,8 @@ op { endpoint { name: "io.QueueSize" } + out_arg { + name: "size" + rename_to: "output" + } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SetSize.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SetSize.pbtxt index 1c000e9c8aa..78c43275762 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SetSize.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SetSize.pbtxt @@ -1,3 +1,7 @@ op { graph_op_name: "SetSize" + out_arg { + name: "size" + rename_to: "output" + } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_StageSize.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_StageSize.pbtxt index d8188c3e0b3..a697b775571 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_StageSize.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_StageSize.pbtxt @@ -1,3 +1,7 @@ op { graph_op_name: "StageSize" + out_arg { + name: "size" + rename_to: "output" + } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_TensorArraySizeV3.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_TensorArraySizeV3.pbtxt index 2df9a2d3f13..55fe2ae46e9 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_TensorArraySizeV3.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_TensorArraySizeV3.pbtxt @@ -3,4 +3,8 @@ op { endpoint { name: "TensorArraySize" } + out_arg { + name: "size" + rename_to: "output" + } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierIncompleteSize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierIncompleteSize.java index 72dbe1533d6..34b6b673665 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierIncompleteSize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierIncompleteSize.java @@ -53,23 +53,23 @@ public static BarrierIncompleteSize create(Scope scope, Operand handle) * The number of incomplete elements (i.e. those with some of their value * components not set) in the barrier. */ - public Output size() { - return size; + public Output output() { + return output; } @Override public Output asOutput() { - return size; + return output; } /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "BarrierIncompleteSize"; - private Output size; + private Output output; private BarrierIncompleteSize(Operation operation) { super(operation); int outputIdx = 0; - size = operation.output(outputIdx++); + output = operation.output(outputIdx++); } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierReadySize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierReadySize.java index 1cb29f97e1d..54910190258 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierReadySize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierReadySize.java @@ -53,23 +53,23 @@ public static BarrierReadySize create(Scope scope, Operand handle) { * The number of complete elements (i.e. those with all of their value * components set) in the barrier. */ - public Output size() { - return size; + public Output output() { + return output; } @Override public Output asOutput() { - return size; + return output; } /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "BarrierReadySize"; - private Output size; + private Output output; private BarrierReadySize(Operation operation) { super(operation); int outputIdx = 0; - size = operation.output(outputIdx++); + output = operation.output(outputIdx++); } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableSize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableSize.java index 10e84953496..dc4b1f29d9c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableSize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableSize.java @@ -51,23 +51,23 @@ public static LookupTableSize create(Scope scope, Operand tableHandle) { /** * Scalar that contains number of elements in the table. */ - public Output size() { - return size; + public Output output() { + return output; } @Override public Output asOutput() { - return size; + return output; } /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LookupTableSizeV2"; - private Output size; + private Output output; private LookupTableSize(Operation operation) { super(operation); int outputIdx = 0; - size = operation.output(outputIdx++); + output = operation.output(outputIdx++); } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapIncompleteSize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapIncompleteSize.java index 19e9e87a08a..281297e552b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapIncompleteSize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapIncompleteSize.java @@ -147,23 +147,23 @@ public static Options sharedName(String sharedName) { /** */ - public Output size() { - return size; + public Output output() { + return output; } @Override public Output asOutput() { - return size; + return output; } /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "MapIncompleteSize"; - private Output size; + private Output output; private MapIncompleteSize(Operation operation) { super(operation); int outputIdx = 0; - size = operation.output(outputIdx++); + output = operation.output(outputIdx++); } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapSize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapSize.java index 7f4eea906f5..a4c497e0011 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapSize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapSize.java @@ -147,23 +147,23 @@ public static Options sharedName(String sharedName) { /** */ - public Output size() { - return size; + public Output output() { + return output; } @Override public Output asOutput() { - return size; + return output; } /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "MapSize"; - private Output size; + private Output output; private MapSize(Operation operation) { super(operation); int outputIdx = 0; - size = operation.output(outputIdx++); + output = operation.output(outputIdx++); } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapIncompleteSize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapIncompleteSize.java index 865810568db..ccb906bd75d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapIncompleteSize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapIncompleteSize.java @@ -147,23 +147,23 @@ public static Options sharedName(String sharedName) { /** */ - public Output size() { - return size; + public Output output() { + return output; } @Override public Output asOutput() { - return size; + return output; } /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "OrderedMapIncompleteSize"; - private Output size; + private Output output; private OrderedMapIncompleteSize(Operation operation) { super(operation); int outputIdx = 0; - size = operation.output(outputIdx++); + output = operation.output(outputIdx++); } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapSize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapSize.java index afdee7de1bd..e8120e07e4c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapSize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapSize.java @@ -147,23 +147,23 @@ public static Options sharedName(String sharedName) { /** */ - public Output size() { - return size; + public Output output() { + return output; } @Override public Output asOutput() { - return size; + return output; } /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "OrderedMapSize"; - private Output size; + private Output output; private OrderedMapSize(Operation operation) { super(operation); int outputIdx = 0; - size = operation.output(outputIdx++); + output = operation.output(outputIdx++); } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/SetSize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/SetSize.java index 2fcdf728542..e2d193f95e1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/SetSize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/SetSize.java @@ -100,23 +100,23 @@ public static Options validateIndices(Boolean validateIndices) { * `n-1` dimensions as `set`. Each value is the number of unique elements in * the corresponding `[0...n-1]` dimension of `set`. */ - public Output size() { - return size; + public Output output() { + return output; } @Override public Output asOutput() { - return size; + return output; } /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "SetSize"; - private Output size; + private Output output; private SetSize(Operation operation) { super(operation); int outputIdx = 0; - size = operation.output(outputIdx++); + output = operation.output(outputIdx++); } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StageSize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StageSize.java index 94ef566e708..d7731377203 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StageSize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StageSize.java @@ -147,23 +147,23 @@ public static Options sharedName(String sharedName) { /** */ - public Output size() { - return size; + public Output output() { + return output; } @Override public Output asOutput() { - return size; + return output; } /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "StageSize"; - private Output size; + private Output output; private StageSize(Operation operation) { super(operation); int outputIdx = 0; - size = operation.output(outputIdx++); + output = operation.output(outputIdx++); } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArraySize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArraySize.java index 3e7987bc388..0c07c646bd4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArraySize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArraySize.java @@ -54,23 +54,23 @@ public static TensorArraySize create(Scope scope, Operand handle, Operand size() { - return size; + public Output output() { + return output; } @Override public Output asOutput() { - return size; + return output; } /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "TensorArraySizeV3"; - private Output size; + private Output output; private TensorArraySize(Operation operation) { super(operation); int outputIdx = 0; - size = operation.output(outputIdx++); + output = operation.output(outputIdx++); } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueSize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueSize.java index cefc1b3fd11..7706e4c6e81 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueSize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueSize.java @@ -51,23 +51,23 @@ public static QueueSize create(Scope scope, Operand handle) { /** * The number of elements in the given queue. */ - public Output size() { - return size; + public Output output() { + return output; } @Override public Output asOutput() { - return size; + return output; } /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "QueueSizeV2"; - private Output size; + private Output output; private QueueSize(Operation operation) { super(operation); int outputIdx = 0; - size = operation.output(outputIdx++); + output = operation.output(outputIdx++); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operand.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operand.java index fa21f32d4ce..3f13e1004e8 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operand.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operand.java @@ -15,6 +15,8 @@ package org.tensorflow; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.Shaped; import org.tensorflow.op.Op; import org.tensorflow.types.family.TType; @@ -39,7 +41,7 @@ * tf.concat(split, tf.constant(0)); * } */ -public interface Operand extends Op { +public interface Operand extends Op, Shaped { /** * Returns the symbolic handle of the tensor. @@ -76,4 +78,12 @@ default Tensor asTensor() { default T data() { return asOutput().tensor().data(); } + + /** + * Returns the (possibly partially known) shape of the tensor referred to by the {@link Output} of this operand. + */ + @Override + default Shape shape() { + return asOutput().shape(); + } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java index a873df8ff4c..f7977c78474 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java @@ -18,6 +18,7 @@ import java.util.Objects; import org.bytedeco.javacpp.Pointer; import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.Shaped; import org.tensorflow.types.family.TType; /** @@ -36,11 +37,6 @@ public int index() { return index; } - /** Returns the (possibly partially known) shape of the tensor referred to by this Output. */ - public Shape shape() { - return operation.shape(index); - } - /** Returns the DataType of the tensor referred to by this Output. */ @SuppressWarnings("unchecked") public DataType dataType() { @@ -84,6 +80,14 @@ public Tensor tensor() { return (Tensor) operation.tensor(index); } + /** + * Returns the (possibly partially known) shape of the tensor referred to by this output. + */ + @Override + public Shape shape() { + return operation.shape(index); + } + @Override public Operation op() { return operation; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java index 6787713418f..62530e923ac 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java @@ -26,6 +26,7 @@ import org.tensorflow.internal.c_api.TF_Tensor; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.Shaped; import org.tensorflow.ndarray.buffer.ByteDataBuffer; import org.tensorflow.types.family.TType; @@ -44,7 +45,7 @@ * } * } */ -public final class Tensor implements AutoCloseable { +public final class Tensor implements Shaped, AutoCloseable { /** * Allocates a tensor of a given datatype and shape. @@ -232,12 +233,8 @@ public long numBytes() { return numBytes; } - /** - * Returns the shape of - * the Tensor, i.e., the sizes of each dimension. - * - * @return shape of this tensor - */ + /** Returns the shape of this tensor. */ + @Override public Shape shape() { return shape; } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java index 4f3e9569103..e823ed9f6bd 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java @@ -63,11 +63,11 @@ public class SigmoidCrossEntropyWithLogits { @Endpoint(name = "sigmoidCrossEntropyWithLogits") public static Operand sigmoidCrossEntropyWithLogits( Scope scope, Operand labels, Operand logits) { - if (!isCompatible(labels.asOutput().shape(), logits.asOutput().shape())) { + if (!isCompatible(labels.shape(), logits.shape())) { throw new IllegalArgumentException( String.format( "logits and labels must have the same shape (%s vs %s)", - labels.asOutput().shape().toString(), logits.asOutput().shape())); + labels.shape(), logits.shape())); } scope = scope.withSubScope("SigmoidCrossEntropyWithLogits"); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java index 0c8bac697ed..67cbe3fb98c 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java @@ -73,9 +73,9 @@ public class SoftmaxCrossEntropyWithLogits { public static Operand softmaxCrossEntropyWithLogits( Scope scope, Operand labels, Operand logits, int axis) { scope = scope.withSubScope("SoftmaxCrossEntropyWithLogits"); - axis = axis % logits.asOutput().shape().numDimensions(); + axis = axis % logits.shape().numDimensions(); if (axis < 0) { - axis += logits.asOutput().shape().numDimensions(); + axis += logits.shape().numDimensions(); } @@ -96,15 +96,15 @@ public static Operand softmaxCrossEntr } Operand inputRank = Cast.create(scope, Rank.create(scope, logits), TInt64.DTYPE); - Shape shape = logits.asOutput().shape(); + Shape shape = logits.shape(); // Move the dim to the end if dim is not the last dimension. - if (axis != -1 && axis != logits.asOutput().shape().numDimensions() - 1) { + if (axis != -1 && axis != logits.shape().numDimensions() - 1) { logits = moveDimToEnd(scope, logits, axis, inputRank); labels = moveDimToEnd(scope, labels, axis, inputRank); } - Shape inputShape = logits.asOutput().shape(); + Shape inputShape = logits.shape(); logits = flattenOuterDims(scope, logits); labels = flattenOuterDims(scope, labels); @@ -149,7 +149,7 @@ public static Operand softmaxCrossEntr private static Operand flattenOuterDims(Scope scope, Operand logits) { Operand one = Constant.scalarOf(scope, 1L); - Shape shape = logits.asOutput().shape(); + Shape shape = logits.shape(); int ndims = shape.numDimensions(); if (!shape.hasUnknownDimension()) { long product = 1L; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java index ebd6f74e7d8..3598edbe223 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java @@ -80,10 +80,10 @@ public static Operand sparseSoftmaxCrossE if (convertToFloat32) { preciseLogits = Cast.create(scope, logits, TFloat32.DTYPE); } - Shape labelsStaticShape = labels.asOutput().shape(); + Shape labelsStaticShape = labels.shape(); org.tensorflow.op.core.Shape labelsShape = org.tensorflow.op.core.Shape.create(scope, labels); - Shape logitsShape = logits.asOutput().shape(); + Shape logitsShape = logits.shape(); Shape logitsShortened = logitsShape.take(logitsShape.numDimensions() - 1); boolean staticShapesFullyDefined = @@ -98,7 +98,7 @@ public static Operand sparseSoftmaxCrossE throw new IllegalArgumentException( String.format( "Rank mismatch: Rank of labels (received %s) should equal rank of logits minus 1 (received %s).", - labelsStaticShape.toString(), logitsShape.toString())); + labelsStaticShape, logitsShape)); } if (staticShapesFullyDefined && !labelsStaticShape.equals(logitsShortened)) { @@ -107,7 +107,7 @@ public static Operand sparseSoftmaxCrossE "Shape mismatch: The shape of labels (received %s) " + "should equal the shape of logits except for the last " + "dimension (received %s).", - labelsStaticShape.toString(), logitsShape.toString())); + labelsStaticShape, logitsShape)); } // Check if no reshapes are required. if (logitsShape.numDimensions() == 2) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java index d31eebd9007..154e1ecc84a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java @@ -73,7 +73,7 @@ public Softmax(Ops tf, int axis) { */ @Override public Operand call(Operand input) { - Shape shape = input.asOutput().shape(); + Shape shape = input.shape(); int numDimensions = shape.numDimensions(); if (numDimensions == 2) { return tf.nn.softmax(input); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TensorSliceDataset.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TensorSliceDataset.java index 14405ebdaf5..495014f1753 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TensorSliceDataset.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/impl/TensorSliceDataset.java @@ -31,7 +31,7 @@ public TensorSliceDataset(Ops tf, List> components, List> } private static List outputShapes(List> components) { - return components.stream().map(c -> c.asOutput().shape().tail()).collect(Collectors.toList()); + return components.stream().map(c -> c.shape().tail()).collect(Collectors.toList()); } private static Operand makeVariant( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 7a633ede2bf..57875ffd6f2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -535,7 +535,7 @@ public static Operand sparseCategorica predictions = tf.clipByValue(predictions, epsilonConst, oneMinusEpsilonConst); predictions = tf.math.log(predictions); } - Shape predictionsShape = predictions.asOutput().shape(); + Shape predictionsShape = predictions.shape(); int predictionsRank = predictionsShape.numDimensions(); axis %= predictionsRank; if (axis < 0) { @@ -549,7 +549,7 @@ public static Operand sparseCategorica Operand iLabels = cast(tf, labels, TInt64.DTYPE); // Try to adjust the shape so that rank of labels = rank of logits - 1. - Shape labelsShape = labels.asOutput().shape(); + Shape labelsShape = labels.shape(); int labelsRank = labelsShape.numDimensions(); boolean updateShape = labelsRank != predictionsRank - 1; @@ -635,7 +635,7 @@ private static Operand smoothCategoricalLabels( Ops tf, Operand labels, float labelSmoothing) { DataType dataType = labels.asOutput().dataType(); Operand smoothing = cast(tf, tf.constant(labelSmoothing), dataType); - Shape labelsShape = labels.asOutput().shape(); + Shape labelsShape = labels.shape(); int numDims = labelsShape.numDimensions(); Operand numClasses = cast(tf, tf.constant(labelsShape.size(numDims - 1)), dataType); Operand oneMinusSmoothing = cast(tf, tf.constant(1.f - labelSmoothing), dataType); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java index 463296a1f50..3e46cf4e825 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java @@ -88,13 +88,13 @@ public static LossTuple squeezeOrExpandDimensions( public static LossTuple squeezeOrExpandDimensions( Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { - Shape predictionsShape = predictions.asOutput().shape(); + Shape predictionsShape = predictions.shape(); long predictionsRank = predictionsShape.numDimensions(); // Default case when no modifications are made. LossTuple lossTuple = new LossTuple<>(labels, predictions, sampleWeights); if (labels != null) { - Shape labelsShape = labels.asOutput().shape(); + Shape labelsShape = labels.shape(); long labelsRank = labelsShape.numDimensions(); if (labelsRank != Shape.UNKNOWN_SIZE && predictionsRank != Shape.UNKNOWN_SIZE) { // Use static rank for 'label' and 'prediction'. @@ -108,7 +108,7 @@ public static LossTuple squeezeOrExpandDimensions( if (sampleWeights == null) { // nothing more to do. return lossTuple; } - Shape weightsShape = sampleWeights.asOutput().shape(); + Shape weightsShape = sampleWeights.shape(); long weightsRank = weightsShape.numDimensions(); if (weightsRank == 0) { // scalar return new LossTuple<>(lossTuple.getLabels(), lossTuple.getTarget(), sampleWeights); @@ -200,9 +200,9 @@ public static LossTuple removeSqueezableDimensions( Ops tf, Operand labels, Operand predictions, int expectedRankDiff) { tf = tf.withSubScope("removeSqueezableDimensions"); - Shape predictionsShape = predictions.asOutput().shape(); + Shape predictionsShape = predictions.shape(); int predictionsRank = predictionsShape.numDimensions(); - Shape labelsShape = labels.asOutput().shape(); + Shape labelsShape = labels.shape(); int labelsRank = labelsShape.numDimensions(); if (predictionsRank != Shape.UNKNOWN_SIZE || labelsRank != Shape.UNKNOWN_SIZE) { @@ -280,7 +280,7 @@ private static Operand reduceWeightedLoss( loss = tf.reduceSum(weightedLoss, allAxes(tf, weightedLoss), ReduceSum.keepDims(Boolean.FALSE)); if (reduction == Reduction.AUTO || reduction == Reduction.SUM_OVER_BATCH_SIZE) { - loss = safeMean(tf, loss, weightedLoss.asOutput().shape().size()); + loss = safeMean(tf, loss, weightedLoss.shape().size()); } } return loss; @@ -312,7 +312,7 @@ public static Operand safeMean( * @return a Constant that represents all the axes of the operand. */ public static Operand allAxes(Ops tf, Operand op) { - int rank = op.asOutput().shape().numDimensions(); + int rank = op.shape().numDimensions(); if (rank != Shape.UNKNOWN_SIZE) { int[] axes = new int[rank]; for (int i = 0; i < rank; i++) { @@ -385,9 +385,9 @@ public static Operand rangeCheck( public static Operand valueCheck( Ops tf, String prefix, Operand values, Operand allowedValues) { Operand flatValues = - tf.reshape(values, tf.constant(Shape.of(values.asOutput().shape().size()))); + tf.reshape(values, tf.constant(Shape.of(values.shape().size()))); SetDiff1d diff = tf.setDiff1d(flatValues, allowedValues, TInt32.DTYPE); - long diffSize = diff.out().asOutput().shape().size(); + long diffSize = diff.out().shape().size(); if (diffSize != Shape.UNKNOWN_SIZE) { if (diffSize != 0) { // at least 1 value in the diff did not match the allowed values. diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java index 5c4ce542c65..339157c99d1 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java @@ -118,16 +118,16 @@ public void testBasic() { Variable[] slotUpdates = new Variable[2]; slots[0] = adaDelta.getSlot(var0.asOutput(), ACCUMULATOR).get(); - assertEquals(slots[0].asOutput().shape(), var0.asOutput().shape()); + assertEquals(slots[0].shape(), var0.shape()); slotUpdates[0] = adaDelta.getSlot(var0.asOutput(), ACCUMULATOR_UPDATE).get(); - assertEquals(slotUpdates[0].asOutput().shape(), var0.asOutput().shape()); + assertEquals(slotUpdates[0].shape(), var0.shape()); slots[1] = adaDelta.getSlot(var1.asOutput(), ACCUMULATOR).get(); - assertEquals(slots[1].asOutput().shape(), var1.asOutput().shape()); + assertEquals(slots[1].shape(), var1.shape()); slotUpdates[1] = adaDelta.getSlot(var1.asOutput(), ACCUMULATOR_UPDATE).get(); - assertEquals(slotUpdates[1].asOutput().shape(), var1.asOutput().shape()); + assertEquals(slotUpdates[1].shape(), var1.shape()); /* initialize the local variables */ session.run(var0Initializer); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java index c5ae178b84c..03717083efc 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java @@ -99,10 +99,10 @@ public void testBasic() { Variable[] accumulatorSlots = new Variable[2]; accumulatorSlots[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR).get(); - assertEquals(accumulatorSlots[0].asOutput().shape(), var0.asOutput().shape()); + assertEquals(accumulatorSlots[0].shape(), var0.shape()); accumulatorSlots[1] = instance.getSlot(var1.asOutput(), ACCUMULATOR).get(); - assertEquals(accumulatorSlots[1].asOutput().shape(), var1.asOutput().shape()); + assertEquals(accumulatorSlots[1].shape(), var1.shape()); /* initialize the local variables */ session.run(var0Initializer); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java index 461fa75397f..4b992b0a79d 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java @@ -111,16 +111,16 @@ public void testBasic() { Variable[] secondMomentSlots = new Variable[2]; firstMomentSlots[0] = instance.getSlot(var0.asOutput(), FIRST_MOMENT).get(); - assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + assertEquals(firstMomentSlots[0].shape(), var0.shape()); secondMomentSlots[0] = instance.getSlot(var0.asOutput(), SECOND_MOMENT).get(); - assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + assertEquals(secondMomentSlots[0].shape(), var0.shape()); firstMomentSlots[1] = instance.getSlot(var1.asOutput(), FIRST_MOMENT).get(); - assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + assertEquals(firstMomentSlots[1].shape(), var1.shape()); secondMomentSlots[1] = instance.getSlot(var1.asOutput(), SECOND_MOMENT).get(); - assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + assertEquals(secondMomentSlots[1].shape(), var1.shape()); /* initialize the accumulators */ session.run(tf.init()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java index de17395f76a..a303067cdc8 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java @@ -127,16 +127,16 @@ public void testBasic() { Variable[] secondMomentSlots = new Variable[2]; firstMomentSlots[0] = instance.getSlot(var0.asOutput(), FIRST_MOMENT).get(); - assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + assertEquals(firstMomentSlots[0].shape(), var0.shape()); secondMomentSlots[0] = instance.getSlot(var0.asOutput(), SECOND_MOMENT).get(); - assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + assertEquals(secondMomentSlots[0].shape(), var0.shape()); firstMomentSlots[1] = instance.getSlot(var1.asOutput(), FIRST_MOMENT).get(); - assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + assertEquals(firstMomentSlots[1].shape(), var1.shape()); secondMomentSlots[1] = instance.getSlot(var1.asOutput(), SECOND_MOMENT).get(); - assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + assertEquals(secondMomentSlots[1].shape(), var1.shape()); /* initialize the accumulators */ session.run(tf.init()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java index bcfff97773d..3649fbd8287 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java @@ -148,9 +148,9 @@ public void testMomentum() { Op update = instance.applyGradients(gradsAndVars, "SGDTest"); Variable momentumSlot0 = instance.getSlot(var0.asOutput(), MOMENTUM).get(); - assertEquals(momentumSlot0.asOutput().shape(), var0.asOutput().shape()); + assertEquals(momentumSlot0.shape(), var0.shape()); Variable momentumSlot1 = instance.getSlot(var1.asOutput(), MOMENTUM).get(); - assertEquals(momentumSlot1.asOutput().shape(), var1.asOutput().shape()); + assertEquals(momentumSlot1.shape(), var1.shape()); /* initialize the local variables */ session.run(var0Initializer); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java index a583d74246b..6c26aab2995 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java @@ -124,16 +124,16 @@ public void testBasic() { Variable[] secondMomentSlots = new Variable[2]; firstMomentSlots[0] = instance.getSlot(var0.asOutput(), Nadam.FIRST_MOMENT).get(); - assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + assertEquals(firstMomentSlots[0].shape(), var0.shape()); secondMomentSlots[0] = instance.getSlot(var0.asOutput(), Nadam.SECOND_MOMENT).get(); - assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + assertEquals(secondMomentSlots[0].shape(), var0.shape()); firstMomentSlots[1] = instance.getSlot(var1.asOutput(), Nadam.FIRST_MOMENT).get(); - assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + assertEquals(firstMomentSlots[1].shape(), var1.shape()); secondMomentSlots[1] = instance.getSlot(var1.asOutput(), Nadam.SECOND_MOMENT).get(); - assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + assertEquals(secondMomentSlots[1].shape(), var1.shape()); /* initialize the local variables */ session.run(var0Initializer); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java index 33ddec6dce3..c58667c15d0 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java @@ -1021,7 +1021,7 @@ public void evaluate(Output input, Predicate predic /** {@inheritDoc} */ @Override public void print(PrintWriter writer, Output input) { - boolean isScalar = input.asOutput().shape().size() == 1; + boolean isScalar = input.shape().size() == 1; DataType dtype = input.dataType(); if (dtype == TFloat32.DTYPE) {