Skip to content

Commit 0225ab6

Browse files
committed
Merge TType and Tensor instances as a single entity
1 parent fc3d960 commit 0225ab6

File tree

68 files changed

+1695
-1609
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

68 files changed

+1695
-1609
lines changed

tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_specs.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,7 @@ std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def,
152152
types = MakeTypePair(Type::Class("Shape", "org.tensorflow.ndarray"));
153153

154154
} else if (attr_type == "tensor") {
155-
types = MakeTypePair(Type::Class("Tensor", "org.tensorflow")
156-
.add_parameter(Type::Wildcard()));
155+
types = MakeTypePair(Type::Class("Tensor", "org.tensorflow"));
157156

158157
} else if (attr_type == "type") {
159158
Type type = *iterable_out ? Type::Wildcard() : NextGeneric();

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DebuggingOps.java

Lines changed: 0 additions & 50 deletions
This file was deleted.

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import org.tensorflow.EagerSession;
2424
import org.tensorflow.ExecutionEnvironment;
2525
import org.tensorflow.Operand;
26-
import org.tensorflow.Tensor;
2726
import org.tensorflow.ndarray.BooleanNdArray;
2827
import org.tensorflow.ndarray.ByteNdArray;
2928
import org.tensorflow.ndarray.DoubleNdArray;
@@ -347,10 +346,10 @@ public final class Ops {
347346

348347
public final SignalOps signal;
349348

350-
public final TrainOps train;
351-
352349
public final QuantizationOps quantization;
353350

351+
public final TrainOps train;
352+
354353
private final Scope scope;
355354

356355
private Ops(Scope scope) {
@@ -372,8 +371,8 @@ private Ops(Scope scope) {
372371
math = new MathOps(this);
373372
audio = new AudioOps(this);
374373
signal = new SignalOps(this);
375-
train = new TrainOps(this);
376374
quantization = new QuantizationOps(this);
375+
train = new TrainOps(this);
377376
}
378377

379378
/**
@@ -1071,6 +1070,17 @@ public <T extends TNumber> Bucketize bucketize(Operand<T> input, List<Float> bou
10711070
return Bucketize.create(scope, input, boundaries);
10721071
}
10731072

1073+
/**
1074+
* Capture a {@code tensor} by making a constant copy of it.
1075+
*
1076+
* @param scope is a scope used to add the underlying operation.
1077+
* @param tensor a Tensor holding the constant value
1078+
* @return a constant of the same data type as `tensor`
1079+
*/
1080+
public <T extends TType> Constant<T> capture(T tensor) {
1081+
return Constant.create(scope, tensor);
1082+
}
1083+
10741084
/**
10751085
* Clips tensor values to a specified min and max.
10761086
* <p>
@@ -1706,17 +1716,6 @@ public Constant<TInt64> constant(Shape shape) {
17061716
return Constant.tensorOf(scope, shape);
17071717
}
17081718

1709-
/**
1710-
* Create a constant from a Tensor.
1711-
*
1712-
* @param scope is a scope used to add the underlying operation.
1713-
* @param tensor a Tensor holding the constant value
1714-
* @return a constant of the same data type as `tensor`
1715-
*/
1716-
public <T extends TType> Constant<T> constant(Tensor<T> tensor) {
1717-
return Constant.create(scope, tensor);
1718-
}
1719-
17201719
/**
17211720
* Creates a constant of {@code String} elements, using the given charset.
17221721
*

tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ImageSummary.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,13 @@ public Options maxImages(Long maxImages) {
9999
/**
100100
* @param badColor Color to use for pixels with non-finite values.
101101
*/
102-
public Options badColor(Tensor<?> badColor) {
102+
public Options badColor(Tensor badColor) {
103103
this.badColor = badColor;
104104
return this;
105105
}
106106

107107
private Long maxImages;
108-
private Tensor<?> badColor;
108+
private Tensor badColor;
109109

110110
private Options() {
111111
}
@@ -150,7 +150,7 @@ public static Options maxImages(Long maxImages) {
150150
/**
151151
* @param badColor Color to use for pixels with non-finite values.
152152
*/
153-
public static Options badColor(Tensor<?> badColor) {
153+
public static Options badColor(Tensor badColor) {
154154
return new Options().badColor(badColor);
155155
}
156156

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractOperation.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,5 +86,5 @@ public String toString() {
8686
* @param outputIdx index of the output of this operation
8787
* @return output tensor
8888
*/
89-
abstract Tensor<?> tensor(int outputIdx);
89+
abstract Tensor tensor(int outputIdx);
9090
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
*
3535
* <pre>{@code
3636
* ConcreteFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
37-
* Map<String, Tensor<?>> outputTensorMap = myFunction.call(inputTensorMap);
37+
* Map<String, Tensor> outputTensorMap = myFunction.call(inputTensorMap);
3838
* }</pre>
3939
*/
4040
public class ConcreteFunction implements AutoCloseable {
@@ -61,8 +61,8 @@ public class ConcreteFunction implements AutoCloseable {
6161
*
6262
* public static void main(String args[]) {
6363
* try (ConcreteFunction function = ConcreteFunction.create(MyModel::addTwo);
64-
* Tensor<TFloat32> x = TFloat32.scalarOf(2.0f)) {
65-
* assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat());
64+
* TFloat32 x = TFloat32.scalarOf(2.0f)) {
65+
* assertEquals(4.0f, ((TFloat32)function.call(x)).getFloat());
6666
* }
6767
* }
6868
* }
@@ -97,8 +97,8 @@ public static ConcreteFunction create(Function<Ops, Signature> functionBuilder)
9797
* Signature signature = Signature.builder().input("x", input).output("y", output).build();
9898
*
9999
* try (ConcreteFunction f = ConcreteFunction.create(signature, g);
100-
* Tensor<TFloat32> x = TFloat32.scalarOf(2.0f)) {
101-
* assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat());
100+
* TFloat32 x = TFloat32.scalarOf(2.0f)) {
101+
* assertEquals(4.0f, ((TFloat32)function.call(x)).getFloat());
102102
* }
103103
* // Graph g is still valid at this point
104104
* }
@@ -129,8 +129,8 @@ public static ConcreteFunction create(Signature signature, Graph graph) {
129129
* // Auto-closing the function just as an example but this is not required since it has
130130
* // no effect
131131
* try (ConcreteFunction f = ConcreteFunction.create(signature, s);
132-
* Tensor<TFloat32> t = TFloat32.scalarOf(2.0f)) {
133-
* assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat());
132+
* TFloat32 t = TFloat32.scalarOf(2.0f)) {
133+
* assertEquals(4.0f, ((TFloat32)function.call(x)).getFloat());
134134
* }
135135
* // Session s is still valid at this point
136136
* }
@@ -163,14 +163,14 @@ public Signature signature() {
163163
* @return output tensors resulting from the execution of the function,
164164
* mapped by their signature name
165165
*/
166-
public Map<String, Tensor<?>> call(Map<String, Tensor<?>> arguments)
166+
public Map<String, Tensor> call(Map<String, Tensor> arguments)
167167
throws IllegalArgumentException {
168168

169169
final SignatureDef signatureDef = signature.asSignatureDef();
170170
final Session.Runner runner = session.runner();
171171

172172
signatureDef.getInputsMap().forEach((argName, t) -> {
173-
Tensor<?> tensor = arguments.get(argName);
173+
Tensor tensor = arguments.get(argName);
174174
if (tensor == null) {
175175
throw new IllegalArgumentException(String.format("Missing argument [%s]", argName));
176176
}
@@ -180,10 +180,10 @@ public Map<String, Tensor<?>> call(Map<String, Tensor<?>> arguments)
180180
Map<String, TensorInfo> outputToNode = signatureDef.getOutputsMap();
181181
outputToNode.values().forEach(t -> runner.fetch(t.getName()));
182182

183-
List<Tensor<?>> resultTensors = runner.run();
183+
List<Tensor> resultTensors = runner.run();
184184
try {
185-
ListIterator<Tensor<?>> resultTensorIter = resultTensors.listIterator();
186-
Map<String, Tensor<?>> returnMap = new HashMap<String, Tensor<?>>();
185+
ListIterator<Tensor> resultTensorIter = resultTensors.listIterator();
186+
Map<String, Tensor> returnMap = new HashMap<String, Tensor>();
187187

188188
// Use the output names as present in the signature definition
189189
for (String nodeName: outputToNode.keySet()) {
@@ -193,7 +193,7 @@ public Map<String, Tensor<?>> call(Map<String, Tensor<?>> arguments)
193193

194194
} catch (Exception e) {
195195
// Release tensors before throwing exception
196-
for (Tensor<?> t : resultTensors) {
196+
for (Tensor t : resultTensors) {
197197
t.close();
198198
}
199199
throw e;
@@ -210,7 +210,7 @@ public Map<String, Tensor<?>> call(Map<String, Tensor<?>> arguments)
210210
* @throws IllegalArgumentException if there are multiple input or output parameters defined
211211
* in the function
212212
*/
213-
public Tensor<?> call(Tensor<?> tensor) throws IllegalArgumentException {
213+
public Tensor call(Tensor tensor) throws IllegalArgumentException {
214214
final SignatureDef signatureDef = signature.asSignatureDef();
215215

216216
if (signatureDef.getInputsCount() != 1) {

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataType.java

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
package org.tensorflow;
1717

1818
import org.tensorflow.internal.c_api.TF_Tensor;
19-
import org.tensorflow.ndarray.Shape;
2019
import org.tensorflow.types.TBfloat16;
2120
import org.tensorflow.types.TBool;
2221
import org.tensorflow.types.TFloat16;
@@ -35,13 +34,13 @@ public final class DataType<T extends TType> {
3534
public interface TensorMapper<T> {
3635

3736
/**
38-
* Maps tensor memory to a data structure for manipulating elements of this type.
37+
* Maps the tensor memory to a n-dimensional typed data space.
3938
*
40-
* @param nativeTensor pointer to the native tensor
41-
* @param shape the shape of the tensor
42-
* @return data structure of elements of this type
39+
* @param tensor the tensor to map in its raw nature
40+
* @param nativeHandle native handle of the tensor
41+
* @return a typed tensor of type {@code T}
4342
*/
44-
T apply(TF_Tensor nativeTensor, Shape shape);
43+
T apply(RawTensor tensor, TF_Tensor nativeHandle);
4544
}
4645

4746
/**
@@ -158,13 +157,13 @@ int nativeCode() {
158157
}
159158

160159
/**
161-
* Maps a tensor to a data structure for manipulating elements of this type.
160+
* Maps a raw tensor to a typed tensor.
162161
*
163162
* @param tensor tensor to map
164163
* @return data structure of elements of this type
165164
*/
166-
T map(Tensor<T> tensor) {
167-
return tensorMapper.apply(tensor.nativeHandle(), tensor.shape());
165+
T map(RawTensor tensor) {
166+
return tensorMapper.apply(tensor, tensor.nativeHandle());
168167
}
169168

170169
private final int nativeCode;

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ public TFE_TensorHandle getUnsafeNativeHandle(int outputIndex) {
9191
public Shape shape(int outputIndex) {
9292
// If the tensor of this output has already been resolved, return its shape.
9393
// Otherwise, retrieve the tensor shape from the native library.
94-
Tensor<?> tensor = outputTensors.get(outputIndex);
94+
Tensor tensor = outputTensors.get(outputIndex);
9595
if (tensor != null) {
9696
return tensor.shape();
9797
}
@@ -107,7 +107,7 @@ public Shape shape(int outputIndex) {
107107
public DataType<?> dtype(int outputIndex) {
108108
// If the tensor of this output has already been resolved, return its datatype.
109109
// Otherwise, retrieve the tensor datatype from the native library.
110-
Tensor<?> tensor = outputTensors.get(outputIndex);
110+
Tensor tensor = outputTensors.get(outputIndex);
111111
if (tensor != null) {
112112
return tensor.dataType();
113113
}
@@ -116,8 +116,8 @@ public DataType<?> dtype(int outputIndex) {
116116
}
117117

118118
@Override
119-
public Tensor<?> tensor(int outputIndex) {
120-
Tensor<?> tensor = outputTensors.get(outputIndex);
119+
public Tensor tensor(int outputIndex) {
120+
Tensor tensor = outputTensors.get(outputIndex);
121121
if (tensor == null) {
122122
tensor = resolveTensor(outputIndex);
123123
}
@@ -127,21 +127,21 @@ public Tensor<?> tensor(int outputIndex) {
127127
private final EagerSession session;
128128
private final String type;
129129
private final String name;
130-
private final AtomicReferenceArray<Tensor<?>> outputTensors;
130+
private final AtomicReferenceArray<Tensor> outputTensors;
131131

132-
private Tensor<?> resolveTensor(int outputIndex) {
132+
private Tensor resolveTensor(int outputIndex) {
133133
// Take an optimistic approach, where we attempt to resolve the output tensor without locking.
134134
// If another thread has resolved it meanwhile, release our copy and reuse the existing one
135135
// instead.
136-
Tensor<?> tensor = resolveTensorHandle(getUnsafeNativeHandle(outputIndex), session);
136+
Tensor tensor = resolveTensorHandle(getUnsafeNativeHandle(outputIndex), session);
137137
if (!outputTensors.compareAndSet(outputIndex, null, tensor)) {
138-
session.detach(tensor.nativeHandle());
138+
session.detach(tensor.asRawTensor().nativeHandle());
139139
tensor = outputTensors.get(outputIndex);
140140
}
141141
return tensor;
142142
}
143143

144-
private TFE_Op opHandle;
144+
private final TFE_Op opHandle;
145145
private final TFE_TensorHandle[] outputHandles;
146146

147147
private static void requireOp(TFE_Op handle) {
@@ -156,13 +156,13 @@ private static void requireTensorHandle(TFE_TensorHandle handle) {
156156
}
157157
}
158158

159-
private static Tensor<?> resolveTensorHandle(TFE_TensorHandle handle, EagerSession session) {
159+
private static Tensor resolveTensorHandle(TFE_TensorHandle handle, EagerSession session) {
160160
requireTensorHandle(handle);
161161
try (PointerScope scope = new PointerScope()) {
162162
TF_Status status = TF_Status.newStatus();
163163
TF_Tensor tensor = TFE_TensorHandleResolve(handle, status).withDeallocator();
164164
status.throwExceptionIfNotOK();
165-
return Tensor.fromHandle(tensor, session);
165+
return RawTensor.fromHandle(tensor, session).asTypedTensor();
166166
}
167167
}
168168

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,13 +175,13 @@ public EagerOperationBuilder setAttr(String name, DataType<?>[] values) {
175175
}
176176

177177
@Override
178-
public EagerOperationBuilder setAttr(String name, Tensor<?> value) {
179-
setAttrTensor(opHandle, name, value.nativeHandle());
178+
public EagerOperationBuilder setAttr(String name, Tensor value) {
179+
setAttrTensor(opHandle, name, value.asRawTensor().nativeHandle());
180180
return this;
181181
}
182182

183183
@Override
184-
public EagerOperationBuilder setAttr(String name, Tensor<?>[] values) {
184+
public EagerOperationBuilder setAttr(String name, Tensor[] values) {
185185
// TODO (karllessard) could be supported by adding this attribute type in the eager C API
186186
throw new UnsupportedOperationException(
187187
"Tensor list attributes are not supported in eager mode");

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ DataType<?> dtype(int outputIdx) {
158158
}
159159

160160
@Override
161-
Tensor<?> tensor(int outputIdx) {
161+
Tensor tensor(int outputIdx) {
162162
throw new IllegalStateException("Graph tensors must be fetched by running a session");
163163
}
164164

0 commit comments

Comments
 (0)