From 2576e35752ae11dbb060ef05042b08df1460df26 Mon Sep 17 00:00:00 2001 From: klessard Date: Sat, 28 Nov 2020 16:51:26 -0500 Subject: [PATCH 1/3] Merge TType and Tensor instances as a single entity --- .../src/bazel/op_generator/op_specs.cc | 3 +- .../org/tensorflow/op/DebuggingOps.java | 50 -- .../annotations/org/tensorflow/op/Ops.java | 29 +- .../tensorflow/op/summary/ImageSummary.java | 6 +- .../org/tensorflow/AbstractOperation.java | 2 +- .../java/org/tensorflow/ConcreteFunction.java | 28 +- .../main/java/org/tensorflow/DataType.java | 17 +- .../java/org/tensorflow/EagerOperation.java | 22 +- .../org/tensorflow/EagerOperationBuilder.java | 6 +- .../java/org/tensorflow/GraphOperation.java | 2 +- .../org/tensorflow/GraphOperationBuilder.java | 10 +- .../src/main/java/org/tensorflow/Operand.java | 19 +- .../java/org/tensorflow/OperationBuilder.java | 4 +- .../src/main/java/org/tensorflow/Output.java | 6 +- .../main/java/org/tensorflow/RawTensor.java | 206 ++++++ .../java/org/tensorflow/SavedModelBundle.java | 4 +- .../src/main/java/org/tensorflow/Session.java | 30 +- .../src/main/java/org/tensorflow/Tensor.java | 265 ++------ .../java/org/tensorflow/op/core/Constant.java | 163 +++-- .../java/org/tensorflow/types/TBfloat16.java | 42 +- .../main/java/org/tensorflow/types/TBool.java | 38 +- .../java/org/tensorflow/types/TFloat16.java | 42 +- .../java/org/tensorflow/types/TFloat32.java | 39 +- .../java/org/tensorflow/types/TFloat64.java | 39 +- .../java/org/tensorflow/types/TInt32.java | 39 +- .../java/org/tensorflow/types/TInt64.java | 39 +- .../java/org/tensorflow/types/TString.java | 61 +- .../java/org/tensorflow/types/TUint8.java | 39 +- .../tensorflow/types/family/TFloating.java | 13 +- .../org/tensorflow/types/family/TNumber.java | 13 +- .../org/tensorflow/types/family/TType.java | 41 +- .../org/tensorflow/ConcreteFunctionTest.java | 16 +- .../tensorflow/EagerOperationBuilderTest.java | 2 +- .../org/tensorflow/EagerOperationTest.java | 4 +- .../tensorflow/GraphOperationBuilderTest.java | 10 +- .../org/tensorflow/GraphOperationTest.java | 2 +- .../test/java/org/tensorflow/GraphTest.java | 49 +- .../org/tensorflow/SavedModelBundleTest.java | 48 +- .../test/java/org/tensorflow/SessionTest.java | 52 +- .../test/java/org/tensorflow/TensorTest.java | 182 +++-- .../java/org/tensorflow/op/ScopeTest.java | 11 +- .../org/tensorflow/op/core/ConstantTest.java | 53 +- .../op/core/GeneratedOperationsTest.java | 12 +- .../org/tensorflow/op/core/GradientsTest.java | 20 +- .../org/tensorflow/op/core/ShapesTest.java | 206 +++--- .../org/tensorflow/op/core/ZerosTest.java | 30 +- .../types/NumericTypesTestBase.java | 81 +-- .../org/tensorflow/types/TBfloat16Test.java | 10 +- .../org/tensorflow/types/TFloat16Test.java | 10 +- .../org/tensorflow/types/TFloat32Test.java | 10 +- .../org/tensorflow/types/TFloat64Test.java | 10 +- .../java/org/tensorflow/types/TInt32Test.java | 12 +- .../java/org/tensorflow/types/TInt64Test.java | 10 +- .../org/tensorflow/types/TStringTest.java | 55 +- .../java/org/tensorflow/types/TUint8Test.java | 10 +- .../framework/data/DatasetIterator.java | 2 +- .../framework/losses/impl/LossesHelper.java | 4 +- .../framework/utils/ShapeUtils.java | 74 +-- .../framework/data/BatchDatasetTest.java | 40 +- .../framework/data/DatasetIteratorTest.java | 18 +- .../framework/data/MapDatasetTest.java | 21 +- .../framework/data/SkipDatasetTest.java | 10 +- .../framework/data/TakeDatasetTest.java | 8 +- .../framework/optimizers/AdamTest.java | 18 +- .../framework/optimizers/AdamaxTest.java | 9 +- .../framework/optimizers/NadamTest.java | 18 +- .../framework/utils/EagerTestSession.java | 234 +++---- .../framework/utils/GraphTestSession.java | 626 +++++++++--------- 68 files changed, 1695 insertions(+), 1609 deletions(-) delete mode 100644 tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DebuggingOps.java create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_specs.cc b/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_specs.cc index c9e0525edb7..7d184bf2a46 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_specs.cc +++ b/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_specs.cc @@ -152,8 +152,7 @@ std::pair TypeResolver::TypesOf(const OpDef_AttrDef& attr_def, types = MakeTypePair(Type::Class("Shape", "org.tensorflow.ndarray")); } else if (attr_type == "tensor") { - types = MakeTypePair(Type::Class("Tensor", "org.tensorflow") - .add_parameter(Type::Wildcard())); + types = MakeTypePair(Type::Class("Tensor", "org.tensorflow")); } else if (attr_type == "type") { Type type = *iterable_out ? Type::Wildcard() : NextGeneric(); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DebuggingOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DebuggingOps.java deleted file mode 100644 index f12d18f925b..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DebuggingOps.java +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2020 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================== -// -// This class has been generated, DO NOT EDIT! -// -package org.tensorflow.op; - -import org.tensorflow.Operand; -import org.tensorflow.op.debugging.CheckNumerics; -import org.tensorflow.types.family.TNumber; - -/** - * An API for building {@code debugging} operations as {@link Op Op}s - * - * @see {@link Ops} - */ -public final class DebuggingOps { - private final Scope scope; - - DebuggingOps(Scope scope) { - this.scope = scope; - } - - /** - * Checks a tensor for NaN and Inf values. - *

- * When run, reports an `InvalidArgument` error if `tensor` has any values - * that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is. - * - * @param data type for {@code output()} output - * @param tensor - * @param message Prefix of the error message. - * @return a new instance of CheckNumerics - */ - public CheckNumerics checkNumerics(Operand tensor, String message) { - return CheckNumerics.create(scope, tensor, message); - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 59ab4d05f73..f4e70f54e39 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -24,7 +24,6 @@ import org.tensorflow.EagerSession; import org.tensorflow.ExecutionEnvironment; import org.tensorflow.Operand; -import org.tensorflow.Tensor; import org.tensorflow.ndarray.BooleanNdArray; import org.tensorflow.ndarray.ByteNdArray; import org.tensorflow.ndarray.DoubleNdArray; @@ -349,10 +348,10 @@ public final class Ops { public final SignalOps signal; - public final TrainOps train; - public final QuantizationOps quantization; + public final TrainOps train; + private final Scope scope; private Ops(Scope scope) { @@ -374,8 +373,8 @@ private Ops(Scope scope) { math = new MathOps(this); audio = new AudioOps(this); signal = new SignalOps(this); - train = new TrainOps(this); quantization = new QuantizationOps(this); + train = new TrainOps(this); } /** @@ -1073,6 +1072,17 @@ public Bucketize bucketize(Operand input, List bou return Bucketize.create(scope, input, boundaries); } + /** + * Capture a {@code tensor} by making a constant copy of it. + * + * @param scope is a scope used to add the underlying operation. + * @param tensor a Tensor holding the constant value + * @return a constant of the same data type as `tensor` + */ + public Constant capture(T tensor) { + return Constant.create(scope, tensor); + } + /** * Clips tensor values to a specified min and max. *

@@ -1708,17 +1718,6 @@ public Constant constant(Shape shape) { return Constant.tensorOf(scope, shape); } - /** - * Create a constant from a Tensor. - * - * @param scope is a scope used to add the underlying operation. - * @param tensor a Tensor holding the constant value - * @return a constant of the same data type as `tensor` - */ - public Constant constant(Tensor tensor) { - return Constant.create(scope, tensor); - } - /** * Creates a constant of {@code String} elements, using the given charset. * diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ImageSummary.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ImageSummary.java index 13abcd4276d..796f5e24bfc 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ImageSummary.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ImageSummary.java @@ -99,13 +99,13 @@ public Options maxImages(Long maxImages) { /** * @param badColor Color to use for pixels with non-finite values. */ - public Options badColor(Tensor badColor) { + public Options badColor(Tensor badColor) { this.badColor = badColor; return this; } private Long maxImages; - private Tensor badColor; + private Tensor badColor; private Options() { } @@ -150,7 +150,7 @@ public static Options maxImages(Long maxImages) { /** * @param badColor Color to use for pixels with non-finite values. */ - public static Options badColor(Tensor badColor) { + public static Options badColor(Tensor badColor) { return new Options().badColor(badColor); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractOperation.java index 96da6bc5ff4..18f42d08e82 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractOperation.java @@ -86,5 +86,5 @@ public String toString() { * @param outputIdx index of the output of this operation * @return output tensor */ - abstract Tensor tensor(int outputIdx); + abstract Tensor tensor(int outputIdx); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index 872b4b4d16d..0bb0d20cae3 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -34,7 +34,7 @@ * *

{@code
  * ConcreteFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
- * Map> outputTensorMap = myFunction.call(inputTensorMap);
+ * Map outputTensorMap = myFunction.call(inputTensorMap);
  * }
*/ public class ConcreteFunction implements AutoCloseable { @@ -61,8 +61,8 @@ public class ConcreteFunction implements AutoCloseable { * * public static void main(String args[]) { * try (ConcreteFunction function = ConcreteFunction.create(MyModel::addTwo); - * Tensor x = TFloat32.scalarOf(2.0f)) { - * assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat()); + * TFloat32 x = TFloat32.scalarOf(2.0f)) { + * assertEquals(4.0f, ((TFloat32)function.call(x)).getFloat()); * } * } * } @@ -97,8 +97,8 @@ public static ConcreteFunction create(Function functionBuilder) * Signature signature = Signature.builder().input("x", input).output("y", output).build(); * * try (ConcreteFunction f = ConcreteFunction.create(signature, g); - * Tensor x = TFloat32.scalarOf(2.0f)) { - * assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat()); + * TFloat32 x = TFloat32.scalarOf(2.0f)) { + * assertEquals(4.0f, ((TFloat32)function.call(x)).getFloat()); * } * // Graph g is still valid at this point * } @@ -129,8 +129,8 @@ public static ConcreteFunction create(Signature signature, Graph graph) { * // Auto-closing the function just as an example but this is not required since it has * // no effect * try (ConcreteFunction f = ConcreteFunction.create(signature, s); - * Tensor t = TFloat32.scalarOf(2.0f)) { - * assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat()); + * TFloat32 t = TFloat32.scalarOf(2.0f)) { + * assertEquals(4.0f, ((TFloat32)function.call(x)).getFloat()); * } * // Session s is still valid at this point * } @@ -163,14 +163,14 @@ public Signature signature() { * @return output tensors resulting from the execution of the function, * mapped by their signature name */ - public Map> call(Map> arguments) + public Map call(Map arguments) throws IllegalArgumentException { final SignatureDef signatureDef = signature.asSignatureDef(); final Session.Runner runner = session.runner(); signatureDef.getInputsMap().forEach((argName, t) -> { - Tensor tensor = arguments.get(argName); + Tensor tensor = arguments.get(argName); if (tensor == null) { throw new IllegalArgumentException(String.format("Missing argument [%s]", argName)); } @@ -180,10 +180,10 @@ public Map> call(Map> arguments) Map outputToNode = signatureDef.getOutputsMap(); outputToNode.values().forEach(t -> runner.fetch(t.getName())); - List> resultTensors = runner.run(); + List resultTensors = runner.run(); try { - ListIterator> resultTensorIter = resultTensors.listIterator(); - Map> returnMap = new HashMap>(); + ListIterator resultTensorIter = resultTensors.listIterator(); + Map returnMap = new HashMap(); // Use the output names as present in the signature definition for (String nodeName: outputToNode.keySet()) { @@ -193,7 +193,7 @@ public Map> call(Map> arguments) } catch (Exception e) { // Release tensors before throwing exception - for (Tensor t : resultTensors) { + for (Tensor t : resultTensors) { t.close(); } throw e; @@ -210,7 +210,7 @@ public Map> call(Map> arguments) * @throws IllegalArgumentException if there are multiple input or output parameters defined * in the function */ - public Tensor call(Tensor tensor) throws IllegalArgumentException { + public Tensor call(Tensor tensor) throws IllegalArgumentException { final SignatureDef signatureDef = signature.asSignatureDef(); if (signatureDef.getInputsCount() != 1) { diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataType.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataType.java index 7b76b6dd02e..f76dc1696a7 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataType.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataType.java @@ -16,7 +16,6 @@ package org.tensorflow; import org.tensorflow.internal.c_api.TF_Tensor; -import org.tensorflow.ndarray.Shape; import org.tensorflow.types.TBfloat16; import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat16; @@ -35,13 +34,13 @@ public final class DataType { public interface TensorMapper { /** - * Maps tensor memory to a data structure for manipulating elements of this type. + * Maps the tensor memory to a n-dimensional typed data space. * - * @param nativeTensor pointer to the native tensor - * @param shape the shape of the tensor - * @return data structure of elements of this type + * @param tensor the tensor to map in its raw nature + * @param nativeHandle native handle of the tensor + * @return a typed tensor of type {@code T} */ - T apply(TF_Tensor nativeTensor, Shape shape); + T apply(RawTensor tensor, TF_Tensor nativeHandle); } /** @@ -158,13 +157,13 @@ int nativeCode() { } /** - * Maps a tensor to a data structure for manipulating elements of this type. + * Maps a raw tensor to a typed tensor. * * @param tensor tensor to map * @return data structure of elements of this type */ - T map(Tensor tensor) { - return tensorMapper.apply(tensor.nativeHandle(), tensor.shape()); + T map(RawTensor tensor) { + return tensorMapper.apply(tensor, tensor.nativeHandle()); } private final int nativeCode; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java index 012981ac59c..30387e390ed 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java @@ -91,7 +91,7 @@ public TFE_TensorHandle getUnsafeNativeHandle(int outputIndex) { public Shape shape(int outputIndex) { // If the tensor of this output has already been resolved, return its shape. // Otherwise, retrieve the tensor shape from the native library. - Tensor tensor = outputTensors.get(outputIndex); + Tensor tensor = outputTensors.get(outputIndex); if (tensor != null) { return tensor.shape(); } @@ -107,7 +107,7 @@ public Shape shape(int outputIndex) { public DataType dtype(int outputIndex) { // If the tensor of this output has already been resolved, return its datatype. // Otherwise, retrieve the tensor datatype from the native library. - Tensor tensor = outputTensors.get(outputIndex); + Tensor tensor = outputTensors.get(outputIndex); if (tensor != null) { return tensor.dataType(); } @@ -116,8 +116,8 @@ public DataType dtype(int outputIndex) { } @Override - public Tensor tensor(int outputIndex) { - Tensor tensor = outputTensors.get(outputIndex); + public Tensor tensor(int outputIndex) { + Tensor tensor = outputTensors.get(outputIndex); if (tensor == null) { tensor = resolveTensor(outputIndex); } @@ -127,21 +127,21 @@ public Tensor tensor(int outputIndex) { private final EagerSession session; private final String type; private final String name; - private final AtomicReferenceArray> outputTensors; + private final AtomicReferenceArray outputTensors; - private Tensor resolveTensor(int outputIndex) { + private Tensor resolveTensor(int outputIndex) { // Take an optimistic approach, where we attempt to resolve the output tensor without locking. // If another thread has resolved it meanwhile, release our copy and reuse the existing one // instead. - Tensor tensor = resolveTensorHandle(getUnsafeNativeHandle(outputIndex), session); + Tensor tensor = resolveTensorHandle(getUnsafeNativeHandle(outputIndex), session); if (!outputTensors.compareAndSet(outputIndex, null, tensor)) { - session.detach(tensor.nativeHandle()); + session.detach(tensor.asRawTensor().nativeHandle()); tensor = outputTensors.get(outputIndex); } return tensor; } - private TFE_Op opHandle; + private final TFE_Op opHandle; private final TFE_TensorHandle[] outputHandles; private static void requireOp(TFE_Op handle) { @@ -156,13 +156,13 @@ private static void requireTensorHandle(TFE_TensorHandle handle) { } } - private static Tensor resolveTensorHandle(TFE_TensorHandle handle, EagerSession session) { + private static Tensor resolveTensorHandle(TFE_TensorHandle handle, EagerSession session) { requireTensorHandle(handle); try (PointerScope scope = new PointerScope()) { TF_Status status = TF_Status.newStatus(); TF_Tensor tensor = TFE_TensorHandleResolve(handle, status).withDeallocator(); status.throwExceptionIfNotOK(); - return Tensor.fromHandle(tensor, session); + return RawTensor.fromHandle(tensor, session).asTypedTensor(); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java index ad11e63c7c8..5c975929fee 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java @@ -175,13 +175,13 @@ public EagerOperationBuilder setAttr(String name, DataType[] values) { } @Override - public EagerOperationBuilder setAttr(String name, Tensor value) { - setAttrTensor(opHandle, name, value.nativeHandle()); + public EagerOperationBuilder setAttr(String name, Tensor value) { + setAttrTensor(opHandle, name, value.asRawTensor().nativeHandle()); return this; } @Override - public EagerOperationBuilder setAttr(String name, Tensor[] values) { + public EagerOperationBuilder setAttr(String name, Tensor[] values) { // TODO (karllessard) could be supported by adding this attribute type in the eager C API throw new UnsupportedOperationException( "Tensor list attributes are not supported in eager mode"); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java index 70cd31366ce..d2fbc4e4995 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java @@ -158,7 +158,7 @@ DataType dtype(int outputIdx) { } @Override - Tensor tensor(int outputIdx) { + Tensor tensor(int outputIdx) { throw new IllegalStateException("Graph tensors must be fetched by running a session"); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java index 2ef5c9010a1..5fda65480e9 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java @@ -247,10 +247,10 @@ public GraphOperationBuilder setAttr(String name, DataType[] value) { } @Override - public GraphOperationBuilder setAttr(String name, Tensor value) { + public GraphOperationBuilder setAttr(String name, Tensor value) { Graph.Reference r = graph.ref(); try { - setAttrTensor(unsafeNativeHandle, name, value.nativeHandle()); + setAttrTensor(unsafeNativeHandle, name, value.asRawTensor().nativeHandle()); } finally { r.close(); } @@ -258,11 +258,11 @@ public GraphOperationBuilder setAttr(String name, Tensor value) { } @Override - public GraphOperationBuilder setAttr(String name, Tensor[] value) { + public GraphOperationBuilder setAttr(String name, Tensor[] value) { TF_Tensor[] handles = new TF_Tensor[value.length]; int idx = 0; - for (Tensor t : value) { - handles[idx++] = t.nativeHandle(); + for (Tensor t : value) { + handles[idx++] = t.asRawTensor().nativeHandle(); } Graph.Reference r = graph.ref(); try { diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operand.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operand.java index 3f13e1004e8..31a93fb999a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operand.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operand.java @@ -54,29 +54,22 @@ public interface Operand extends Op, Shaped { Output asOutput(); /** - * Returns this operand as a tensor. + * Returns the tensor at this operand. * * Only works when running in an eager execution - *

This helper method is equivalent to {@code asOutput().tensor()} * * @return the tensor * @throws IllegalStateException if this is an operand of a graph */ - default Tensor asTensor() { - return asOutput().tensor(); + default T asTensor() { + return asOutput().asTensor(); } /** - * Returns the data of this operand. - * - * Only works when running in an eager execution - *

This helper method is equivalent to {@code asTensor().data()} - * - * @return the tensor data - * @throws IllegalStateException if this is an operand of a graph + * Returns the data type of this operand */ - default T data() { - return asOutput().tensor().data(); + default DataType dataType() { + return asOutput().dataType(); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/OperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/OperationBuilder.java index af1b8cc9130..79f21c33fb7 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/OperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/OperationBuilder.java @@ -195,7 +195,7 @@ public interface OperationBuilder { * @param value attribute value * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, Tensor value); + OperationBuilder setAttr(String name, Tensor value); /** * Set the tensor values of an attribute of the operation being built. @@ -204,7 +204,7 @@ public interface OperationBuilder { * @param value attribute values * @return the OperationBuilder instance for chaining. */ - OperationBuilder setAttr(String name, Tensor[] value); + OperationBuilder setAttr(String name, Tensor[] value); /** * Set the shape value of an attribute of the operation being built. diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java index f7977c78474..5b8337f4d70 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java @@ -60,7 +60,6 @@ public Output expect(DataType dt) { return ((Output) this); } - /** * Returns the tensor at this output. * @@ -73,11 +72,12 @@ public Output expect(DataType dt) { * * @return tensor * @throws IllegalStateException if this output results from a graph + * @throws ClassCastException if the type of the tensor and this output are unexpectedly incompatible * @see EagerSession */ @SuppressWarnings("unchecked") - public Tensor tensor() { - return (Tensor) operation.tensor(index); + public T asTensor() { + return (T)operation.tensor(index); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java new file mode 100644 index 00000000000..dde040ff141 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java @@ -0,0 +1,206 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow; + +import static org.tensorflow.internal.c_api.global.tensorflow.TF_Dim; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_NumDims; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_TensorByteSize; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_TensorType; + +import org.bytedeco.javacpp.PointerScope; +import org.tensorflow.internal.buffer.TensorBuffers; +import org.tensorflow.internal.c_api.TF_Tensor; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.Shaped; +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.types.family.TType; + +/** + * A tensor which memory has not been mapped. + * + *

A raw tensor is a minimalist representation of a tensor allocated in native memory by the + * TensorFlow runtime library and it controls its lifetime within the current process. The data + * is represented by a flat {@link ByteDataBuffer buffer of bytes}, until it is mapped in a + * n-dimensional typed space by a {@link TType typed tensor}.

+ * + *

Instances of a RawTensor are not thread-safe and their resource must be released + * by calling {@link #close()} explicitly or implicitly (try-with-resources).

+ */ +public final class RawTensor implements Tensor { + + /** + * Returns a typed version of this tensor + */ + TType asTypedTensor() { + return dtype.map(this); + } + + @Override + public DataType dataType() { + return dtype; + } + + @Override + public long numBytes() { + return TF_TensorByteSize(nativeHandle()); + } + + @Override + public Shape shape() { + return shape; + } + + @Override + public RawTensor asRawTensor() { + return this; + } + + @Override + public void close() { + tensorScope.close(); + } + + /** + * Returns the raw data of this tensor as a buffer of bytes. + * + * @return the tensor bytes + * @throws IllegalStateException if the tensor has been closed + */ + public ByteDataBuffer data() { + if (buffer == null) { + buffer = TensorBuffers.toBytes(nativeHandle()); + } + return buffer; + } + + /** + * Returns a string describing the type and shape of the tensor. + */ + @Override + public String toString() { + return String.format("%s tensor with shape %s", dtype.toString(), shape); + } + + /** + * Allocates a new tensor in native memory of the given type, shape and size. + * + *

The size of the tensor must be at least large enough to contain all scalars for the + * given type and shape, i.e. size >= dtype.byteSize() * shape.size(). More memory + * can be allocated to store also metadata within the tensor itself, e.g. a lookup table + * in a string tensor. + * + * @param dtype data type + * @param shape shape of the tensor + * @param size size of the tensor + * @return allocated tensor + */ + static RawTensor allocate(DataType dtype, Shape shape, long size) { + // Minimum requirements for datatypes of variable length cannot be verified in a relevant way so + // we only validate them for fixed length datatypes + if (!dtype.isVariableLength() && shape.size() * dtype.byteSize() > size) { + throw new IllegalArgumentException("Tensor size is not large enough to contain all scalar values"); + } + TF_Tensor nativeHandle = allocate(dtype.nativeCode(), shape.asArray(), size); + try (PointerScope scope = new PointerScope()) { + scope.attach(nativeHandle); + + RawTensor t = new RawTensor(dtype, shape); + t.tensorHandle = nativeHandle; + t.tensorScope = scope.extend(); + return t; + } + } + + /** + * Create a Tensor object from a handle to the C TF_Tensor object. + * + *

Takes ownership of the handle. + */ + static RawTensor fromHandle(TF_Tensor handle) { + RawTensor t = new RawTensor(DataTypes.fromNativeCode(dtype(handle)), Shape.of(shape(handle))); + try (PointerScope scope = new PointerScope()) { + scope.attach(handle); + t.tensorHandle = handle; + t.tensorScope = scope.extend(); + } + return t; + } + + /** + * Create an eager Tensor object from a handle to the C TF_Tensor object. + * + *

Takes ownership of the handle. + */ + static RawTensor fromHandle(TF_Tensor handle, EagerSession session) { + RawTensor t = fromHandle(handle); + session.attach(handle); + t.tensorScope.detach(handle); + return t; + } + + /** + * @return native handle to this tensor + * @throws IllegalStateException if tensor has been closed + */ + TF_Tensor nativeHandle() { + return requireHandle(tensorHandle); + } + + private static TF_Tensor requireHandle(TF_Tensor handle) { + if (handle == null || handle.isNull()) { + throw new IllegalStateException("close() was called on the Tensor"); + } + return handle; + } + + private static TF_Tensor allocate(int dtype, long[] shape, long byteSize) { + TF_Tensor t = TF_Tensor.allocateTensor(dtype, shape, byteSize); + if (t == null || t.isNull()) { + throw new IllegalStateException("unable to allocate memory for the Tensor"); + } + return t; + } + + private static int dtype(TF_Tensor handle) { + requireHandle(handle); + return TF_TensorType(handle); + } + + private static long[] shape(TF_Tensor handle) { + requireHandle(handle); + int numDims = TF_NumDims(handle); + long[] dims = new long[numDims]; + for (int i = 0; i < numDims; ++i) { + dims[i] = TF_Dim(handle, i); + } + return dims; + } + + RawTensor(DataType dtype, Shape shape) { + this.dtype = dtype; + this.shape = shape; + } + + private PointerScope tensorScope; + private TF_Tensor tensorHandle; + private final DataType dtype; + private final Shape shape; + private ByteDataBuffer buffer = null; + + static { + TensorFlow.init(); + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index 093898ae56c..0974cc94a24 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -301,7 +301,7 @@ public List signatures() { * *

{@code
    * ConcreteFunction myFunction = savedModelBundle.function("mySignatureKey");
-   * Map> outputTensorMap = myFunction.call(session, inputTensorMap);
+   * Map outputTensorMap = myFunction.call(session, inputTensorMap);
    * }
* * @param signatureKey name of the {@code SignatureDef} in the saved model. @@ -334,7 +334,7 @@ public ConcreteFunction function(String signatureKey) { * @return list of output tensors, mapped by the signature name * @throws IllegalArgumentException if no function can be selected by default */ - public Map> call(Map> arguments) { + public Map call(Map arguments) { ConcreteFunction function = null; if (functions.size() == 1) { function = functions.values().iterator().next(); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index 4e82f3944b8..e9d517a6548 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -158,7 +158,7 @@ public final class Runner { * @param t the tensor substituting the operation * @return this session runner */ - public Runner feed(String operation, Tensor t) { + public Runner feed(String operation, Tensor t) { return feed(parseOutput(operation), t); } @@ -173,7 +173,7 @@ public Runner feed(String operation, Tensor t) { * @param t the tensor substituting the operation * @return this session runner */ - public Runner feed(String operation, int index, Tensor t) { + public Runner feed(String operation, int index, Tensor t) { Operation op = operationByName(operation); if (op != null) { inputs.add(op.output(index)); @@ -190,7 +190,7 @@ public Runner feed(String operation, int index, Tensor t) { * @param t the tensor substituting the operation * @return this session runner */ - public Runner feed(Operand operand, Tensor t) { + public Runner feed(Operand operand, Tensor t) { inputs.add(operand.asOutput()); inputTensors.add(t); return this; @@ -325,7 +325,7 @@ public Runner setOptions(RunOptions options) { * * @return list of resulting tensors fetched by this session runner */ - public List> run() { + public List run() { return runHelper(false).outputs; } @@ -354,8 +354,8 @@ private Run runHelper(boolean wantMetadata) { // It's okay to use Operation.getUnsafeNativeHandle() here since the safety depends on the // validity of the Graph and graphRef ensures that. int idx = 0; - for (Tensor t : inputTensors) { - inputTensorHandles[idx++] = t.nativeHandle(); + for (Tensor t : inputTensors) { + inputTensorHandles[idx++] = t.asRawTensor().nativeHandle(); } idx = 0; for (Output o : inputs) { @@ -375,7 +375,7 @@ private Run runHelper(boolean wantMetadata) { } Reference runRef = new Reference(); RunMetadata metadata = null; - List> outputs = new ArrayList<>(); + List outputs = new ArrayList<>(); try { metadata = Session.run( @@ -390,7 +390,7 @@ private Run runHelper(boolean wantMetadata) { wantMetadata, outputs); } catch (Exception e) { - for (Tensor t : outputs) { + for (Tensor t : outputs) { t.close(); } outputs.clear(); @@ -450,10 +450,10 @@ private Output parseOutput(String opName) { } } - private ArrayList> inputs = new ArrayList<>(); - private ArrayList> inputTensors = new ArrayList<>(); - private ArrayList> outputs = new ArrayList<>(); - private ArrayList targets = new ArrayList<>(); + private final ArrayList> inputs = new ArrayList<>(); + private final ArrayList inputTensors = new ArrayList<>(); + private final ArrayList> outputs = new ArrayList<>(); + private final ArrayList targets = new ArrayList<>(); private RunOptions runOptions = null; } @@ -518,7 +518,7 @@ public void save(String prefix) { */ public static final class Run { /** Tensors from requested fetches. */ - public List> outputs; + public List outputs; /** * Metadata about the run. @@ -627,7 +627,7 @@ private static RunMetadata run( int[] outputOpIndices, TF_Operation[] targetOpHandles, boolean wantRunMetadata, - List> outputTensors) { + List outputTensors) { requireHandle(handle); int ninputs = inputTensorHandles.length; @@ -667,7 +667,7 @@ private static RunMetadata run( for (int i = 0; i < noutputs; ++i) { TF_Tensor h = outputValues.get(TF_Tensor.class, i).withDeallocator(); - outputTensors.add(Tensor.fromHandle(h)); + outputTensors.add(RawTensor.fromHandle(h).asTypedTensor()); } try { return runMetadata != null ? RunMetadata.parseFrom(runMetadata.dataAsByteBuffer()) : null; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java index 62530e923ac..aa4618db3e5 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java @@ -31,9 +31,12 @@ import org.tensorflow.types.family.TType; /** - * A statically typed multi-dimensional array whose elements are of a type described by T. + * A statically typed multi-dimensional array. * - *

Instances of a Tensor are not thread-safe. + *

There are two categories of tensors in TensorFlow Java: {@link TType typed tensors} and + * {@link RawTensor raw tensors}. The former maps the tensor native memory to an + * n-dimensional typed data space, allowing direct I/O operations from the JVM, while the latter + * is only a reference to a native tensor allowing basic operations and flat data access.

* *

WARNING: Resources consumed by the Tensor object must be explicitly freed by * invoking the {@link #close()} method when the object is no longer needed. For example, using a @@ -43,25 +46,17 @@ * try (Tensor t = Tensor.of(...)) { * doSomethingWith(t); * } + * + *

Instances of a Tensor are not thread-safe. * } */ -public final class Tensor implements Shaped, AutoCloseable { +public interface Tensor extends Shaped, AutoCloseable { /** * Allocates a tensor of a given datatype and shape. * - *

The amount of memory to allocate is derived from the datatype and the shape of the tensor. - * Memory is left uninitialized after this method returns, so it is the responsibility of the - * caller to initialize the tensor data before it is used, via the {@link #data()} accessor. - * For example: - * - *

{@code
-   * FloatNdArray data = ...
-   * try (Tensor t = Tensor.of(TFloat32.DTYPE, Shape.of(2, 2))) {
-   *   data.copyTo(t.data());
-   *   ...
-   * }
-   * }
+ *

The amount of memory to allocate is derived from the datatype and the shape of the tensor, + * and is left uninitialized. * * @param the tensor element type * @param dtype datatype of the tensor @@ -69,7 +64,7 @@ public final class Tensor implements Shaped, AutoCloseable { * @return an allocated but uninitialized tensor * @throws IllegalStateException if tensor failed to be allocated */ - public static Tensor of(DataType dtype, Shape shape) { + static T of(DataType dtype, Shape shape) { return of(dtype, shape, shape.size() * dtype.byteSize()); } @@ -77,10 +72,8 @@ public static Tensor of(DataType dtype, Shape shape) { * Allocates a tensor of a given datatype, shape and size. * *

This method is identical to {@link #of(DataType, Shape)}, except that the final size of the - * tensor is explicitly set instead of computing it from the datatype and shape. - * - *

This could be useful for tensor types that stores data but also metadata in the tensor memory, - * like {@link org.tensorflow.types.TString TString}. + * tensor is explicitly set instead of computing it from the datatype and shape, which could be + * larger than the actual space required to store the data but not smaller. * * @param the tensor element type * @param dtype datatype of the tensor @@ -92,19 +85,13 @@ public static Tensor of(DataType dtype, Shape shape) { * store the tensor data * @throws IllegalStateException if tensor failed to be allocated */ - public static Tensor of(DataType dtype, Shape shape, long size) { - // Minimum requirements for datatypes of variable length cannot be verified in a relevant way so - // we only validate them for fixed length datatypes - if (!dtype.isVariableLength() && shape.size() * dtype.byteSize() > size) { - throw new IllegalArgumentException("Tensor size is not large enough to contain all scalar values"); - } - Tensor t = new Tensor<>(dtype, shape); - TF_Tensor nativeHandle = allocate(t.dtype.nativeCode(), shape.asArray(), size); - try (PointerScope scope = new PointerScope()) { - scope.attach(nativeHandle); - t.tensorHandle = nativeHandle; - t.tensorScope = scope.extend(); - return t; + static T of(DataType dtype, Shape shape, long size) { + RawTensor tensor = RawTensor.allocate(dtype, shape, size); + try { + return dtype.map(tensor); + } catch (Exception e) { + tensor.close(); + throw e; } } @@ -117,7 +104,7 @@ public static Tensor of(DataType dtype, Shape shape, lon * *

{@code
    * FloatNdArray data = ...
-   * try (Tensor t = Tensor.of(TFloat32.DTYPE, Shape.of(2, 2), data::copyTo)) {
+   * try (TFloat32 t = Tensor.of(TFloat32.DTYPE, Shape.of(2, 2), data::copyTo)) {
    *   ...
    * }
    * }
@@ -132,8 +119,7 @@ public static Tensor of(DataType dtype, Shape shape, lon * @return an allocated and initialized tensor * @throws IllegalStateException if tensor failed to be allocated */ - public static Tensor of(DataType dtype, Shape shape, - Consumer dataInitializer) { + static T of(DataType dtype, Shape shape, Consumer dataInitializer) { return of(dtype, shape, shape.size() * dtype.byteSize(), dataInitializer); } @@ -144,7 +130,7 @@ public static Tensor of(DataType dtype, Shape shape, * size for the tensor is explicitly set instead of being computed from the datatype and shape. * *

This could be useful for tensor types that stores data but also metadata in the tensor memory, - * such as {@link org.tensorflow.types.TString TString}. + * such as lookup table in a tensor of strings. * * @param the tensor element type * @param dtype datatype of the tensor @@ -157,11 +143,10 @@ public static Tensor of(DataType dtype, Shape shape, * store the tensor data * @throws IllegalStateException if tensor failed to be allocated */ - public static Tensor of(DataType dtype, Shape shape, long size, - Consumer dataInitializer) { - Tensor tensor = of(dtype, shape, size); + static T of(DataType dtype, Shape shape, long size, Consumer dataInitializer) { + T tensor = of(dtype, shape, size); try { - dataInitializer.accept(tensor.data()); + dataInitializer.accept(tensor); return tensor; } catch (Throwable t) { tensor.close(); @@ -182,203 +167,39 @@ public static Tensor of(DataType dtype, Shape shape, lon * @throws IllegalArgumentException if {@code rawData} is not large enough to contain the tensor data * @throws IllegalStateException if tensor failed to be allocated with the given parameters */ - public static Tensor of(DataType dtype, Shape shape, ByteDataBuffer rawData) { - Tensor t = of(dtype, shape, rawData.size()); - rawData.copyTo(TensorBuffers.toBytes(t.nativeHandle()), rawData.size()); - return t; - } - - /** - * Returns this Tensor object with the type {@code Tensor}. This method is useful when given a - * value of type {@code Tensor}. - * - * @param dt any supported tensor data type - * @param a tensor type - * @return a tensor of the requested data type - * @throws IllegalArgumentException if the actual data type of this object does not match the type - * {@code U}. - */ - @SuppressWarnings("unchecked") - public Tensor expect(DataType dt) { - if (!dt.equals(this.dtype)) { - throw new IllegalArgumentException( - "Cannot cast from tensor of " + dtype + " to tensor of " + dt); - } - return ((Tensor) this); + static T of(DataType dtype, Shape shape, ByteDataBuffer rawData) { + return of(dtype, shape, rawData.size(), t -> rawData.copyTo(t.asRawTensor().data(), rawData.size())); } /** - * Release resources associated with the Tensor. - * - *

WARNING:This must be invoked for all tensors that were not been produced by an eager - * operation or memory will be leaked. - * - *

The Tensor object is no longer usable after {@code close} returns. + * Returns the {@link DataType} of elements stored in the tensor. */ - @Override - public void close() { - tensorScope.close(); - } - - /** Returns the {@link DataType} of elements stored in the Tensor. */ - public DataType dataType() { - return dtype; - } - - /** Returns the size, in bytes, of the tensor data. */ - public long numBytes() { - if (numBytes == null) { - numBytes = TF_TensorByteSize(tensorHandle); - } - return numBytes; - } - - /** Returns the shape of this tensor. */ - @Override - public Shape shape() { - return shape; - } + DataType dataType(); /** - * Returns the data of this tensor. - * - *

This method returns an accessor to the tensor data as an instance of {@code T}, which - * commonly maps this data to an {@link NdArray NdArray}. Input and - * output operations performed on the returned n-dimensional array are applied directly to the - * tensor native memory. For example: - * - *

{@code
-   * Ops tf = Ops.create();
-   * try (Tensor t = TFloat32.tensorOf(Shape.of(2, 2))) {
-   *   TFloat32 data = t.data();
-   *
-   *   StdArrays.copyTo(data, new float[][] {
-   *     {1.0f, 2.0f},
-   *     {3.0f, 4.0f}
-   *   });
-   *   assertEquals(NdArrays.vectorOf(3.0f, 4.0f), data.getFloat(1));
-   *
-   *   Constant c = tf.constant(t);
-   *   assertEquals(4.0f, c.data().getFloat(1, 1));
-   * }
-   * }
- * - *

Please refer to the documentation of the {@link NdArray NdArray} - * classes for more information on the various techniques to read or write data in an - * n-dimensional space using this data structure. - * - * @return the tensor data mapped to an n-dimensional space - * @throws IllegalStateException if the tensor has been closed - * @see NdArray + * Returns the size, in bytes, of the tensor data. */ - public T data() { - if (data == null) { - data = dtype.map(this); - } else { - nativeHandle(); // Checks that the tensor has not been released or will throw - } - return data; - } + long numBytes(); /** - * Returns the raw data of this tensor as a buffer of bytes. - * - *

Use this method to obtain a read-only serializable view of the tensor raw data and must be - * used with care since there is no guard on the element boundaries. For regular input or output - * operations, use {@link #data()}. - * - * @return the tensor raw data mapped to a read-only byte buffer - * @throws IllegalStateException if the tensor has been closed + * Returns the shape of the tensor. */ - public ByteDataBuffer rawData() { - return TensorBuffers.toBytes(nativeHandle(), true); - } - - /** Returns a string describing the type and shape of the Tensor. */ @Override - public String toString() { - return String.format("%s tensor with shape %s", dtype.toString(), shape); - } + Shape shape(); /** - * Create a Tensor object from a handle to the C TF_Tensor object. - * - *

Takes ownership of the handle. + * Returns a raw (untyped) representation of this tensor */ - static Tensor fromHandle(TF_Tensor handle) { - Tensor t = new Tensor<>(DataTypes.fromNativeCode(dtype(handle)), Shape.of(shape(handle))); - try (PointerScope scope = new PointerScope()) { - scope.attach(handle); - t.tensorHandle = handle; - t.tensorScope = scope.extend(); - } - return t; - } + RawTensor asRawTensor(); /** - * Create an eager Tensor object from a handle to the C TF_Tensor object. + * Release resources associated with the Tensor. * - *

Takes ownership of the handle. - */ - static Tensor fromHandle(TF_Tensor handle, EagerSession session) { - Tensor t = fromHandle(handle); - session.attach(handle); - t.tensorScope.detach(handle); - return t; - } - - /** - * @return native handle to this tensor - * @throws IllegalStateException if tensor has been closed + *

WARNING:This must be invoked for all tensors that were not been produced by an eager + * operation or memory will be leaked. + * + *

The Tensor object is no longer usable after {@code close} returns. */ - TF_Tensor nativeHandle() { - return requireHandle(tensorHandle); - } - - private PointerScope tensorScope; - private TF_Tensor tensorHandle; - - private static TF_Tensor requireHandle(TF_Tensor handle) { - if (handle == null || handle.isNull()) { - throw new IllegalStateException("close() was called on the Tensor"); - } - return handle; - } - - private static TF_Tensor allocate(int dtype, long[] shape, long byteSize) { - TF_Tensor t = TF_Tensor.allocateTensor(dtype, shape, byteSize); - if (t == null || t.isNull()) { - throw new IllegalStateException("unable to allocate memory for the Tensor"); - } - return t; - } - - private static int dtype(TF_Tensor handle) { - requireHandle(handle); - return TF_TensorType(handle); - } - - private static long[] shape(TF_Tensor handle) { - requireHandle(handle); - int numDims = TF_NumDims(handle); - long[] dims = new long[numDims]; - for (int i = 0; i < numDims; ++i) { - dims[i] = TF_Dim(handle, i); - } - return dims; - } - - private final DataType dtype; - private final Shape shape; - private T data = null; - private Long numBytes = null; - - private Tensor(DataType dtype, Shape shape) { - this.dtype = dtype; - this.shape = shape; - } - - static { - TensorFlow.init(); - } + @Override + void close(); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java index 6c214cc6819..ff1c990d3ea 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java @@ -77,7 +77,7 @@ public final class Constant extends RawOp implements Operand */ @Endpoint public static Constant scalarOf(Scope scope, int data) { - try (Tensor value = TInt32.scalarOf(data)) { + try (TInt32 value = TInt32.scalarOf(data)) { return create(scope, value); } } @@ -92,7 +92,7 @@ public static Constant scalarOf(Scope scope, int data) { */ @Endpoint public static Constant vectorOf(Scope scope, int[] data) { - try (Tensor value = TInt32.vectorOf(data)) { + try (TInt32 value = TInt32.vectorOf(data)) { return create(scope, value); } } @@ -122,7 +122,7 @@ public static Constant arrayOf(Scope scope, int... data) { */ @Endpoint public static Constant tensorOf(Scope scope, int[][] data) { - try (Tensor value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -138,7 +138,7 @@ public static Constant tensorOf(Scope scope, int[][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, int[][][] data) { - try (Tensor value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -154,7 +154,7 @@ public static Constant tensorOf(Scope scope, int[][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, int[][][][] data) { - try (Tensor value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -170,7 +170,7 @@ public static Constant tensorOf(Scope scope, int[][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, int[][][][][] data) { - try (Tensor value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -186,7 +186,7 @@ public static Constant tensorOf(Scope scope, int[][][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, int[][][][][][] data) { - try (Tensor value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -201,7 +201,10 @@ public static Constant tensorOf(Scope scope, int[][][][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, IntNdArray data) { - try (Tensor value = TInt32.tensorOf(data)) { + if (data instanceof TInt32) { + return create(scope, (TInt32) data); + } + try (TInt32 value = TInt32.tensorOf(data)) { return create(scope, value); } } @@ -217,7 +220,7 @@ public static Constant tensorOf(Scope scope, IntNdArray data) { */ @Endpoint public static Constant tensorOf(Scope scope, Shape shape, IntDataBuffer data) { - try (Tensor value = TInt32.tensorOf(shape, data)) { + try (TInt32 value = TInt32.tensorOf(shape, data)) { return create(scope, value); } } @@ -231,7 +234,7 @@ public static Constant tensorOf(Scope scope, Shape shape, IntDataBuffer */ @Endpoint public static Constant scalarOf(Scope scope, float data) { - try (Tensor value = TFloat32.scalarOf(data)) { + try (TFloat32 value = TFloat32.scalarOf(data)) { return create(scope, value); } } @@ -246,7 +249,7 @@ public static Constant scalarOf(Scope scope, float data) { */ @Endpoint public static Constant vectorOf(Scope scope, float[] data) { - try (Tensor value = TFloat32.vectorOf(data)) { + try (TFloat32 value = TFloat32.vectorOf(data)) { return create(scope, value); } } @@ -276,7 +279,7 @@ public static Constant arrayOf(Scope scope, float... data) { */ @Endpoint public static Constant tensorOf(Scope scope, float[][] data) { - try (Tensor value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( + try (TFloat32 value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( data, t))) { return create(scope, value); } @@ -292,7 +295,7 @@ public static Constant tensorOf(Scope scope, float[][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, float[][][] data) { - try (Tensor value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( + try (TFloat32 value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( data, t))) { return create(scope, value); } @@ -308,7 +311,7 @@ public static Constant tensorOf(Scope scope, float[][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, float[][][][] data) { - try (Tensor value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( + try (TFloat32 value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( data, t))) { return create(scope, value); } @@ -324,7 +327,7 @@ public static Constant tensorOf(Scope scope, float[][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, float[][][][][] data) { - try (Tensor value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( + try (TFloat32 value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( data, t))) { return create(scope, value); } @@ -340,7 +343,7 @@ public static Constant tensorOf(Scope scope, float[][][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, float[][][][][][] data) { - try (Tensor value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( + try (TFloat32 value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( data, t))) { return create(scope, value); } @@ -355,7 +358,10 @@ public static Constant tensorOf(Scope scope, float[][][][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, FloatNdArray data) { - try (Tensor value = TFloat32.tensorOf(data)) { + if (data instanceof TFloat32) { + return create(scope, (TFloat32) data); + } + try (TFloat32 value = TFloat32.tensorOf(data)) { return create(scope, value); } } @@ -371,7 +377,7 @@ public static Constant tensorOf(Scope scope, FloatNdArray data) { */ @Endpoint public static Constant tensorOf(Scope scope, Shape shape, FloatDataBuffer data) { - try (Tensor value = TFloat32.tensorOf(shape, data)) { + try (TFloat32 value = TFloat32.tensorOf(shape, data)) { return create(scope, value); } } @@ -385,7 +391,7 @@ public static Constant tensorOf(Scope scope, Shape shape, FloatDataBuf */ @Endpoint public static Constant scalarOf(Scope scope, double data) { - try (Tensor value = TFloat64.scalarOf(data)) { + try (TFloat64 value = TFloat64.scalarOf(data)) { return create(scope, value); } } @@ -400,7 +406,7 @@ public static Constant scalarOf(Scope scope, double data) { */ @Endpoint public static Constant vectorOf(Scope scope, double[] data) { - try (Tensor value = TFloat64.vectorOf(data)) { + try (TFloat64 value = TFloat64.vectorOf(data)) { return create(scope, value); } } @@ -430,7 +436,7 @@ public static Constant arrayOf(Scope scope, double... data) { */ @Endpoint public static Constant tensorOf(Scope scope, double[][] data) { - try (Tensor value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( + try (TFloat64 value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( data, t))) { return create(scope, value); } @@ -446,7 +452,7 @@ public static Constant tensorOf(Scope scope, double[][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, double[][][] data) { - try (Tensor value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( + try (TFloat64 value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( data, t))) { return create(scope, value); } @@ -462,7 +468,7 @@ public static Constant tensorOf(Scope scope, double[][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, double[][][][] data) { - try (Tensor value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( + try (TFloat64 value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( data, t))) { return create(scope, value); } @@ -478,7 +484,7 @@ public static Constant tensorOf(Scope scope, double[][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, double[][][][][] data) { - try (Tensor value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( + try (TFloat64 value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( data, t))) { return create(scope, value); } @@ -494,7 +500,7 @@ public static Constant tensorOf(Scope scope, double[][][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, double[][][][][][] data) { - try (Tensor value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( + try (TFloat64 value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( data, t))) { return create(scope, value); } @@ -509,7 +515,10 @@ public static Constant tensorOf(Scope scope, double[][][][][][] data) */ @Endpoint public static Constant tensorOf(Scope scope, DoubleNdArray data) { - try (Tensor value = TFloat64.tensorOf(data)) { + if (data instanceof TFloat64) { + return create(scope, (TFloat64) data); + } + try (TFloat64 value = TFloat64.tensorOf(data)) { return create(scope, value); } } @@ -525,7 +534,7 @@ public static Constant tensorOf(Scope scope, DoubleNdArray data) { */ @Endpoint public static Constant tensorOf(Scope scope, Shape shape, DoubleDataBuffer data) { - try (Tensor value = TFloat64.tensorOf(shape, data)) { + try (TFloat64 value = TFloat64.tensorOf(shape, data)) { return create(scope, value); } } @@ -539,7 +548,7 @@ public static Constant tensorOf(Scope scope, Shape shape, DoubleDataBu */ @Endpoint public static Constant scalarOf(Scope scope, long data) { - try (Tensor value = TInt64.scalarOf(data)) { + try (TInt64 value = TInt64.scalarOf(data)) { return create(scope, value); } } @@ -554,7 +563,7 @@ public static Constant scalarOf(Scope scope, long data) { */ @Endpoint public static Constant vectorOf(Scope scope, long[] data) { - try (Tensor value = TInt64.vectorOf(data)) { + try (TInt64 value = TInt64.vectorOf(data)) { return create(scope, value); } } @@ -569,7 +578,7 @@ public static Constant vectorOf(Scope scope, long[] data) { */ @Endpoint public static Constant tensorOf(Scope scope, long[][] data) { - try (Tensor value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -600,7 +609,7 @@ public static Constant arrayOf(Scope scope, long... data) { */ @Endpoint public static Constant tensorOf(Scope scope, long[][][] data) { - try (Tensor value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -616,7 +625,7 @@ public static Constant tensorOf(Scope scope, long[][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, long[][][][] data) { - try (Tensor value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -632,7 +641,7 @@ public static Constant tensorOf(Scope scope, long[][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, long[][][][][] data) { - try (Tensor value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -648,7 +657,7 @@ public static Constant tensorOf(Scope scope, long[][][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, long[][][][][][] data) { - try (Tensor value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -663,7 +672,10 @@ public static Constant tensorOf(Scope scope, long[][][][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, LongNdArray data) { - try (Tensor value = TInt64.tensorOf(data)) { + if (data instanceof TInt64) { + return create(scope, (TInt64) data); + } + try (TInt64 value = TInt64.tensorOf(data)) { return create(scope, value); } } @@ -679,7 +691,7 @@ public static Constant tensorOf(Scope scope, LongNdArray data) { */ @Endpoint public static Constant tensorOf(Scope scope, Shape shape, LongDataBuffer data) { - try (Tensor value = TInt64.tensorOf(shape, data)) { + try (TInt64 value = TInt64.tensorOf(shape, data)) { return create(scope, value); } } @@ -693,7 +705,7 @@ public static Constant tensorOf(Scope scope, Shape shape, LongDataBuffer */ @Endpoint public static Constant scalarOf(Scope scope, boolean data) { - try (Tensor value = TBool.scalarOf(data)) { + try (TBool value = TBool.scalarOf(data)) { return create(scope, value); } } @@ -708,7 +720,7 @@ public static Constant scalarOf(Scope scope, boolean data) { */ @Endpoint public static Constant vectorOf(Scope scope, boolean[] data) { - try (Tensor value = TBool.vectorOf(data)) { + try (TBool value = TBool.vectorOf(data)) { return create(scope, value); } } @@ -738,7 +750,7 @@ public static Constant arrayOf(Scope scope, boolean... data) { */ @Endpoint public static Constant tensorOf(Scope scope, boolean[][] data) { - try (Tensor value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -754,7 +766,7 @@ public static Constant tensorOf(Scope scope, boolean[][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, boolean[][][] data) { - try (Tensor value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -770,7 +782,7 @@ public static Constant tensorOf(Scope scope, boolean[][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, boolean[][][][] data) { - try (Tensor value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -786,7 +798,7 @@ public static Constant tensorOf(Scope scope, boolean[][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, boolean[][][][][] data) { - try (Tensor value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -802,7 +814,7 @@ public static Constant tensorOf(Scope scope, boolean[][][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, boolean[][][][][][] data) { - try (Tensor value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, t))) { return create(scope, value); } @@ -817,7 +829,10 @@ public static Constant tensorOf(Scope scope, boolean[][][][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, BooleanNdArray data) { - try (Tensor value = TBool.tensorOf(data)) { + if (data instanceof TBool) { + return create(scope, (TBool) data); + } + try (TBool value = TBool.tensorOf(data)) { return create(scope, value); } } @@ -833,7 +848,7 @@ public static Constant tensorOf(Scope scope, BooleanNdArray data) { */ @Endpoint public static Constant tensorOf(Scope scope, Shape shape, BooleanDataBuffer data) { - try (Tensor value = TBool.tensorOf(shape, data)) { + try (TBool value = TBool.tensorOf(shape, data)) { return create(scope, value); } } @@ -847,7 +862,7 @@ public static Constant tensorOf(Scope scope, Shape shape, BooleanDataBuff */ @Endpoint public static Constant scalarOf(Scope scope, byte data) { - try (Tensor value = TUint8.scalarOf(data)) { + try (TUint8 value = TUint8.scalarOf(data)) { return create(scope, value); } } @@ -862,7 +877,7 @@ public static Constant scalarOf(Scope scope, byte data) { */ @Endpoint public static Constant vectorOf(Scope scope, byte[] data) { - try (Tensor value = TUint8.vectorOf(data)) { + try (TUint8 value = TUint8.vectorOf(data)) { return create(scope, value); } } @@ -892,7 +907,7 @@ public static Constant arrayOf(Scope scope, byte... data) { */ @Endpoint public static Constant tensorOf(Scope scope, byte[][] data) { - try (Tensor value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, + try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, d))) { return create(scope, value); } @@ -908,7 +923,7 @@ public static Constant tensorOf(Scope scope, byte[][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, byte[][][] data) { - try (Tensor value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, + try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, d))) { return create(scope, value); } @@ -924,7 +939,7 @@ public static Constant tensorOf(Scope scope, byte[][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, byte[][][][] data) { - try (Tensor value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, + try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, d))) { return create(scope, value); } @@ -940,7 +955,7 @@ public static Constant tensorOf(Scope scope, byte[][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, byte[][][][][] data) { - try (Tensor value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, + try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, d))) { return create(scope, value); } @@ -956,7 +971,7 @@ public static Constant tensorOf(Scope scope, byte[][][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, byte[][][][][][] data) { - try (Tensor value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, + try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, d))) { return create(scope, value); } @@ -971,7 +986,10 @@ public static Constant tensorOf(Scope scope, byte[][][][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, ByteNdArray data) { - try (Tensor value = TUint8.tensorOf(data)) { + if (data instanceof TUint8) { + return create(scope, (TUint8) data); + } + try (TUint8 value = TUint8.tensorOf(data)) { return create(scope, value); } } @@ -987,7 +1005,7 @@ public static Constant tensorOf(Scope scope, ByteNdArray data) { */ @Endpoint public static Constant tensorOf(Scope scope, Shape shape, ByteDataBuffer data) { - try (Tensor value = TUint8.tensorOf(shape, data)) { + try (TUint8 value = TUint8.tensorOf(shape, data)) { return create(scope, value); } } @@ -1006,7 +1024,7 @@ public static Constant tensorOf(Scope scope, Shape shape, ByteDataBuffer @Endpoint public static Constant tensorOf(Scope scope, DataType type, Shape shape, ByteDataBuffer data) { - try (Tensor value = Tensor.of(type, shape, data)) { + try (T value = Tensor.of(type, shape, data)) { return create(scope, value); } } @@ -1020,7 +1038,7 @@ public static Constant tensorOf(Scope scope, DataType ty */ @Endpoint public static Constant scalarOf(Scope scope, String data) { - try (Tensor value = TString.scalarOf(data)) { + try (TString value = TString.scalarOf(data)) { return create(scope, value); } } @@ -1035,7 +1053,7 @@ public static Constant scalarOf(Scope scope, String data) { */ @Endpoint public static Constant scalarOf(Scope scope, Charset charset, String data) { - try (Tensor value = TString.tensorOf(charset, NdArrays.scalarOfObject(data))) { + try (TString value = TString.tensorOf(charset, NdArrays.scalarOfObject(data))) { return create(scope, value); } } @@ -1049,7 +1067,7 @@ public static Constant scalarOf(Scope scope, Charset charset, String da */ public static Constant vectorOf(Scope scope, String[] data) { NdArray src = NdArrays.vectorOfObjects(data); - try (Tensor value = TString.tensorOf(src)) { + try (TString value = TString.tensorOf(src)) { return create(scope, value); } } @@ -1065,7 +1083,7 @@ public static Constant vectorOf(Scope scope, String[] data) { */ @Endpoint public static Constant vectorOf(Scope scope, Charset charset, String[] data) { - try (Tensor value = TString.tensorOf(charset, NdArrays.vectorOfObjects(data))) { + try (TString value = TString.tensorOf(charset, NdArrays.vectorOfObjects(data))) { return Constant.create(scope, value); } } @@ -1112,7 +1130,7 @@ public static Constant arrayOf(Scope scope, Charset charset, String... public static Constant tensorOf(Scope scope, String[][] data) { NdArray src = NdArrays.ofObjects(String.class, StdArrays.shapeOf(data)); StdArrays.copyTo(data, src); - try (Tensor value = TString.tensorOf(src)) { + try (TString value = TString.tensorOf(src)) { return create(scope, value); } } @@ -1127,7 +1145,7 @@ public static Constant tensorOf(Scope scope, String[][] data) { public static Constant tensorOf(Scope scope, String[][][] data) { NdArray src = NdArrays.ofObjects(String.class, StdArrays.shapeOf(data)); StdArrays.copyTo(data, src); - try (Tensor value = TString.tensorOf(src)) { + try (TString value = TString.tensorOf(src)) { return create(scope, value); } } @@ -1142,7 +1160,7 @@ public static Constant tensorOf(Scope scope, String[][][] data) { public static Constant tensorOf(Scope scope, String[][][][] data) { NdArray src = NdArrays.ofObjects(String.class, StdArrays.shapeOf(data)); StdArrays.copyTo(data, src); - try (Tensor value = TString.tensorOf(src)) { + try (TString value = TString.tensorOf(src)) { return create(scope, value); } } @@ -1157,7 +1175,7 @@ public static Constant tensorOf(Scope scope, String[][][][] data) { public static Constant tensorOf(Scope scope, String[][][][][] data) { NdArray src = NdArrays.ofObjects(String.class, StdArrays.shapeOf(data)); StdArrays.copyTo(data, src); - try (Tensor value = TString.tensorOf(src)) { + try (TString value = TString.tensorOf(src)) { return create(scope, value); } } @@ -1172,7 +1190,7 @@ public static Constant tensorOf(Scope scope, String[][][][][] data) { public static Constant tensorOf(Scope scope, String[][][][][][] data) { NdArray src = NdArrays.ofObjects(String.class, StdArrays.shapeOf(data)); StdArrays.copyTo(data, src); - try (Tensor value = TString.tensorOf(src)) { + try (TString value = TString.tensorOf(src)) { return create(scope, value); } } @@ -1187,7 +1205,10 @@ public static Constant tensorOf(Scope scope, String[][][][][][] data) { */ @Endpoint public static Constant tensorOf(Scope scope, NdArray data) { - try (Tensor value = TString.tensorOf(data)) { + if (data instanceof TString) { + return create(scope, (TString) data); + } + try (TString value = TString.tensorOf(data)) { return create(scope, value); } } @@ -1203,7 +1224,7 @@ public static Constant tensorOf(Scope scope, NdArray data) { */ @Endpoint public static Constant tensorOf(Scope scope, Charset charset, NdArray data) { - try (Tensor value = TString.tensorOf(charset, data)) { + try (TString value = TString.tensorOf(charset, data)) { return create(scope, value); } } @@ -1220,7 +1241,7 @@ public static Constant tensorOf(Scope scope, Charset charset, NdArray tensorOf(Scope scope, Shape shape, DataBuffer data) { - try (Tensor value = TString.tensorOf(shape, data)) { + try (TString value = TString.tensorOf(shape, data)) { return create(scope, value); } } @@ -1238,7 +1259,7 @@ public static Constant tensorOf(Scope scope, Shape shape, DataBuffer tensorOf(Scope scope, Charset charset, Shape shape, DataBuffer data) { - try (Tensor value = TString.tensorOf(charset, shape, data)) { + try (TString value = TString.tensorOf(charset, shape, data)) { return create(scope, value); } } @@ -1257,14 +1278,14 @@ public static Constant tensorOf(Scope scope, Shape shape) { } /** - * Create a constant from a Tensor. + * Capture a {@code tensor} by making a constant copy of it. * * @param scope is a scope used to add the underlying operation. * @param tensor a Tensor holding the constant value * @return a constant of the same data type as `tensor` */ - @Endpoint - public static Constant create(Scope scope, Tensor tensor) { + @Endpoint(name = "capture") // Cannot be "constant" since it will conflict with other endpoints accepting an NdArray + public static Constant create(Scope scope, T tensor) { return new Constant<>( scope .env() diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBfloat16.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBfloat16.java index 50f6ea49b06..e7fd03af46a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBfloat16.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBfloat16.java @@ -19,16 +19,17 @@ import java.util.function.Consumer; import org.tensorflow.DataType; +import org.tensorflow.RawTensor; import org.tensorflow.Tensor; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.buffer.TensorBuffers; import org.tensorflow.internal.c_api.TF_Tensor; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.buffer.FloatDataBuffer; -import org.tensorflow.ndarray.buffer.layout.DataLayouts; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.buffer.layout.DataLayouts; import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray; import org.tensorflow.types.family.TFloating; @@ -61,7 +62,7 @@ public interface TBfloat16 extends FloatNdArray, TFloating { * @param value float to store in the new tensor * @return the new tensor */ - static Tensor scalarOf(float value) { + static TBfloat16 scalarOf(float value) { return Tensor.of(DTYPE, Shape.scalar(), data -> data.setFloat(value)); } @@ -71,7 +72,7 @@ static Tensor scalarOf(float value) { * @param values floats to store in the new tensor * @return the new tensor */ - static Tensor vectorOf(float... values) { + static TBfloat16 vectorOf(float... values) { if (values == null) { throw new IllegalArgumentException(); } @@ -86,7 +87,7 @@ static Tensor vectorOf(float... values) { * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static Tensor tensorOf(NdArray src) { + static TBfloat16 tensorOf(NdArray src) { return Tensor.of(DTYPE, src.shape(), src::copyTo); } @@ -96,7 +97,7 @@ static Tensor tensorOf(NdArray src) { * @param shape shape of the tensor to allocate * @return the new tensor */ - static Tensor tensorOf(Shape shape) { + static TBfloat16 tensorOf(Shape shape) { return Tensor.of(DTYPE, shape); } @@ -107,7 +108,7 @@ static Tensor tensorOf(Shape shape) { * @param data buffer of floats to initialize the tensor with * @return the new tensor */ - static Tensor tensorOf(Shape shape, FloatDataBuffer data) { + static TBfloat16 tensorOf(Shape shape, FloatDataBuffer data) { return Tensor.of(DTYPE, shape, d -> d.write(data)); } @@ -119,7 +120,7 @@ static Tensor tensorOf(Shape shape, FloatDataBuffer data) { * @return the new tensor * @throws TensorFlowException if the tensor cannot be allocated or initialized */ - static Tensor tensorOf(Shape shape, Consumer dataInit) { + static TBfloat16 tensorOf(Shape shape, Consumer dataInit) { return Tensor.of(DTYPE, shape, dataInit); } } @@ -127,12 +128,25 @@ static Tensor tensorOf(Shape shape, Consumer dataInit) { /** Hidden implementation of a {@code TBfloat16} */ class TBfloat16Impl extends FloatDenseNdArray implements TBfloat16 { - static TBfloat16 mapTensor(TF_Tensor nativeTensor, Shape shape) { - return new TBfloat16Impl( - DataLayouts.BFLOAT16.applyTo(TensorBuffers.toShorts(nativeTensor)), shape); + @Override + public DataType dataType() { + return TBfloat16.DTYPE; } - private TBfloat16Impl(FloatDataBuffer buffer, Shape shape) { - super(buffer, shape); + @Override + public RawTensor asRawTensor() { + return rawTensor; + } + + static TBfloat16 mapTensor(RawTensor tensor, TF_Tensor nativeHandle) { + FloatDataBuffer buffer = DataLayouts.BFLOAT16.applyTo(TensorBuffers.toShorts(nativeHandle)); + return new TBfloat16Impl(tensor, buffer); + } + + private final RawTensor rawTensor; + + private TBfloat16Impl(RawTensor rawTensor, FloatDataBuffer buffer) { + super(buffer, rawTensor.shape()); + this.rawTensor = rawTensor; } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java index 3cc72101893..0571dce410c 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java @@ -17,7 +17,9 @@ package org.tensorflow.types; +import java.util.function.Consumer; import org.tensorflow.DataType; +import org.tensorflow.RawTensor; import org.tensorflow.Tensor; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.buffer.TensorBuffers; @@ -31,8 +33,6 @@ import org.tensorflow.ndarray.impl.dense.BooleanDenseNdArray; import org.tensorflow.types.family.TType; -import java.util.function.Consumer; - /** * Boolean tensor type. * @@ -53,7 +53,7 @@ public interface TBool extends BooleanNdArray, TType { * @param value boolean to store in the new tensor * @return the new tensor */ - static Tensor scalarOf(boolean value) { + static TBool scalarOf(boolean value) { return Tensor.of(DTYPE, Shape.scalar(), data -> data.setBoolean(value)); } @@ -63,7 +63,7 @@ static Tensor scalarOf(boolean value) { * @param values booleans to store in the new tensor * @return the new tensor */ - static Tensor vectorOf(boolean... values) { + static TBool vectorOf(boolean... values) { if (values == null) { throw new IllegalArgumentException(); } @@ -78,7 +78,7 @@ static Tensor vectorOf(boolean... values) { * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static Tensor tensorOf(NdArray src) { + static TBool tensorOf(NdArray src) { return Tensor.of(DTYPE, src.shape(), src::copyTo); } @@ -88,7 +88,7 @@ static Tensor tensorOf(NdArray src) { * @param shape shape of the tensor to allocate * @return the new tensor */ - static Tensor tensorOf(Shape shape) { + static TBool tensorOf(Shape shape) { return Tensor.of(DTYPE, shape); } @@ -99,7 +99,7 @@ static Tensor tensorOf(Shape shape) { * @param data buffer of booleans to initialize the tensor with * @return the new tensor */ - static Tensor tensorOf(Shape shape, BooleanDataBuffer data) { + static TBool tensorOf(Shape shape, BooleanDataBuffer data) { return Tensor.of(DTYPE, shape, d -> d.write(data)); } @@ -111,7 +111,7 @@ static Tensor tensorOf(Shape shape, BooleanDataBuffer data) { * @return the new tensor * @throws TensorFlowException if the tensor cannot be allocated or initialized */ - static Tensor tensorOf(Shape shape, Consumer dataInit) { + static TBool tensorOf(Shape shape, Consumer dataInit) { return Tensor.of(DTYPE, shape, dataInit); } } @@ -119,11 +119,25 @@ static Tensor tensorOf(Shape shape, Consumer dataInit) { /** Hidden implementation of a {@code TBool} */ class TBoolImpl extends BooleanDenseNdArray implements TBool { - static TBool mapTensor(TF_Tensor nativeTensor, Shape shape) { - return new TBoolImpl(TensorBuffers.toBooleans(nativeTensor), shape); + @Override + public DataType dataType() { + return TBool.DTYPE; + } + + @Override + public RawTensor asRawTensor() { + return rawTensor; + } + + static TBool mapTensor(RawTensor tensor, TF_Tensor nativeHandle) { + BooleanDataBuffer buffer = TensorBuffers.toBooleans(nativeHandle); + return new TBoolImpl(tensor, buffer); } - private TBoolImpl(BooleanDataBuffer buffer, Shape shape) { - super(buffer, shape); + private final RawTensor rawTensor; + + private TBoolImpl(RawTensor rawTensor, BooleanDataBuffer buffer) { + super(buffer, rawTensor.shape()); + this.rawTensor = rawTensor; } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat16.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat16.java index 0cd441a1ff1..b675701b0d0 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat16.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat16.java @@ -19,16 +19,17 @@ import java.util.function.Consumer; import org.tensorflow.DataType; +import org.tensorflow.RawTensor; import org.tensorflow.Tensor; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.buffer.TensorBuffers; import org.tensorflow.internal.c_api.TF_Tensor; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.buffer.FloatDataBuffer; -import org.tensorflow.ndarray.buffer.layout.DataLayouts; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.buffer.layout.DataLayouts; import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray; import org.tensorflow.types.family.TFloating; @@ -59,7 +60,7 @@ public interface TFloat16 extends FloatNdArray, TFloating { * @param value float to store in the new tensor * @return the new tensor */ - static Tensor scalarOf(float value) { + static TFloat16 scalarOf(float value) { return Tensor.of(DTYPE, Shape.scalar(), data -> data.setFloat(value)); } @@ -69,7 +70,7 @@ static Tensor scalarOf(float value) { * @param values floats to store in the new tensor * @return the new tensor */ - static Tensor vectorOf(float... values) { + static TFloat16 vectorOf(float... values) { if (values == null) { throw new IllegalArgumentException(); } @@ -84,7 +85,7 @@ static Tensor vectorOf(float... values) { * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static Tensor tensorOf(NdArray src) { + static TFloat16 tensorOf(NdArray src) { return Tensor.of(DTYPE, src.shape(), src::copyTo); } @@ -94,7 +95,7 @@ static Tensor tensorOf(NdArray src) { * @param shape shape of the tensor to allocate * @return the new tensor */ - static Tensor tensorOf(Shape shape) { + static TFloat16 tensorOf(Shape shape) { return Tensor.of(DTYPE, shape); } @@ -105,7 +106,7 @@ static Tensor tensorOf(Shape shape) { * @param data buffer of floats to initialize the tensor with * @return the new tensor */ - static Tensor tensorOf(Shape shape, FloatDataBuffer data) { + static TFloat16 tensorOf(Shape shape, FloatDataBuffer data) { return Tensor.of(DTYPE, shape, d -> d.write(data)); } @@ -117,7 +118,7 @@ static Tensor tensorOf(Shape shape, FloatDataBuffer data) { * @return the new tensor * @throws TensorFlowException if the tensor cannot be allocated or initialized */ - static Tensor tensorOf(Shape shape, Consumer dataInit) { + static TFloat16 tensorOf(Shape shape, Consumer dataInit) { return Tensor.of(DTYPE, shape, dataInit); } } @@ -125,12 +126,25 @@ static Tensor tensorOf(Shape shape, Consumer dataInit) { /** Hidden implementation of a {@code TFloat16} */ class TFloat16Impl extends FloatDenseNdArray implements TFloat16 { - static TFloat16 mapTensor(TF_Tensor nativeTensor, Shape shape) { - return new TFloat16Impl( - DataLayouts.FLOAT16.applyTo(TensorBuffers.toShorts(nativeTensor)), shape); + @Override + public DataType dataType() { + return TFloat16.DTYPE; } - private TFloat16Impl(FloatDataBuffer buffer, Shape shape) { - super(buffer, shape); + @Override + public RawTensor asRawTensor() { + return rawTensor; + } + + static TFloat16 mapTensor(RawTensor tensor, TF_Tensor nativeHandle) { + FloatDataBuffer buffer = DataLayouts.FLOAT16.applyTo(TensorBuffers.toShorts(nativeHandle)); + return new TFloat16Impl(tensor, buffer); + } + + private final RawTensor rawTensor; + + private TFloat16Impl(RawTensor rawTensor, FloatDataBuffer buffer) { + super(buffer, rawTensor.shape()); + this.rawTensor = rawTensor; } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java index 571ec118ddc..9bcefd628c2 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java @@ -19,15 +19,16 @@ import java.util.function.Consumer; import org.tensorflow.DataType; +import org.tensorflow.RawTensor; import org.tensorflow.Tensor; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.buffer.TensorBuffers; import org.tensorflow.internal.c_api.TF_Tensor; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.buffer.FloatDataBuffer; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray; import org.tensorflow.types.family.TFloating; @@ -46,7 +47,7 @@ public interface TFloat32 extends FloatNdArray, TFloating { * @param value float to store in the new tensor * @return the new tensor */ - static Tensor scalarOf(float value) { + static TFloat32 scalarOf(float value) { return Tensor.of(DTYPE, Shape.scalar(), data -> data.setFloat(value)); } @@ -56,7 +57,7 @@ static Tensor scalarOf(float value) { * @param values floats to store in the new tensor * @return the new tensor */ - static Tensor vectorOf(float... values) { + static TFloat32 vectorOf(float... values) { if (values == null) { throw new IllegalArgumentException(); } @@ -71,7 +72,7 @@ static Tensor vectorOf(float... values) { * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static Tensor tensorOf(NdArray src) { + static TFloat32 tensorOf(NdArray src) { return Tensor.of(DTYPE, src.shape(), src::copyTo); } @@ -81,7 +82,7 @@ static Tensor tensorOf(NdArray src) { * @param shape shape of the tensor to allocate * @return the new tensor */ - static Tensor tensorOf(Shape shape) { + static TFloat32 tensorOf(Shape shape) { return Tensor.of(DTYPE, shape); } @@ -92,7 +93,7 @@ static Tensor tensorOf(Shape shape) { * @param data buffer of floats to initialize the tensor with * @return the new tensor */ - static Tensor tensorOf(Shape shape, FloatDataBuffer data) { + static TFloat32 tensorOf(Shape shape, FloatDataBuffer data) { return Tensor.of(DTYPE, shape, d -> d.write(data)); } @@ -104,7 +105,7 @@ static Tensor tensorOf(Shape shape, FloatDataBuffer data) { * @return the new tensor * @throws TensorFlowException if the tensor cannot be allocated or initialized */ - static Tensor tensorOf(Shape shape, Consumer dataInit) { + static TFloat32 tensorOf(Shape shape, Consumer dataInit) { return Tensor.of(DTYPE, shape, dataInit); } } @@ -112,11 +113,25 @@ static Tensor tensorOf(Shape shape, Consumer dataInit) { /** Hidden implementation of a {@code TFloat32} */ class TFloat32Impl extends FloatDenseNdArray implements TFloat32 { - static TFloat32 mapTensor(TF_Tensor nativeTensor, Shape shape) { - return new TFloat32Impl(TensorBuffers.toFloats(nativeTensor), shape); + @Override + public DataType dataType() { + return TFloat32.DTYPE; } - private TFloat32Impl(FloatDataBuffer buffer, Shape shape) { - super(buffer, shape); + @Override + public RawTensor asRawTensor() { + return rawTensor; + } + + static TFloat32 mapTensor(RawTensor tensor, TF_Tensor nativeHandle) { + FloatDataBuffer buffer = TensorBuffers.toFloats(nativeHandle); + return new TFloat32Impl(tensor, buffer); + } + + private final RawTensor rawTensor; + + private TFloat32Impl(RawTensor rawTensor, FloatDataBuffer buffer) { + super(buffer, rawTensor.shape()); + this.rawTensor = rawTensor; } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat64.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat64.java index 5d2744c4b3c..806725d5b21 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat64.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat64.java @@ -19,15 +19,16 @@ import java.util.function.Consumer; import org.tensorflow.DataType; +import org.tensorflow.RawTensor; import org.tensorflow.Tensor; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.buffer.TensorBuffers; import org.tensorflow.internal.c_api.TF_Tensor; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.buffer.DoubleDataBuffer; import org.tensorflow.ndarray.DoubleNdArray; import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; import org.tensorflow.ndarray.impl.dense.DoubleDenseNdArray; import org.tensorflow.types.family.TFloating; @@ -47,7 +48,7 @@ public interface TFloat64 extends DoubleNdArray, TFloating { * @param value double to store in the new tensor * @return the new tensor */ - static Tensor scalarOf(double value) { + static TFloat64 scalarOf(double value) { return Tensor.of(DTYPE, Shape.scalar(), data -> data.setDouble(value)); } @@ -57,7 +58,7 @@ static Tensor scalarOf(double value) { * @param values doubles to store in the new tensor * @return the new tensor */ - static Tensor vectorOf(double... values) { + static TFloat64 vectorOf(double... values) { if (values == null) { throw new IllegalArgumentException(); } @@ -72,7 +73,7 @@ static Tensor vectorOf(double... values) { * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static Tensor tensorOf(NdArray src) { + static TFloat64 tensorOf(NdArray src) { return Tensor.of(DTYPE, src.shape(), src::copyTo); } @@ -82,7 +83,7 @@ static Tensor tensorOf(NdArray src) { * @param shape shape of the tensor to allocate * @return the new tensor */ - static Tensor tensorOf(Shape shape) { + static TFloat64 tensorOf(Shape shape) { return Tensor.of(DTYPE, shape); } @@ -93,7 +94,7 @@ static Tensor tensorOf(Shape shape) { * @param data buffer of doubles to initialize the tensor with * @return the new tensor */ - static Tensor tensorOf(Shape shape, DoubleDataBuffer data) { + static TFloat64 tensorOf(Shape shape, DoubleDataBuffer data) { return Tensor.of(DTYPE, shape, d -> d.write(data)); } @@ -105,7 +106,7 @@ static Tensor tensorOf(Shape shape, DoubleDataBuffer data) { * @return the new tensor * @throws TensorFlowException if the tensor cannot be allocated or initialized */ - static Tensor tensorOf(Shape shape, Consumer dataInit) { + static TFloat64 tensorOf(Shape shape, Consumer dataInit) { return Tensor.of(DTYPE, shape, dataInit); } } @@ -113,11 +114,25 @@ static Tensor tensorOf(Shape shape, Consumer dataInit) { /** Hidden implementation of a {@code TFloat64} */ class TFloat64Impl extends DoubleDenseNdArray implements TFloat64 { - static TFloat64 mapTensor(TF_Tensor nativeTensor, Shape shape) { - return new TFloat64Impl(TensorBuffers.toDoubles(nativeTensor), shape); + @Override + public DataType dataType() { + return TFloat64.DTYPE; } - private TFloat64Impl(DoubleDataBuffer buffer, Shape shape) { - super(buffer, shape); + @Override + public RawTensor asRawTensor() { + return rawTensor; + } + + static TFloat64 mapTensor(RawTensor tensor, TF_Tensor nativeHandle) { + DoubleDataBuffer buffer = TensorBuffers.toDoubles(nativeHandle); + return new TFloat64Impl(tensor, buffer); + } + + private final RawTensor rawTensor; + + private TFloat64Impl(RawTensor rawTensor, DoubleDataBuffer buffer) { + super(buffer, rawTensor.shape()); + this.rawTensor = rawTensor; } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt32.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt32.java index 4a1139ddde2..1aa4333f34f 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt32.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt32.java @@ -19,14 +19,15 @@ import java.util.function.Consumer; import org.tensorflow.DataType; +import org.tensorflow.RawTensor; import org.tensorflow.Tensor; import org.tensorflow.internal.buffer.TensorBuffers; import org.tensorflow.internal.c_api.TF_Tensor; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.buffer.IntDataBuffer; import org.tensorflow.ndarray.IntNdArray; import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.buffer.IntDataBuffer; import org.tensorflow.ndarray.impl.dense.IntDenseNdArray; import org.tensorflow.types.family.TNumber; @@ -45,7 +46,7 @@ public interface TInt32 extends IntNdArray, TNumber { * @param value int to store in the new tensor * @return the new tensor */ - static Tensor scalarOf(int value) { + static TInt32 scalarOf(int value) { return Tensor.of(DTYPE, Shape.scalar(), data -> data.setInt(value)); } @@ -56,7 +57,7 @@ static Tensor scalarOf(int value) { * @return the new tensor * @throws IllegalArgumentException if no values are provided */ - static Tensor vectorOf(int... values) { + static TInt32 vectorOf(int... values) { if (values == null) { throw new IllegalArgumentException(); } @@ -71,7 +72,7 @@ static Tensor vectorOf(int... values) { * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static Tensor tensorOf(NdArray src) { + static TInt32 tensorOf(NdArray src) { return Tensor.of(DTYPE, src.shape(), src::copyTo); } @@ -81,7 +82,7 @@ static Tensor tensorOf(NdArray src) { * @param shape shape of the tensor to allocate * @return the new tensor */ - static Tensor tensorOf(Shape shape) { + static TInt32 tensorOf(Shape shape) { return Tensor.of(DTYPE, shape); } @@ -92,7 +93,7 @@ static Tensor tensorOf(Shape shape) { * @param data buffer of ints to initialize the tensor with * @return the new tensor */ - static Tensor tensorOf(Shape shape, IntDataBuffer data) { + static TInt32 tensorOf(Shape shape, IntDataBuffer data) { return Tensor.of(DTYPE, shape, d -> d.write(data)); } @@ -103,7 +104,7 @@ static Tensor tensorOf(Shape shape, IntDataBuffer data) { * @param dataInit tensor data initializer * @return the new tensor */ - static Tensor tensorOf(Shape shape, Consumer dataInit) { + static TInt32 tensorOf(Shape shape, Consumer dataInit) { return Tensor.of(DTYPE, shape, dataInit); } } @@ -111,11 +112,25 @@ static Tensor tensorOf(Shape shape, Consumer dataInit) { /** Hidden implementation of a {@code TInt32} */ class TInt32Impl extends IntDenseNdArray implements TInt32 { - static TInt32 mapTensor(TF_Tensor nativeTensor, Shape shape) { - return new TInt32Impl(TensorBuffers.toInts(nativeTensor), shape); + @Override + public DataType dataType() { + return TInt32.DTYPE; } - private TInt32Impl(IntDataBuffer buffer, Shape shape) { - super(buffer, shape); + @Override + public RawTensor asRawTensor() { + return rawTensor; + } + + static TInt32 mapTensor(RawTensor tensor, TF_Tensor nativeHandle) { + IntDataBuffer buffer = TensorBuffers.toInts(nativeHandle); + return new TInt32Impl(tensor, buffer); + } + + private final RawTensor rawTensor; + + private TInt32Impl(RawTensor rawTensor, IntDataBuffer buffer) { + super(buffer, rawTensor.shape()); + this.rawTensor = rawTensor; } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt64.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt64.java index 04fd4fd7799..0853ae9bac7 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt64.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt64.java @@ -19,15 +19,16 @@ import java.util.function.Consumer; import org.tensorflow.DataType; +import org.tensorflow.RawTensor; import org.tensorflow.Tensor; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.buffer.TensorBuffers; import org.tensorflow.internal.c_api.TF_Tensor; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.buffer.LongDataBuffer; import org.tensorflow.ndarray.LongNdArray; import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.buffer.LongDataBuffer; import org.tensorflow.ndarray.impl.dense.LongDenseNdArray; import org.tensorflow.types.family.TNumber; @@ -46,7 +47,7 @@ public interface TInt64 extends LongNdArray, TNumber { * @param value long to store in the new tensor * @return the new tensor */ - static Tensor scalarOf(long value) { + static TInt64 scalarOf(long value) { return Tensor.of(DTYPE, Shape.scalar(), data -> data.setLong(value)); } @@ -56,7 +57,7 @@ static Tensor scalarOf(long value) { * @param values longs to store in the new tensor * @return the new tensor */ - static Tensor vectorOf(long... values) { + static TInt64 vectorOf(long... values) { if (values == null) { throw new IllegalArgumentException(); } @@ -71,7 +72,7 @@ static Tensor vectorOf(long... values) { * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static Tensor tensorOf(NdArray src) { + static TInt64 tensorOf(NdArray src) { return Tensor.of(DTYPE, src.shape(), src::copyTo); } @@ -81,7 +82,7 @@ static Tensor tensorOf(NdArray src) { * @param shape shape of the tensor to allocate * @return the new tensor */ - static Tensor tensorOf(Shape shape) { + static TInt64 tensorOf(Shape shape) { return Tensor.of(DTYPE, shape); } @@ -92,7 +93,7 @@ static Tensor tensorOf(Shape shape) { * @param data buffer of longs to initialize the tensor with * @return the new tensor */ - static Tensor tensorOf(Shape shape, LongDataBuffer data) { + static TInt64 tensorOf(Shape shape, LongDataBuffer data) { return Tensor.of(DTYPE, shape, d -> d.write(data)); } @@ -104,7 +105,7 @@ static Tensor tensorOf(Shape shape, LongDataBuffer data) { * @return the new tensor * @throws TensorFlowException if the tensor cannot be allocated or initialized */ - static Tensor tensorOf(Shape shape, Consumer dataInit) { + static TInt64 tensorOf(Shape shape, Consumer dataInit) { return Tensor.of(DTYPE, shape, dataInit); } } @@ -112,11 +113,25 @@ static Tensor tensorOf(Shape shape, Consumer dataInit) { /** Hidden implementation of a {@code TInt64} */ class TInt64Impl extends LongDenseNdArray implements TInt64 { - static TInt64 mapTensor(TF_Tensor nativeTensor, Shape shape) { - return new TInt64Impl(TensorBuffers.toLongs(nativeTensor), shape); + @Override + public DataType dataType() { + return TInt64.DTYPE; } - private TInt64Impl(LongDataBuffer buffer, Shape shape) { - super(buffer, shape); + @Override + public RawTensor asRawTensor() { + return rawTensor; + } + + static TInt64 mapTensor(RawTensor tensor, TF_Tensor nativeHandle) { + LongDataBuffer buffer = TensorBuffers.toLongs(nativeHandle); + return new TInt64Impl(tensor, buffer); + } + + private final RawTensor rawTensor; + + private TInt64Impl(RawTensor rawTensor, LongDataBuffer buffer) { + super(buffer, rawTensor.shape()); + this.rawTensor = rawTensor; } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java index 57a121edcf1..c6f3a9872a9 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java @@ -17,7 +17,11 @@ package org.tensorflow.types; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.function.Function; import org.tensorflow.DataType; +import org.tensorflow.RawTensor; import org.tensorflow.Tensor; import org.tensorflow.internal.buffer.StringTensorBuffer; import org.tensorflow.internal.buffer.TensorBuffers; @@ -31,10 +35,6 @@ import org.tensorflow.ndarray.impl.dense.DenseNdArray; import org.tensorflow.types.family.TType; -import java.nio.charset.Charset; -import java.nio.charset.StandardCharsets; -import java.util.function.Function; - /** * String type. * @@ -60,7 +60,7 @@ public interface TString extends NdArray, TType { * @param value scalar value to store in the new tensor * @return the new tensor */ - static Tensor scalarOf(String value) { + static TString scalarOf(String value) { return tensorOf(NdArrays.scalarOfObject(value)); } @@ -72,7 +72,7 @@ static Tensor scalarOf(String value) { * @param values values to store in the new tensor * @return the new tensor */ - static Tensor vectorOf(String... values) { + static TString vectorOf(String... values) { if (values == null) { throw new IllegalArgumentException(); } @@ -88,7 +88,7 @@ static Tensor vectorOf(String... values) { * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static Tensor tensorOf(NdArray src) { + static TString tensorOf(NdArray src) { return tensorOf(StandardCharsets.UTF_8, src); } @@ -103,7 +103,7 @@ static Tensor tensorOf(NdArray src) { * *

{@code
    * // Given `originalStrings` an initialized vector of strings
-   * Tensor tensor = TString.tensorOf(Charsets.UTF_16, originalStrings);
+   * TString tensor = TString.tensorOf(Charsets.UTF_16, originalStrings);
    * ...
    * TString tensorStrings = tensor.data().using(Charsets.UTF_16);
    * assertEquals(originalStrings.getObject(0), tensorStrings.getObject(0));
@@ -113,7 +113,7 @@ static Tensor tensorOf(NdArray src) {
    * @param src the source array giving the shape and data to the new tensor
    * @return the new tensor
    */
-  static Tensor tensorOf(Charset charset, NdArray src) {
+  static TString tensorOf(Charset charset, NdArray src) {
     return TStringImpl.createTensor(src, s -> s.getBytes(charset));
   }
 
@@ -127,7 +127,7 @@ static Tensor tensorOf(Charset charset, NdArray src) {
    * @param data buffer of strings to initialize the tensor with
    * @return the new tensor
    */
-  static Tensor tensorOf(Shape shape, DataBuffer data) {
+  static TString tensorOf(Shape shape, DataBuffer data) {
     return tensorOf(NdArrays.wrap(shape, data));
   }
 
@@ -142,7 +142,7 @@ static Tensor tensorOf(Shape shape, DataBuffer data) {
    *
    * 
{@code
    * // Given `originalStrings` an initialized buffer of strings
-   * Tensor tensor =
+   * TString tensor =
    *    TString.tensorOf(Charsets.UTF_16, Shape.of(originalString.size()), originalStrings);
    * ...
    * TString tensorStrings = tensor.data().using(Charsets.UTF_16);
@@ -154,7 +154,7 @@ static Tensor tensorOf(Shape shape, DataBuffer data) {
    * @param data buffer of strings to initialize the tensor with
    * @return the new tensor
    */
-  static Tensor tensorOf(Charset charset, Shape shape, DataBuffer data) {
+  static TString tensorOf(Charset charset, Shape shape, DataBuffer data) {
     return tensorOf(charset, NdArrays.wrap(shape, data));
   }
 
@@ -173,7 +173,7 @@ static Tensor tensorOf(Charset charset, Shape shape, DataBuffer
    * @param src the source array giving the shape and data to the new tensor
    * @return the new tensor
    */
-  static Tensor tensorOfBytes(NdArray src) {
+  static TString tensorOfBytes(NdArray src) {
     return TStringImpl.createTensor(src, Function.identity());
   }
 
@@ -193,7 +193,7 @@ static Tensor tensorOfBytes(NdArray src) {
    * @param data the source array giving the shape and data to the new tensor
    * @return the new tensor
    */
-  static Tensor tensorOfBytes(Shape shape, DataBuffer data) {
+  static TString tensorOfBytes(Shape shape, DataBuffer data) {
     return tensorOfBytes(NdArrays.wrap(shape, data));
   }
 
@@ -204,7 +204,7 @@ static Tensor tensorOfBytes(Shape shape, DataBuffer data) {
    * created. For example:
    *
    * 
{@code
-   * Tensor tensor =
+   * TString tensor =
    *    TString.tensorOf(StandardCharsets.UTF_16, NdArrays.scalarOfObject("TensorFlow");
    *
    * assertEquals("TensorFlow", tensor.data().using(StandardCharsets.UTF_16).getObject());
@@ -224,7 +224,7 @@ class TStringImpl extends DenseNdArray implements TString {
 
   @Override
   public TString using(Charset charset) {
-    return new TStringImpl(tensorBuffer, DataLayouts.ofStrings(charset), shape());
+    return new TStringImpl(rawTensor, tensorBuffer, DataLayouts.ofStrings(charset));
   }
 
   @Override
@@ -232,7 +232,17 @@ public NdArray asBytes() {
     return NdArrays.wrap(shape(), tensorBuffer);
   }
 
-  static  Tensor createTensor(NdArray src, Function getBytes) {
+  @Override
+  public DataType dataType() {
+    return TString.DTYPE;
+  }
+
+  @Override
+  public RawTensor asRawTensor() {
+    return rawTensor;
+  }
+
+  static  TString createTensor(NdArray src, Function getBytes) {
     long size = StringTensorBuffer.computeSize(src, getBytes);
     return Tensor.of(
         TString.DTYPE,
@@ -241,19 +251,24 @@ static  Tensor createTensor(NdArray src, Function getB
         data -> ((TStringImpl) data).tensorBuffer.init(src, getBytes));
   }
 
-  static TString mapTensor(TF_Tensor nativeTensor, Shape shape) {
-    StringTensorBuffer buffer = TensorBuffers.toStrings(nativeTensor, shape.size());
-    return new TStringImpl(buffer, UTF_8_LAYOUT, shape);
+  static TString mapTensor(RawTensor tensor, TF_Tensor nativeHandle) {
+    StringTensorBuffer buffer = TensorBuffers.toStrings(nativeHandle, tensor.shape().size());
+    return new TStringImpl(tensor, buffer, UTF_8_LAYOUT);
   }
 
-  private static DataLayout, String> UTF_8_LAYOUT =
+  private static final DataLayout, String> UTF_8_LAYOUT =
       DataLayouts.ofStrings(StandardCharsets.UTF_8);
 
+  private final RawTensor rawTensor;
   private final StringTensorBuffer tensorBuffer;
 
   private TStringImpl(
-      StringTensorBuffer buffer, DataLayout, String> layout, Shape shape) {
-    super(layout.applyTo(buffer), shape);
+      RawTensor rawTensor,
+      StringTensorBuffer buffer,
+      DataLayout, String> layout
+  ) {
+    super(layout.applyTo(buffer), rawTensor.shape());
+    this.rawTensor = rawTensor;
     tensorBuffer = buffer;
   }
 }
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TUint8.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TUint8.java
index 365f41196fb..fd05857c295 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TUint8.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TUint8.java
@@ -19,15 +19,16 @@
 
 import java.util.function.Consumer;
 import org.tensorflow.DataType;
+import org.tensorflow.RawTensor;
 import org.tensorflow.Tensor;
 import org.tensorflow.exceptions.TensorFlowException;
 import org.tensorflow.internal.buffer.TensorBuffers;
 import org.tensorflow.internal.c_api.TF_Tensor;
-import org.tensorflow.ndarray.Shape;
-import org.tensorflow.ndarray.buffer.ByteDataBuffer;
 import org.tensorflow.ndarray.ByteNdArray;
 import org.tensorflow.ndarray.NdArray;
+import org.tensorflow.ndarray.Shape;
 import org.tensorflow.ndarray.StdArrays;
+import org.tensorflow.ndarray.buffer.ByteDataBuffer;
 import org.tensorflow.ndarray.impl.dense.ByteDenseNdArray;
 import org.tensorflow.types.family.TNumber;
 
@@ -46,7 +47,7 @@ public interface TUint8 extends ByteNdArray, TNumber {
    * @param value byte to store in the new tensor
    * @return the new tensor
    */
-  static Tensor scalarOf(byte value) {
+  static TUint8 scalarOf(byte value) {
     return Tensor.of(DTYPE, Shape.scalar(), data -> data.setByte(value));
   }
 
@@ -56,7 +57,7 @@ static Tensor scalarOf(byte value) {
    * @param values bytes to store in the new tensor
    * @return the new tensor
    */
-  static Tensor vectorOf(byte... values) {
+  static TUint8 vectorOf(byte... values) {
     if (values == null) {
       throw new IllegalArgumentException();
     }
@@ -71,7 +72,7 @@ static Tensor vectorOf(byte... values) {
    * @param src the source array giving the shape and data to the new tensor
    * @return the new tensor
    */
-  static Tensor tensorOf(NdArray src) {
+  static TUint8 tensorOf(NdArray src) {
     return Tensor.of(DTYPE, src.shape(), src::copyTo);
   }
 
@@ -81,7 +82,7 @@ static Tensor tensorOf(NdArray src) {
    * @param shape shape of the tensor to allocate
    * @return the new tensor
    */
-  static Tensor tensorOf(Shape shape) {
+  static TUint8 tensorOf(Shape shape) {
     return Tensor.of(DTYPE, shape);
   }
 
@@ -92,7 +93,7 @@ static Tensor tensorOf(Shape shape) {
    * @param data buffer of bytes to initialize the tensor with
    * @return the new tensor
    */
-  static Tensor tensorOf(Shape shape, ByteDataBuffer data) {
+  static TUint8 tensorOf(Shape shape, ByteDataBuffer data) {
     return Tensor.of(DTYPE, shape, d -> d.write(data));
   }
 
@@ -104,7 +105,7 @@ static Tensor tensorOf(Shape shape, ByteDataBuffer data) {
    * @return the new tensor
    * @throws TensorFlowException if the tensor cannot be allocated or initialized
    */
-  static Tensor tensorOf(Shape shape, Consumer dataInit) {
+  static TUint8 tensorOf(Shape shape, Consumer dataInit) {
     return Tensor.of(DTYPE, shape, dataInit);
   }
 }
@@ -112,11 +113,25 @@ static Tensor tensorOf(Shape shape, Consumer dataInit) {
 /** Hidden implementation of a {@code TUint8} */
 class TUint8Impl extends ByteDenseNdArray implements TUint8 {
 
-  static TUint8 mapTensor(TF_Tensor nativeTensor, Shape shape) {
-    return new TUint8Impl(TensorBuffers.toBytes(nativeTensor), shape);
+  @Override
+  public DataType dataType() {
+    return TUint8.DTYPE;
   }
 
-  private TUint8Impl(ByteDataBuffer buffer, Shape shape) {
-    super(buffer, shape);
+  @Override
+  public RawTensor asRawTensor() {
+    return rawTensor;
+  }
+
+  static TUint8 mapTensor(RawTensor tensor, TF_Tensor nativeHandle) {
+    ByteDataBuffer buffer = TensorBuffers.toBytes(nativeHandle);
+    return new TUint8Impl(tensor, buffer);
+  }
+
+  private final RawTensor rawTensor;
+
+  private TUint8Impl(RawTensor rawTensor, ByteDataBuffer buffer) {
+    super(buffer, rawTensor.shape());
+    this.rawTensor = rawTensor;
   }
 }
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TFloating.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TFloating.java
index 92deaffdc68..2c5d19f3b62 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TFloating.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TFloating.java
@@ -1,19 +1,20 @@
 package org.tensorflow.types.family;
 
 /**
- * Marker interface for floating point tensor types.
+ * Common interface for all floating point tensors.
  *
  * 

Operations that only accepts floating point values as some of their operands enforce that the tensor * types for these operands to be bound to this interface. For example: * *

{@code
- * TFloat32 tensor1 = TFloat32.vectorOf(1, 2, 3);
- * TBool tensor2 = TBool.vectorOf(true, false, true);
- *
  * Ops tf = Ops.create();
+ *
+ * Constant c1 = tf.array(1.0f, 2.0f, 3.0f);
+ * Constant c2 = tf.array(true, false, true);
+ *
  * Exponential exp = new Exponential<>(tf);
- * exp.call(tf.constant(tensor1));  // OK
- * exp.call(tf.constant(tensor2));  // Compilation failure
+ * exp.call(c1);  // OK
+ * exp.call(c2);  // Compilation failure
  * }
*/ public interface TFloating extends TNumber {} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TNumber.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TNumber.java index 97ee59af095..1a1e094e9f5 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TNumber.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TNumber.java @@ -18,18 +18,19 @@ package org.tensorflow.types.family; /** - * Marker interface for numeric tensor types. + * Common interface for all numeric tensors. * *

Operations that only accepts numeric values as some of their operands enforce that the tensor * types for these operands to be bound to this interface. For example: * *

{@code
- * TFloat32 tensor1 = TFloat32.vectorOf(1, 2, 3);
- * TBool tensor2 = TBool.vectorOf(true, false, true);
- *
  * Ops tf = Ops.create();
- * tf.nn.softmax(tf.constant(tensor1));  // OK
- * tf.nn.softmax(tf.constant(tensor2));  // Compilation failure
+ *
+ * Constant c1 = tf.array(1.0f, 2.0f, 3.0f);
+ * Constant c2 = tf.array(true, false, true);
+ *
+ * tf.nn.softmax(c1);  // OK
+ * tf.nn.softmax(c2);  // Compilation failure
  * }
*/ public interface TNumber extends TType {} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java index 8f3451b9a68..304c070cc8e 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java @@ -17,21 +17,40 @@ package org.tensorflow.types.family; +import org.tensorflow.Tensor; + /** - * Marker interface for all tensor types. + * Common interface for all typed tensors. * - *

Tensor types are carried as a generic parameter of the {@link org.tensorflow.Tensor Tensor} - * class bound by the {@code TType} interface. This generic parameter ensure type-compatibility - * between operands of a computation at compile-time. For example: + *

Typed tensors wraps a {@link RawTensor} by mapping their native memory to a n-dimensional + * data space allowing direct I/O access from the JVM.

* - *
{@code
- * Tensor tensor1 = TFloat32.ofShape(2, 3, 2);
- * Tensor tensor2 = TFloat32.ofShape(2, 3, 2);
- * Tensor tensor3 = TInt32.ofShape(2, 3, 2);
+ * 

Subinterfaces of {@code TType} are propagated as a generic parameter to various entities of + * TensorFlow to identify the type of the tensor they carry. For example, a + * {@link org.tensorflow.Operand Operand} is an operand which outputs is a 32-bit floating + * point tensor. This parameter ensure type-compatibility between operands of a computation at + * compile-time. For example: * + *

{@code
  * Ops tf = Ops.create();
- * tf.math.add(tf.constant(tensor1), tf.constant(tensor2));  // OK
- * tf.math.add(tf.constant(tensor1), tf.constant(tensor3));  // Compilation failure
+ *
+ * Constant c1 = tf.array(2.0f, 3.0f, 2.0f);
+ * Constant c2 = tf.array(1.0f, 2.0f, 3.0f);
+ * Constant c3 = tf.array(2, 3, 2);
+ *
+ * tf.math.add(c1, c2);  // OK
+ * tf.math.add(c1, c3);  // Compilation failure
  * }
*/ -public interface TType {} +public interface TType extends Tensor { + + @Override + default long numBytes() { + return asRawTensor().numBytes(); + } + + @Override + default void close() { + asRawTensor().close(); + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java index 3ea20fcbb46..72dcfb12430 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java @@ -44,8 +44,8 @@ private static Signature minusTwo(Ops tf) { @Test public void createFunction() { try (ConcreteFunction f = ConcreteFunction.create(ConcreteFunctionTest::plusFive); - Tensor x = TFloat32.scalarOf(3.0f)) { - assertEquals(8.0f, f.call(x).expect(TFloat32.DTYPE).data().getFloat()); + TFloat32 x = TFloat32.scalarOf(3.0f)) { + assertEquals(8.0f, ((TFloat32)f.call(x)).getFloat()); } } @@ -54,8 +54,8 @@ public void createFunctionFromGraph() { try (Graph g = new Graph()) { Signature signature = plusFive(Ops.create(g)); try (ConcreteFunction f = ConcreteFunction.create(signature, g); - Tensor x = TFloat32.scalarOf(3.0f)) { - assertEquals(8.0f, f.call(x).expect(TFloat32.DTYPE).data().getFloat()); + TFloat32 x = TFloat32.scalarOf(3.0f)) { + assertEquals(8.0f, ((TFloat32)f.call(x)).getFloat()); } } } @@ -66,8 +66,8 @@ public void createFunctionFromSession() { Signature signature = plusFive(Ops.create(g)); try (Session s = new Session(g)) { try (ConcreteFunction f = ConcreteFunction.create(signature, s); - Tensor x = TFloat32.scalarOf(3.0f)) { - assertEquals(8.0f, f.call(x).expect(TFloat32.DTYPE).data().getFloat()); + TFloat32 x = TFloat32.scalarOf(3.0f)) { + assertEquals(8.0f, ((TFloat32)f.call(x)).getFloat()); } } } @@ -77,8 +77,8 @@ public void createFunctionFromSession() { public void chainFunctions() { try (ConcreteFunction f1 = ConcreteFunction.create(ConcreteFunctionTest::plusFive); ConcreteFunction f2 = ConcreteFunction.create(ConcreteFunctionTest::minusTwo); - Tensor x = TFloat32.scalarOf(3.0f)) { - assertEquals(6.0f, f2.call(f1.call(x)).expect(TFloat32.DTYPE).data().getFloat()); + TFloat32 x = TFloat32.scalarOf(3.0f)) { + assertEquals(6.0f, ((TFloat32)f2.call(f1.call(x))).getFloat()); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java index b6a5a50a7a4..6751c513ef3 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java @@ -88,7 +88,7 @@ public void setAttrs() { try (EagerSession session = EagerSession.create()) { Ops tf = Ops.create(session); // dtype, tensor attributes. - try (Tensor t = TInt32.scalarOf(1)) { + try (TInt32 t = TInt32.scalarOf(1)) { opBuilder(session, "Const", "DataTypeAndTensor") .setAttr("dtype", TInt32.DTYPE) .setAttr("value", t) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java index 09d2214cc6a..a14a295fddd 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java @@ -46,7 +46,7 @@ public void failToCreateIfSessionIsClosed() { @Test public void outputDataTypeAndShape() { try (EagerSession session = EagerSession.create(); - Tensor t = TInt32.tensorOf(Shape.of(2, 3))) { + TInt32 t = TInt32.tensorOf(Shape.of(2, 3))) { EagerOperation op = opBuilder(session, "Const", "OutputAttrs") .setAttr("dtype", TInt32.DTYPE) @@ -67,7 +67,7 @@ public void outputTensor() { .addInput(tf.constant(2).asOutput()) .addInput(tf.constant(4).asOutput()) .build(); - assertEquals(6, add.tensor(0).expect(TInt32.DTYPE).data().getInt()); + assertEquals(6, ((TInt32)add.tensor(0)).getInt()); // Validate that we retrieve the right shape and datatype from the tensor // that has been resolved diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java index 35bfa808238..7573a25ac13 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java @@ -50,7 +50,7 @@ public void failWhenMixingOperationsOnDifferentGraphs() { @Test public void failOnUseAfterBuild() { try (Graph g = new Graph(); - Tensor t = TInt32.scalarOf(1)) { + TInt32 t = TInt32.scalarOf(1)) { OperationBuilder b = g.opBuilder("Const", "Const").setAttr("dtype", t.dataType()).setAttr("value", t); b.build(); @@ -66,7 +66,7 @@ public void failOnUseAfterBuild() { public void failOnUseAfterGraphClose() { OperationBuilder b = null; try (Graph g = new Graph(); - Tensor t = TInt32.scalarOf(1)) { + TInt32 t = TInt32.scalarOf(1)) { b = g.opBuilder("Const", "Const").setAttr("dtype", t.dataType()).setAttr("value", t); } try { @@ -88,7 +88,7 @@ public void setAttr() { try (Graph g = new Graph()) { Ops tf = Ops.create(g); // dtype, tensor attributes. - try (Tensor t = TInt32.scalarOf(1)) { + try (TInt32 t = TInt32.scalarOf(1)) { g.opBuilder("Const", "DataTypeAndTensor") .setAttr("dtype", TInt32.DTYPE) .setAttr("value", t) @@ -169,8 +169,8 @@ public void setAttrShapeList() { public void addControlInput() { try (Graph g = new Graph(); Session s = new Session(g); - Tensor yes = TBool.scalarOf(true); - Tensor no = TBool.scalarOf(false)) { + TBool yes = TBool.scalarOf(true); + TBool no = TBool.scalarOf(false)) { Ops tf = Ops.create(g); Output placeholder = tf.placeholder(TBool.DTYPE).asOutput(); GraphOperation check = diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationTest.java index d6f5ab9a6d9..b164c129745 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationTest.java @@ -183,7 +183,7 @@ public void outputTensorNotSupported() { Ops tf = Ops.create(g); Operation split = tf.split(tf.constant(0), tf.array(0, 1, 2), 3L).op(); try { - split.output(0).tensor(); + split.output(0).asTensor(); fail(); } catch (IllegalStateException e) { } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java index de376015e3f..a9eed79041a 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java @@ -157,9 +157,9 @@ public void addGradientsToGraph() { assertEquals(TFloat32.DTYPE, grads1[0].dataType()); assertEquals(TFloat32.DTYPE, grads1[1].dataType()); - try (Tensor c1 = TFloat32.scalarOf(3.0f); - Tensor c2 = TFloat32.scalarOf(2.0f); - AutoCloseableList> outputs = new AutoCloseableList<>( + try (TFloat32 c1 = TFloat32.scalarOf(3.0f); + TFloat32 c2 = TFloat32.scalarOf(2.0f); + AutoCloseableList outputs = new AutoCloseableList<>( s.runner() .feed(x1, c1) .feed(x2, c2) @@ -169,9 +169,9 @@ public void addGradientsToGraph() { .run())) { assertEquals(3, outputs.size()); - assertEquals(108.0f, outputs.get(0).expect(TFloat32.DTYPE).data().getFloat(), 0.0f); - assertEquals(6.0f, outputs.get(1).expect(TFloat32.DTYPE).data().getFloat(), 0.0f); - assertEquals(1.0f, outputs.get(2).expect(TFloat32.DTYPE).data().getFloat(), 0.0f); + assertEquals(108.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); + assertEquals(6.0f, ((TFloat32)outputs.get(1)).getFloat(), 0.0f); + assertEquals(1.0f, ((TFloat32)outputs.get(2)).getFloat(), 0.0f); } } } @@ -191,14 +191,13 @@ public void addGradientSumsToGraph() { assertEquals(1, grad.length); assertEquals(TFloat32.DTYPE, grad[0].dataType()); - try (Tensor c = TFloat32.scalarOf(3.0f); - Tensor output = s.runner() + try (TFloat32 c = TFloat32.scalarOf(3.0f); + TFloat32 output = (TFloat32)s.runner() .feed(x, c) .fetch(grad[0]) .run() - .get(0) - .expect(TFloat32.DTYPE)) { - assertEquals(114.0f, output.data().getFloat(), 0.0f); + .get(0)) { + assertEquals(114.0f, output.getFloat(), 0.0f); } } } @@ -223,14 +222,13 @@ public void addGradientsWithInitialValuesToGraph() { assertEquals(1, grad1.length); assertEquals(TFloat32.DTYPE, grad1[0].dataType()); - try (Tensor c = TFloat32.scalarOf(3.0f); - Tensor output = s.runner() + try (TFloat32 c = TFloat32.scalarOf(3.0f); + TFloat32 output = (TFloat32)s.runner() .feed(x, c) .fetch(grad1[0]) .run() - .get(0) - .expect(TFloat32.DTYPE)) { - assertEquals(108.0f, output.data().getFloat(), 0.0f); + .get(0)) { + assertEquals(108.0f, output.getFloat(), 0.0f); } } } @@ -284,14 +282,13 @@ public void buildWhileLoopSingleInput() { }, "test_loop"); - try (Tensor c = TInt32.scalarOf(2); - Tensor output = s.runner() + try (TInt32 c = TInt32.scalarOf(2); + TInt32 output = (TInt32)s.runner() .feed(input, c) .fetch(loopOutputs[0]) .run() - .get(0) - .expect(TInt32.DTYPE)) { - assertEquals(16, output.data().getInt()); // ((2^2)^2) + .get(0)) { + assertEquals(16, output.getInt()); // ((2^2)^2) } } } @@ -320,9 +317,9 @@ public void buildWhileLoopMultipleInputs() { }, "test_loop"); - try (Tensor c1 = TInt32.scalarOf(2); - Tensor c2 = TInt32.scalarOf(5); - AutoCloseableList> outputs = + try (TInt32 c1 = TInt32.scalarOf(2); + TInt32 c2 = TInt32.scalarOf(5); + AutoCloseableList outputs = new AutoCloseableList<>( s.runner() .feed(input1, c1) @@ -331,8 +328,8 @@ public void buildWhileLoopMultipleInputs() { .fetch(loopOutputs[1]) .run())) { assertEquals(2, outputs.size()); - assertEquals(16, outputs.get(0).expect(TInt32.DTYPE).data().getInt()); // ((2^2)^2) - assertEquals(625, outputs.get(1).expect(TInt32.DTYPE).data().getInt()); // ((5^2)^2) + assertEquals(16, ((TInt32)outputs.get(0)).getInt()); // ((2^2)^2) + assertEquals(625, ((TInt32)outputs.get(1)).getInt()); // ((5^2)^2) } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index d807d13de00..417205f1f38 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -107,9 +107,9 @@ public void exportFunctionWithVariables() throws IOException { f.session().run(Init.DEFAULT_NAME); // Call the graph and remember the result of computation for later - try (Tensor xTensor = TFloat32.tensorOf(xValue); - Tensor zTensor = f.call(xTensor).expect(TFloat32.DTYPE)) { - reducedSum = zTensor.data().getFloat(); + try (TFloat32 xTensor = TFloat32.tensorOf(xValue); + TFloat32 zTensor = (TFloat32)f.call(xTensor)) { + reducedSum = zTensor.getFloat(); } // Save/export the model (which is a single function in this case) f.save(testFolder.toString()); @@ -153,15 +153,15 @@ public void exportFunctionWithVariables() throws IOException { assertNotNull(outputInfo); assertEquals(0, outputInfo.getTensorShape().getDimCount()); - try (Tensor xTensor = TFloat32.tensorOf(xValue)) { + try (TFloat32 xTensor = TFloat32.tensorOf(xValue)) { // Call the saved model function and make sure it returns the same result as before - try (Tensor zTensor = function.call(xTensor).expect(TFloat32.DTYPE)) { - assertEquals(reducedSum, zTensor.data().getFloat(), EPSILON); + try (TFloat32 zTensor = (TFloat32)function.call(xTensor)) { + assertEquals(reducedSum, zTensor.getFloat(), EPSILON); } // Now call the same function directly from the model - try (Tensor zTensor = - savedModel.call(Collections.singletonMap("input", xTensor)).get("reducedSum").expect(TFloat32.DTYPE)) { - assertEquals(reducedSum, zTensor.data().getFloat(), EPSILON); + try (TFloat32 zTensor = + (TFloat32)savedModel.call(Collections.singletonMap("input", xTensor)).get("reducedSum")) { + assertEquals(reducedSum, zTensor.getFloat(), EPSILON); } } } @@ -179,9 +179,9 @@ public void exportMultipleFunctions() throws IOException { ConcreteFunction f1 = ConcreteFunction.create(f1Signature, s); ConcreteFunction f2 = ConcreteFunction.create(f2Signature, s)) { f1.session().run(Init.DEFAULT_NAME); - try (Tensor x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[]{2, 2})); - Tensor t = f1.call(x).expect(TFloat32.DTYPE)) { - reducedSum = t.data().getFloat(); + try (TFloat32 x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[]{2, 2})); + TFloat32 t = (TFloat32)f1.call(x)) { + reducedSum = t.getFloat(); } SavedModelBundle.exporter(testFolder.toString()) .withFunction(f1) @@ -193,15 +193,15 @@ public void exportMultipleFunctions() throws IOException { assertEquals(2, model.signatures().size()); ConcreteFunction f1 = model.function(Signature.DEFAULT_KEY); assertNotNull(f1); - try (Tensor x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[]{2, 2})); - Tensor t = f1.call(x).expect(TFloat32.DTYPE)) { - assertEquals(reducedSum, t.data().getFloat(), EPSILON); + try (TFloat32 x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[]{2, 2})); + TFloat32 t = (TFloat32)f1.call(x)) { + assertEquals(reducedSum, t.getFloat(), EPSILON); } ConcreteFunction f2 = model.function("identity"); assertNotNull(f2); - try (Tensor x = TFloat32.scalarOf(10.0f); - Tensor t = f2.call(x).expect(TFloat32.DTYPE)) { - assertEquals(10.0f, t.data().getFloat(), 0.0f); + try (TFloat32 x = TFloat32.scalarOf(10.0f); + TFloat32 t = (TFloat32)f2.call(x)) { + assertEquals(10.0f, t.getFloat(), 0.0f); } try { model.function("NoSuchFunction"); @@ -290,15 +290,15 @@ public void pythonTfFunction() { * Signature name used for saving 'add', argument names 'a' and 'b' */ ConcreteFunction add = bundle.function("add"); - Map> args = new HashMap(); - try (Tensor a = TFloat32.scalarOf(10.0f); - Tensor b = TFloat32.scalarOf(15.5f)) { + Map args = new HashMap(); + try (TFloat32 a = TFloat32.scalarOf(10.0f); + TFloat32 b = TFloat32.scalarOf(15.5f)) { args.put("a", a); args.put("b", b); - Map> result = add.call(args); + Map result = add.call(args); assertEquals(result.size(), 1); - try (Tensor c = result.values().iterator().next().expect(TFloat32.DTYPE)) { - assertEquals(25.5f, c.data().getFloat()); + try (TFloat32 c = (TFloat32)result.values().iterator().next()) { + assertEquals(25.5f, c.getFloat()); } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java index fa41af32a29..0556c1ff17f 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java @@ -48,11 +48,11 @@ public void runUsingOperationNames() { Session s = new Session(g)) { Ops tf = Ops.create(g); transpose_A_times_X(tf, new int[][] {{2}, {3}}); - try (Tensor x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); - AutoCloseableList> outputs = + try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); + AutoCloseableList outputs = new AutoCloseableList<>(s.runner().feed("X", x).fetch("Y").run())) { assertEquals(1, outputs.size()); - assertEquals(31, outputs.get(0).expect(TInt32.DTYPE).data().getInt(0, 0)); + assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0)); } } } @@ -65,11 +65,11 @@ public void runUsingOperationHandles() { transpose_A_times_X(tf, new int[][] {{2}, {3}}); Output feed = g.operation("X").output(0); Output fetch = g.operation("Y").output(0); - try (Tensor x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); - AutoCloseableList> outputs = + try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); + AutoCloseableList outputs = new AutoCloseableList<>(s.runner().feed(feed, x).fetch(fetch).run())) { assertEquals(1, outputs.size()); - assertEquals(31, outputs.get(0).expect(TInt32.DTYPE).data().getInt(0, 0)); + assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0)); } } } @@ -83,22 +83,19 @@ public void runUsingColonSeparatedNames() { tf.math.add(split.output().get(0), split.output().get(1)); // Fetch using colon separated names. - try (Tensor fetched = - s.runner().fetch("Split:1").run().get(0).expect(TInt32.DTYPE)) { - assertEquals(3, fetched.data().getInt(0)); - assertEquals(4, fetched.data().getInt(1)); + try (TInt32 fetched = (TInt32)s.runner().fetch("Split:1").run().get(0)) { + assertEquals(3, fetched.getInt(0)); + assertEquals(4, fetched.getInt(1)); } // Feed using colon separated names. - try (Tensor fed = TInt32.vectorOf(4, 3, 2, 1); - Tensor fetched = - s.runner() + try (TInt32 fed = TInt32.vectorOf(4, 3, 2, 1); + TInt32 fetched = (TInt32) s.runner() .feed("Split:0", fed) .feed("Split:1", fed) .fetch("Add") .run() - .get(0) - .expect(TInt32.DTYPE)) { - assertEquals(NdArrays.vectorOf(8, 6, 4, 2), fetched.data()); + .get(0)) { + assertEquals(NdArrays.vectorOf(8, 6, 4, 2), fetched); } } } @@ -109,17 +106,16 @@ public void runWithMetadata() { Session s = new Session(g)) { Ops tf = Ops.create(g); transpose_A_times_X(tf, new int[][] {{2}, {3}}); - try (Tensor x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}}))) { - Session.Run result = - s.runner() + try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}}))) { + Session.Run result = s.runner() .feed("X", x) .fetch("Y") .setOptions(fullTraceRunOptions()) .runAndFetchMetadata(); // Sanity check on outputs. - AutoCloseableList> outputs = new AutoCloseableList<>(result.outputs); + AutoCloseableList outputs = new AutoCloseableList<>(result.outputs); assertEquals(1, outputs.size()); - assertEquals(31, outputs.get(0).expect(TInt32.DTYPE).data().getInt(0, 0)); + assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0)); // Sanity check on metadata assertNotNull(result.metadata); assertTrue(result.metadata.hasStepStats(), result.metadata.toString()); @@ -135,11 +131,11 @@ public void runMultipleOutputs() { Ops tf = Ops.create(g); tf.withName("c1").constant(2718); tf.withName("c2").constant(31415); - AutoCloseableList> outputs = + AutoCloseableList outputs = new AutoCloseableList<>(s.runner().fetch("c2").fetch("c1").run()); assertEquals(2, outputs.size()); - assertEquals(31415, outputs.get(0).expect(TInt32.DTYPE).data().getInt()); - assertEquals(2718, outputs.get(1).expect(TInt32.DTYPE).data().getInt()); + assertEquals(31415, ((TInt32)outputs.get(0)).getInt()); + assertEquals(2718, ((TInt32)outputs.get(1)).getInt()); outputs.close(); } } @@ -177,8 +173,8 @@ public void runInit() { try (Session s = new Session(g)) { s.run(tf.init()); - try (Tensor t = s.runner().fetch(add).run().get(0).expect(TInt32.DTYPE)) { - assertEquals(30, t.data().getInt()); + try (TInt32 t = (TInt32) s.runner().fetch(add).run().get(0)) { + assertEquals(30, t.getInt()); } } } @@ -198,8 +194,8 @@ public void runInitByName() { try (Session s = new Session(g)) { s.run("init_test"); - try (Tensor t = s.runner().fetch(add).run().get(0).expect(TInt32.DTYPE)) { - assertEquals(30, t.data().getInt()); + try (TInt32 t = (TInt32) s.runner().fetch(add).run().get(0)) { + assertEquals(30, t.getInt()); } try { s.run("wrong_name"); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java index 01ef11efedd..4da9bce9e90 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java @@ -64,22 +64,22 @@ public void createWithRawData() { String strings = "test"; Shape strings_shape = Shape.scalar(); byte[] strings_; // raw TF_STRING - try (Tensor t = TString.tensorOf(NdArrays.scalarOfObject(strings))) { + try (TString t = TString.tensorOf(NdArrays.scalarOfObject(strings))) { strings_ = new byte[(int)t.numBytes()]; - t.rawData().read(strings_); + t.asRawTensor().data().read(strings_); } // validate creating a tensor using a raw data byte buffers { - try (Tensor t = Tensor.of(TBool.DTYPE, bools_shape, DataBuffers.of(bools_))) { + try (TBool t = Tensor.of(TBool.DTYPE, bools_shape, DataBuffers.of(bools_))) { boolean[] actual = new boolean[bools_.length]; - t.data().read(DataBuffers.of(actual)); + t.read(DataBuffers.of(actual)); assertArrayEquals(bools, actual); } // note: the buffer is expected to contain raw TF_STRING (as per C API) - try (Tensor t = Tensor.of(TString.DTYPE, strings_shape, DataBuffers.of(strings_))) { - assertEquals(strings, t.data().getObject()); + try (TString t = Tensor.of(TString.DTYPE, strings_shape, DataBuffers.of(strings_))) { + assertEquals(strings, t.getObject()); } } @@ -87,15 +87,15 @@ public void createWithRawData() { { DoubleBuffer buf = ByteBuffer.allocateDirect(8 * doubles.length).order(ByteOrder.nativeOrder()) .asDoubleBuffer().put(doubles); - try (Tensor t = TFloat64.tensorOf(doubles_shape, d -> d.write(DataBuffers.of(buf)))) { + try (TFloat64 t = TFloat64.tensorOf(doubles_shape, d -> d.write(DataBuffers.of(buf)))) { double[] actual = new double[doubles.length]; - t.data().read(DataBuffers.of(actual)); + t.read(DataBuffers.of(actual)); assertArrayEquals(doubles, actual, EPSILON); } } // validate shape checking - try (Tensor t = Tensor.of(TBool.DTYPE, Shape.of(bools_.length * 2), DataBuffers.of(bools_))) { + try (TBool t = Tensor.of(TBool.DTYPE, Shape.of(bools_.length * 2), DataBuffers.of(bools_))) { fail("should have failed on incompatible buffer"); } catch (IllegalArgumentException e) { // expected @@ -111,9 +111,9 @@ public void createFromBufferWithNativeByteOrder() { .asDoubleBuffer() .put(doubles); flipBuffer(buf); - try (Tensor t = TFloat64.tensorOf(Shape.of(4), DataBuffers.of(buf))) { + try (TFloat64 t = TFloat64.tensorOf(Shape.of(4), DataBuffers.of(buf))) { double[] actual = new double[doubles.length]; - t.data().read(DataBuffers.of(actual)); + t.read(DataBuffers.of(actual)); assertArrayEquals(doubles, actual, EPSILON); } } @@ -130,9 +130,9 @@ public void createFromBufferWithNonNativeByteOrder() { .asDoubleBuffer() .put(doubles); flipBuffer(buf); - try (Tensor t = TFloat64.tensorOf(Shape.of(4), DataBuffers.of(buf))) { + try (TFloat64 t = TFloat64.tensorOf(Shape.of(4), DataBuffers.of(buf))) { double[] actual = new double[doubles.length]; - t.data().read(DataBuffers.of(actual)); + t.read(DataBuffers.of(actual)); assertArrayEquals(doubles, actual, EPSILON); } } @@ -147,24 +147,24 @@ public void createWithTypedBuffer() { // validate creating a tensor using a typed buffer { Shape shape = Shape.of(4); - try (Tensor t = TFloat64.tensorOf(shape, DataBuffers.of(doubles))) { + try (TFloat64 t = TFloat64.tensorOf(shape, DataBuffers.of(doubles))) { DoubleBuffer actual = DoubleBuffer.allocate(doubles.capacity()); - t.data().read(DataBuffers.of(actual)); + t.read(DataBuffers.of(actual)); assertEquals(doubles, actual); } - try (Tensor t = TFloat32.tensorOf(shape, DataBuffers.of(floats))) { + try (TFloat32 t = TFloat32.tensorOf(shape, DataBuffers.of(floats))) { FloatBuffer actual = FloatBuffer.allocate(floats.capacity()); - t.data().read(DataBuffers.of(actual)); + t.read(DataBuffers.of(actual)); assertEquals(floats, actual); } - try (Tensor t = TInt32.tensorOf(shape, DataBuffers.of(ints))) { + try (TInt32 t = TInt32.tensorOf(shape, DataBuffers.of(ints))) { IntBuffer actual = IntBuffer.allocate(ints.capacity()); - t.data().read(DataBuffers.of(actual)); + t.read(DataBuffers.of(actual)); assertEquals(ints, actual); } - try (Tensor t = TInt64.tensorOf(shape, DataBuffers.of(longs))) { + try (TInt64 t = TInt64.tensorOf(shape, DataBuffers.of(longs))) { LongBuffer actual = LongBuffer.allocate(longs.capacity()); - t.data().read(DataBuffers.of(actual)); + t.read(DataBuffers.of(actual)); assertEquals(longs, actual); } } @@ -172,22 +172,22 @@ public void createWithTypedBuffer() { // validate shape-checking { Shape shape = Shape.of(5); - try (Tensor t = TFloat64.tensorOf(shape, DataBuffers.of(doubles))) { + try (TFloat64 t = TFloat64.tensorOf(shape, DataBuffers.of(doubles))) { fail("should have failed on incompatible buffer"); } catch (BufferUnderflowException e) { // expected } - try (Tensor t = TFloat32.tensorOf(shape, DataBuffers.of(floats))) { + try (TFloat32 t = TFloat32.tensorOf(shape, DataBuffers.of(floats))) { fail("should have failed on incompatible buffer"); } catch (BufferUnderflowException e) { // expected } - try (Tensor t = TInt32.tensorOf(shape, DataBuffers.of(ints))) { + try (TInt32 t = TInt32.tensorOf(shape, DataBuffers.of(ints))) { fail("should have failed on incompatible buffer"); } catch (BufferUnderflowException e) { // expected } - try (Tensor t = TInt64.tensorOf(shape, DataBuffers.of(longs))) { + try (TInt64 t = TInt64.tensorOf(shape, DataBuffers.of(longs))) { fail("should have failed on incompatible buffer"); } catch (BufferUnderflowException e) { // expected @@ -203,39 +203,39 @@ public void readFromRawData() { long[] longs = {1L, 2L, 3L}; boolean[] bools = {true, false, true}; - try (Tensor tints = TInt32.vectorOf(ints); - Tensor tfloats = TFloat32.vectorOf(floats); - Tensor tdoubles = TFloat64.vectorOf(doubles); - Tensor tlongs = TInt64.vectorOf(longs); - Tensor tbools = TBool.vectorOf(bools)) { + try (TInt32 tints = TInt32.vectorOf(ints); + TFloat32 tfloats = TFloat32.vectorOf(floats); + TFloat64 tdoubles = TFloat64.vectorOf(doubles); + TInt64 tlongs = TInt64.vectorOf(longs); + TBool tbools = TBool.vectorOf(bools)) { // validate that any datatype is readable with ByteBuffer (content, position) { ByteBuffer bbuf = ByteBuffer.allocate(1024).order(ByteOrder.nativeOrder()); clearBuffer(bbuf); // FLOAT - assertEquals(tfloats.numBytes(), tfloats.rawData().size()); - tfloats.rawData().copyTo(DataBuffers.of(bbuf), tfloats.numBytes()); + assertEquals(tfloats.numBytes(), tfloats.asRawTensor().data().size()); + tfloats.asRawTensor().data().copyTo(DataBuffers.of(bbuf), tfloats.numBytes()); assertEquals(floats[0], bbuf.asFloatBuffer().get(0), EPSILON); clearBuffer(bbuf); // DOUBLE - assertEquals(tdoubles.numBytes(), tdoubles.rawData().size()); - tdoubles.rawData().copyTo(DataBuffers.of(bbuf), tdoubles.numBytes()); + assertEquals(tdoubles.numBytes(), tdoubles.asRawTensor().data().size()); + tdoubles.asRawTensor().data().copyTo(DataBuffers.of(bbuf), tdoubles.numBytes()); assertEquals(doubles[0], bbuf.asDoubleBuffer().get(0), EPSILON); - clearBuffer(bbuf); // INT32 - assertEquals(tints.numBytes(), tints.rawData().size()); - tints.rawData().copyTo(DataBuffers.of(bbuf), tints.numBytes()); + clearBuffer(bbuf); // INT3 + assertEquals(tints.numBytes(), tints.asRawTensor().data().size()); + tints.asRawTensor().data().copyTo(DataBuffers.of(bbuf), tints.numBytes()); assertEquals(ints[0], bbuf.asIntBuffer().get(0)); clearBuffer(bbuf); // INT64 - assertEquals(tlongs.numBytes(), tlongs.rawData().size()); - tlongs.rawData().copyTo(DataBuffers.of(bbuf), tlongs.numBytes()); + assertEquals(tlongs.numBytes(), tlongs.asRawTensor().data().size()); + tlongs.asRawTensor().data().copyTo(DataBuffers.of(bbuf), tlongs.numBytes()); assertEquals(longs[0], bbuf.asLongBuffer().get(0)); clearBuffer(bbuf); // BOOL - assertEquals(tbools.numBytes(), tbools.rawData().size()); - tbools.rawData().copyTo(DataBuffers.of(bbuf), tbools.numBytes()); + assertEquals(tbools.numBytes(), tbools.asRawTensor().data().size()); + tbools.asRawTensor().data().copyTo(DataBuffers.of(bbuf), tbools.numBytes()); assertEquals(bools[0], bbuf.get(0) != 0); } @@ -243,7 +243,7 @@ public void readFromRawData() { { ByteBuffer bbuf = ByteBuffer.allocateDirect((int)tdoubles.numBytes()).order(ByteOrder.nativeOrder()); - tdoubles.rawData().copyTo(DataBuffers.of(bbuf), tdoubles.numBytes()); + tdoubles.asRawTensor().data().copyTo(DataBuffers.of(bbuf), tdoubles.numBytes()); assertEquals(doubles[0], bbuf.asDoubleBuffer().get(0), EPSILON); } @@ -256,7 +256,7 @@ public void readFromRawData() { ? ByteOrder.BIG_ENDIAN : ByteOrder.LITTLE_ENDIAN) .asDoubleBuffer(); - tdoubles.rawData().asDoubles().copyTo(DataBuffers.of(foreignBuf), foreignBuf.capacity()); + tdoubles.asRawTensor().data().asDoubles().copyTo(DataBuffers.of(foreignBuf), foreignBuf.capacity()); double[] actual = new double[foreignBuf.remaining()]; foreignBuf.get(actual); assertArrayEquals(doubles, actual, EPSILON); @@ -266,79 +266,79 @@ public void readFromRawData() { @Test public void scalars() { - try (Tensor t = TFloat32.scalarOf(2.718f)) { + try (TFloat32 t = TFloat32.scalarOf(2.718f)) { assertEquals(TFloat32.DTYPE, t.dataType()); assertEquals(0, t.shape().numDimensions()); - assertEquals(2.718f, t.data().getFloat(), EPSILON_F); + assertEquals(2.718f, t.getFloat(), EPSILON_F); } - try (Tensor t = TFloat64.scalarOf(3.1415)) { + try (TFloat64 t = TFloat64.scalarOf(3.1415)) { assertEquals(TFloat64.DTYPE, t.dataType()); assertEquals(0, t.shape().numDimensions()); - assertEquals(3.1415, t.data().getDouble(), EPSILON); + assertEquals(3.1415, t.getDouble(), EPSILON); } - try (Tensor t = TInt32.scalarOf(-33)) { + try (TInt32 t = TInt32.scalarOf(-33)) { assertEquals(TInt32.DTYPE, t.dataType()); assertEquals(0, t.shape().numDimensions()); - assertEquals(-33, t.data().getInt()); + assertEquals(-33, t.getInt()); } - try (Tensor t = TInt64.scalarOf(8589934592L)) { + try (TInt64 t = TInt64.scalarOf(8589934592L)) { assertEquals(TInt64.DTYPE, t.dataType()); assertEquals(0, t.shape().numDimensions()); - assertEquals(8589934592L, t.data().getLong()); + assertEquals(8589934592L, t.getLong()); } - try (Tensor t = TBool.scalarOf(true)) { + try (TBool t = TBool.scalarOf(true)) { assertEquals(TBool.DTYPE, t.dataType()); assertEquals(0, t.shape().numDimensions()); - assertTrue(t.data().getBoolean()); + assertTrue(t.getBoolean()); } - try (Tensor t = TString.scalarOf("sombrero")) { + try (TString t = TString.scalarOf("sombrero")) { assertEquals(TString.DTYPE, t.dataType()); assertEquals(0, t.shape().numDimensions()); - assertEquals("sombrero", t.data().getObject()); + assertEquals("sombrero", t.getObject()); } final byte[] bytes = {1, 2, 3, 4}; - try (Tensor t = TString.tensorOfBytes(NdArrays.scalarOfObject(bytes))) { + try (TString t = TString.tensorOfBytes(NdArrays.scalarOfObject(bytes))) { assertEquals(TString.DTYPE, t.dataType()); assertEquals(0, t.shape().numDimensions()); - assertArrayEquals(bytes, t.data().asBytes().getObject()); + assertArrayEquals(bytes, t.asBytes().getObject()); } } @Test public void nDimensional() { DoubleNdArray vector = StdArrays.ndCopyOf(new double[]{1.414, 2.718, 3.1415}); - try (Tensor t = TFloat64.tensorOf(vector)) { + try (TFloat64 t = TFloat64.tensorOf(vector)) { assertEquals(TFloat64.DTYPE, t.dataType()); assertEquals(1, t.shape().numDimensions()); assertEquals(3, t.shape().size(0)); - assertEquals(vector, t.data()); + assertEquals(vector, t); } IntNdArray matrix = StdArrays.ndCopyOf(new int[][]{{1, 2, 3}, {4, 5, 6}}); - try (Tensor t = TInt32.tensorOf(matrix)) { + try (TInt32 t = TInt32.tensorOf(matrix)) { assertEquals(TInt32.DTYPE, t.dataType()); assertEquals(2, t.shape().numDimensions()); assertEquals(2, t.shape().size(0)); assertEquals(3, t.shape().size(1)); - assertEquals(matrix, t.data()); + assertEquals(matrix, t); } LongNdArray threeD = StdArrays.ndCopyOf(new long[][][]{ {{1}, {3}, {5}, {7}, {9}}, {{2}, {4}, {6}, {8}, {0}}, }); - try (Tensor t = TInt64.tensorOf(threeD)) { + try (TInt64 t = TInt64.tensorOf(threeD)) { assertEquals(TInt64.DTYPE, t.dataType()); assertEquals(3, t.shape().numDimensions()); assertEquals(2, t.shape().size(0)); assertEquals(5, t.shape().size(1)); assertEquals(1, t.shape().size(2)); - assertEquals(threeD, t.data()); + assertEquals(threeD, t); } BooleanNdArray fourD = StdArrays.ndCopyOf(new boolean[][][][]{ @@ -346,14 +346,14 @@ public void nDimensional() { {{{false, false, true, true}, {false, true, false, false}}}, {{{false, true, false, true}, {false, true, true, false}}}, }); - try (Tensor t = TBool.tensorOf(fourD)) { + try (TBool t = TBool.tensorOf(fourD)) { assertEquals(TBool.DTYPE, t.dataType()); assertEquals(4, t.shape().numDimensions()); assertEquals(3, t.shape().size(0)); assertEquals(1, t.shape().size(1)); assertEquals(2, t.shape().size(2)); assertEquals(4, t.shape().size(3)); - assertEquals(fourD, t.data()); + assertEquals(fourD, t); } } @@ -365,36 +365,36 @@ public void testNDimensionalStringTensor() { matrix.setObject(String.format("(%d, %d) = %d", i, j, i << j), i, j); } } - try (Tensor t = TString.tensorOf(matrix)) { + try (TString t = TString.tensorOf(matrix)) { assertEquals(TString.DTYPE, t.dataType()); assertEquals(2, t.shape().numDimensions()); assertEquals(4, t.shape().size(0)); assertEquals(3, t.shape().size(1)); - assertEquals(matrix, t.data()); + assertEquals(matrix, t); } NdArray byteMatrix = NdArrays.ofObjects(byte[].class, matrix.shape()); matrix.scalars().forEachIndexed((i, s) -> byteMatrix.setObject(s.getObject().getBytes(UTF_8), i)); - try (Tensor t = TString.tensorOfBytes(byteMatrix)) { + try (TString t = TString.tensorOfBytes(byteMatrix)) { assertEquals(TString.DTYPE, t.dataType()); assertEquals(2, t.shape().numDimensions()); assertEquals(4, t.shape().size(0)); assertEquals(3, t.shape().size(1)); - assertEquals(byteMatrix, t.data().asBytes()); - assertEquals(matrix, t.data()); + assertEquals(byteMatrix, t.asBytes()); + assertEquals(matrix, t); } } @Test public void testUint8TensorFromArray() { byte[] vector = new byte[] {1, 2, 3, 4}; - try (Tensor t = TUint8.vectorOf(vector)) { + try (TUint8 t = TUint8.vectorOf(vector)) { assertEquals(TUint8.DTYPE, t.dataType()); assertEquals(1, t.shape().numDimensions()); assertEquals(4, t.shape().size(0)); byte[] got = new byte[4]; - t.data().read(DataBuffers.of(got)); + t.read(DataBuffers.of(got)); assertArrayEquals(vector, got); } } @@ -402,13 +402,13 @@ public void testUint8TensorFromArray() { @Test public void testCreateFromArrayOfBoxed() { Integer[] vector = new Integer[] {1, 2, 3, 4}; - try (Tensor t = TInt32.tensorOf(Shape.of(4), d -> d.write(DataBuffers.ofObjects(vector)))) { + try (TInt32 t = TInt32.tensorOf(Shape.of(4), d -> d.write(DataBuffers.ofObjects(vector)))) { assertEquals(TInt32.DTYPE, t.dataType()); assertEquals(1, t.shape().numDimensions()); assertEquals(4, t.shape().size(0)); Integer[] got = new Integer[4]; - t.data().read(DataBuffers.ofObjects(got)); + t.read(DataBuffers.ofObjects(got)); assertArrayEquals(vector, got); } } @@ -421,7 +421,7 @@ public void failCreateOnMismatchedDimensions() { invalid[x][y] = new int[x + y + 1]; } } - try (Tensor t = TInt32.tensorOf(StdArrays.ndCopyOf(invalid))) { + try (TInt32 t = TInt32.tensorOf(StdArrays.ndCopyOf(invalid))) { fail("Tensor.create() should fail because of differing sizes in the 3rd dimension"); } catch (IllegalArgumentException e) { // The expected exception. @@ -433,11 +433,11 @@ public void tensorWithZeroDimension() { // Note: Historically, TF Java failed on purpose when trying to allocate a tensor with a shape // that has one or more dimensions set to 0 elements. But Python API allows it, so we should do // the same. - try (Tensor t = TInt32.tensorOf(Shape.of(3, 0, 1))) { + try (TInt32 t = TInt32.tensorOf(Shape.of(3, 0, 1))) { assertEquals(0, t.numBytes()); assertEquals(0, t.shape().size()); } - try (Tensor t = TInt32.tensorOf(StdArrays.ndCopyOf(new int[3][0][1]))) { + try (TInt32 t = TInt32.tensorOf(StdArrays.ndCopyOf(new int[3][0][1]))) { assertEquals(0, t.numBytes()); assertEquals(0, t.shape().size()); } @@ -445,10 +445,10 @@ public void tensorWithZeroDimension() { @Test public void allocateTensorWithSize() { - try (Tensor t = Tensor.of(TInt32.DTYPE, Shape.of(2, 2, 2), 8 * TInt32.DTYPE.byteSize())) { + try (TInt32 t = Tensor.of(TInt32.DTYPE, Shape.of(2, 2, 2), 8 * TInt32.DTYPE.byteSize())) { // ok } - try (Tensor t = Tensor.of(TInt32.DTYPE, Shape.of(2, 2, 2), 9 * TInt32.DTYPE.byteSize())) { + try (TInt32 t = Tensor.of(TInt32.DTYPE, Shape.of(2, 2, 2), 9 * TInt32.DTYPE.byteSize())) { // ok (size requested is larger that minimum space required) } try { @@ -462,10 +462,10 @@ public void allocateTensorWithSize() { @Test public void useAfterClose() { int n = 4; - Tensor t = TInt32.scalarOf(n); + TInt32 t = TInt32.scalarOf(n); t.close(); try { - t.data(); + t.numBytes(); } catch (IllegalStateException e) { // The expected exception. } @@ -473,25 +473,19 @@ public void useAfterClose() { @Test public void eagerTensorIsReleasedAfterSessionIsClosed() { - Tensor sum; + TInt32 sum; try (EagerSession session = EagerSession.create()) { Ops tf = Ops.create(session); sum = tf.math.add(tf.constant(10), tf.constant(20)).asTensor(); - sum.nativeHandle(); // does not throw - assertEquals(30, sum.data().getInt()); + sum.asRawTensor().nativeHandle(); // does not throw + assertEquals(30, sum.getInt()); } try { - sum.nativeHandle(); + sum.asRawTensor().nativeHandle(); fail("Tensor native handle should have been closed by ending eager session"); } catch (IllegalStateException e) { // as expected } - try { - sum.data().getInt(); - fail("Tensor data should not be accessible after tensor is closed"); - } catch (IllegalStateException e) { - // as expected - } } @Test @@ -503,12 +497,12 @@ public void fromHandle() { // An exception is made for this test, where the pitfalls of this is avoided by not calling // close() on both Tensors. final FloatNdArray matrix = StdArrays.ndCopyOf(new float[][]{{1, 2, 3}, {4, 5, 6}}); - try (Tensor src = TFloat32.tensorOf(matrix)) { - Tensor cpy = Tensor.fromHandle(src.nativeHandle()).expect(TFloat32.DTYPE); + try (TFloat32 src = TFloat32.tensorOf(matrix)) { + TFloat32 cpy = (TFloat32)RawTensor.fromHandle(src.asRawTensor().nativeHandle()).asTypedTensor(); assertEquals(src.dataType(), cpy.dataType()); assertEquals(src.shape().numDimensions(), cpy.shape().numDimensions()); assertEquals(src.shape(), cpy.shape()); - assertEquals(matrix, cpy.data()); + assertEquals(matrix, cpy); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java index bbebfd5f454..c97dc83f510 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java @@ -169,11 +169,10 @@ public void composite() { // assertNotNull(g.operation("variance/zero")); // Verify correct results as well. - Tensor result = - sess.runner().fetch(var1.output()).run().get(0).expect(TInt32.DTYPE); - assertEquals(21704, result.data().getInt()); - result = sess.runner().fetch(var2.output()).run().get(0).expect(TInt32.DTYPE); - assertEquals(21704, result.data().getInt()); + TInt32 result = (TInt32)sess.runner().fetch(var1.output()).run().get(0); + assertEquals(21704, result.getInt()); + result = (TInt32)sess.runner().fetch(var2.output()).run().get(0); + assertEquals(21704, result.getInt()); } } @@ -189,7 +188,7 @@ static Const create(Scope s, int[] v) { return create(s, TInt32.vectorOf(v)); } - static Const create(Scope s, Tensor value) { + static Const create(Scope s, T value) { return new Const<>( s.env() .opBuilder("Const", s.makeOpName("Const")) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java index 7cdc01f2d31..16a644aced1 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java @@ -21,9 +21,11 @@ import org.junit.jupiter.api.Test; import org.tensorflow.AutoCloseableList; +import org.tensorflow.EagerSession; import org.tensorflow.Graph; import org.tensorflow.Session; import org.tensorflow.Tensor; +import org.tensorflow.op.Ops; import org.tensorflow.op.Scope; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.buffer.DataBuffer; @@ -58,10 +60,10 @@ public void createInts() { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList> t = + try (AutoCloseableList t = new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { - assertEquals(array, t.get(0).expect(TInt32.DTYPE).data()); - assertEquals(array, t.get(1).expect(TInt32.DTYPE).data()); + assertEquals(array, t.get(0)); + assertEquals(array, t.get(1)); } } } @@ -77,10 +79,10 @@ public void createFloats() { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList> t = + try (AutoCloseableList t = new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { - assertEquals(array, t.get(0).expect(TFloat32.DTYPE).data()); - assertEquals(array, t.get(1).expect(TFloat32.DTYPE).data()); + assertEquals(array, t.get(0)); + assertEquals(array, t.get(1)); } } } @@ -96,10 +98,10 @@ public void createDoubles() { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList> t = + try (AutoCloseableList t = new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { - assertEquals(array, t.get(0).expect(TFloat64.DTYPE).data()); - assertEquals(array, t.get(1).expect(TFloat64.DTYPE).data()); + assertEquals(array, t.get(0)); + assertEquals(array, t.get(1)); } } } @@ -115,10 +117,10 @@ public void createLongs() { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList> t = + try (AutoCloseableList t = new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { - assertEquals(array, t.get(0).expect(TInt64.DTYPE).data()); - assertEquals(array, t.get(1).expect(TInt64.DTYPE).data()); + assertEquals(array, t.get(0)); + assertEquals(array, t.get(1)); } } } @@ -134,11 +136,32 @@ public void createStrings() throws IOException { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList> t = + try (AutoCloseableList t = new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { - assertEquals(array, t.get(0).expect(TString.DTYPE).data()); - assertEquals(array, t.get(1).expect(TString.DTYPE).data()); + assertEquals(array, t.get(0)); + assertEquals(array, t.get(1)); } } } + + @Test + public void createFromTensorsInEagerMode() throws IOException { + try (EagerSession s = EagerSession.create(); + TInt32 t = TInt32.vectorOf(1, 2, 3, 4)) { + Ops tf = Ops.create(s); + + Constant c1 = tf.constant(t); + assertEquals(c1.asTensor(), t); + + // A different endpoint for capturing a tensor as a constant, which supports all data types + Constant c2 = tf.capture(t); + assertEquals(c2.asTensor(), t); + assertEquals(c1.asTensor(), c2.asTensor()); + + // Permute data in the tensor to make sure that constant copies are independent + t.setInt(10); + assertEquals(NdArrays.vectorOf(10, 2, 3, 4), t); + assertEquals(NdArrays.vectorOf(1, 2, 3, 4), c1.asTensor()); + } + } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java index a337bd73098..ede179740e5 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java @@ -36,8 +36,8 @@ public void tensorInputTensorOutput() { Session sess = new Session(g)) { Ops ops = Ops.create(g); Operand x = ops.math.add(ops.constant(1), ops.constant(2)); - try (Tensor result = sess.runner().fetch(x).run().get(0).expect(TInt32.DTYPE)) { - assertEquals(3, result.data().getInt()); + try (TInt32 result = (TInt32)sess.runner().fetch(x).run().get(0)) { + assertEquals(3, result.getInt()); } } } @@ -52,8 +52,8 @@ public void testListInputTensorOutput() { inputs.add(ops.constant(2)); inputs.add(ops.constant(3)); Operand x = ops.math.addN(inputs); - try (Tensor result = sess.runner().fetch(x).run().get(0).expect(TInt32.DTYPE)) { - assertEquals(6, result.data().getInt()); + try (TInt32 result = (TInt32)sess.runner().fetch(x).run().get(0)) { + assertEquals(6, result.getInt()); } } } @@ -77,8 +77,8 @@ public void testControlDependencies() { Operand x = ops.withControlDependencies(controls).math.add(variable, ops.constant(0)); sess.runner().addTarget(initVariable).run(); - try (Tensor result = sess.runner().fetch(x).run().get(0).expect(TInt32.DTYPE)) { - assertEquals(3, result.data().getInt()); + try (TInt32 result = (TInt32)sess.runner().fetch(x).run().get(0)) { + assertEquals(3, result.getInt()); } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java index fe1503d415f..5333b8e0d33 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java @@ -47,13 +47,13 @@ public void createGradients() { assertNotNull(grads.dy()); assertEquals(2, grads.dy().size()); - try (Tensor c = TFloat32.scalarOf(3.0f); - AutoCloseableList> outputs = + try (TFloat32 c = TFloat32.scalarOf(3.0f); + AutoCloseableList outputs = new AutoCloseableList<>( sess.runner().feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run())) { - assertEquals(108.0f, outputs.get(0).expect(TFloat32.DTYPE).data().getFloat(), 0.0f); - assertEquals(18.0f, outputs.get(1).expect(TFloat32.DTYPE).data().getFloat(), 0.0f); + assertEquals(108.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); + assertEquals(18.0f, ((TFloat32)outputs.get(1)).getFloat(), 0.0f); } } } @@ -74,11 +74,11 @@ public void createGradientsWithSum() { assertNotNull(grads.dy()); assertEquals(1, grads.dy().size()); - try (Tensor c = TFloat32.scalarOf(3.0f); - AutoCloseableList> outputs = + try (TFloat32 c = TFloat32.scalarOf(3.0f); + AutoCloseableList outputs = new AutoCloseableList<>(sess.runner().feed(x, c).fetch(grads.dy(0)).run())) { - assertEquals(114.0f, outputs.get(0).expect(TFloat32.DTYPE).data().getFloat(), 0.0f); + assertEquals(114.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); } } } @@ -100,12 +100,12 @@ public void createGradientsWithInitialValues() { assertNotNull(grads1.dy()); assertEquals(1, grads1.dy().size()); - try (Tensor c = TFloat32.scalarOf(3.0f); - AutoCloseableList> outputs = + try (TFloat32 c = TFloat32.scalarOf(3.0f); + AutoCloseableList outputs = new AutoCloseableList<>( sess.runner().feed(x, c).fetch(grads1.dy(0)).run())) { - assertEquals(108.0f, outputs.get(0).expect(TFloat32.DTYPE).data().getFloat(), 0.0f); + assertEquals(108.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ShapesTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ShapesTest.java index d5eb7412ea3..083beca923c 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ShapesTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ShapesTest.java @@ -43,19 +43,16 @@ public void testFlatten_Operand() { Shape expResult = Shape.create(scope, operand, TInt64.DTYPE); Operand reshaped = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2, 1})); - Operand actual = Shapes.flatten(scope, reshaped); + Operand actual = Shapes.flatten(scope, reshaped); Shape tfshape = Shape.create(scope, actual, TInt64.DTYPE); AtomicInteger index = new AtomicInteger(); - try (Tensor result1 = - session.runner().fetch(tfshape.asOutput()).run().get(0).expect(TInt64.DTYPE); - Tensor result2 = - session.runner().fetch(expResult.asOutput()).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result1 = (TInt64)session.runner().fetch(tfshape.asOutput()).run().get(0); + TInt64 result2 = (TInt64)session.runner().fetch(expResult.asOutput()).run().get(0)) { result1 - .data() .scalars() .forEach( - s -> assertEquals(result2.data().getLong(index.getAndIncrement()), s.getLong())); + s -> assertEquals(result2.getLong(index.getAndIncrement()), s.getLong())); } } } @@ -65,22 +62,21 @@ public void testFlatten_Operand() { public void testFlatten_Shape() { try (EagerSession session = EagerSession.create()) { Scope scope = new Scope(session); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); Shape expShape = Shape.create(scope, operand, TInt64.DTYPE); - Operand actual = + Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2, 1})); Shape tfshape = Shape.create(scope, actual, TInt64.DTYPE); Operand flattened = Shapes.flatten(scope, tfshape, TInt64.DTYPE); AtomicInteger index = new AtomicInteger(); flattened - .asOutput() - .data() + .asTensor() .scalars() .forEach( s -> assertEquals( - expShape.asOutput().data().getLong(index.getAndIncrement()), s.getLong())); + expShape.asTensor().getLong(index.getAndIncrement()), s.getLong())); } } @@ -90,16 +86,15 @@ public void testSize_Shape() { try (Graph g = new Graph(); Session session = new Session(g)) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Operand actual = + Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2, 1})); Shape tfshape = Shape.create(scope, actual, TInt64.DTYPE); Operand size = Shapes.size(scope, tfshape, TInt64.DTYPE); AtomicInteger index = new AtomicInteger(); - try (Tensor result1 = - session.runner().fetch(size.asOutput()).run().get(0).expect(TInt64.DTYPE)) { - result1.data().scalars().forEach(s -> assertEquals(8, s.getLong())); + try (TInt64 result1 = (TInt64)session.runner().fetch(size.asOutput()).run().get(0)) { + result1.scalars().forEach(s -> assertEquals(8, s.getLong())); } } } @@ -110,27 +105,24 @@ public void testSize_Shape_Operand() { try (Graph g = new Graph(); Session session = new Session(g)) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Operand actual = + Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2, 1})); Shape tfshape = Shape.create(scope, actual); Operand size = Shapes.size(scope, tfshape, Constant.scalarOf(scope, 0)); - try (Tensor result = - session.runner().fetch(size.asOutput()).run().get(0).expect(TInt32.DTYPE)) { - result.data().scalars().forEach(s -> assertEquals(4, s.getInt())); + try (TInt32 result = (TInt32)session.runner().fetch(size.asOutput()).run().get(0)) { + result.scalars().forEach(s -> assertEquals(4, s.getInt())); } size = Shapes.size(scope, tfshape, Constant.scalarOf(scope, 1)); - try (Tensor result = - session.runner().fetch(size.asOutput()).run().get(0).expect(TInt32.DTYPE)) { - result.data().scalars().forEach(s -> assertEquals(2, s.getInt())); + try (TInt32 result = (TInt32)session.runner().fetch(size.asOutput()).run().get(0)) { + result.scalars().forEach(s -> assertEquals(2, s.getInt())); } size = Shapes.size(scope, tfshape, Constant.scalarOf(scope, 2)); - try (Tensor result = - session.runner().fetch(size.asOutput()).run().get(0).expect(TInt32.DTYPE)) { - result.data().scalars().forEach(s -> assertEquals(1, s.getInt())); + try (TInt32 result = (TInt32)session.runner().fetch(size.asOutput()).run().get(0)) { + result.scalars().forEach(s -> assertEquals(1, s.getInt())); } } } @@ -141,26 +133,23 @@ public void testSize_Operand_Operand() { try (Graph g = new Graph(); Session session = new Session(g)) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Operand actual = + Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2, 1})); Operand size = Shapes.size(scope, actual, Constant.scalarOf(scope, 0)); - try (Tensor result = - session.runner().fetch(size.asOutput()).run().get(0).expect(TInt32.DTYPE)) { - result.data().scalars().forEach(s -> assertEquals(4, s.getInt())); + try (TInt32 result = (TInt32)session.runner().fetch(size.asOutput()).run().get(0)) { + result.scalars().forEach(s -> assertEquals(4, s.getInt())); } size = Shapes.size(scope, actual, Constant.scalarOf(scope, 1)); - try (Tensor result = - session.runner().fetch(size.asOutput()).run().get(0).expect(TInt32.DTYPE)) { - result.data().scalars().forEach(s -> assertEquals(2, s.getInt())); + try (TInt32 result = (TInt32)session.runner().fetch(size.asOutput()).run().get(0)) { + result.scalars().forEach(s -> assertEquals(2, s.getInt())); } size = Shapes.size(scope, actual, Constant.scalarOf(scope, 2)); - try (Tensor result = - session.runner().fetch(size.asOutput()).run().get(0).expect(TInt32.DTYPE)) { - result.data().scalars().forEach(s -> assertEquals(1, s.getInt())); + try (TInt32 result = (TInt32)session.runner().fetch(size.asOutput()).run().get(0)) { + result.scalars().forEach(s -> assertEquals(1, s.getInt())); } } } @@ -171,15 +160,14 @@ public void testNumDimensions() { try (Graph g = new Graph(); Session session = new Session(g)) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Operand actual = + Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2, 1})); Shape tfshape = Shape.create(scope, actual); Operand nDims = Shapes.numDimensions(scope, tfshape); - try (Tensor result = - session.runner().fetch(nDims.asOutput()).run().get(0).expect(TInt32.DTYPE)) { - result.data().scalars().forEach(s -> assertEquals(3, s.getInt())); + try (TInt32 result = (TInt32)session.runner().fetch(nDims.asOutput()).run().get(0)) { + result.scalars().forEach(s -> assertEquals(3, s.getInt())); } } } @@ -199,7 +187,7 @@ public void testReduceDims_Operand_Operand() { AtomicInteger index = new AtomicInteger(); int[] expected = {8}; reducedShape - .data() + .asTensor() .scalars() .forEach( s -> { @@ -219,12 +207,12 @@ public void testReduceDims_Shape_Operand() { Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {2, 2, 2})); Shape tfshape = Shape.create(scope, actual); - Operand reduced = Shapes.reduceDims(scope, actual, Constant.scalarOf(scope, 0)); + Operand reduced = Shapes.reduceDims(scope, actual, Constant.scalarOf(scope, 0)); Shape reducedShape = Shape.create(scope, reduced); AtomicInteger index = new AtomicInteger(); int[] expected1 = {8}; reducedShape - .data() + .asTensor() .scalars() .forEach( s -> { @@ -237,7 +225,7 @@ public void testReduceDims_Shape_Operand() { index.set(0); int[] expected2 = {2, 4}; reducedShape - .data() + .asTensor() .scalars() .forEach( s -> { @@ -250,7 +238,7 @@ public void testReduceDims_Shape_Operand() { index.set(0); int[] expected3 = {2, 2, 2}; reducedShape - .data() + .asTensor() .scalars() .forEach( s -> { @@ -266,18 +254,16 @@ public void testSqueeze() { try (Graph g = new Graph(); Session session = new Session(g)) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Operand actual = + Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 1, 2, 1})); Shape tfshape = Shape.create(scope, actual); Operand squeezed = Shapes.squeeze(scope, tfshape); AtomicInteger index = new AtomicInteger(); int[] expected = {4, 2}; - try (Tensor result = - session.runner().fetch(squeezed.asOutput()).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = (TInt32)session.runner().fetch(squeezed.asOutput()).run().get(0)) { result - .data() .scalars() .forEach( s -> { @@ -293,18 +279,16 @@ public void testHead() { try (Graph g = new Graph(); Session session = new Session(g)) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Operand actual = + Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 1, 2, 1})); Shape tfshape = Shape.create(scope, actual); Operand head = Shapes.head(scope, tfshape); AtomicInteger index = new AtomicInteger(); int[] expected = {4}; - try (Tensor result = - session.runner().fetch(head.asOutput()).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = (TInt32)session.runner().fetch(head.asOutput()).run().get(0)) { result - .data() .scalars() .forEach( s -> { @@ -320,18 +304,16 @@ public void testTake() { try (Graph g = new Graph(); Session session = new Session(g)) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Operand actual = + Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 1, 2, 1})); Shape tfshape = Shape.create(scope, actual); Operand take = Shapes.take(scope, tfshape, Constant.scalarOf(scope, 2)); AtomicInteger index = new AtomicInteger(); int[] expected = {4, 1}; - try (Tensor result = - session.runner().fetch(take.asOutput()).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = (TInt32)session.runner().fetch(take.asOutput()).run().get(0)) { result - .data() .scalars() .forEach( s -> { @@ -347,18 +329,16 @@ public void testTail() { try (Graph g = new Graph(); Session session = new Session(g)) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Operand actual = + Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 1, 2, 1})); Shape tfshape = Shape.create(scope, actual); Operand tail = Shapes.tail(scope, tfshape); AtomicInteger index = new AtomicInteger(); int[] expected = {1}; - try (Tensor result = - session.runner().fetch(tail.asOutput()).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = (TInt32)session.runner().fetch(tail.asOutput()).run().get(0)) { result - .data() .scalars() .forEach( s -> { @@ -374,18 +354,16 @@ public void testTakeLast() { try (Graph g = new Graph(); Session session = new Session(g)) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Operand actual = + Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 1, 2, 1})); Shape tfshape = Shape.create(scope, actual); Operand takeLast = Shapes.takeLast(scope, tfshape, Constant.scalarOf(scope, 3)); AtomicInteger index = new AtomicInteger(); int[] expected = {1, 2, 1}; - try (Tensor result = - session.runner().fetch(takeLast.asOutput()).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = (TInt32)session.runner().fetch(takeLast.asOutput()).run().get(0)) { result - .data() .scalars() .forEach( s -> { @@ -401,17 +379,15 @@ public void testPrependInt() { try (Graph g = new Graph(); Session session = new Session(g)) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2})); + Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2})); Shape tfshape = Shape.create(scope, actual); Operand prepend = Shapes.prepend(scope, tfshape, 3); AtomicInteger index = new AtomicInteger(); int[] expected = {3, 4, 2}; - try (Tensor result = - session.runner().fetch(prepend.asOutput()).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = (TInt32)session.runner().fetch(prepend.asOutput()).run().get(0)) { result - .data() .scalars() .forEach( s -> { @@ -427,17 +403,15 @@ public void testPrependLong() { try (Graph g = new Graph(); Session session = new Session(g)) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2})); + Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2})); Shape tfshape = Shape.create(scope, actual, TInt64.DTYPE); Operand prepend = Shapes.prepend(scope, tfshape, 1L); AtomicInteger index = new AtomicInteger(); long[] expected = {1, 4, 2}; - try (Tensor result = - session.runner().fetch(prepend.asOutput()).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result = (TInt64)session.runner().fetch(prepend.asOutput()).run().get(0)) { result - .data() .scalars() .forEach( s -> { @@ -453,11 +427,11 @@ public void testPrependShapeTInt32() { try (Graph g = new Graph(); Session session = new Session(g)) { Scope scope = new Scope(g); - Operand operand1 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Operand actual1 = + Operand operand1 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand actual1 = Reshape.create(scope, operand1, Constant.vectorOf(scope, new long[] {4, 2})); - Operand operand2 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Operand actual2 = + Operand operand2 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand actual2 = Reshape.create(scope, operand2, Constant.vectorOf(scope, new long[] {2, 4})); Shape tfshape1 = Shape.create(scope, actual1); Shape tfshape2 = Shape.create(scope, actual2); @@ -465,10 +439,8 @@ public void testPrependShapeTInt32() { Operand prepend = Shapes.prepend(scope, tfshape1, tfshape2); AtomicInteger index = new AtomicInteger(); int[] expected = {2, 4, 4, 2}; - try (Tensor result = - session.runner().fetch(prepend.asOutput()).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = (TInt32)session.runner().fetch(prepend.asOutput()).run().get(0)) { result - .data() .scalars() .forEach( s -> { @@ -484,11 +456,11 @@ public void testPrependShapeTInt64() { try (Graph g = new Graph(); Session session = new Session(g)) { Scope scope = new Scope(g); - Operand operand1 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Operand actual1 = + Operand operand1 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand actual1 = Reshape.create(scope, operand1, Constant.vectorOf(scope, new long[] {4, 2})); - Operand operand2 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Operand actual2 = + Operand operand2 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand actual2 = Reshape.create(scope, operand2, Constant.vectorOf(scope, new long[] {2, 4})); Shape tfshape1 = Shape.create(scope, actual1, TInt64.DTYPE); Shape tfshape2 = Shape.create(scope, actual2, TInt64.DTYPE); @@ -496,10 +468,8 @@ public void testPrependShapeTInt64() { Operand prepend = Shapes.prepend(scope, tfshape1, tfshape2); AtomicInteger index = new AtomicInteger(); long[] expected = {2, 4, 4, 2}; - try (Tensor result = - session.runner().fetch(prepend.asOutput()).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result = (TInt64)session.runner().fetch(prepend.asOutput()).run().get(0)) { result - .data() .scalars() .forEach( s -> { @@ -515,17 +485,15 @@ public void testAppendLong() { try (Graph g = new Graph(); Session session = new Session(g)) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2})); + Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2})); Shape tfshape = Shape.create(scope, actual, TInt64.DTYPE); Operand append = Shapes.append(scope, tfshape, 2L); AtomicInteger index = new AtomicInteger(); long[] expected = {4L, 2L, 2L}; - try (Tensor result = - session.runner().fetch(append.asOutput()).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result = (TInt64)session.runner().fetch(append.asOutput()).run().get(0)) { result - .data() .scalars() .forEach( s -> { @@ -541,17 +509,15 @@ public void testAppendInt() { try (Graph g = new Graph(); Session session = new Session(g)) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2})); + Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2})); Shape tfshape = Shape.create(scope, actual); Operand append = Shapes.append(scope, tfshape, 2); AtomicInteger index = new AtomicInteger(); int[] expected = {4, 2, 2}; - try (Tensor result = - session.runner().fetch(append.asOutput()).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = (TInt32)session.runner().fetch(append.asOutput()).run().get(0)) { result - .data() .scalars() .forEach( s -> { @@ -567,11 +533,11 @@ public void testAppendShapeTInt32() { try (Graph g = new Graph(); Session session = new Session(g)) { Scope scope = new Scope(g); - Operand operand1 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Operand actual1 = + Operand operand1 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand actual1 = Reshape.create(scope, operand1, Constant.vectorOf(scope, new long[] {4, 2})); - Operand operand2 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Operand actual2 = + Operand operand2 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand actual2 = Reshape.create(scope, operand2, Constant.vectorOf(scope, new long[] {2, 4})); Shape tfshape1 = Shape.create(scope, actual1); Shape tfshape2 = Shape.create(scope, actual2); @@ -579,10 +545,8 @@ public void testAppendShapeTInt32() { Operand append = Shapes.append(scope, tfshape1, tfshape2); AtomicInteger index = new AtomicInteger(); int[] expected = {4, 2, 2, 4}; - try (Tensor result = - session.runner().fetch(append.asOutput()).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = (TInt32)session.runner().fetch(append.asOutput()).run().get(0)) { result - .data() .scalars() .forEach( s -> { @@ -598,11 +562,11 @@ public void testAppendShapeTInt64() { try (Graph g = new Graph(); Session session = new Session(g)) { Scope scope = new Scope(g); - Operand operand1 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Operand actual1 = + Operand operand1 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand actual1 = Reshape.create(scope, operand1, Constant.vectorOf(scope, new long[] {4, 2})); - Operand operand2 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Operand actual2 = + Operand operand2 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand actual2 = Reshape.create(scope, operand2, Constant.vectorOf(scope, new long[] {2, 4})); Shape tfshape1 = Shape.create(scope, actual1, TInt64.DTYPE); Shape tfshape2 = Shape.create(scope, actual2, TInt64.DTYPE); @@ -610,10 +574,8 @@ public void testAppendShapeTInt64() { Operand append = Shapes.append(scope, tfshape1, tfshape2); AtomicInteger index = new AtomicInteger(); long[] expected = {4, 2, 2, 4}; - try (Tensor result = - session.runner().fetch(append.asOutput()).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result = (TInt64)session.runner().fetch(append.asOutput()).run().get(0)) { result - .data() .scalars() .forEach( s -> { diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java index 9600f8b38fc..204bd4b10f3 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java @@ -42,8 +42,8 @@ public void createIntZeros() { Scope scope = new Scope(g); long[] shape = {2, 2}; Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TInt32.DTYPE); - try (Tensor result = sess.runner().fetch(op).run().get(0).expect(TInt32.DTYPE)) { - result.data().scalars().forEach(s -> assertEquals(0, s.getInt())); + try (TInt32 result = (TInt32)sess.runner().fetch(op).run().get(0)) { + result.scalars().forEach(s -> assertEquals(0, s.getInt())); } } } @@ -55,8 +55,8 @@ public void createFloatZeros() { Scope scope = new Scope(g); long[] shape = {2, 2}; Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TFloat32.DTYPE); - try (Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(TFloat32.DTYPE)) { - result.data().scalars().forEach(s -> assertEquals(0.0f, s.getFloat(), 0)); + try (TFloat32 result = (TFloat32)sess.runner().fetch(op.asOutput()).run().get(0)) { + result.scalars().forEach(s -> assertEquals(0.0f, s.getFloat(), 0)); } } } @@ -68,8 +68,8 @@ public void createDoubleZeros() { Scope scope = new Scope(g); long[] shape = {2, 2}; Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TFloat64.DTYPE); - try (Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(TFloat64.DTYPE)) { - result.data().scalars().forEach(s -> assertEquals(0.0f, s.getDouble(), 0)); + try (TFloat64 result = (TFloat64)sess.runner().fetch(op.asOutput()).run().get(0)) { + result.scalars().forEach(s -> assertEquals(0.0f, s.getDouble(), 0)); } } } @@ -81,8 +81,8 @@ public void createLongZeros() { Scope scope = new Scope(g); long[] shape = {2, 2}; Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TInt64.DTYPE); - try (Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(TInt64.DTYPE)) { - result.data().scalars().forEach(s -> assertEquals(0L, s.getLong())); + try (TInt64 result = (TInt64)sess.runner().fetch(op.asOutput()).run().get(0)) { + result.scalars().forEach(s -> assertEquals(0L, s.getLong())); } } } @@ -94,8 +94,8 @@ public void createBooleanZeros() { Scope scope = new Scope(g); long[] shape = {2, 2}; Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TBool.DTYPE); - try (Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(TBool.DTYPE)) { - result.data().scalars().forEach(s -> assertFalse(s.getBoolean())); + try (TBool result = (TBool)sess.runner().fetch(op.asOutput()).run().get(0)) { + result.scalars().forEach(s -> assertFalse(s.getBoolean())); } } } @@ -107,8 +107,8 @@ public void createUint8Zeros() { Scope scope = new Scope(g); long[] shape = {2, 2}; Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TUint8.DTYPE); - try (Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(TUint8.DTYPE)) { - result.data().scalars().forEach(s -> assertEquals(0, s.getByte())); + try (TUint8 result = (TUint8)sess.runner().fetch(op.asOutput()).run().get(0)) { + result.scalars().forEach(s -> assertEquals(0, s.getByte())); } } } @@ -120,8 +120,8 @@ public void createStringZeros() { Scope scope = new Scope(g); long[] shape = {2, 2}; Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TString.DTYPE); - try (Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(TString.DTYPE)) { - result.data().scalars().forEach(s -> assertTrue(s.getObject().isEmpty())); + try (TString result = (TString)sess.runner().fetch(op.asOutput()).run().get(0)) { + result.scalars().forEach(s -> assertTrue(s.getObject().isEmpty())); } } } @@ -133,7 +133,7 @@ public void operationsComposingZerosAreCorrectlyNamed() { Scope scope = new Scope(g); long[] shape = {2, 2}; Zeros zeros = Zeros.create(scope.withSubScope("test"), Constant.vectorOf(scope, shape), TFloat32.DTYPE); - List> results = sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run(); + List results = sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run(); } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/NumericTypesTestBase.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/NumericTypesTestBase.java index 87b24b0da2a..24390ae0a91 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/NumericTypesTestBase.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/NumericTypesTestBase.java @@ -21,90 +21,93 @@ import org.junit.jupiter.api.Test; import org.tensorflow.EagerSession; -import org.tensorflow.Tensor; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.index.Indices; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Constant; +import org.tensorflow.op.math.Add; import org.tensorflow.op.math.Sub; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.IntNdArray; -import org.tensorflow.ndarray.NdArray; -import org.tensorflow.ndarray.NdArrays; -import org.tensorflow.ndarray.index.Indices; import org.tensorflow.types.family.TNumber; -abstract class NumericTypesTestBase, U> { +abstract class NumericTypesTestBase { @Test public void initializeTensorsWithZeros() { // Allocate a tensor of 32-bits integer of the shape (2, 3, 2) - Tensor tensor = allocateTensor(Shape.of(2, 3, 2)); - NdArray tensorData = tensor.data(); + T tensor = allocateTensor(Shape.of(2, 3, 2)); - assertEquals(3, tensorData.rank()); - assertEquals(12, tensorData.size()); + assertEquals(3, tensor.rank()); + assertEquals(12, tensor.size()); + NdArray data = (NdArray)tensor; try (EagerSession session = EagerSession.create()) { Ops tf = Ops.create(session); // Initialize tensor memory with zeros and take a snapshot - tensorData.scalars().forEach(scalar -> scalar.setObject(valueOf(0))); - Constant x = tf.constant(tensor); + data.scalars().forEach(scalar -> ((NdArray)scalar).setObject(valueOf(0))); + Constant x = tf.capture(tensor); // Initialize the same tensor memory with ones and take a snapshot - tensorData.scalars().forEach(scalar -> scalar.setObject(valueOf(1))); - Constant y = tf.constant(tensor); + data.scalars().forEach(scalar -> ((NdArray)scalar).setObject(valueOf(1))); + Constant y = tf.capture(tensor); // Subtract y from x and validate the result Sub sub = tf.math.sub(x, y); - sub.data().scalars().forEach(scalar -> + ((NdArray)sub.asTensor()).scalars().forEach(scalar -> assertEquals(valueOf(-1), scalar.getObject()) ); } } @Test - public void genericTest() { - IntNdArray heapData = NdArrays.vectorOf(0, 1, 2, 3); + public void setAndCompute() { + NdArray heapData = allocateNdArray(Shape.of(4)) + .setObject(valueOf(0), 0) + .setObject(valueOf(1), 1) + .setObject(valueOf(2), 2) + .setObject(valueOf(3), 3); // Creates a 2x2 matrix - try (Tensor tensor = TInt32.tensorOf(Shape.of(2, 2))) { - IntNdArray tensorData = tensor.data(); + try (T tensor = allocateTensor(Shape.of(2, 2))) { + NdArray data = (NdArray)tensor; // Copy first 2 values of the vector to the first row of the matrix - tensorData.set(heapData.slice(Indices.range(0, 2)), 0); + data.set(heapData.slice(Indices.range(0, 2)), 0); // Copy values at an odd position in the vector as the second row of the matrix - tensorData.set(heapData.slice(Indices.odd()), 1); + data.set(heapData.slice(Indices.odd()), 1); - assertEquals(0, tensorData.getInt(0, 0)); - assertEquals(1, tensorData.getInt(0, 1)); - assertEquals(1, tensorData.getInt(1, 0)); - assertEquals(3, tensorData.getInt(1, 1)); + assertEquals(valueOf(0), data.getObject(0, 0)); + assertEquals(valueOf(1), data.getObject(0, 1)); + assertEquals(valueOf(1), data.getObject(1, 0)); + assertEquals(valueOf(3), data.getObject(1, 1)); // Read rows of the tensor in reverse order - IntNdArray reversedTensorData = tensorData.slice(Indices.all(), Indices.flip()); + NdArray flippedData = data.slice(Indices.flip(), Indices.flip()); - assertEquals(1, reversedTensorData.getInt(0, 0)); - assertEquals(0, reversedTensorData.getInt(0, 1)); - assertEquals(3, reversedTensorData.getInt(1, 0)); - assertEquals(1, reversedTensorData.getInt(1, 1)); + assertEquals(valueOf(3), flippedData.getObject(0, 0)); + assertEquals(valueOf(1), flippedData.getObject(0, 1)); + assertEquals(valueOf(1), flippedData.getObject(1, 0)); + assertEquals(valueOf(0), flippedData.getObject(1, 1)); try (EagerSession session = EagerSession.create()) { Ops tf = Ops.create(session); - // Compute the power of the tensor by itself - Constant x = tf.constant(tensor); - IntNdArray result = tf.math.pow(x, x).data(); + Add add = tf.math.add(tf.capture(tensor), tf.capture(tensor)); + NdArray result = (NdArray)add.asTensor(); - // Validate result by computing the same operation in Java - tensorData.scalars().forEachIndexed((coords, s) -> - assertEquals(Math.pow(s.getInt(), s.getInt()), result.getInt(coords), 1e-7f) - ); + assertEquals(valueOf(0), result.getObject(0, 0)); + assertEquals(valueOf(2), result.getObject(0, 1)); + assertEquals(valueOf(2), result.getObject(1, 0)); + assertEquals(valueOf(6), result.getObject(1, 1)); } } } - abstract Tensor allocateTensor(Shape shape); + abstract T allocateTensor(Shape shape); + + abstract NdArray allocateNdArray(Shape shape); abstract U valueOf(Integer value); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TBfloat16Test.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TBfloat16Test.java index 8681e805e3d..17a6e0dd2b5 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TBfloat16Test.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TBfloat16Test.java @@ -17,16 +17,22 @@ package org.tensorflow.types; -import org.tensorflow.Tensor; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; public class TBfloat16Test extends NumericTypesTestBase { @Override - Tensor allocateTensor(Shape shape) { + TBfloat16 allocateTensor(Shape shape) { return TBfloat16.tensorOf(shape); } + @Override + NdArray allocateNdArray(Shape shape) { + return NdArrays.ofFloats(shape); + } + @Override Float valueOf(Integer value) { return value.floatValue(); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat16Test.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat16Test.java index b72fe6fc01c..c1ae8ad3b6d 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat16Test.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat16Test.java @@ -17,16 +17,22 @@ package org.tensorflow.types; -import org.tensorflow.Tensor; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; public class TFloat16Test extends NumericTypesTestBase { @Override - Tensor allocateTensor(Shape shape) { + TFloat16 allocateTensor(Shape shape) { return TFloat16.tensorOf(shape); } + @Override + NdArray allocateNdArray(Shape shape) { + return NdArrays.ofFloats(shape); + } + @Override Float valueOf(Integer value) { return value.floatValue(); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat32Test.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat32Test.java index c4b1f6023f3..8df96f2871a 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat32Test.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat32Test.java @@ -17,16 +17,22 @@ package org.tensorflow.types; -import org.tensorflow.Tensor; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; public class TFloat32Test extends NumericTypesTestBase { @Override - Tensor allocateTensor(Shape shape) { + TFloat32 allocateTensor(Shape shape) { return TFloat32.tensorOf(shape); } + @Override + NdArray allocateNdArray(Shape shape) { + return NdArrays.ofFloats(shape); + } + @Override Float valueOf(Integer value) { return value.floatValue(); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat64Test.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat64Test.java index 0e9c8947d0f..47b4b6d936a 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat64Test.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat64Test.java @@ -17,16 +17,22 @@ package org.tensorflow.types; -import org.tensorflow.Tensor; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; public class TFloat64Test extends NumericTypesTestBase { @Override - Tensor allocateTensor(Shape shape) { + TFloat64 allocateTensor(Shape shape) { return TFloat64.tensorOf(shape); } + @Override + NdArray allocateNdArray(Shape shape) { + return NdArrays.ofDoubles(shape); + } + @Override Double valueOf(Integer value) { return value.doubleValue(); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt32Test.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt32Test.java index c52394bf210..9ea7f952f04 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt32Test.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt32Test.java @@ -17,16 +17,24 @@ package org.tensorflow.types; -import org.tensorflow.Tensor; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; public class TInt32Test extends NumericTypesTestBase { @Override - Tensor allocateTensor(Shape shape) { + TInt32 allocateTensor(Shape shape) { return TInt32.tensorOf(shape); } + @Override + NdArray allocateNdArray(Shape shape) { + return NdArrays.ofInts(shape); + } + @Override Integer valueOf(Integer value) { return value; diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt64Test.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt64Test.java index 261ac546fc5..a88f3fb4d6d 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt64Test.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt64Test.java @@ -17,16 +17,22 @@ package org.tensorflow.types; -import org.tensorflow.Tensor; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; public class TInt64Test extends NumericTypesTestBase { @Override - Tensor allocateTensor(Shape shape) { + TInt64 allocateTensor(Shape shape) { return TInt64.tensorOf(shape); } + @Override + NdArray allocateNdArray(Shape shape) { + return NdArrays.ofLongs(shape); + } + @Override Long valueOf(Integer value) { return value.longValue(); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TStringTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TStringTest.java index a4700aa652f..015f93b70e7 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TStringTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TStringTest.java @@ -23,46 +23,36 @@ import java.nio.charset.StandardCharsets; import org.junit.jupiter.api.Test; -import org.tensorflow.Tensor; -import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; public class TStringTest { @Test public void createScalar() { - Tensor tensor = TString.scalarOf("Pretty vacant"); + TString tensor = TString.scalarOf("Pretty vacant"); assertNotNull(tensor); - - TString data = tensor.data(); - assertNotNull(data); - assertEquals(Shape.scalar(), data.shape()); - assertEquals("Pretty vacant", data.getObject()); + assertEquals(Shape.scalar(), tensor.shape()); + assertEquals("Pretty vacant", tensor.getObject()); } @Test public void createrScalarLongerThan127() { - Tensor tensor = TString.scalarOf("Long String 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 !"); + TString tensor = TString.scalarOf("Long String 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 !"); assertNotNull(tensor); - - TString data = tensor.data(); - assertNotNull(data); - assertEquals(Shape.scalar(), data.shape()); - assertEquals("Long String 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 !", data.getObject()); + assertEquals(Shape.scalar(), tensor.shape()); + assertEquals("Long String 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 !", tensor.getObject()); } @Test public void createVector() { - Tensor tensor = TString.vectorOf("Pretty", "vacant"); + TString tensor = TString.vectorOf("Pretty", "vacant"); assertNotNull(tensor); - - TString data = tensor.data(); - assertNotNull(data); - assertEquals(Shape.of(2), data.shape()); - assertEquals("Pretty", data.getObject(0)); - assertEquals("vacant", data.getObject(1)); + assertEquals(Shape.of(2), tensor.shape()); + assertEquals("Pretty", tensor.getObject(0)); + assertEquals("vacant", tensor.getObject(1)); } @Test @@ -73,30 +63,27 @@ public void createCopy() { .setObject("New", 1, 0) .setObject("York", 1, 1); - Tensor tensor = TString.tensorOf(strings); + TString tensor = TString.tensorOf(strings); assertNotNull(tensor); - - TString data = tensor.data(); - assertNotNull(data); strings.scalars().forEachIndexed((idx, s) -> - assertEquals(s.getObject(), data.getObject(idx)) + assertEquals(s.getObject(), tensor.getObject(idx)) ); } @Test public void defaultCharsetIsUtf8() { - Tensor tensor = TString.tensorOf(NdArrays.scalarOfObject(BABY_CHICK)); - byte[] bytes = tensor.data().asBytes().getObject(); + TString tensor = TString.tensorOf(NdArrays.scalarOfObject(BABY_CHICK)); + byte[] bytes = tensor.asBytes().getObject(); assertArrayEquals(new byte[] { (byte)0xF0, (byte)0x9F, (byte)0x90, (byte)0xA5 }, bytes); - assertEquals(BABY_CHICK, tensor.data().getObject()); + assertEquals(BABY_CHICK, tensor.getObject()); } @Test public void usingDifferentCharset() { - Tensor tensor = TString.tensorOf(StandardCharsets.UTF_16LE, NdArrays.scalarOfObject(BABY_CHICK)); - byte[] bytes = tensor.data().asBytes().getObject(); + TString tensor = TString.tensorOf(StandardCharsets.UTF_16LE, NdArrays.scalarOfObject(BABY_CHICK)); + byte[] bytes = tensor.asBytes().getObject(); assertArrayEquals(new byte[] { (byte)0x3D, (byte)0xD8, (byte)0x25, (byte)0xDC }, bytes); - assertEquals(BABY_CHICK, tensor.data().using(StandardCharsets.UTF_16LE).getObject()); + assertEquals(BABY_CHICK, tensor.using(StandardCharsets.UTF_16LE).getObject()); } @Test @@ -106,11 +93,11 @@ public void initializingTensorWithRawBytes() { for (int i = 0; i < strings.length; ++i) { bytes.setObject(strings[i].getBytes(), i); } - Tensor tensor = TString.tensorOfBytes(bytes); + TString tensor = TString.tensorOfBytes(bytes); assertNotNull(tensor); assertEquals(bytes.shape(), tensor.shape()); - NdArray tensorBytes = tensor.data().asBytes(); + NdArray tensorBytes = tensor.asBytes(); for (int i = 0; i < strings.length; ++i) { assertArrayEquals(bytes.getObject(i), tensorBytes.getObject(i)); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TUint8Test.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TUint8Test.java index cc83087e018..ce7397d5878 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TUint8Test.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TUint8Test.java @@ -17,16 +17,22 @@ package org.tensorflow.types; -import org.tensorflow.Tensor; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; public class TUint8Test extends NumericTypesTestBase { @Override - Tensor allocateTensor(Shape shape) { + TUint8 allocateTensor(Shape shape) { return TUint8.tensorOf(shape); } + @Override + NdArray allocateNdArray(Shape shape) { + return NdArrays.ofBytes(shape); + } + @Override Byte valueOf(Integer value) { return value.byteValue(); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetIterator.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetIterator.java index f4c4b681715..0bad6a41214 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetIterator.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetIterator.java @@ -272,7 +272,7 @@ public Iterator>> iterator() { @Override public boolean hasNext() { - return nextOptional.hasValue().data().getBoolean(); + return nextOptional.hasValue().asTensor().getBoolean(); } @Override diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java index 9a4e0b7d3c4..fbc92f14d31 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java @@ -361,7 +361,7 @@ public static Operand rangeCheck( tf.withSubScope("rangeCheck") .withControlDependencies(Collections.singletonList(assertThat)); return ltf.identity(values); - } else if (!cond.asOutput().data().getBoolean()) + } else if (!cond.asTensor().getBoolean()) throw new IllegalArgumentException(String.format("%s : values out of range", prefix)); else return values; } @@ -409,7 +409,7 @@ public static Operand valueCheck( tf.withSubScope("valueCheck") .withControlDependencies(Collections.singletonList(assertThat)); return ltf.identity(values); - } else if (!cond.asOutput().data().getBoolean()) + } else if (!cond.asTensor().getBoolean()) throw new IllegalArgumentException(String.format("%s : values not in value set", prefix)); else return values; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java index 122de9f21ae..29f04d9f398 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java @@ -63,64 +63,36 @@ public static int[] getIntArray(Scope scope, Operand dims) { * @throws java.lang.IllegalArgumentException if the dims type is not an integer */ public static long[] getLongArray(Scope scope, Operand dims) { - DataType dType = dims.asOutput().dataType(); - if (!dType.isInteger()) { - throw new IllegalArgumentException("the data type must be an integer type"); - } - List result = new ArrayList<>(); - if (scope.env().isEager()) { - if (dType.equals(TInt32.DTYPE)) { - @SuppressWarnings("unchecked") - Operand idims = (Operand) dims; - - idims.asOutput().data().scalars().forEach(s -> result.add((long) s.getInt())); - } else if (dType.equals(TInt64.DTYPE)) { - @SuppressWarnings("unchecked") - Operand ldims = (Operand) dims; - ldims.asOutput().data().scalars().forEach(s -> result.add(s.getLong())); - } else if (dType.equals(TUint8.DTYPE)) { - @SuppressWarnings("unchecked") - Operand udims = (Operand) dims; - udims.asOutput().data().scalars().forEach(s -> result.add(s.getObject().longValue())); - } else { // shouldn't happen - throw new IllegalArgumentException("the data type must be an integer type"); - } - - } else { - try (Session session = new Session((Graph) scope.env())) { - if (dType.equals(TInt32.DTYPE)) { - try (Tensor tensorResult = - session.runner().fetch(dims).run().get(0).expect(TInt32.DTYPE)) { - tensorResult.data().scalars().forEach(s -> result.add((long) s.getInt())); - } - } else if (dType.equals(TInt64.DTYPE)) { - try (Tensor tensorResult = - session.runner().fetch(dims).run().get(0).expect(TInt64.DTYPE)) { - tensorResult.data().scalars().forEach(s -> result.add(s.getLong())); - } - } else if (dType.equals(TUint8.DTYPE)) { - try (Tensor tensorResult = - session.runner().fetch(dims).run().get(0).expect(TUint8.DTYPE)) { - tensorResult.data().scalars().forEach(s -> result.add(s.getObject().longValue())); - } - } else { // shouldn't happen - throw new IllegalArgumentException("the data type must be an integer type"); - } - } + return getLongArray(dims.asTensor()); + } + try (Session session = new Session((Graph)scope.env()); + Tensor tensor = session.runner().fetch(dims).run().get(0)) { + return getLongArray(tensor); } - return result.stream().mapToLong(i -> i).toArray(); } /** - * Gets the shape for the data within a Tensor + * Converts a TInt32 or TInt64 to a java long array * - * @param tensor the tensor - * @return the Shape of the tensor's data; + * @param scope the TensorFlow scope + * @param dims the dimension tensor + * @param the type of the dimensions, must either be TInt32 or TInt64 type + * @return the long array + * @throws java.lang.IllegalArgumentException if the dims type is not an integer */ - public static Shape getShape(Tensor tensor) { - NdArray data = (NdArray) tensor.data(); - return data.shape(); + public static long[] getLongArray(Tensor dims) { + List result = new ArrayList<>(); + if (dims instanceof TInt32) { + ((TInt32)dims).scalars().forEach(s -> result.add((long) s.getInt())); + } else if (dims instanceof TInt64) { + ((TInt64)dims).scalars().forEach(s -> result.add(s.getLong())); + } else if (dims instanceof TUint8) { + ((TUint8)dims).scalars().forEach(s -> result.add(s.getObject().longValue())); + } else { // shouldn't happen + throw new IllegalArgumentException("the data type must be an integer type"); + } + return result.stream().mapToLong(i -> i).toArray(); } /** diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/BatchDatasetTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/BatchDatasetTest.java index 6a54cb08de6..48800d4dc1b 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/BatchDatasetTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/BatchDatasetTest.java @@ -45,13 +45,13 @@ public void testEagerBatchDataset() { int count = 0; for (List> components : dataset) { - try (Tensor batch1 = - components.get(0).asTensor().expect(TInt32.DTYPE); - Tensor batch2 = - components.get(1).asTensor().expect(TInt32.DTYPE);) { + try (TInt32 batch1 = + (TInt32)components.get(0).asTensor(); + TInt32 batch2 = + (TInt32)components.get(1).asTensor();) { - assertEquals(testMatrix1.slice(range(count, count + 2)), batch1.data()); - assertEquals(testMatrix2.slice(range(count, count + 2)), batch2.data()); + assertEquals(testMatrix1.slice(range(count, count + 2)), batch1); + assertEquals(testMatrix2.slice(range(count, count + 2)), batch2); count += 2; } @@ -72,13 +72,13 @@ public void testDropLastBatch() { int count = 0; for (List> components : dataset) { - try (Tensor batch1 = - components.get(0).asTensor().expect(TInt32.DTYPE); - Tensor batch2 = - components.get(1).asTensor().expect(TInt32.DTYPE);) { + try (TInt32 batch1 = + (TInt32)components.get(0).asTensor(); + TInt32 batch2 = + (TInt32)components.get(1).asTensor();) { - assertEquals(testMatrix1.slice(range(count, count + 3)), batch1.data()); - assertEquals(testMatrix2.slice(range(count, count + 3)), batch2.data()); + assertEquals(testMatrix1.slice(range(count, count + 3)), batch1); + assertEquals(testMatrix2.slice(range(count, count + 3)), batch2); count += 3; } @@ -100,21 +100,21 @@ public void testKeepLastBatch() { boolean foundLastBatch = false; for (List> components : dataset) { - try (Tensor batch1 = - components.get(0).asTensor().expect(TInt32.DTYPE); - Tensor batch2 = - components.get(1).asTensor().expect(TInt32.DTYPE);) { + try (TInt32 batch1 = + (TInt32)components.get(0).asTensor(); + TInt32 batch2 = + (TInt32)components.get(1).asTensor();) { if (count == 0) { assertEquals(testMatrix1.slice(range(count, count + 3)), - batch1.data()); + batch1); assertEquals(testMatrix2.slice(range(count, count + 3)), - batch2.data()); + batch2); count += 3; } else { assertEquals(testMatrix1.slice(range(count, count + 1)), - batch1.data()); + batch1); assertEquals(testMatrix2.slice(range(count, count + 1)), - batch2.data()); + batch2); foundLastBatch = true; } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java index 6bb6e21f330..448e90a17ea 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java @@ -54,12 +54,12 @@ public void testGraphIteration() { int batches = 0; while (true) { try { - List> outputs = session.runner().fetch(x).fetch(y).run(); + List outputs = session.runner().fetch(x).fetch(y).run(); - try (Tensor xBatch = outputs.get(0).expect(TInt32.DTYPE); - Tensor yBatch = outputs.get(1).expect(TInt32.DTYPE)) { - assertEquals(testMatrix1.get(batches), xBatch.data()); - assertEquals(testMatrix2.get(batches), yBatch.data()); + try (TInt32 xBatch = (TInt32)outputs.get(0); + TInt32 yBatch = (TInt32)outputs.get(1)) { + assertEquals(testMatrix1.get(batches), xBatch); + assertEquals(testMatrix2.get(batches), yBatch); batches++; } } catch (TFOutOfRangeException e) { @@ -82,11 +82,11 @@ public void testEagerIteration() { Dataset dataset = Dataset.fromTensorSlices(tf, tensors, dataTypes); int count = 0; for (List> outputs : dataset) { - try (Tensor batch1 = outputs.get(0).asTensor().expect(TInt32.DTYPE); - Tensor batch2 = outputs.get(1).asTensor().expect(TInt32.DTYPE); ) { + try (TInt32 batch1 = (TInt32)outputs.get(0).asTensor(); + TInt32 batch2 = (TInt32)outputs.get(1).asTensor(); ) { - assertEquals(testMatrix1.get(count), batch1.data()); - assertEquals(testMatrix2.get(count), batch2.data()); + assertEquals(testMatrix1.get(count), batch1); + assertEquals(testMatrix2.get(count), batch2); count++; } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java index 5960442ff70..ede0a1aa61d 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java @@ -65,8 +65,7 @@ public void testGraphIteration() { Dataset dataset = Dataset.fromTensorSlices(tf, tensors, dataTypes) .mapAllComponents( - component -> - tf.math.mul(component.asOutput().expect(TInt32.DTYPE), tf.constant(2))); + component -> tf.math.mul(component.asOutput().expect(TInt32.DTYPE), tf.constant(2))); DatasetIterator iterator = dataset.makeOneShotIterator(); List> components = iterator.getNext(); @@ -79,13 +78,13 @@ public void testGraphIteration() { int batches = 0; while (true) { try { - List> outputs = session.runner().fetch(X).fetch(y).run(); + List outputs = session.runner().fetch(X).fetch(y).run(); - try (Tensor XBatch = outputs.get(0).expect(TInt32.DTYPE); - Tensor yBatch = outputs.get(1).expect(TInt32.DTYPE)) { + try (TInt32 XBatch = (TInt32)outputs.get(0); + TInt32 yBatch = (TInt32)outputs.get(1)) { - assertEquals(mapped1.get(batches), XBatch.data()); - assertEquals(mapped2.get(batches), yBatch.data()); + assertEquals(mapped1.get(batches), XBatch); + assertEquals(mapped2.get(batches), yBatch); batches++; } @@ -114,11 +113,11 @@ public void testEagerIteration() { int count = 0; for (List> outputs : dataset) { - try (Tensor XBatch = outputs.get(0).asTensor().expect(TInt32.DTYPE); - Tensor yBatch = outputs.get(1).asTensor().expect(TInt32.DTYPE); ) { + try (TInt32 XBatch = (TInt32)outputs.get(0).asTensor(); + TInt32 yBatch = (TInt32)outputs.get(1).asTensor(); ) { - assertEquals(mapped1.get(count), XBatch.data()); - assertEquals(mapped2.get(count), yBatch.data()); + assertEquals(mapped1.get(count), XBatch); + assertEquals(mapped2.get(count), yBatch); count++; } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/SkipDatasetTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/SkipDatasetTest.java index 9ff8080034d..6dc877cc6eb 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/SkipDatasetTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/SkipDatasetTest.java @@ -40,11 +40,11 @@ public void testEagerSkipDataset() { int count = 2; for (List> components : dataset) { - try (Tensor batch1 = components.get(0).asTensor().expect(TInt32.DTYPE); - Tensor batch2 = - components.get(1).asTensor().expect(TInt32.DTYPE); ) { - assertEquals(testMatrix1.get(count), batch1.data()); - assertEquals(testMatrix2.get(count), batch2.data()); + try (TInt32 batch1 = (TInt32)components.get(0).asTensor(); + TInt32 batch2 = + (TInt32)components.get(1).asTensor(); ) { + assertEquals(testMatrix1.get(count), batch1); + assertEquals(testMatrix2.get(count), batch2); count++; } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/TakeDatasetTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/TakeDatasetTest.java index 4419f4660db..626fe719936 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/TakeDatasetTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/TakeDatasetTest.java @@ -41,11 +41,11 @@ public void testEagerTakeDataset() { int count = 0; for (List> components : dataset) { - try (Tensor batch1 = components.get(0).asTensor().expect(TInt32.DTYPE); - Tensor batch2 = components.get(1).asTensor().expect(TInt32.DTYPE); ) { + try (TInt32 batch1 = (TInt32)components.get(0).asTensor(); + TInt32 batch2 = (TInt32)components.get(1).asTensor(); ) { - assertEquals(testMatrix1.get(count), batch1.data()); - assertEquals(testMatrix2.get(count), batch2.data()); + assertEquals(testMatrix1.get(count), batch1); + assertEquals(testMatrix2.get(count), batch2); count++; } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java index 4b992b0a79d..d2592026f12 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java @@ -140,25 +140,23 @@ public void testBasic() { (float) Math.pow(beta1, step + 1), (float) Math.pow(beta2, step + 1) }; - try (Tensor result = - session + try (TFloat32 result = + (TFloat32)session .getGraphSession() .runner() .fetch("beta1_power") .run() - .get(0) - .expect(TFloat32.DTYPE)) { - result.data().scalars().forEach(f -> assertEquals(powers[0], f.getFloat(), epsilon1)); + .get(0)) { + result.scalars().forEach(f -> assertEquals(powers[0], f.getFloat(), epsilon1)); } - try (Tensor result = - session + try (TFloat32 result = + (TFloat32)session .getGraphSession() .runner() .fetch("beta2_power") .run() - .get(0) - .expect(TFloat32.DTYPE)) { - result.data().scalars().forEach(f -> assertEquals(powers[1], f.getFloat(), epsilon1)); + .get(0)) { + result.scalars().forEach(f -> assertEquals(powers[1], f.getFloat(), epsilon1)); } session.run(update); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java index a303067cdc8..ac322f952db 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java @@ -149,15 +149,14 @@ public void testBasic() { // Test powers final float beta1Power = (float) Math.pow(BETA_ONE_DEFAULT, step + 1); - try (Tensor result = - session + try (TFloat32 result = + (TFloat32)session .getGraphSession() .runner() .fetch("beta1_power") .run() - .get(0) - .expect(TFloat32.DTYPE)) { - result.data().scalars().forEach(f -> assertEquals(beta1Power, f.getFloat(), epsilon1)); + .get(0)) { + result.scalars().forEach(f -> assertEquals(beta1Power, f.getFloat(), epsilon1)); } session.run(update); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java index 6c26aab2995..e064e793a37 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java @@ -147,15 +147,14 @@ public void testBasic() { session.evaluate(var0Init, var0); session.evaluate(var1Init, var1); - try (Tensor result = - session + try (TFloat32 result = + (TFloat32)session .getGraphSession() .runner() .fetch("momentum") .run() - .get(0) - .expect(TFloat32.DTYPE)) { - result.data().scalars().forEach(f -> assertEquals(1F, f.getFloat(), epsilon1)); + .get(0)) { + result.scalars().forEach(f -> assertEquals(1F, f.getFloat(), epsilon1)); } momentum = 1F; @@ -167,15 +166,14 @@ public void testBasic() { Nadam.BETA_ONE_DEFAULT * (1F - 0.5F * (float) Math.pow(0.96F, (0.004F * (step + 1)))); momentum = momentum * mut; - try (Tensor result = - session + try (TFloat32 result = + (TFloat32)session .getGraphSession() .runner() .fetch("momentum") .run() - .get(0) - .expect(TFloat32.DTYPE)) { - result.data().scalars().forEach(f -> assertEquals(momentum, f.getFloat(), epsilon1)); + .get(0)) { + result.scalars().forEach(f -> assertEquals(momentum, f.getFloat(), epsilon1)); } mcache = ND.mul(mcache, momentum); FloatNdArray[] resultsNP = nadamUpdateNdArray(var0Np, grads0Np, step, m0, v0, mcache); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java index bca90211e50..1f5f2f16053 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java @@ -89,52 +89,52 @@ public void evaluate(double expected, Operand input) { Operand o = (Operand) input; AtomicInteger index = new AtomicInteger(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); } index.set(0); - o.data().scalars().forEach(f -> assertEquals(expected, f.getFloat(), epsilon)); + o.asTensor().scalars().forEach(f -> assertEquals(expected, f.getFloat(), epsilon)); } else if (dtype == TFloat64.DTYPE) { Operand o = (Operand) input; AtomicInteger index = new AtomicInteger(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); } index.set(0); - o.data().scalars().forEach(f -> assertEquals(expected, f.getDouble(), epsilon)); + o.asTensor().scalars().forEach(f -> assertEquals(expected, f.getDouble(), epsilon)); } else if (dtype == TInt32.DTYPE) { Operand o = (Operand) input; AtomicInteger index = new AtomicInteger(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); } index.set(0); - o.data().scalars().forEach(f -> assertEquals((int) expected, f.getInt())); + o.asTensor().scalars().forEach(f -> assertEquals((int) expected, f.getInt())); } else if (dtype == TInt64.DTYPE) { Operand o = (Operand) input; AtomicInteger index = new AtomicInteger(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); } index.set(0); - o.data().scalars().forEach(f -> assertEquals((long) expected, f.getLong())); + o.asTensor().scalars().forEach(f -> assertEquals((long) expected, f.getLong())); } else if (dtype == TUint8.DTYPE) { Operand o = (Operand) input; AtomicInteger index = new AtomicInteger(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); } index.set(0); - o.data().scalars().forEach(f -> assertEquals((long) expected, f.getByte())); + o.asTensor().scalars().forEach(f -> assertEquals((long) expected, f.getByte())); } } @@ -151,12 +151,12 @@ public void evaluate(Number[] expected, Output input) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); } index.set(0); - o.data() + o.asTensor() .scalars() .forEach( f -> @@ -166,12 +166,12 @@ public void evaluate(Number[] expected, Output input) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); } index.set(0); - o.data() + o.asTensor() .scalars() .forEach( f -> @@ -181,36 +181,36 @@ public void evaluate(Number[] expected, Output input) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); } index.set(0); - o.data() + o.asTensor() .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()].intValue(), f.getInt())); } else if (dtype == TInt64.DTYPE) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); } index.set(0); - o.data() + o.asTensor() .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getLong())); } else if (dtype == TUint8.DTYPE) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%x). %d\n", index.getAndIncrement(), f.getByte())); } index.set(0); - o.data() + o.asTensor() .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()].byteValue(), f.getByte())); } @@ -224,12 +224,12 @@ public void evaluate(FloatNdArray expected, Output input) { Output o = (Output) input; AtomicLong index = new AtomicLong(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); } index.set(0); - o.data() + o.asTensor() .scalars() .forEach( f -> assertEquals(expected.getFloat(index.getAndIncrement()), f.getFloat(), epsilon)); @@ -237,12 +237,12 @@ public void evaluate(FloatNdArray expected, Output input) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); } index.set(0); - o.data() + o.asTensor() .scalars() .forEach( f -> @@ -251,24 +251,24 @@ public void evaluate(FloatNdArray expected, Output input) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); } index.set(0); - for (IntNdArray f : o.data().scalars()) { + for (IntNdArray f : o.asTensor().scalars()) { assertEquals((int) expected.getFloat(index.getAndIncrement()), f.getInt()); } } else if (dtype == TInt64.DTYPE) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); } index.set(0); - o.data() + o.asTensor() .scalars() .forEach( f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getLong())); @@ -276,12 +276,12 @@ public void evaluate(FloatNdArray expected, Output input) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); if (debug) { - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); } index.set(0); - o.data() + o.asTensor() .scalars() .forEach( f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getByte())); @@ -296,10 +296,10 @@ public void evaluateString(Output input, Predicate predicate) { if (debug) { if (isScalar) { System.out.printf( - "0). %b <==> %s\n", predicate.test(input.data().getObject()), input.data().getObject()); + "0). %b <==> %s\n", predicate.test(input.asTensor().getObject()), input.asTensor().getObject()); } else { input - .data() + .asTensor() .scalars() .forEachIndexed( (idx, s) -> @@ -310,9 +310,9 @@ public void evaluateString(Output input, Predicate predicate) { } index.set(0); if (isScalar) { - assertTrue(predicate.test(input.data().getObject())); + assertTrue(predicate.test(input.asTensor().getObject())); } else { - input.data().scalars().forEachIndexed((idx, s) -> assertTrue(predicate.test(s.getObject()))); + input.asTensor().scalars().forEachIndexed((idx, s) -> assertTrue(predicate.test(s.getObject()))); } } @@ -327,9 +327,9 @@ public void evaluate(Output input, Predicate predic if (debug) { if (isScalar) { System.out.printf( - "0). %b <==> %f\n", predicate.test(o.data().getFloat()), o.data().getFloat()); + "0). %b <==> %f\n", predicate.test(o.asTensor().getFloat()), o.asTensor().getFloat()); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( (idx, f) -> @@ -340,20 +340,20 @@ public void evaluate(Output input, Predicate predic } index.set(0); if (isScalar) { - assertTrue(predicate.test(o.data().getFloat())); + assertTrue(predicate.test(o.asTensor().getFloat())); } else { - o.data() + o.asTensor() .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.data().getFloat()))); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getFloat()))); } } else if (dtype == TFloat64.DTYPE) { Output o = (Output) input; if (debug) { if (isScalar) { System.out.printf( - "0). %b <==> %f\n", predicate.test(o.data().getDouble()), o.data().getDouble()); + "0). %b <==> %f\n", predicate.test(o.asTensor().getDouble()), o.asTensor().getDouble()); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( (idx, f) -> @@ -364,20 +364,20 @@ public void evaluate(Output input, Predicate predic } index.set(0); if (isScalar) { - assertTrue(predicate.test(o.data().getDouble())); + assertTrue(predicate.test(o.asTensor().getDouble())); } else { - o.data() + o.asTensor() .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.data().getDouble()))); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getDouble()))); } } else if (dtype == TFloat16.DTYPE) { Output o = (Output) input; if (debug) { if (isScalar) { System.out.printf( - "0). %b <==> %f\n", predicate.test(o.data().getFloat()), o.data().getFloat()); + "0). %b <==> %f\n", predicate.test(o.asTensor().getFloat()), o.asTensor().getFloat()); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( (idx, f) -> @@ -388,20 +388,20 @@ public void evaluate(Output input, Predicate predic } index.set(0); if (isScalar) { - assertTrue(predicate.test(o.data().getFloat())); + assertTrue(predicate.test(o.asTensor().getFloat())); } else { - o.data() + o.asTensor() .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.data().getFloat()))); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getFloat()))); } } else if (dtype == TInt32.DTYPE) { Output o = (Output) input; if (debug) { if (isScalar) { System.out.printf( - "0). %b <==> %d\n", predicate.test(o.data().getInt()), o.data().getInt()); + "0). %b <==> %d\n", predicate.test(o.asTensor().getInt()), o.asTensor().getInt()); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( (idx, f) -> @@ -412,20 +412,20 @@ public void evaluate(Output input, Predicate predic } index.set(0); if (isScalar) { - assertTrue(predicate.test(o.data().getInt())); + assertTrue(predicate.test(o.asTensor().getInt())); } else { - o.data() + o.asTensor() .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.data().getInt()))); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getInt()))); } } else if (dtype == TInt64.DTYPE) { Output o = (Output) input; if (debug) { if (isScalar) { System.out.printf( - "0). %b <==> %d\n", predicate.test(o.data().getLong()), o.data().getLong()); + "0). %b <==> %d\n", predicate.test(o.asTensor().getLong()), o.asTensor().getLong()); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( (idx, f) -> @@ -436,20 +436,20 @@ public void evaluate(Output input, Predicate predic } index.set(0); if (isScalar) { - assertTrue(predicate.test(o.data().getLong())); + assertTrue(predicate.test(o.asTensor().getLong())); } else { - o.data() + o.asTensor() .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.data().getLong()))); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getLong()))); } } else if (dtype == TUint8.DTYPE) { Output o = (Output) input; if (debug) { if (isScalar) { System.out.printf( - "0). %b <==> %x\n", predicate.test(o.data().getByte()), o.data().getByte()); + "0). %b <==> %x\n", predicate.test(o.asTensor().getByte()), o.asTensor().getByte()); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( (idx, f) -> @@ -460,11 +460,11 @@ public void evaluate(Output input, Predicate predic } index.set(0); if (isScalar) { - assertTrue(predicate.test(o.data().getByte())); + assertTrue(predicate.test(o.asTensor().getByte())); } else { - o.data() + o.asTensor() .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.data().getByte()))); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getByte()))); } } else { fail("Unexpected DataType: " + dtype); @@ -482,13 +482,13 @@ public void evaluate(String[] expected, Output input) { AtomicInteger index = new AtomicInteger(); if (debug) { input - .data() + .asTensor() .scalars() .forEach(f -> System.out.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); } index.set(0); input - .data() + .asTensor() .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject())); } @@ -504,13 +504,13 @@ public void evaluate(Boolean[] expected, Output input) { AtomicInteger index = new AtomicInteger(); if (debug) { input - .data() + .asTensor() .scalars() .forEach(f -> System.out.printf("%d). %b\n", index.getAndIncrement(), f.getBoolean())); } index.set(0); input - .data() + .asTensor() .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getBoolean())); } @@ -530,25 +530,25 @@ public void evaluate(Output expected, Output input) { AtomicInteger index = new AtomicInteger(); if (debug) { if (isScalar) { - System.out.printf("0). %f <==> %f\n", x.data().getFloat(), o.data().getFloat()); + System.out.printf("0). %f <==> %f\n", x.asTensor().getFloat(), o.asTensor().getFloat()); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( (idx, f) -> System.out.printf( "%d). %f <==> %f\n", - index.getAndIncrement(), x.data().getFloat(idx), f.getFloat())); + index.getAndIncrement(), x.asTensor().getFloat(idx), f.getFloat())); } } index.set(0); if (isScalar) { - assertEquals(x.data().getFloat(), o.data().getFloat(), epsilon); + assertEquals(x.asTensor().getFloat(), o.asTensor().getFloat(), epsilon); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( - (idx, f) -> assertEquals(x.data().getFloat(idx), f.getFloat(), epsilon)); + (idx, f) -> assertEquals(x.asTensor().getFloat(idx), f.getFloat(), epsilon)); } } else if (dtype == TFloat64.DTYPE) { Output x = (Output) expected; @@ -556,25 +556,25 @@ public void evaluate(Output expected, Output input) { AtomicInteger index = new AtomicInteger(); if (debug) { if (isScalar) { - System.out.printf("0). %f <==> %f\n", x.data().getDouble(), o.data().getDouble()); + System.out.printf("0). %f <==> %f\n", x.asTensor().getDouble(), o.asTensor().getDouble()); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( (idx, f) -> System.out.printf( "%d). %f <==> %f\n", - index.getAndIncrement(), x.data().getDouble(idx), f.getDouble())); + index.getAndIncrement(), x.asTensor().getDouble(idx), f.getDouble())); } } index.set(0); if (isScalar) { - assertEquals(x.data().getDouble(), o.data().getDouble(), epsilon); + assertEquals(x.asTensor().getDouble(), o.asTensor().getDouble(), epsilon); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( - (idx, f) -> assertEquals(x.data().getDouble(idx), f.getDouble(), epsilon)); + (idx, f) -> assertEquals(x.asTensor().getDouble(idx), f.getDouble(), epsilon)); } } else if (dtype == TInt32.DTYPE) { Output x = (Output) expected; @@ -582,24 +582,24 @@ public void evaluate(Output expected, Output input) { AtomicInteger index = new AtomicInteger(); if (debug) { if (isScalar) { - System.out.printf("0). %d <==> %d\n", x.data().getInt(), o.data().getInt()); + System.out.printf("0). %d <==> %d\n", x.asTensor().getInt(), o.asTensor().getInt()); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( (idx, f) -> System.out.printf( "%d). %d <==> %d\n", - index.getAndIncrement(), x.data().getInt(idx), f.getInt())); + index.getAndIncrement(), x.asTensor().getInt(idx), f.getInt())); } } index.set(0); if (isScalar) { - assertEquals(x.data().getInt(), o.data().getInt()); + assertEquals(x.asTensor().getInt(), o.asTensor().getInt()); } else { - o.data() + o.asTensor() .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.data().getInt(idx), f.getInt())); + .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getInt(idx), f.getInt())); } } else if (dtype == TInt64.DTYPE) { Output x = (Output) expected; @@ -607,24 +607,24 @@ public void evaluate(Output expected, Output input) { AtomicInteger index = new AtomicInteger(); if (debug) { if (isScalar) { - System.out.printf("0). %d <==> %d\n", x.data().getLong(), o.data().getLong()); + System.out.printf("0). %d <==> %d\n", x.asTensor().getLong(), o.asTensor().getLong()); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( (idx, f) -> System.out.printf( "%d). %d <==> %d\n", - index.getAndIncrement(), x.data().getLong(idx), f.getLong())); + index.getAndIncrement(), x.asTensor().getLong(idx), f.getLong())); } } index.set(0); if (isScalar) { - assertEquals(x.data().getLong(), o.data().getLong()); + assertEquals(x.asTensor().getLong(), o.asTensor().getLong()); } else { - o.data() + o.asTensor() .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.data().getLong(idx), f.getLong())); + .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getLong(idx), f.getLong())); } } else if (dtype == TUint8.DTYPE) { Output x = (Output) expected; @@ -632,24 +632,24 @@ public void evaluate(Output expected, Output input) { AtomicInteger index = new AtomicInteger(); if (debug) { if (isScalar) { - System.out.printf("0). %x <==> %x\n", x.data().getByte(), o.data().getByte()); + System.out.printf("0). %x <==> %x\n", x.asTensor().getByte(), o.asTensor().getByte()); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( (idx, f) -> System.out.printf( "%d). %x <==> %x\n", - index.getAndIncrement(), x.data().getByte(idx), f.getByte())); + index.getAndIncrement(), x.asTensor().getByte(idx), f.getByte())); } } index.set(0); if (isScalar) { - assertEquals(x.data().getByte(), o.data().getByte()); + assertEquals(x.asTensor().getByte(), o.asTensor().getByte()); } else { - o.data() + o.asTensor() .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.data().getByte(idx), f.getByte())); + .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getByte(idx), f.getByte())); } } else if (dtype == TString.DTYPE) { Output x = (Output) expected; @@ -657,24 +657,24 @@ public void evaluate(Output expected, Output input) { AtomicInteger index = new AtomicInteger(); if (debug) { if (isScalar) { - System.out.printf("0). %s <==> %s\n", x.data().getObject(), o.data().getObject()); + System.out.printf("0). %s <==> %s\n", x.asTensor().getObject(), o.asTensor().getObject()); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( (idx, f) -> System.out.printf( "%d). %s <==> %s\n", - index.getAndIncrement(), x.data().getObject(idx), f.getObject())); + index.getAndIncrement(), x.asTensor().getObject(idx), f.getObject())); } } index.set(0); if (isScalar) { - assertEquals(x.data().getObject(), o.data().getObject()); + assertEquals(x.asTensor().getObject(), o.asTensor().getObject()); } else { - o.data() + o.asTensor() .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.data().getObject(idx), f.getObject())); + .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getObject(idx), f.getObject())); } } else if (dtype == TBool.DTYPE) { Output x = (Output) expected; @@ -682,24 +682,24 @@ public void evaluate(Output expected, Output input) { AtomicInteger index = new AtomicInteger(); if (debug) { if (isScalar) { - System.out.printf("0). %b <==> %b\n", x.data().getBoolean(), o.data().getBoolean()); + System.out.printf("0). %b <==> %b\n", x.asTensor().getBoolean(), o.asTensor().getBoolean()); } else { - o.data() + o.asTensor() .scalars() .forEachIndexed( (idx, f) -> System.out.printf( "%d). %b <==> %b\n", - index.getAndIncrement(), x.data().getBoolean(idx), f.getBoolean())); + index.getAndIncrement(), x.asTensor().getBoolean(idx), f.getBoolean())); } } index.set(0); if (isScalar) { - assertEquals(x.data().getBoolean(), o.data().getBoolean()); + assertEquals(x.asTensor().getBoolean(), o.asTensor().getBoolean()); } else { - o.data() + o.asTensor() .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.data().getBoolean(idx), f.getBoolean())); + .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getBoolean(idx), f.getBoolean())); } } } @@ -711,43 +711,43 @@ public void print(PrintWriter writer, Output input) { if (dtype == TFloat32.DTYPE) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); } else if (dtype == TFloat64.DTYPE) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); } else if (dtype == TInt32.DTYPE) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); } else if (dtype == TInt64.DTYPE) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); } else if (dtype == TUint8.DTYPE) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); } else if (dtype == TString.DTYPE) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); } else if (dtype == TBool.DTYPE) { Output o = (Output) input; AtomicInteger index = new AtomicInteger(); - o.data() + o.asTensor() .scalars() .forEach(f -> System.out.printf("%d). %b\n", index.getAndIncrement(), f.getBoolean())); } else { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java index c58667c15d0..6df79a4432a 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java @@ -30,27 +30,35 @@ import static org.junit.jupiter.api.Assertions.*; -/** Graph Mode Test Session */ +/** + * Graph Mode Test Session + */ public class GraphTestSession extends TestSession { private final Graph graph; private final Session session; private final Ops tf; - /** Create a Graph mode test session. */ + /** + * Create a Graph mode test session. + */ public GraphTestSession() { graph = new Graph(); session = new Session(graph); tf = Ops.create(graph).withName("test"); } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public Ops getTF() { return tf; } - /** Get the Graph object that is represented by this Test Session */ + /** + * Get the Graph object that is represented by this Test Session + */ public Graph getGraph() { return graph; } @@ -64,133 +72,144 @@ public Session getSession() { return session; } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void close() { session.close(); graph.close(); } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public boolean isEager() { return false; } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public Session getGraphSession() { return this.session; } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public EagerSession getEagerSession() { return null; } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void initialize() { graph.initializers().forEach(initializer -> session.runner().addTarget(initializer).run()); } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void run(Op op) { session.run(op); } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void evaluate(double expected, Operand input) { DataType dtype = input.asOutput().dataType(); if (dtype == TFloat32.DTYPE) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + try (TFloat32 result = + (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { - result.data().scalars().forEach(f -> assertEquals((float) expected, f.getFloat(), epsilon)); + try (TFloat32 result = + (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { + result.scalars().forEach(f -> assertEquals((float) expected, f.getFloat(), epsilon)); } } else if (dtype == TFloat64.DTYPE) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + try (TFloat64 result = + (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { - result.data().scalars().forEach(f -> assertEquals(expected, f.getDouble(), epsilon)); + try (TFloat64 result = + (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { + result.scalars().forEach(f -> assertEquals(expected, f.getDouble(), epsilon)); } } else if (dtype == TInt32.DTYPE) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = + (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { - result.data().scalars().forEach(f -> assertEquals((int) expected, f.getInt())); + try (TInt32 result = + (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { + result.scalars().forEach(f -> assertEquals((int) expected, f.getInt())); } } else if (dtype == TInt64.DTYPE) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result = + (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { - result.data().scalars().forEach(f -> assertEquals((long) expected, f.getLong())); + try (TInt64 result = + (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { + result.scalars().forEach(f -> assertEquals((long) expected, f.getLong())); } } else if (dtype == TUint8.DTYPE) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) { + try (TUint8 result = + (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getByte())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) { - result.data().scalars().forEach(f -> assertEquals((long) expected, f.getByte())); + try (TUint8 result = + (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { + result.scalars().forEach(f -> assertEquals((long) expected, f.getByte())); } } else { fail("Unexpected DataType: " + dtype); } } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void evaluate(Number[] expected, Output input) { int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); @@ -202,19 +221,17 @@ public void evaluate(Number[] expected, Output input) { if (dtype == TFloat32.DTYPE) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + try (TFloat32 result = + (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + try (TFloat32 result = + (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach( f -> @@ -224,19 +241,17 @@ public void evaluate(Number[] expected, Output input) { } else if (dtype == TFloat64.DTYPE) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + try (TFloat64 result = + (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + try (TFloat64 result = + (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach( f -> @@ -246,57 +261,51 @@ public void evaluate(Number[] expected, Output input) { } else if (dtype == TInt32.DTYPE) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = + (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = + (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()].intValue(), f.getInt())); } } else if (dtype == TInt64.DTYPE) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result = + (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result = + (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getLong())); } } else if (dtype == TUint8.DTYPE) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) { + try (TUint8 result = + (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getByte())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) { + try (TUint8 result = + (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getByte())); } @@ -305,26 +314,26 @@ public void evaluate(Number[] expected, Output input) { } } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void evaluate(FloatNdArray expected, Output input) { DataType dtype = input.asOutput().dataType(); if (dtype == TFloat32.DTYPE) { AtomicLong index = new AtomicLong(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + try (TFloat32 result = + (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + try (TFloat32 result = + (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach( f -> @@ -334,19 +343,17 @@ public void evaluate(FloatNdArray expected, Output input) { } else if (dtype == TFloat64.DTYPE) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + try (TFloat64 result = + (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + try (TFloat64 result = + (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach( f -> @@ -356,19 +363,17 @@ public void evaluate(FloatNdArray expected, Output input) { } else if (dtype == TInt32.DTYPE) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = + (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = + (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach( f -> assertEquals((int) expected.getFloat(index.getAndIncrement()), f.getInt())); @@ -376,19 +381,17 @@ public void evaluate(FloatNdArray expected, Output input) { } else if (dtype == TInt64.DTYPE) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result = + (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result = + (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach( f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getLong())); @@ -396,19 +399,17 @@ public void evaluate(FloatNdArray expected, Output input) { } else if (dtype == TUint8.DTYPE) { AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) { + try (TUint8 result = + (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getByte())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) { + try (TUint8 result = + (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach( f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getByte())); @@ -418,7 +419,9 @@ public void evaluate(FloatNdArray expected, Output input) { } } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void evaluate(String[] expected, Output input) { int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); @@ -428,25 +431,25 @@ public void evaluate(String[] expected, Output input) { () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) { + try (TString result = + (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) { + try (TString result = + (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject())); } } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void evaluate(Boolean[] expected, Output input) { int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); @@ -456,31 +459,31 @@ public void evaluate(Boolean[] expected, Output input) { () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) { + try (TBool result = + (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> System.out.printf("%d). %b\n", index.getAndIncrement(), f.getObject())); } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) { + try (TBool result = + (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { result - .data() .scalars() .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject())); } } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void evaluate(Output expected, Output input) { assert input.shape().equals(expected.shape()) : String.format( - "expected shape (%s) != to input shape (%s)", - expected.shape().toString(), input.shape().toString()); + "expected shape (%s) != to input shape (%s)", + expected.shape().toString(), input.shape().toString()); AtomicInteger index = new AtomicInteger(); DataType dtype = input.asOutput().dataType(); if (!dtype.equals(expected.dataType())) { @@ -493,316 +496,300 @@ public void evaluate(Output expected, Output input) { if (dtype == TFloat32.DTYPE) { final Output finalExpected = (Output) expected; if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + try (TFloat32 result = + (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0); + TFloat32 expectedResult = + (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( - "0). %f <==> %f\n", expectedResult.data().getFloat(), result.data().getFloat()); + "0). %f <==> %f\n", expectedResult.getFloat(), result.getFloat()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> System.out.printf( "%d). %f <==> %f\n", index.getAndIncrement(), - finalExpected.data().getFloat(idx), + finalExpected.asTensor().getFloat(idx), f.getFloat())); } } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + try (TFloat32 result = + (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0); + TFloat32 expectedResult = + (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - assertEquals(expectedResult.data().getFloat(), result.data().getFloat(), epsilon); + assertEquals(expectedResult.getFloat(), result.getFloat(), epsilon); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> - assertEquals(expectedResult.data().getFloat(idx), f.getFloat(), epsilon)); + assertEquals(expectedResult.getFloat(idx), f.getFloat(), epsilon)); } } } else if (dtype == TFloat64.DTYPE) { final Output finalExpected = (Output) expected; if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + try (TFloat64 result = + (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0); + TFloat64 expectedResult = + (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( - "0). %f <==> %f\n", expectedResult.data().getDouble(), result.data().getDouble()); + "0). %f <==> %f\n", expectedResult.getDouble(), result.getDouble()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> System.out.printf( "%d). %f <==> %f\n", index.getAndIncrement(), - finalExpected.data().getDouble(idx), + finalExpected.asTensor().getDouble(idx), f.getDouble())); } } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + try (TFloat64 result = + (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0); + TFloat64 expectedResult = + (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - assertEquals(expectedResult.data().getDouble(), result.data().getDouble(), epsilon); + assertEquals(expectedResult.getDouble(), result.getDouble(), epsilon); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> - assertEquals(expectedResult.data().getDouble(idx), f.getDouble(), epsilon)); + assertEquals(expectedResult.getDouble(idx), f.getDouble(), epsilon)); } } } else if (dtype == TFloat16.DTYPE) { final Output finalExpected = (Output) expected; if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat16.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat16.DTYPE)) { + try (TFloat16 result = + (TFloat16)this.getGraphSession().runner().fetch(input).run().get(0); + TFloat16 expectedResult = + (TFloat16)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( - "0). %f <==> %f\n", expectedResult.data().getFloat(), result.data().getFloat()); + "0). %f <==> %f\n", expectedResult.getFloat(), result.getFloat()); } else { result - .data() - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %f <==> %f\n", - index.getAndIncrement(), - finalExpected.data().getFloat(idx), - f.getFloat())); + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %f <==> %f\n", + index.getAndIncrement(), + finalExpected.asTensor().getFloat(idx), + f.getFloat())); } } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat16.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat16.DTYPE)) { + try (TFloat16 result = + (TFloat16)this.getGraphSession().runner().fetch(input).run().get(0); + TFloat16 expectedResult = + (TFloat16)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - assertEquals(expectedResult.data().getFloat(), result.data().getFloat(), epsilon); + assertEquals(expectedResult.getFloat(), result.getFloat(), epsilon); } else { result - .data() - .scalars() - .forEachIndexed( - (idx, f) -> - assertEquals(expectedResult.data().getFloat(idx), f.getFloat(), epsilon)); + .scalars() + .forEachIndexed( + (idx, f) -> + assertEquals(expectedResult.getFloat(idx), f.getFloat(), epsilon)); } } } else if (dtype == TInt32.DTYPE) { final Output finalExpected = (Output) expected; if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = + (TInt32)this.getGraphSession().runner().fetch(input).run().get(0); + TInt32 expectedResult = + (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( - "0). %d <==> %d\n", expectedResult.data().getInt(), result.data().getInt()); + "0). %d <==> %d\n", expectedResult.getInt(), result.getInt()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> System.out.printf( "%d). %d <==> %d\n", - index.getAndIncrement(), finalExpected.data().getInt(idx), f.getInt())); + index.getAndIncrement(), finalExpected.asTensor().getInt(idx), f.getInt())); } } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = + (TInt32)this.getGraphSession().runner().fetch(input).run().get(0); + TInt32 expectedResult = + (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - assertEquals(expectedResult.data().getInt(), result.data().getInt(), epsilon); + assertEquals(expectedResult.getInt(), result.getInt(), epsilon); } else { result - .data() .scalars() .forEachIndexed( - (idx, f) -> assertEquals(expectedResult.data().getInt(idx), f.getInt(), epsilon)); + (idx, f) -> assertEquals(expectedResult.getInt(idx), f.getInt(), epsilon)); } } } else if (dtype == TInt64.DTYPE) { final Output finalExpected = (Output) expected; if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result = + (TInt64)this.getGraphSession().runner().fetch(input).run().get(0); + TInt64 expectedResult = + (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( - "0). %d <==> %d\n", expectedResult.data().getLong(), result.data().getLong()); + "0). %d <==> %d\n", expectedResult.getLong(), result.getLong()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> System.out.printf( "%d). %d <==> %d\n", index.getAndIncrement(), - finalExpected.data().getLong(idx), + finalExpected.asTensor().getLong(idx), f.getLong())); } } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result = + (TInt64)this.getGraphSession().runner().fetch(input).run().get(0); + TInt64 expectedResult = + (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - assertEquals(expectedResult.data().getLong(), result.data().getLong(), epsilon); + assertEquals(expectedResult.getLong(), result.getLong(), epsilon); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> - assertEquals(expectedResult.data().getLong(idx), f.getLong(), epsilon)); + assertEquals(expectedResult.getLong(idx), f.getLong(), epsilon)); } } } else if (dtype == TUint8.DTYPE) { final Output finalExpected = (Output) expected; if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) { + try (TUint8 result = + (TUint8)this.getGraphSession().runner().fetch(input).run().get(0); + TUint8 expectedResult = + (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( - "0). %d <==> %d\n", expectedResult.data().getByte(), result.data().getByte()); + "0). %d <==> %d\n", expectedResult.getByte(), result.getByte()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> System.out.printf( "%d). %d <==> %d\n", index.getAndIncrement(), - finalExpected.data().getByte(idx), + finalExpected.asTensor().getByte(idx), f.getByte())); } } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) { + try (TUint8 result = + (TUint8)this.getGraphSession().runner().fetch(input).run().get(0); + TUint8 expectedResult = + (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - assertEquals(expectedResult.data().getByte(), result.data().getByte(), epsilon); + assertEquals(expectedResult.getByte(), result.getByte(), epsilon); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> - assertEquals(expectedResult.data().getByte(idx), f.getByte(), epsilon)); + assertEquals(expectedResult.getByte(idx), f.getByte(), epsilon)); } } } else if (dtype == TBool.DTYPE) { final Output finalExpected = (Output) expected; if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) { + try (TBool result = + (TBool)this.getGraphSession().runner().fetch(input).run().get(0); + TBool expectedResult = + (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( - "0). %b <==> %b\n", expectedResult.data().getBoolean(), result.data().getBoolean()); + "0). %b <==> %b\n", expectedResult.getBoolean(), result.getBoolean()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> System.out.printf( "%d). %b <==> %b\n", index.getAndIncrement(), - finalExpected.data().getBoolean(idx), + finalExpected.asTensor().getBoolean(idx), f.getBoolean())); } } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) { + try (TBool result = + (TBool)this.getGraphSession().runner().fetch(input).run().get(0); + TBool expectedResult = + (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - assertEquals(expectedResult.data().getBoolean(), result.data().getBoolean()); + assertEquals(expectedResult.getBoolean(), result.getBoolean()); } else { result - .data() .scalars() .forEachIndexed( - (idx, f) -> assertEquals(expectedResult.data().getBoolean(idx), f.getBoolean())); + (idx, f) -> assertEquals(expectedResult.getBoolean(idx), f.getBoolean())); } } } else if (dtype == TString.DTYPE) { final Output finalExpected = (Output) expected; if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) { + try (TString result = + (TString)this.getGraphSession().runner().fetch(input).run().get(0); + TString expectedResult = + (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( - "0). %s <==> %s\n", expectedResult.data().getObject(), result.data().getObject()); + "0). %s <==> %s\n", expectedResult.getObject(), result.getObject()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> System.out.printf( "%d). %s <==> %s\n", index.getAndIncrement(), - finalExpected.data().getObject(idx), + finalExpected.asTensor().getObject(idx), f.getObject())); } } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE); - Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) { + try (TString result = + (TString)this.getGraphSession().runner().fetch(input).run().get(0); + TString expectedResult = + (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - assertEquals(expectedResult.data().getObject(), result.data().getObject()); + assertEquals(expectedResult.getObject(), result.getObject()); } else { result - .data() .scalars() .forEachIndexed( - (idx, f) -> assertEquals(expectedResult.data().getObject(idx), f.getObject())); + (idx, f) -> assertEquals(expectedResult.getObject(idx), f.getObject())); } } } else { @@ -810,21 +797,22 @@ public void evaluate(Output expected, Output input) { } } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void evaluateString(Output input, Predicate predicate) { boolean isScalar = input.shape().equals(Shape.scalar()); AtomicInteger index = new AtomicInteger(); if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) { + try (TString result = + (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( "0). %b <==> %s\n", - predicate.test(result.data().getObject()), result.data().getObject()); + predicate.test(result.getObject()), result.getObject()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> @@ -835,20 +823,21 @@ public void evaluateString(Output input, Predicate predicate) { } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) { + try (TString result = + (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - assertTrue(predicate.test(result.data().getObject())); + assertTrue(predicate.test(result.getObject())); } else { result - .data() .scalars() .forEachIndexed((idx, s) -> assertTrue(predicate.test(s.getObject()))); } } } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void evaluate(Output input, Predicate predicate) { AtomicInteger index = new AtomicInteger(); @@ -856,15 +845,14 @@ public void evaluate(Output input, Predicate predic boolean isScalar = input.shape().equals(Shape.scalar()); if (dtype == TFloat32.DTYPE) { if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + try (TFloat32 result = + (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( "0). %b <==> %f\n", - predicate.test(result.data().getFloat()), result.data().getFloat()); + predicate.test(result.getFloat()), result.getFloat()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> @@ -875,28 +863,26 @@ public void evaluate(Output input, Predicate predic } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + try (TFloat32 result = + (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - assertTrue(predicate.test(result.data().getFloat())); + assertTrue(predicate.test(result.getFloat())); } else { result - .data() .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.data().getFloat()))); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getFloat()))); } } } else if (dtype == TFloat64.DTYPE) { if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + try (TFloat64 result = + (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( "0). %b <==> %f\n", - predicate.test(result.data().getDouble()), result.data().getDouble()); + predicate.test(result.getDouble()), result.getDouble()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> @@ -907,27 +893,25 @@ public void evaluate(Output input, Predicate predic } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + try (TFloat64 result = + (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - assertTrue(predicate.test(result.data().getDouble())); + assertTrue(predicate.test(result.getDouble())); } else { result - .data() .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.data().getDouble()))); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getDouble()))); } } } else if (dtype == TInt32.DTYPE) { if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = + (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( - "0). %b <==> %d\n", predicate.test(result.data().getInt()), result.data().getInt()); + "0). %b <==> %d\n", predicate.test(result.getInt()), result.getInt()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> @@ -938,28 +922,26 @@ public void evaluate(Output input, Predicate predic } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = + (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - assertTrue(predicate.test(result.data().getInt())); + assertTrue(predicate.test(result.getInt())); } else { result - .data() .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.data().getInt()))); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getInt()))); } } } else if (dtype == TInt64.DTYPE) { if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result = + (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( "0). %b <==> %d\n", - predicate.test(result.data().getLong()), result.data().getLong()); + predicate.test(result.getLong()), result.getLong()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> @@ -970,28 +952,26 @@ public void evaluate(Output input, Predicate predic } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result = + (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - assertTrue(predicate.test(result.data().getLong())); + assertTrue(predicate.test(result.getLong())); } else { result - .data() .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.data().getLong()))); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getLong()))); } } } else if (dtype == TUint8.DTYPE) { if (debug) { - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) { + try (TUint8 result = + (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { System.out.printf( "0). %b <==> %d\n", - predicate.test(result.data().getByte()), result.data().getByte()); + predicate.test(result.getByte()), result.getByte()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> @@ -1002,15 +982,14 @@ public void evaluate(Output input, Predicate predic } } index.set(0); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) { + try (TUint8 result = + (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - assertTrue(predicate.test(result.data().getByte())); + assertTrue(predicate.test(result.getByte())); } else { result - .data() .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.data().getByte()))); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getByte()))); } } } else { @@ -1018,7 +997,9 @@ public void evaluate(Output input, Predicate predic } } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void print(PrintWriter writer, Output input) { boolean isScalar = input.shape().size() == 1; @@ -1026,13 +1007,12 @@ public void print(PrintWriter writer, Output input) { DataType dtype = input.dataType(); if (dtype == TFloat32.DTYPE) { AtomicInteger index = new AtomicInteger(); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + try (TFloat32 result = + (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { - writer.printf("%d). %f\n", index.getAndIncrement(), result.data().getFloat()); + writer.printf("%d). %f\n", index.getAndIncrement(), result.getFloat()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> writer.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); @@ -1041,14 +1021,13 @@ public void print(PrintWriter writer, Output input) { } else if (dtype == TFloat64.DTYPE) { AtomicInteger index = new AtomicInteger(); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + try (TFloat64 result = + (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { writer.printf( - "%d). %f\n", index.getAndIncrement(), ((Output) input).data().getDouble()); + "%d). %f\n", index.getAndIncrement(), ((Output) input).asTensor().getDouble()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> writer.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); @@ -1057,14 +1036,13 @@ public void print(PrintWriter writer, Output input) { } else if (dtype == TInt32.DTYPE) { AtomicInteger index = new AtomicInteger(); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + try (TInt32 result = + (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { writer.printf( - "%d). %d\n", index.getAndIncrement(), ((Output) input).data().getInt()); + "%d). %d\n", index.getAndIncrement(), ((Output) input).asTensor().getInt()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> writer.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); @@ -1073,14 +1051,13 @@ public void print(PrintWriter writer, Output input) { } else if (dtype == TInt64.DTYPE) { AtomicInteger index = new AtomicInteger(); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + try (TInt64 result = + (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { writer.printf( - "%d). %d\n", index.getAndIncrement(), ((Output) input).data().getLong()); + "%d). %d\n", index.getAndIncrement(), ((Output) input).asTensor().getLong()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> writer.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); @@ -1089,14 +1066,13 @@ public void print(PrintWriter writer, Output input) { } else if (dtype == TUint8.DTYPE) { AtomicInteger index = new AtomicInteger(); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TUint8.DTYPE)) { + try (TUint8 result = + (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { writer.printf( - "%d). %x\n", index.getAndIncrement(), ((Output) input).data().getByte()); + "%d). %x\n", index.getAndIncrement(), ((Output) input).asTensor().getByte()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> writer.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); @@ -1105,14 +1081,13 @@ public void print(PrintWriter writer, Output input) { } else if (dtype == TBool.DTYPE) { AtomicInteger index = new AtomicInteger(); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) { + try (TBool result = + (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { writer.printf( - "%d). %b\n", index.getAndIncrement(), ((Output) input).data().getBoolean()); + "%d). %b\n", index.getAndIncrement(), ((Output) input).asTensor().getBoolean()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> writer.printf("%d). %b\n", index.getAndIncrement(), f.getBoolean())); @@ -1121,14 +1096,13 @@ public void print(PrintWriter writer, Output input) { } else if (dtype == TString.DTYPE) { AtomicInteger index = new AtomicInteger(); - try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) { + try (TString result = + (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { writer.printf( - "%d). %s\n", index.getAndIncrement(), ((Output) input).data().getObject()); + "%d). %s\n", index.getAndIncrement(), ((Output) input).asTensor().getObject()); } else { result - .data() .scalars() .forEachIndexed( (idx, f) -> writer.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); From 88e9a2166399a39bcbd5917af52dd9f32a495c17 Mon Sep 17 00:00:00 2001 From: klessard Date: Sat, 12 Dec 2020 22:51:30 -0500 Subject: [PATCH 2/3] Rectify documentation based on PR review --- .../annotations/org/tensorflow/op/Ops.java | 25 +++++++++-------- .../main/java/org/tensorflow/DataType.java | 4 +++ .../main/java/org/tensorflow/RawTensor.java | 28 ++++++++++++------- .../src/main/java/org/tensorflow/Tensor.java | 9 +++--- .../java/org/tensorflow/op/core/Constant.java | 7 +++-- .../org/tensorflow/types/family/TType.java | 22 +++++++++++++-- .../org/tensorflow/op/core/ConstantTest.java | 2 +- .../org/tensorflow/op/core/GradientsTest.java | 6 ++-- .../types/NumericTypesTestBase.java | 6 ++-- 9 files changed, 71 insertions(+), 38 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index f4e70f54e39..06908c41d6a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -1072,17 +1072,6 @@ public Bucketize bucketize(Operand input, List bou return Bucketize.create(scope, input, boundaries); } - /** - * Capture a {@code tensor} by making a constant copy of it. - * - * @param scope is a scope used to add the underlying operation. - * @param tensor a Tensor holding the constant value - * @return a constant of the same data type as `tensor` - */ - public Constant capture(T tensor) { - return Constant.create(scope, tensor); - } - /** * Clips tensor values to a specified min and max. *

@@ -1878,6 +1867,20 @@ public Constant constant(DataType type, Shape shape, return Constant.tensorOf(scope, type, shape, data); } + /** + * Create a constant by making an immutable copy of {@code tensor}. + * + *

Note: this endpoint cannot be simply called {@code constant} since it will conflict with + * other endpoints accepting an NdArray in parameter {e.g. {@link #tensorOf(Scope, FloatNdArray)}}. + * + * @param scope is a scope used to add the underlying operation. + * @param tensor a Tensor holding the constant value + * @return a constant of the same data type as `tensor` + */ + public Constant constantOf(T tensor) { + return Constant.create(scope, tensor); + } + /** * This op consumes a lock created by `MutexLock`. *

diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataType.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataType.java index f76dc1696a7..fc6268f40a2 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataType.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataType.java @@ -36,6 +36,10 @@ public interface TensorMapper { /** * Maps the tensor memory to a n-dimensional typed data space. * + *

This method is designed to be invoked internally by this library only, in order to pass the + * native handle of {@code tensor} as {@code nativeHandle} (and since only classes from the + * {@code org.tensorflow} package can retrieve such handle). + * * @param tensor the tensor to map in its raw nature * @param nativeHandle native handle of the tensor * @return a typed tensor of type {@code T} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java index dde040ff141..8aa2499d5ec 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java @@ -29,7 +29,7 @@ import org.tensorflow.types.family.TType; /** - * A tensor which memory has not been mapped. + * A tensor which memory has not been mapped to a data space directly accessible from the JVM. * *

A raw tensor is a minimalist representation of a tensor allocated in native memory by the * TensorFlow runtime library and it controls its lifetime within the current process. The data @@ -37,17 +37,10 @@ * n-dimensional typed space by a {@link TType typed tensor}.

* *

Instances of a RawTensor are not thread-safe and their resource must be released - * by calling {@link #close()} explicitly or implicitly (try-with-resources).

+ * by calling {@link #close()} explicitly or implicitly via try-with-resources.

*/ public final class RawTensor implements Tensor { - /** - * Returns a typed version of this tensor - */ - TType asTypedTensor() { - return dtype.map(this); - } - @Override public DataType dataType() { return dtype; @@ -152,13 +145,28 @@ static RawTensor fromHandle(TF_Tensor handle, EagerSession session) { } /** - * @return native handle to this tensor + * Returns the native handle to this tensor * @throws IllegalStateException if tensor has been closed */ TF_Tensor nativeHandle() { return requireHandle(tensorHandle); } + /** + * Returns a typed reference to this tensor + * + *

In some cases, it is more useful to keep a typed reference to a tensor rather than its raw + * nature to prevent mapping its memory on every access (e.g. when calling {@link Operand#asTensor()}). + * + * @param type of the tensor (must be compatible with the internal representation of this tensor, + * as indicated by {@link #dataType()}) + * @return typed reference to this tensor + * @throws ClassCastException if {@code T} is not compatible type with {@link #dataType()} + */ + T asTypedTensor() { + return (T)dtype.map(this); + } + private static TF_Tensor requireHandle(TF_Tensor handle) { if (handle == null || handle.isNull()) { throw new IllegalStateException("close() was called on the Tensor"); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java index aa4618db3e5..bccdf698608 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java @@ -46,9 +46,8 @@ * try (Tensor t = Tensor.of(...)) { * doSomethingWith(t); * } - * - *

Instances of a Tensor are not thread-safe. * }

+ *

Instances of a Tensor are not thread-safe. */ public interface Tensor extends Shaped, AutoCloseable { @@ -88,7 +87,7 @@ static T of(DataType dtype, Shape shape) { static T of(DataType dtype, Shape shape, long size) { RawTensor tensor = RawTensor.allocate(dtype, shape, size); try { - return dtype.map(tensor); + return tensor.asTypedTensor(); } catch (Exception e) { tensor.close(); throw e; @@ -130,7 +129,7 @@ static T of(DataType dtype, Shape shape, Consumer dataIn * size for the tensor is explicitly set instead of being computed from the datatype and shape. * *

This could be useful for tensor types that stores data but also metadata in the tensor memory, - * such as lookup table in a tensor of strings. + * such as the lookup table in a tensor of strings. * * @param the tensor element type * @param dtype datatype of the tensor @@ -148,7 +147,7 @@ static T of(DataType dtype, Shape shape, long size, Consume try { dataInitializer.accept(tensor); return tensor; - } catch (Throwable t) { + } catch (Exception t) { tensor.close(); throw t; } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java index ff1c990d3ea..3be83d8173e 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java @@ -1278,13 +1278,16 @@ public static Constant tensorOf(Scope scope, Shape shape) { } /** - * Capture a {@code tensor} by making a constant copy of it. + * Create a constant by making an immutable copy of {@code tensor}. + * + *

Note: this endpoint cannot be simply called {@code constant} since it will conflict with + * other endpoints accepting an NdArray in parameter {e.g. {@link #tensorOf(Scope, FloatNdArray)}}. * * @param scope is a scope used to add the underlying operation. * @param tensor a Tensor holding the constant value * @return a constant of the same data type as `tensor` */ - @Endpoint(name = "capture") // Cannot be "constant" since it will conflict with other endpoints accepting an NdArray + @Endpoint(name = "constantOf") public static Constant create(Scope scope, T tensor) { return new Constant<>( scope diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java index 304c070cc8e..21d275296c8 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java @@ -22,12 +22,12 @@ /** * Common interface for all typed tensors. * - *

Typed tensors wraps a {@link RawTensor} by mapping their native memory to a n-dimensional - * data space allowing direct I/O access from the JVM.

+ *

Typed tensors wrap a {@link org.tensorflow.RawTensor RawTensor} by mapping their native memory + * to a n-dimensional data space allowing direct I/O access from the JVM.

* *

Subinterfaces of {@code TType} are propagated as a generic parameter to various entities of * TensorFlow to identify the type of the tensor they carry. For example, a - * {@link org.tensorflow.Operand Operand} is an operand which outputs is a 32-bit floating + * {@link org.tensorflow.Operand Operand} is an operand which outputs a 32-bit floating * point tensor. This parameter ensure type-compatibility between operands of a computation at * compile-time. For example: * @@ -41,6 +41,22 @@ * tf.math.add(c1, c2); // OK * tf.math.add(c1, c3); // Compilation failure * }

+ * + *

Even if all typed tensors implements somehow {@link org.tensorflow.ndarray.NdArray NdArray} + * to provide access to their data, {@code TType} deliberately does not extend directly from this + * interface, for the following reasons: + *

    + *
  • Implementing {@code NdArray} at this level could only expose boxed-type accessors, which + * are less performant than their primitive equivalent, only exposed by subinterfaces of + * {@code NdArray} (e.g. {@code FloatNdArray}). + *
  • + *
  • {@code TType} would need to carry a new generic parameter for typing the {@code NdArray}, + * which will increase the verbosity in the signature of any method accepting or returning + * an instance of this interface, which is very common. + *
  • + *
+ * Therefore, enforcing the user to cast a reference of {@code TType} in a concrete tensor type before + * accessing its data guarantees better performance and improves readability. */ public interface TType extends Tensor { diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java index 16a644aced1..5dd6903d913 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java @@ -154,7 +154,7 @@ public void createFromTensorsInEagerMode() throws IOException { assertEquals(c1.asTensor(), t); // A different endpoint for capturing a tensor as a constant, which supports all data types - Constant c2 = tf.capture(t); + Constant c2 = tf.constantOf(t); assertEquals(c2.asTensor(), t); assertEquals(c1.asTensor(), c2.asTensor()); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java index 5333b8e0d33..6bab99095e4 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java @@ -48,7 +48,7 @@ public void createGradients() { assertEquals(2, grads.dy().size()); try (TFloat32 c = TFloat32.scalarOf(3.0f); - AutoCloseableList outputs = + AutoCloseableList outputs = new AutoCloseableList<>( sess.runner().feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run())) { @@ -75,7 +75,7 @@ public void createGradientsWithSum() { assertEquals(1, grads.dy().size()); try (TFloat32 c = TFloat32.scalarOf(3.0f); - AutoCloseableList outputs = + AutoCloseableList outputs = new AutoCloseableList<>(sess.runner().feed(x, c).fetch(grads.dy(0)).run())) { assertEquals(114.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); @@ -101,7 +101,7 @@ public void createGradientsWithInitialValues() { assertEquals(1, grads1.dy().size()); try (TFloat32 c = TFloat32.scalarOf(3.0f); - AutoCloseableList outputs = + AutoCloseableList outputs = new AutoCloseableList<>( sess.runner().feed(x, c).fetch(grads1.dy(0)).run())) { diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/NumericTypesTestBase.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/NumericTypesTestBase.java index 24390ae0a91..faddc7c5826 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/NumericTypesTestBase.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/NumericTypesTestBase.java @@ -46,11 +46,11 @@ public void initializeTensorsWithZeros() { // Initialize tensor memory with zeros and take a snapshot data.scalars().forEach(scalar -> ((NdArray)scalar).setObject(valueOf(0))); - Constant x = tf.capture(tensor); + Constant x = tf.constantOf(tensor); // Initialize the same tensor memory with ones and take a snapshot data.scalars().forEach(scalar -> ((NdArray)scalar).setObject(valueOf(1))); - Constant y = tf.capture(tensor); + Constant y = tf.constantOf(tensor); // Subtract y from x and validate the result Sub sub = tf.math.sub(x, y); @@ -94,7 +94,7 @@ public void setAndCompute() { try (EagerSession session = EagerSession.create()) { Ops tf = Ops.create(session); - Add add = tf.math.add(tf.capture(tensor), tf.capture(tensor)); + Add add = tf.math.add(tf.constantOf(tensor), tf.constantOf(tensor)); NdArray result = (NdArray)add.asTensor(); assertEquals(valueOf(0), result.getObject(0, 0)); From 8773ccc1981e4bc934af4031258df6a41a460a09 Mon Sep 17 00:00:00 2001 From: klessard Date: Sun, 13 Dec 2020 16:00:13 -0500 Subject: [PATCH 3/3] Rebase on master --- .../java/org/tensorflow/DeviceSpecTest.java | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java index 49f64931307..e4340da3275 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java @@ -53,9 +53,9 @@ public void withDeviceMethod() { .abs(aOps) .asOutput(); - try (AutoCloseableList> t = + try (AutoCloseableList t = new AutoCloseableList<>(session.runner().fetch(absOps).run())) { - assertEquals(1, t.get(0).rawData().asInts().getObject(0)); + assertEquals(1, ((TInt32)t.get(0)).getInt()); } } } @@ -85,9 +85,9 @@ public void withEmptyDeviceSpec() { .abs(aOps) .asOutput(); - try (AutoCloseableList> t = + try (AutoCloseableList t = new AutoCloseableList<>(session.runner().fetch(absOps).run())) { - assertEquals(1, t.get(0).rawData().asInts().getObject(0)); + assertEquals(1, ((TInt32)t.get(0)).getInt()); } } } @@ -131,9 +131,9 @@ public void withTwoScopes() { .mul(absOps, bOps) .asOutput(); - try (AutoCloseableList> t = + try (AutoCloseableList t = new AutoCloseableList<>(session.runner().fetch(mulOps).run())) { - assertEquals(10, t.get(0).rawData().asInts().getObject(0)); + assertEquals(10, ((TInt32)t.get(0)).getInt()); } } } @@ -179,7 +179,7 @@ public void withIncorrectDeviceSpec() { .mul(absOps, bOps) .asOutput(); - try (AutoCloseableList> t = + try (AutoCloseableList t = new AutoCloseableList<>(session.runner().fetch(mulOps).run())) { fail(); } catch (TFInvalidArgumentException e) { @@ -212,9 +212,9 @@ public void withDeviceSpecInScope() { .abs(aOps) .asOutput(); - try (AutoCloseableList> t = + try (AutoCloseableList t = new AutoCloseableList<>(session.runner().fetch(absOps).run())) { - assertEquals(1, t.get(0).rawData().asInts().getObject(0)); + assertEquals(1, ((TInt32)t.get(0)).getInt()); } } }