Skip to content

Commit 598fdc3

Browse files
committed
Rectify documentation based on PR review
1 parent 0225ab6 commit 598fdc3

File tree

9 files changed

+71
-38
lines changed

9 files changed

+71
-38
lines changed

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,17 +1070,6 @@ public <T extends TNumber> Bucketize bucketize(Operand<T> input, List<Float> bou
10701070
return Bucketize.create(scope, input, boundaries);
10711071
}
10721072

1073-
/**
1074-
* Capture a {@code tensor} by making a constant copy of it.
1075-
*
1076-
* @param scope is a scope used to add the underlying operation.
1077-
* @param tensor a Tensor holding the constant value
1078-
* @return a constant of the same data type as `tensor`
1079-
*/
1080-
public <T extends TType> Constant<T> capture(T tensor) {
1081-
return Constant.create(scope, tensor);
1082-
}
1083-
10841073
/**
10851074
* Clips tensor values to a specified min and max.
10861075
* <p>
@@ -1876,6 +1865,20 @@ public <T extends TType> Constant<T> constant(DataType<T> type, Shape shape,
18761865
return Constant.tensorOf(scope, type, shape, data);
18771866
}
18781867

1868+
/**
1869+
* Create a constant by making an immutable copy of {@code tensor}.
1870+
*
1871+
* <p>Note: this endpoint cannot be simply called {@code constant} since it will conflict with
1872+
* other endpoints accepting an NdArray in parameter {e.g. {@link #tensorOf(Scope, FloatNdArray)}}.
1873+
*
1874+
* @param scope is a scope used to add the underlying operation.
1875+
* @param tensor a Tensor holding the constant value
1876+
* @return a constant of the same data type as `tensor`
1877+
*/
1878+
public <T extends TType> Constant<T> constantOf(T tensor) {
1879+
return Constant.create(scope, tensor);
1880+
}
1881+
18791882
/**
18801883
* This op consumes a lock created by `MutexLock`.
18811884
* <p>

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataType.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ public interface TensorMapper<T> {
3636
/**
3737
* Maps the tensor memory to a n-dimensional typed data space.
3838
*
39+
* <p>This method is designed to be invoked internally by this library only, in order to pass the
40+
* native handle of {@code tensor} as {@code nativeHandle} (and since only classes from the
41+
* {@code org.tensorflow} package can retrieve such handle).
42+
*
3943
* @param tensor the tensor to map in its raw nature
4044
* @param nativeHandle native handle of the tensor
4145
* @return a typed tensor of type {@code T}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,25 +29,18 @@
2929
import org.tensorflow.types.family.TType;
3030

3131
/**
32-
* A tensor which memory has not been mapped.
32+
* A tensor which memory has not been mapped to a data space directly accessible from the JVM.
3333
*
3434
* <p>A raw tensor is a minimalist representation of a tensor allocated in native memory by the
3535
* TensorFlow runtime library and it controls its lifetime within the current process. The data
3636
* is represented by a flat {@link ByteDataBuffer buffer of bytes}, until it is mapped in a
3737
* n-dimensional typed space by a {@link TType typed tensor}.</p>
3838
*
3939
* <p>Instances of a RawTensor are <b>not</b> thread-safe and their resource must be released
40-
* by calling {@link #close()} explicitly or implicitly (try-with-resources).</p>
40+
* by calling {@link #close()} explicitly or implicitly via try-with-resources.</p>
4141
*/
4242
public final class RawTensor implements Tensor {
4343

44-
/**
45-
* Returns a typed version of this tensor
46-
*/
47-
TType asTypedTensor() {
48-
return dtype.map(this);
49-
}
50-
5144
@Override
5245
public DataType<?> dataType() {
5346
return dtype;
@@ -152,13 +145,28 @@ static RawTensor fromHandle(TF_Tensor handle, EagerSession session) {
152145
}
153146

154147
/**
155-
* @return native handle to this tensor
148+
* Returns the native handle to this tensor
156149
* @throws IllegalStateException if tensor has been closed
157150
*/
158151
TF_Tensor nativeHandle() {
159152
return requireHandle(tensorHandle);
160153
}
161154

155+
/**
156+
* Returns a typed reference to this tensor
157+
*
158+
* <p>In some cases, it is more useful to keep a typed reference to a tensor rather than its raw
159+
* nature to prevent mapping its memory on every access (e.g. when calling {@link Operand#asTensor()}).
160+
*
161+
* @param <T> type of the tensor (must be compatible with the internal representation of this tensor,
162+
* as indicated by {@link #dataType()})
163+
* @return typed reference to this tensor
164+
* @throws ClassCastException if {@code T} is not compatible type with {@link #dataType()}
165+
*/
166+
<T extends TType> T asTypedTensor() {
167+
return (T)dtype.map(this);
168+
}
169+
162170
private static TF_Tensor requireHandle(TF_Tensor handle) {
163171
if (handle == null || handle.isNull()) {
164172
throw new IllegalStateException("close() was called on the Tensor");

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,8 @@
4646
* try (Tensor t = Tensor.of(...)) {
4747
* doSomethingWith(t);
4848
* }
49-
*
50-
* <p>Instances of a Tensor are <b>not</b> thread-safe.
5149
* }</pre>
50+
* <p>Instances of a Tensor are <b>not</b> thread-safe.
5251
*/
5352
public interface Tensor extends Shaped, AutoCloseable {
5453

@@ -88,7 +87,7 @@ static <T extends TType> T of(DataType<T> dtype, Shape shape) {
8887
static <T extends TType> T of(DataType<T> dtype, Shape shape, long size) {
8988
RawTensor tensor = RawTensor.allocate(dtype, shape, size);
9089
try {
91-
return dtype.map(tensor);
90+
return tensor.asTypedTensor();
9291
} catch (Exception e) {
9392
tensor.close();
9493
throw e;
@@ -130,7 +129,7 @@ static <T extends TType> T of(DataType<T> dtype, Shape shape, Consumer<T> dataIn
130129
* size for the tensor is explicitly set instead of being computed from the datatype and shape.
131130
*
132131
* <p>This could be useful for tensor types that stores data but also metadata in the tensor memory,
133-
* such as lookup table in a tensor of strings.
132+
* such as the lookup table in a tensor of strings.
134133
*
135134
* @param <T> the tensor element type
136135
* @param dtype datatype of the tensor
@@ -148,7 +147,7 @@ static <T extends TType> T of(DataType<T> dtype, Shape shape, long size, Consume
148147
try {
149148
dataInitializer.accept(tensor);
150149
return tensor;
151-
} catch (Throwable t) {
150+
} catch (Exception t) {
152151
tensor.close();
153152
throw t;
154153
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1278,13 +1278,16 @@ public static Constant<TInt64> tensorOf(Scope scope, Shape shape) {
12781278
}
12791279

12801280
/**
1281-
* Capture a {@code tensor} by making a constant copy of it.
1281+
* Create a constant by making an immutable copy of {@code tensor}.
1282+
*
1283+
* <p>Note: this endpoint cannot be simply called {@code constant} since it will conflict with
1284+
* other endpoints accepting an NdArray in parameter {e.g. {@link #tensorOf(Scope, FloatNdArray)}}.
12821285
*
12831286
* @param scope is a scope used to add the underlying operation.
12841287
* @param tensor a Tensor holding the constant value
12851288
* @return a constant of the same data type as `tensor`
12861289
*/
1287-
@Endpoint(name = "capture") // Cannot be "constant" since it will conflict with other endpoints accepting an NdArray
1290+
@Endpoint(name = "constantOf")
12881291
public static <T extends TType> Constant<T> create(Scope scope, T tensor) {
12891292
return new Constant<>(
12901293
scope

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@
2222
/**
2323
* Common interface for all typed tensors.
2424
*
25-
* <p>Typed tensors wraps a {@link RawTensor} by mapping their native memory to a n-dimensional
26-
* data space allowing direct I/O access from the JVM.</p>
25+
* <p>Typed tensors wrap a {@link org.tensorflow.RawTensor RawTensor} by mapping their native memory
26+
* to a n-dimensional data space allowing direct I/O access from the JVM.</p>
2727
*
2828
* <p>Subinterfaces of {@code TType} are propagated as a generic parameter to various entities of
2929
* TensorFlow to identify the type of the tensor they carry. For example, a
30-
* {@link org.tensorflow.Operand Operand<TFloat32>} is an operand which outputs is a 32-bit floating
30+
* {@link org.tensorflow.Operand Operand<TFloat32>} is an operand which outputs a 32-bit floating
3131
* point tensor. This parameter ensure type-compatibility between operands of a computation at
3232
* compile-time. For example:
3333
*
@@ -41,6 +41,22 @@
4141
* tf.math.add(c1, c2); // OK
4242
* tf.math.add(c1, c3); // Compilation failure
4343
* }</pre>
44+
*
45+
* <p>Even if all typed tensors implements somehow {@link org.tensorflow.ndarray.NdArray NdArray}
46+
* to provide access to their data, {@code TType} deliberately does not extend directly from this
47+
* interface, for the following reasons:
48+
* <ul>
49+
* <li>Implementing {@code NdArray} at this level could only expose boxed-type accessors, which
50+
* are less performant than their primitive equivalent, only exposed by subinterfaces of
51+
* {@code NdArray} (e.g. {@code FloatNdArray}).
52+
* </li>
53+
* <li>{@code TType} would need to carry a new generic parameter for typing the {@code NdArray},
54+
* which will increase the verbosity in the signature of any method accepting or returning
55+
* an instance of this interface, which is very common.
56+
* </li>
57+
* </ul>
58+
* Therefore, enforcing the user to cast a reference of {@code TType} in a concrete tensor type before
59+
* accessing its data guarantees better performance and improves readability.
4460
*/
4561
public interface TType extends Tensor {
4662

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ public void createFromTensorsInEagerMode() throws IOException {
158158
assertEquals(c1.asTensor(), t);
159159

160160
// A different endpoint for capturing a tensor as a constant, which supports all data types
161-
Constant<TInt32> c2 = tf.capture(t);
161+
Constant<TInt32> c2 = tf.constantOf(t);
162162
assertEquals(c2.asTensor(), t);
163163
assertEquals(c1.asTensor(), c2.asTensor());
164164

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public void createGradients() {
4848
assertEquals(2, grads.dy().size());
4949

5050
try (TFloat32 c = TFloat32.scalarOf(3.0f);
51-
AutoCloseableList<?> outputs =
51+
AutoCloseableList<Tensor> outputs =
5252
new AutoCloseableList<>(
5353
sess.runner().feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run())) {
5454

@@ -75,7 +75,7 @@ public void createGradientsWithSum() {
7575
assertEquals(1, grads.dy().size());
7676

7777
try (TFloat32 c = TFloat32.scalarOf(3.0f);
78-
AutoCloseableList<?> outputs =
78+
AutoCloseableList<Tensor> outputs =
7979
new AutoCloseableList<>(sess.runner().feed(x, c).fetch(grads.dy(0)).run())) {
8080

8181
assertEquals(114.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f);
@@ -101,7 +101,7 @@ public void createGradientsWithInitialValues() {
101101
assertEquals(1, grads1.dy().size());
102102

103103
try (TFloat32 c = TFloat32.scalarOf(3.0f);
104-
AutoCloseableList<?> outputs =
104+
AutoCloseableList<Tensor> outputs =
105105
new AutoCloseableList<>(
106106
sess.runner().feed(x, c).fetch(grads1.dy(0)).run())) {
107107

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/NumericTypesTestBase.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,11 @@ public void initializeTensorsWithZeros() {
4646

4747
// Initialize tensor memory with zeros and take a snapshot
4848
data.scalars().forEach(scalar -> ((NdArray<U>)scalar).setObject(valueOf(0)));
49-
Constant<T> x = tf.capture(tensor);
49+
Constant<T> x = tf.constantOf(tensor);
5050

5151
// Initialize the same tensor memory with ones and take a snapshot
5252
data.scalars().forEach(scalar -> ((NdArray<U>)scalar).setObject(valueOf(1)));
53-
Constant<T> y = tf.capture(tensor);
53+
Constant<T> y = tf.constantOf(tensor);
5454

5555
// Subtract y from x and validate the result
5656
Sub<T> sub = tf.math.sub(x, y);
@@ -94,7 +94,7 @@ public void setAndCompute() {
9494
try (EagerSession session = EagerSession.create()) {
9595
Ops tf = Ops.create(session);
9696

97-
Add<T> add = tf.math.add(tf.capture(tensor), tf.capture(tensor));
97+
Add<T> add = tf.math.add(tf.constantOf(tensor), tf.constantOf(tensor));
9898
NdArray<U> result = (NdArray<U>)add.asTensor();
9999

100100
assertEquals(valueOf(0), result.getObject(0, 0));

0 commit comments

Comments
 (0)