Skip to content

Commit 72d0653

Browse files
committed
Rectify documentation based on PR review
1 parent 0225ab6 commit 72d0653

File tree

8 files changed

+63
-32
lines changed

8 files changed

+63
-32
lines changed

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,17 +1070,6 @@ public <T extends TNumber> Bucketize bucketize(Operand<T> input, List<Float> bou
10701070
return Bucketize.create(scope, input, boundaries);
10711071
}
10721072

1073-
/**
1074-
* Capture a {@code tensor} by making a constant copy of it.
1075-
*
1076-
* @param scope is a scope used to add the underlying operation.
1077-
* @param tensor a Tensor holding the constant value
1078-
* @return a constant of the same data type as `tensor`
1079-
*/
1080-
public <T extends TType> Constant<T> capture(T tensor) {
1081-
return Constant.create(scope, tensor);
1082-
}
1083-
10841073
/**
10851074
* Clips tensor values to a specified min and max.
10861075
* <p>
@@ -1876,6 +1865,20 @@ public <T extends TType> Constant<T> constant(DataType<T> type, Shape shape,
18761865
return Constant.tensorOf(scope, type, shape, data);
18771866
}
18781867

1868+
/**
1869+
* Create a constant by making an immutable copy of {@code tensor}.
1870+
*
1871+
* <p>Note: this endpoint cannot be simply called {@code constant} since it will conflict with
1872+
* other endpoints accepting an NdArray in parameter {e.g. {@link #tensorOf(Scope, FloatNdArray)}}.
1873+
*
1874+
* @param scope is a scope used to add the underlying operation.
1875+
* @param tensor a Tensor holding the constant value
1876+
* @return a constant of the same data type as `tensor`
1877+
*/
1878+
public <T extends TType> Constant<T> constantOf(T tensor) {
1879+
return Constant.create(scope, tensor);
1880+
}
1881+
18791882
/**
18801883
* This op consumes a lock created by `MutexLock`.
18811884
* <p>

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataType.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ public interface TensorMapper<T> {
3636
/**
3737
* Maps the tensor memory to a n-dimensional typed data space.
3838
*
39+
* <p>This method is designed to be invoked internally by this library only, in order to pass the
40+
* native handle of {@code tensor} as {@code nativeHandle} (and since only classes from the
41+
* {@code org.tensorflow} package can retrieve such handle).
42+
*
3943
* @param tensor the tensor to map in its raw nature
4044
* @param nativeHandle native handle of the tensor
4145
* @return a typed tensor of type {@code T}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,10 @@
3737
* n-dimensional typed space by a {@link TType typed tensor}.</p>
3838
*
3939
* <p>Instances of a RawTensor are <b>not</b> thread-safe and their resource must be released
40-
* by calling {@link #close()} explicitly or implicitly (try-with-resources).</p>
40+
* by calling {@link #close()} explicitly or implicitly via try-with-resources.</p>
4141
*/
4242
public final class RawTensor implements Tensor {
4343

44-
/**
45-
* Returns a typed version of this tensor
46-
*/
47-
TType asTypedTensor() {
48-
return dtype.map(this);
49-
}
50-
5144
@Override
5245
public DataType<?> dataType() {
5346
return dtype;
@@ -152,13 +145,25 @@ static RawTensor fromHandle(TF_Tensor handle, EagerSession session) {
152145
}
153146

154147
/**
155-
* @return native handle to this tensor
148+
* Returns the native handle to this tensor
156149
* @throws IllegalStateException if tensor has been closed
157150
*/
158151
TF_Tensor nativeHandle() {
159152
return requireHandle(tensorHandle);
160153
}
161154

155+
/**
156+
* Returns a typed reference to this tensor
157+
*
158+
* <p>In some cases, it is more useful to keep a typed reference to a tensor rather than its raw
159+
* nature to prevent mapping its memory on every access (e.g. when calling {@link Operand#asTensor()}).
160+
*
161+
* @return typed reference to this tensor
162+
*/
163+
TType asTypedTensor() {
164+
return dtype.map(this);
165+
}
166+
162167
private static TF_Tensor requireHandle(TF_Tensor handle) {
163168
if (handle == null || handle.isNull()) {
164169
throw new IllegalStateException("close() was called on the Tensor");

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1278,13 +1278,16 @@ public static Constant<TInt64> tensorOf(Scope scope, Shape shape) {
12781278
}
12791279

12801280
/**
1281-
* Capture a {@code tensor} by making a constant copy of it.
1281+
* Create a constant by making an immutable copy of {@code tensor}.
1282+
*
1283+
* <p>Note: this endpoint cannot be simply called {@code constant} since it will conflict with
1284+
* other endpoints accepting an NdArray in parameter {e.g. {@link #tensorOf(Scope, FloatNdArray)}}.
12821285
*
12831286
* @param scope is a scope used to add the underlying operation.
12841287
* @param tensor a Tensor holding the constant value
12851288
* @return a constant of the same data type as `tensor`
12861289
*/
1287-
@Endpoint(name = "capture") // Cannot be "constant" since it will conflict with other endpoints accepting an NdArray
1290+
@Endpoint(name = "constantOf")
12881291
public static <T extends TType> Constant<T> create(Scope scope, T tensor) {
12891292
return new Constant<>(
12901293
scope

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@
2222
/**
2323
* Common interface for all typed tensors.
2424
*
25-
* <p>Typed tensors wraps a {@link RawTensor} by mapping their native memory to a n-dimensional
26-
* data space allowing direct I/O access from the JVM.</p>
25+
* <p>Typed tensors wrap a {@link org.tensorflow.RawTensor RawTensor} by mapping their native memory
26+
* to a n-dimensional data space allowing direct I/O access from the JVM.</p>
2727
*
2828
* <p>Subinterfaces of {@code TType} are propagated as a generic parameter to various entities of
2929
* TensorFlow to identify the type of the tensor they carry. For example, a
30-
* {@link org.tensorflow.Operand Operand<TFloat32>} is an operand which outputs is a 32-bit floating
30+
* {@link org.tensorflow.Operand Operand<TFloat32>} is an operand which outputs a 32-bit floating
3131
* point tensor. This parameter ensure type-compatibility between operands of a computation at
3232
* compile-time. For example:
3333
*
@@ -41,6 +41,22 @@
4141
* tf.math.add(c1, c2); // OK
4242
* tf.math.add(c1, c3); // Compilation failure
4343
* }</pre>
44+
*
45+
* <p>Even if all typed tensors implements somehow {@link org.tensorflow.ndarray.NdArray NdArray}
46+
* to provide access to their data, {@code TType} deliberately does not extend directly from this
47+
* interface, for the following reasons:
48+
* <ul>
49+
* <li>Implementing {@code NdArray} at this level could only expose boxed-type accessors, which
50+
* are less performant than their primitive equivalent, only exposed by subinterfaces of
51+
* {@code NdArray} (e.g. {@code FloatNdArray}).
52+
* </li>
53+
* <li>{@code TType} would need to carry a new generic parameter for typing the {@code NdArray},
54+
* which will increase the verbosity in the signature of any method accepting or returning
55+
* an instance of this interface, which is very common.
56+
* </li>
57+
* </ul>
58+
* Therefore, enforcing the user to cast a reference of {@code TType} in a concrete tensor type before
59+
* accessing its data guarantees better performance and improves readability.
4460
*/
4561
public interface TType extends Tensor {
4662

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ public void createFromTensorsInEagerMode() throws IOException {
158158
assertEquals(c1.asTensor(), t);
159159

160160
// A different endpoint for capturing a tensor as a constant, which supports all data types
161-
Constant<TInt32> c2 = tf.capture(t);
161+
Constant<TInt32> c2 = tf.constantOf(t);
162162
assertEquals(c2.asTensor(), t);
163163
assertEquals(c1.asTensor(), c2.asTensor());
164164

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public void createGradients() {
4848
assertEquals(2, grads.dy().size());
4949

5050
try (TFloat32 c = TFloat32.scalarOf(3.0f);
51-
AutoCloseableList<?> outputs =
51+
AutoCloseableList<Tensor> outputs =
5252
new AutoCloseableList<>(
5353
sess.runner().feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run())) {
5454

@@ -75,7 +75,7 @@ public void createGradientsWithSum() {
7575
assertEquals(1, grads.dy().size());
7676

7777
try (TFloat32 c = TFloat32.scalarOf(3.0f);
78-
AutoCloseableList<?> outputs =
78+
AutoCloseableList<Tensor> outputs =
7979
new AutoCloseableList<>(sess.runner().feed(x, c).fetch(grads.dy(0)).run())) {
8080

8181
assertEquals(114.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f);
@@ -101,7 +101,7 @@ public void createGradientsWithInitialValues() {
101101
assertEquals(1, grads1.dy().size());
102102

103103
try (TFloat32 c = TFloat32.scalarOf(3.0f);
104-
AutoCloseableList<?> outputs =
104+
AutoCloseableList<Tensor> outputs =
105105
new AutoCloseableList<>(
106106
sess.runner().feed(x, c).fetch(grads1.dy(0)).run())) {
107107

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/NumericTypesTestBase.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,11 @@ public void initializeTensorsWithZeros() {
4646

4747
// Initialize tensor memory with zeros and take a snapshot
4848
data.scalars().forEach(scalar -> ((NdArray<U>)scalar).setObject(valueOf(0)));
49-
Constant<T> x = tf.capture(tensor);
49+
Constant<T> x = tf.constantOf(tensor);
5050

5151
// Initialize the same tensor memory with ones and take a snapshot
5252
data.scalars().forEach(scalar -> ((NdArray<U>)scalar).setObject(valueOf(1)));
53-
Constant<T> y = tf.capture(tensor);
53+
Constant<T> y = tf.constantOf(tensor);
5454

5555
// Subtract y from x and validate the result
5656
Sub<T> sub = tf.math.sub(x, y);
@@ -94,7 +94,7 @@ public void setAndCompute() {
9494
try (EagerSession session = EagerSession.create()) {
9595
Ops tf = Ops.create(session);
9696

97-
Add<T> add = tf.math.add(tf.capture(tensor), tf.capture(tensor));
97+
Add<T> add = tf.math.add(tf.constantOf(tensor), tf.constantOf(tensor));
9898
NdArray<U> result = (NdArray<U>)add.asTensor();
9999

100100
assertEquals(valueOf(0), result.getObject(0, 0));

0 commit comments

Comments
 (0)