[Type Refactor] Merge TType and Tensor instances as a single entity #160

Merged
merged 3 commits on Dec 17, 2020
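At a glance: after this refactor, the typed tensor interfaces (TFloat32, TInt64, ...) are themselves Tensor instances, so user code no longer wraps values in a parameterized Tensor<T> or unwraps them through expect()/data(). The sketch below is assembled from snippets visible in this diff and is not part of the PR itself; the wrapper class and main method are illustrative only.

    import org.tensorflow.Tensor;
    import org.tensorflow.types.TFloat32;

    public class TypeRefactorSketch {
      public static void main(String[] args) {
        // Before this PR (roughly): factories such as TFloat32.scalarOf() returned a
        // Tensor<TFloat32>, and the typed view had to be unwrapped, e.g.
        //   try (Tensor<TFloat32> x = TFloat32.scalarOf(2.0f)) {
        //     float v = x.data().getFloat();
        //   }

        // After this PR, the typed tensor *is* the Tensor:
        try (TFloat32 x = TFloat32.scalarOf(2.0f)) {
          Tensor raw = x;             // TFloat32 now extends Tensor directly
          float v = x.getFloat();     // NdArray accessors live on the tensor itself
          System.out.println(v + " " + raw.shape());
        }
      }
    }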
@@ -152,8 +152,7 @@ std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def,
types = MakeTypePair(Type::Class("Shape", "org.tensorflow.ndarray"));

} else if (attr_type == "tensor") {
types = MakeTypePair(Type::Class("Tensor", "org.tensorflow")
.add_parameter(Type::Wildcard()));
types = MakeTypePair(Type::Class("Tensor", "org.tensorflow"));

} else if (attr_type == "type") {
Type type = *iterable_out ? Type::Wildcard() : NextGeneric();

This file was deleted.

@@ -24,7 +24,6 @@
import org.tensorflow.EagerSession;
import org.tensorflow.ExecutionEnvironment;
import org.tensorflow.Operand;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.BooleanNdArray;
import org.tensorflow.ndarray.ByteNdArray;
import org.tensorflow.ndarray.DoubleNdArray;
@@ -349,10 +348,10 @@ public final class Ops {

public final SignalOps signal;

public final TrainOps train;

public final QuantizationOps quantization;

public final TrainOps train;

private final Scope scope;

private Ops(Scope scope) {
@@ -374,8 +373,8 @@ private Ops(Scope scope) {
math = new MathOps(this);
audio = new AudioOps(this);
signal = new SignalOps(this);
train = new TrainOps(this);
quantization = new QuantizationOps(this);
train = new TrainOps(this);
}

/**
@@ -1708,17 +1707,6 @@ public Constant<TInt64> constant(Shape shape) {
return Constant.tensorOf(scope, shape);
}

/**
* Create a constant from a Tensor.
*
* @param scope is a scope used to add the underlying operation.
* @param tensor a Tensor holding the constant value
* @return a constant of the same data type as `tensor`
*/
public <T extends TType> Constant<T> constant(Tensor<T> tensor) {
return Constant.create(scope, tensor);
}

/**
* Creates a constant of {@code String} elements, using the given charset.
*
@@ -1879,6 +1867,20 @@ public <T extends TType> Constant<T> constant(DataType<T> type, Shape shape,
return Constant.tensorOf(scope, type, shape, data);
}

/**
* Create a constant by making an immutable copy of {@code tensor}.
*
* <p>Note: this endpoint cannot simply be called {@code constant} since it would conflict with
* other endpoints accepting an NdArray as a parameter (e.g. {@link #tensorOf(Scope, FloatNdArray)}).
*
* @param scope is a scope used to add the underlying operation.
* @param tensor a Tensor holding the constant value
* @return a constant of the same data type as `tensor`
*/
public <T extends TType> Constant<T> constantOf(T tensor) {
return Constant.create(scope, tensor);
}

/**
* This op consumes a lock created by `MutexLock`.
* <p>
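Returning to the constantOf endpoint added above, a minimal usage sketch (not part of this diff; it assumes an eager Ops instance created with Ops.create()):

    import org.tensorflow.op.Ops;
    import org.tensorflow.op.core.Constant;
    import org.tensorflow.types.TFloat32;

    public class ConstantOfSketch {
      public static void main(String[] args) {
        Ops tf = Ops.create();                      // eager execution environment
        try (TFloat32 tensor = TFloat32.scalarOf(42.0f)) {
          // constantOf() makes an immutable copy of the tensor, keeping its data type
          Constant<TFloat32> c = tf.constantOf(tensor);
          tf.math.add(c, c);                        // usable like any other operand
        }
      }
    }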
@@ -99,13 +99,13 @@ public Options maxImages(Long maxImages) {
/**
* @param badColor Color to use for pixels with non-finite values.
*/
public Options badColor(Tensor<?> badColor) {
public Options badColor(Tensor badColor) {
this.badColor = badColor;
return this;
}

private Long maxImages;
private Tensor<?> badColor;
private Tensor badColor;

private Options() {
}
@@ -150,7 +150,7 @@ public static Options maxImages(Long maxImages) {
/**
* @param badColor Color to use for pixels with non-finite values.
*/
public static Options badColor(Tensor<?> badColor) {
public static Options badColor(Tensor badColor) {
return new Options().badColor(badColor);
}

@@ -86,5 +86,5 @@ public String toString() {
* @param outputIdx index of the output of this operation
* @return output tensor
*/
abstract Tensor<?> tensor(int outputIdx);
abstract Tensor tensor(int outputIdx);
}
@@ -34,7 +34,7 @@
*
* <pre>{@code
* ConcreteFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
* Map<String, Tensor<?>> outputTensorMap = myFunction.call(inputTensorMap);
* Map<String, Tensor> outputTensorMap = myFunction.call(inputTensorMap);
* }</pre>
*/
public class ConcreteFunction implements AutoCloseable {
@@ -61,8 +61,8 @@ public class ConcreteFunction implements AutoCloseable {
*
* public static void main(String args[]) {
* try (ConcreteFunction function = ConcreteFunction.create(MyModel::addTwo);
* Tensor<TFloat32> x = TFloat32.scalarOf(2.0f)) {
* assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat());
* TFloat32 x = TFloat32.scalarOf(2.0f)) {
* assertEquals(4.0f, ((TFloat32)function.call(x)).getFloat());
* }
* }
* }
@@ -97,8 +97,8 @@ public static ConcreteFunction create(Function<Ops, Signature> functionBuilder)
* Signature signature = Signature.builder().input("x", input).output("y", output).build();
*
* try (ConcreteFunction f = ConcreteFunction.create(signature, g);
* Tensor<TFloat32> x = TFloat32.scalarOf(2.0f)) {
* assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat());
* TFloat32 x = TFloat32.scalarOf(2.0f)) {
* assertEquals(4.0f, ((TFloat32)f.call(x)).getFloat());
* }
* // Graph g is still valid at this point
* }
@@ -129,8 +129,8 @@ public static ConcreteFunction create(Signature signature, Graph graph) {
* // Auto-closing the function just as an example but this is not required since it has
* // no effect
* try (ConcreteFunction f = ConcreteFunction.create(signature, s);
* Tensor<TFloat32> t = TFloat32.scalarOf(2.0f)) {
* assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat());
* TFloat32 t = TFloat32.scalarOf(2.0f)) {
* assertEquals(4.0f, ((TFloat32)f.call(t)).getFloat());
* }
* // Session s is still valid at this point
* }
@@ -163,14 +163,14 @@ public Signature signature() {
* @return output tensors resulting from the execution of the function,
* mapped by their signature name
*/
public Map<String, Tensor<?>> call(Map<String, Tensor<?>> arguments)
public Map<String, Tensor> call(Map<String, Tensor> arguments)
throws IllegalArgumentException {

final SignatureDef signatureDef = signature.asSignatureDef();
final Session.Runner runner = session.runner();

signatureDef.getInputsMap().forEach((argName, t) -> {
Tensor<?> tensor = arguments.get(argName);
Tensor tensor = arguments.get(argName);
if (tensor == null) {
throw new IllegalArgumentException(String.format("Missing argument [%s]", argName));
}
@@ -180,10 +180,10 @@ public Map<String, Tensor<?>> call(Map<String, Tensor<?>> arguments)
Map<String, TensorInfo> outputToNode = signatureDef.getOutputsMap();
outputToNode.values().forEach(t -> runner.fetch(t.getName()));

List<Tensor<?>> resultTensors = runner.run();
List<Tensor> resultTensors = runner.run();
try {
ListIterator<Tensor<?>> resultTensorIter = resultTensors.listIterator();
Map<String, Tensor<?>> returnMap = new HashMap<String, Tensor<?>>();
ListIterator<Tensor> resultTensorIter = resultTensors.listIterator();
Map<String, Tensor> returnMap = new HashMap<String, Tensor>();

// Use the output names as present in the signature definition
for (String nodeName: outputToNode.keySet()) {
@@ -193,7 +193,7 @@ public Map<String, Tensor<?>> call(Map<String, Tensor<?>> arguments)

} catch (Exception e) {
// Release tensors before throwing exception
for (Tensor<?> t : resultTensors) {
for (Tensor t : resultTensors) {
t.close();
}
throw e;
@@ -210,7 +210,7 @@ public Map<String, Tensor<?>> call(Map<String, Tensor<?>> arguments)
* @throws IllegalArgumentException if there are multiple input or output parameters defined
* in the function
*/
public Tensor<?> call(Tensor<?> tensor) throws IllegalArgumentException {
public Tensor call(Tensor tensor) throws IllegalArgumentException {
final SignatureDef signatureDef = signature.asSignatureDef();

if (signatureDef.getInputsCount() != 1) {
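A hedged sketch of the map-based call() after this change: inputs and outputs are plain Tensor values, and a typed view is obtained with an ordinary cast. The small addTwo signature builder and the names "x"/"y" are assumptions for illustration; the call pattern mirrors the javadoc above.

    import java.util.HashMap;
    import java.util.Map;
    import org.tensorflow.ConcreteFunction;
    import org.tensorflow.Operand;
    import org.tensorflow.Signature;
    import org.tensorflow.Tensor;
    import org.tensorflow.op.Ops;
    import org.tensorflow.types.TFloat32;

    public class ConcreteFunctionCallSketch {

      // Builds y = x + 2 and exposes it under the signature names "x" and "y"
      private static Signature addTwo(Ops tf) {
        Operand<TFloat32> x = tf.placeholder(TFloat32.DTYPE);
        Operand<TFloat32> y = tf.math.add(x, tf.constant(2.0f));
        return Signature.builder().input("x", x).output("y", y).build();
      }

      public static void main(String[] args) {
        try (ConcreteFunction f = ConcreteFunction.create(ConcreteFunctionCallSketch::addTwo);
            TFloat32 two = TFloat32.scalarOf(2.0f)) {

          Map<String, Tensor> inputs = new HashMap<>();
          inputs.put("x", two);

          Map<String, Tensor> outputs = f.call(inputs);
          try (TFloat32 y = (TFloat32) outputs.get("y")) {
            System.out.println(y.getFloat()); // 4.0
          }
        }
      }
    }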
@@ -16,7 +16,6 @@
package org.tensorflow;

import org.tensorflow.internal.c_api.TF_Tensor;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.types.TBfloat16;
import org.tensorflow.types.TBool;
import org.tensorflow.types.TFloat16;
@@ -35,13 +34,17 @@ public final class DataType<T extends TType> {
public interface TensorMapper<T> {

/**
* Maps tensor memory to a data structure for manipulating elements of this type.
* Maps the tensor memory to an n-dimensional typed data space.
*
* @param nativeTensor pointer to the native tensor
* @param shape the shape of the tensor
* @return data structure of elements of this type
* <p>This method is designed to be invoked internally by this library only, in order to pass the
* native handle of {@code tensor} as {@code nativeHandle} (and since only classes from the
* {@code org.tensorflow} package can retrieve such a handle).
*
* @param tensor the tensor to map in its raw nature
* @param nativeHandle native handle of the tensor
* @return a typed tensor of type {@code T}
*/
T apply(TF_Tensor nativeTensor, Shape shape);
T apply(RawTensor tensor, TF_Tensor nativeHandle);
}

/**
@@ -158,13 +161,13 @@ int nativeCode() {
}

/**
* Maps a tensor to a data structure for manipulating elements of this type.
* Maps a raw tensor to a typed tensor.
*
* @param tensor tensor to map
* @return data structure of elements of this type
*/
T map(Tensor<T> tensor) {
return tensorMapper.apply(tensor.nativeHandle(), tensor.shape());
T map(RawTensor tensor) {
return tensorMapper.apply(tensor, tensor.nativeHandle());
}

private final int nativeCode;
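Since DataType#map(RawTensor) now yields the typed tensor itself, caller-side code distinguishes tensor types with a dataType() check and an ordinary Java cast instead of the former expect(...)/data() pair. A small sketch, not part of the diff:

    import org.tensorflow.Tensor;
    import org.tensorflow.types.TFloat32;

    public class TypedTensorSketch {
      public static void main(String[] args) {
        try (Tensor t = TFloat32.scalarOf(3.5f)) { // a typed tensor is also a Tensor
          if (t.dataType() == TFloat32.DTYPE) {
            TFloat32 f32 = (TFloat32) t;           // plain downcast replaces expect(TFloat32.DTYPE).data()
            System.out.println(f32.getFloat());    // 3.5
          }
        }
      }
    }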
@@ -91,7 +91,7 @@ public TFE_TensorHandle getUnsafeNativeHandle(int outputIndex) {
public Shape shape(int outputIndex) {
// If the tensor of this output has already been resolved, return its shape.
// Otherwise, retrieve the tensor shape from the native library.
Tensor<?> tensor = outputTensors.get(outputIndex);
Tensor tensor = outputTensors.get(outputIndex);
if (tensor != null) {
return tensor.shape();
}
@@ -107,7 +107,7 @@ public Shape shape(int outputIndex) {
public DataType<?> dtype(int outputIndex) {
// If the tensor of this output has already been resolved, return its datatype.
// Otherwise, retrieve the tensor datatype from the native library.
Tensor<?> tensor = outputTensors.get(outputIndex);
Tensor tensor = outputTensors.get(outputIndex);
if (tensor != null) {
return tensor.dataType();
}
@@ -116,8 +116,8 @@ public DataType<?> dtype(int outputIndex) {
}

@Override
public Tensor<?> tensor(int outputIndex) {
Tensor<?> tensor = outputTensors.get(outputIndex);
public Tensor tensor(int outputIndex) {
Tensor tensor = outputTensors.get(outputIndex);
if (tensor == null) {
tensor = resolveTensor(outputIndex);
}
@@ -127,21 +127,21 @@ public Tensor<?> tensor(int outputIndex) {
private final EagerSession session;
private final String type;
private final String name;
private final AtomicReferenceArray<Tensor<?>> outputTensors;
private final AtomicReferenceArray<Tensor> outputTensors;

private Tensor<?> resolveTensor(int outputIndex) {
private Tensor resolveTensor(int outputIndex) {
// Take an optimistic approach, where we attempt to resolve the output tensor without locking.
// If another thread has resolved it meanwhile, release our copy and reuse the existing one
// instead.
Tensor<?> tensor = resolveTensorHandle(getUnsafeNativeHandle(outputIndex), session);
Tensor tensor = resolveTensorHandle(getUnsafeNativeHandle(outputIndex), session);
if (!outputTensors.compareAndSet(outputIndex, null, tensor)) {
session.detach(tensor.nativeHandle());
session.detach(tensor.asRawTensor().nativeHandle());
tensor = outputTensors.get(outputIndex);
}
return tensor;
}

private TFE_Op opHandle;
private final TFE_Op opHandle;
private final TFE_TensorHandle[] outputHandles;

private static void requireOp(TFE_Op handle) {
@@ -156,13 +156,13 @@ private static void requireTensorHandle(TFE_TensorHandle handle) {
}
}

private static Tensor<?> resolveTensorHandle(TFE_TensorHandle handle, EagerSession session) {
private static Tensor resolveTensorHandle(TFE_TensorHandle handle, EagerSession session) {
requireTensorHandle(handle);
try (PointerScope scope = new PointerScope()) {
TF_Status status = TF_Status.newStatus();
TF_Tensor tensor = TFE_TensorHandleResolve(handle, status).withDeallocator();
status.throwExceptionIfNotOK();
return Tensor.fromHandle(tensor, session);
return RawTensor.fromHandle(tensor, session).asTypedTensor();
}
}

@@ -175,13 +175,13 @@ public EagerOperationBuilder setAttr(String name, DataType<?>[] values) {
}

@Override
public EagerOperationBuilder setAttr(String name, Tensor<?> value) {
setAttrTensor(opHandle, name, value.nativeHandle());
public EagerOperationBuilder setAttr(String name, Tensor value) {
setAttrTensor(opHandle, name, value.asRawTensor().nativeHandle());
return this;
}

@Override
public EagerOperationBuilder setAttr(String name, Tensor<?>[] values) {
public EagerOperationBuilder setAttr(String name, Tensor[] values) {
// TODO (karllessard) could be supported by adding this attribute type in the eager C API
throw new UnsupportedOperationException(
"Tensor list attributes are not supported in eager mode");
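For completeness, a hedged sketch of the builder-level effect: OperationBuilder#setAttr(String, Tensor) now accepts the raw Tensor interface, so a typed tensor can be passed straight through. The graph-mode builder and the usual Const op attribute names ("dtype", "value") are assumptions here, not part of this diff.

    import org.tensorflow.Graph;
    import org.tensorflow.Output;
    import org.tensorflow.types.TFloat32;

    public class SetAttrSketch {
      public static void main(String[] args) {
        try (Graph g = new Graph();
            TFloat32 value = TFloat32.scalarOf(1.0f)) {
          // The typed tensor is accepted wherever a raw Tensor attribute is expected
          Output<?> c = g.opBuilder("Const", "my_const")
              .setAttr("dtype", value.dataType())
              .setAttr("value", value)
              .build()
              .output(0);
          System.out.println(c.shape());
        }
      }
    }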
@@ -158,7 +158,7 @@ DataType<?> dtype(int outputIdx) {
}

@Override
Tensor<?> tensor(int outputIdx) {
Tensor tensor(int outputIdx) {
throw new IllegalStateException("Graph tensors must be fetched by running a session");
}

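As the exception message above indicates, graph outputs have no resolved tensor until a session runs them. A brief sketch of the fetch path (graph construction and the "result" name are assumptions for illustration; the List<Tensor> return type matches the change in ConcreteFunction above):

    import java.util.List;
    import org.tensorflow.Graph;
    import org.tensorflow.Session;
    import org.tensorflow.Tensor;
    import org.tensorflow.op.Ops;
    import org.tensorflow.types.TFloat32;

    public class GraphFetchSketch {
      public static void main(String[] args) {
        try (Graph g = new Graph()) {
          Ops tf = Ops.create(g);
          tf.withName("result").math.add(tf.constant(1.0f), tf.constant(2.0f));

          try (Session s = new Session(g)) {
            List<Tensor> outputs = s.runner().fetch("result").run();
            try (TFloat32 result = (TFloat32) outputs.get(0)) {
              System.out.println(result.getFloat()); // 3.0
            }
          }
        }
      }
    }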