Skip to content

Commit 1d44f05

Browse files
karllessardrnett
authored andcommitted
[Type Refactor] Merge TType and Tensor instances as a single entity (tensorflow#160)
* Merge TType and Tensor instances as a single entity * Rectify documentation based on PR review * Rebase on master
1 parent fa7426e commit 1d44f05

File tree

69 files changed

+1738
-1619
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

69 files changed

+1738
-1619
lines changed

tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_specs.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,7 @@ std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def,
152152
types = MakeTypePair(Type::Class("Shape", "org.tensorflow.ndarray"));
153153

154154
} else if (attr_type == "tensor") {
155-
types = MakeTypePair(Type::Class("Tensor", "org.tensorflow")
156-
.add_parameter(Type::Wildcard()));
155+
types = MakeTypePair(Type::Class("Tensor", "org.tensorflow"));
157156

158157
} else if (attr_type == "type") {
159158
Type type = *iterable_out ? Type::Wildcard() : NextGeneric();

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DebuggingOps.java

Lines changed: 0 additions & 50 deletions
This file was deleted.

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import org.tensorflow.EagerSession;
2525
import org.tensorflow.ExecutionEnvironment;
2626
import org.tensorflow.Operand;
27-
import org.tensorflow.Tensor;
2827
import org.tensorflow.ndarray.BooleanNdArray;
2928
import org.tensorflow.ndarray.ByteNdArray;
3029
import org.tensorflow.ndarray.DoubleNdArray;
@@ -349,10 +348,10 @@ public final class Ops {
349348

350349
public final SignalOps signal;
351350

352-
public final TrainOps train;
353-
354351
public final QuantizationOps quantization;
355352

353+
public final TrainOps train;
354+
356355
private final Scope scope;
357356

358357
private Ops(Scope scope) {
@@ -374,8 +373,8 @@ private Ops(Scope scope) {
374373
math = new MathOps(this);
375374
audio = new AudioOps(this);
376375
signal = new SignalOps(this);
377-
train = new TrainOps(this);
378376
quantization = new QuantizationOps(this);
377+
train = new TrainOps(this);
379378
}
380379

381380
/**
@@ -1708,17 +1707,6 @@ public Constant<TInt64> constant(Shape shape) {
17081707
return Constant.tensorOf(scope, shape);
17091708
}
17101709

1711-
/**
1712-
* Create a constant from a Tensor.
1713-
*
1714-
* @param scope is a scope used to add the underlying operation.
1715-
* @param tensor a Tensor holding the constant value
1716-
* @return a constant of the same data type as `tensor`
1717-
*/
1718-
public <T extends TType> Constant<T> constant(Tensor<T> tensor) {
1719-
return Constant.create(scope, tensor);
1720-
}
1721-
17221710
/**
17231711
* Creates a constant of {@code String} elements, using the given charset.
17241712
*
@@ -1879,6 +1867,20 @@ public <T extends TType> Constant<T> constant(DataType<T> type, Shape shape,
18791867
return Constant.tensorOf(scope, type, shape, data);
18801868
}
18811869

1870+
/**
1871+
* Create a constant by making an immutable copy of {@code tensor}.
1872+
*
1873+
* <p>Note: this endpoint cannot be simply called {@code constant} since it will conflict with
1874+
* other endpoints accepting an NdArray in parameter {e.g. {@link #tensorOf(Scope, FloatNdArray)}}.
1875+
*
1876+
* @param scope is a scope used to add the underlying operation.
1877+
* @param tensor a Tensor holding the constant value
1878+
* @return a constant of the same data type as `tensor`
1879+
*/
1880+
public <T extends TType> Constant<T> constantOf(T tensor) {
1881+
return Constant.create(scope, tensor);
1882+
}
1883+
18821884
/**
18831885
* This op consumes a lock created by `MutexLock`.
18841886
* <p>

tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ImageSummary.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,13 @@ public Options maxImages(Long maxImages) {
9999
/**
100100
* @param badColor Color to use for pixels with non-finite values.
101101
*/
102-
public Options badColor(Tensor<?> badColor) {
102+
public Options badColor(Tensor badColor) {
103103
this.badColor = badColor;
104104
return this;
105105
}
106106

107107
private Long maxImages;
108-
private Tensor<?> badColor;
108+
private Tensor badColor;
109109

110110
private Options() {
111111
}
@@ -150,7 +150,7 @@ public static Options maxImages(Long maxImages) {
150150
/**
151151
* @param badColor Color to use for pixels with non-finite values.
152152
*/
153-
public static Options badColor(Tensor<?> badColor) {
153+
public static Options badColor(Tensor badColor) {
154154
return new Options().badColor(badColor);
155155
}
156156

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractOperation.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,5 +86,5 @@ public String toString() {
8686
* @param outputIdx index of the output of this operation
8787
* @return output tensor
8888
*/
89-
abstract Tensor<?> tensor(int outputIdx);
89+
abstract Tensor tensor(int outputIdx);
9090
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
*
3535
* <pre>{@code
3636
* ConcreteFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
37-
* Map<String, Tensor<?>> outputTensorMap = myFunction.call(inputTensorMap);
37+
* Map<String, Tensor> outputTensorMap = myFunction.call(inputTensorMap);
3838
* }</pre>
3939
*/
4040
public class ConcreteFunction implements AutoCloseable {
@@ -61,8 +61,8 @@ public class ConcreteFunction implements AutoCloseable {
6161
*
6262
* public static void main(String args[]) {
6363
* try (ConcreteFunction function = ConcreteFunction.create(MyModel::addTwo);
64-
* Tensor<TFloat32> x = TFloat32.scalarOf(2.0f)) {
65-
* assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat());
64+
* TFloat32 x = TFloat32.scalarOf(2.0f)) {
65+
* assertEquals(4.0f, ((TFloat32)function.call(x)).getFloat());
6666
* }
6767
* }
6868
* }
@@ -97,8 +97,8 @@ public static ConcreteFunction create(Function<Ops, Signature> functionBuilder)
9797
* Signature signature = Signature.builder().input("x", input).output("y", output).build();
9898
*
9999
* try (ConcreteFunction f = ConcreteFunction.create(signature, g);
100-
* Tensor<TFloat32> x = TFloat32.scalarOf(2.0f)) {
101-
* assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat());
100+
* TFloat32 x = TFloat32.scalarOf(2.0f)) {
101+
* assertEquals(4.0f, ((TFloat32)function.call(x)).getFloat());
102102
* }
103103
* // Graph g is still valid at this point
104104
* }
@@ -129,8 +129,8 @@ public static ConcreteFunction create(Signature signature, Graph graph) {
129129
* // Auto-closing the function just as an example but this is not required since it has
130130
* // no effect
131131
* try (ConcreteFunction f = ConcreteFunction.create(signature, s);
132-
* Tensor<TFloat32> t = TFloat32.scalarOf(2.0f)) {
133-
* assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat());
132+
* TFloat32 t = TFloat32.scalarOf(2.0f)) {
133+
* assertEquals(4.0f, ((TFloat32)function.call(x)).getFloat());
134134
* }
135135
* // Session s is still valid at this point
136136
* }
@@ -163,14 +163,14 @@ public Signature signature() {
163163
* @return output tensors resulting from the execution of the function,
164164
* mapped by their signature name
165165
*/
166-
public Map<String, Tensor<?>> call(Map<String, Tensor<?>> arguments)
166+
public Map<String, Tensor> call(Map<String, Tensor> arguments)
167167
throws IllegalArgumentException {
168168

169169
final SignatureDef signatureDef = signature.asSignatureDef();
170170
final Session.Runner runner = session.runner();
171171

172172
signatureDef.getInputsMap().forEach((argName, t) -> {
173-
Tensor<?> tensor = arguments.get(argName);
173+
Tensor tensor = arguments.get(argName);
174174
if (tensor == null) {
175175
throw new IllegalArgumentException(String.format("Missing argument [%s]", argName));
176176
}
@@ -180,10 +180,10 @@ public Map<String, Tensor<?>> call(Map<String, Tensor<?>> arguments)
180180
Map<String, TensorInfo> outputToNode = signatureDef.getOutputsMap();
181181
outputToNode.values().forEach(t -> runner.fetch(t.getName()));
182182

183-
List<Tensor<?>> resultTensors = runner.run();
183+
List<Tensor> resultTensors = runner.run();
184184
try {
185-
ListIterator<Tensor<?>> resultTensorIter = resultTensors.listIterator();
186-
Map<String, Tensor<?>> returnMap = new HashMap<String, Tensor<?>>();
185+
ListIterator<Tensor> resultTensorIter = resultTensors.listIterator();
186+
Map<String, Tensor> returnMap = new HashMap<String, Tensor>();
187187

188188
// Use the output names as present in the signature definition
189189
for (String nodeName: outputToNode.keySet()) {
@@ -193,7 +193,7 @@ public Map<String, Tensor<?>> call(Map<String, Tensor<?>> arguments)
193193

194194
} catch (Exception e) {
195195
// Release tensors before throwing exception
196-
for (Tensor<?> t : resultTensors) {
196+
for (Tensor t : resultTensors) {
197197
t.close();
198198
}
199199
throw e;
@@ -210,7 +210,7 @@ public Map<String, Tensor<?>> call(Map<String, Tensor<?>> arguments)
210210
* @throws IllegalArgumentException if there are multiple input or output parameters defined
211211
* in the function
212212
*/
213-
public Tensor<?> call(Tensor<?> tensor) throws IllegalArgumentException {
213+
public Tensor call(Tensor tensor) throws IllegalArgumentException {
214214
final SignatureDef signatureDef = signature.asSignatureDef();
215215

216216
if (signatureDef.getInputsCount() != 1) {

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataType.java

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
package org.tensorflow;
1717

1818
import org.tensorflow.internal.c_api.TF_Tensor;
19-
import org.tensorflow.ndarray.Shape;
2019
import org.tensorflow.types.TBfloat16;
2120
import org.tensorflow.types.TBool;
2221
import org.tensorflow.types.TFloat16;
@@ -35,13 +34,17 @@ public final class DataType<T extends TType> {
3534
public interface TensorMapper<T> {
3635

3736
/**
38-
* Maps tensor memory to a data structure for manipulating elements of this type.
37+
* Maps the tensor memory to a n-dimensional typed data space.
3938
*
40-
* @param nativeTensor pointer to the native tensor
41-
* @param shape the shape of the tensor
42-
* @return data structure of elements of this type
39+
* <p>This method is designed to be invoked internally by this library only, in order to pass the
40+
* native handle of {@code tensor} as {@code nativeHandle} (and since only classes from the
41+
* {@code org.tensorflow} package can retrieve such handle).
42+
*
43+
* @param tensor the tensor to map in its raw nature
44+
* @param nativeHandle native handle of the tensor
45+
* @return a typed tensor of type {@code T}
4346
*/
44-
T apply(TF_Tensor nativeTensor, Shape shape);
47+
T apply(RawTensor tensor, TF_Tensor nativeHandle);
4548
}
4649

4750
/**
@@ -158,13 +161,13 @@ int nativeCode() {
158161
}
159162

160163
/**
161-
* Maps a tensor to a data structure for manipulating elements of this type.
164+
* Maps a raw tensor to a typed tensor.
162165
*
163166
* @param tensor tensor to map
164167
* @return data structure of elements of this type
165168
*/
166-
T map(Tensor<T> tensor) {
167-
return tensorMapper.apply(tensor.nativeHandle(), tensor.shape());
169+
T map(RawTensor tensor) {
170+
return tensorMapper.apply(tensor, tensor.nativeHandle());
168171
}
169172

170173
private final int nativeCode;

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ public TFE_TensorHandle getUnsafeNativeHandle(int outputIndex) {
9191
public Shape shape(int outputIndex) {
9292
// If the tensor of this output has already been resolved, return its shape.
9393
// Otherwise, retrieve the tensor shape from the native library.
94-
Tensor<?> tensor = outputTensors.get(outputIndex);
94+
Tensor tensor = outputTensors.get(outputIndex);
9595
if (tensor != null) {
9696
return tensor.shape();
9797
}
@@ -107,7 +107,7 @@ public Shape shape(int outputIndex) {
107107
public DataType<?> dtype(int outputIndex) {
108108
// If the tensor of this output has already been resolved, return its datatype.
109109
// Otherwise, retrieve the tensor datatype from the native library.
110-
Tensor<?> tensor = outputTensors.get(outputIndex);
110+
Tensor tensor = outputTensors.get(outputIndex);
111111
if (tensor != null) {
112112
return tensor.dataType();
113113
}
@@ -116,8 +116,8 @@ public DataType<?> dtype(int outputIndex) {
116116
}
117117

118118
@Override
119-
public Tensor<?> tensor(int outputIndex) {
120-
Tensor<?> tensor = outputTensors.get(outputIndex);
119+
public Tensor tensor(int outputIndex) {
120+
Tensor tensor = outputTensors.get(outputIndex);
121121
if (tensor == null) {
122122
tensor = resolveTensor(outputIndex);
123123
}
@@ -127,21 +127,21 @@ public Tensor<?> tensor(int outputIndex) {
127127
private final EagerSession session;
128128
private final String type;
129129
private final String name;
130-
private final AtomicReferenceArray<Tensor<?>> outputTensors;
130+
private final AtomicReferenceArray<Tensor> outputTensors;
131131

132-
private Tensor<?> resolveTensor(int outputIndex) {
132+
private Tensor resolveTensor(int outputIndex) {
133133
// Take an optimistic approach, where we attempt to resolve the output tensor without locking.
134134
// If another thread has resolved it meanwhile, release our copy and reuse the existing one
135135
// instead.
136-
Tensor<?> tensor = resolveTensorHandle(getUnsafeNativeHandle(outputIndex), session);
136+
Tensor tensor = resolveTensorHandle(getUnsafeNativeHandle(outputIndex), session);
137137
if (!outputTensors.compareAndSet(outputIndex, null, tensor)) {
138-
session.detach(tensor.nativeHandle());
138+
session.detach(tensor.asRawTensor().nativeHandle());
139139
tensor = outputTensors.get(outputIndex);
140140
}
141141
return tensor;
142142
}
143143

144-
private TFE_Op opHandle;
144+
private final TFE_Op opHandle;
145145
private final TFE_TensorHandle[] outputHandles;
146146

147147
private static void requireOp(TFE_Op handle) {
@@ -156,13 +156,13 @@ private static void requireTensorHandle(TFE_TensorHandle handle) {
156156
}
157157
}
158158

159-
private static Tensor<?> resolveTensorHandle(TFE_TensorHandle handle, EagerSession session) {
159+
private static Tensor resolveTensorHandle(TFE_TensorHandle handle, EagerSession session) {
160160
requireTensorHandle(handle);
161161
try (PointerScope scope = new PointerScope()) {
162162
TF_Status status = TF_Status.newStatus();
163163
TF_Tensor tensor = TFE_TensorHandleResolve(handle, status).withDeallocator();
164164
status.throwExceptionIfNotOK();
165-
return Tensor.fromHandle(tensor, session);
165+
return RawTensor.fromHandle(tensor, session).asTypedTensor();
166166
}
167167
}
168168

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,13 +175,13 @@ public EagerOperationBuilder setAttr(String name, DataType<?>[] values) {
175175
}
176176

177177
@Override
178-
public EagerOperationBuilder setAttr(String name, Tensor<?> value) {
179-
setAttrTensor(opHandle, name, value.nativeHandle());
178+
public EagerOperationBuilder setAttr(String name, Tensor value) {
179+
setAttrTensor(opHandle, name, value.asRawTensor().nativeHandle());
180180
return this;
181181
}
182182

183183
@Override
184-
public EagerOperationBuilder setAttr(String name, Tensor<?>[] values) {
184+
public EagerOperationBuilder setAttr(String name, Tensor[] values) {
185185
// TODO (karllessard) could be supported by adding this attribute type in the eager C API
186186
throw new UnsupportedOperationException(
187187
"Tensor list attributes are not supported in eager mode");

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ DataType<?> dtype(int outputIdx) {
158158
}
159159

160160
@Override
161-
Tensor<?> tensor(int outputIdx) {
161+
Tensor tensor(int outputIdx) {
162162
throw new IllegalStateException("Graph tensors must be fetched by running a session");
163163
}
164164

0 commit comments

Comments
 (0)