Skip to content

Commit 8b14b1d

Browse files
karllessardrnett
authored andcommitted
Move type implementations and mappers to internal package (tensorflow#172)
1 parent 206638c commit 8b14b1d

25 files changed

+595
-335
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataType.java

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
package org.tensorflow;
1717

18-
import org.tensorflow.internal.c_api.TF_Tensor;
1918
import org.tensorflow.types.TBfloat16;
2019
import org.tensorflow.types.TBool;
2120
import org.tensorflow.types.TFloat16;
@@ -30,23 +29,6 @@
3029
/** Represents a type of elements in a {@link Tensor} */
3130
public final class DataType<T extends TType> {
3231

33-
@FunctionalInterface
34-
public interface TensorMapper<T> {
35-
36-
/**
37-
* Maps the tensor memory to a n-dimensional typed data space.
38-
*
39-
* <p>This method is designed to be invoked internally by this library only, in order to pass the
40-
* native handle of {@code tensor} as {@code nativeHandle} (and since only classes from the
41-
* {@code org.tensorflow} package can retrieve such handle).
42-
*
43-
* @param tensor the tensor to map in its raw nature
44-
* @param nativeHandle native handle of the tensor
45-
* @return a typed tensor of type {@code T}
46-
*/
47-
T apply(RawTensor tensor, TF_Tensor nativeHandle);
48-
}
49-
5032
/**
5133
* Creates a new datatype
5234
*
@@ -167,7 +149,7 @@ int nativeCode() {
167149
* @return data structure of elements of this type
168150
*/
169151
T map(RawTensor tensor) {
170-
return tensorMapper.apply(tensor, tensor.nativeHandle());
152+
return tensorMapper.mapDense(tensor);
171153
}
172154

173155
private final int nativeCode;
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package org.tensorflow;
2+
3+
import org.tensorflow.internal.c_api.TF_Tensor;
4+
import org.tensorflow.types.family.TType;
5+
6+
/**
7+
* Maps the native memory of a {@link RawTensor} to a n-dimensional typed data space
8+
* accessible from the JVM.
9+
*
10+
* <p>Usage of this class is reserved for internal purposes only.
11+
*
12+
* @param <T> tensor type mapped by this object
13+
* @see {@link TType}
14+
*/
15+
public abstract class TensorMapper<T extends TType> {
16+
17+
/**
18+
* Maps the provided dense raw {@code tensor} as a tensor of type {@code T}.
19+
*
20+
* @param tensor the dense tensor to map, in its raw nature
21+
* @return an instance of {@code T}
22+
*/
23+
protected abstract T mapDense(RawTensor tensor);
24+
25+
/**
26+
* Helper for retrieving the native handle of a raw tensor
27+
*
28+
* @param tensor a raw tensor
29+
* @return the native handle of that tensor
30+
* @throws IllegalStateException if the tensor has been released
31+
*/
32+
protected static TF_Tensor nativeHandle(RawTensor tensor) {
33+
return tensor.nativeHandle();
34+
}
35+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package org.tensorflow.internal.buffer;
2+
3+
import java.util.Iterator;
4+
import java.util.function.Function;
5+
import org.tensorflow.ndarray.NdArray;
6+
import org.tensorflow.ndarray.NdArraySequence;
7+
8+
/**
9+
* Produces sequence of bytes to be stored in a {@link ByteSequenceTensorBuffer}.
10+
*
11+
* @param <T> source of bytes (byte arrays or strings)
12+
*/
13+
public class ByteSequenceProvider<T> implements Iterable<byte[]> {
14+
15+
/**
16+
* Constructor
17+
*
18+
* @param source source of data
19+
* @param byteExtractor method that converts one value of the source into a sequence of bytes
20+
*/
21+
public ByteSequenceProvider(NdArray<T> source, Function<T, byte[]> byteExtractor) {
22+
this.source = source;
23+
this.byteExtractor = byteExtractor;
24+
}
25+
26+
@Override
27+
public Iterator<byte[]> iterator() {
28+
return new Iterator<byte[]>() {
29+
30+
@Override
31+
public boolean hasNext() {
32+
return scalarIterator.hasNext();
33+
}
34+
35+
@Override
36+
public byte[] next() {
37+
return byteExtractor.apply(scalarIterator.next().getObject());
38+
}
39+
40+
private final Iterator<? extends NdArray<T>> scalarIterator = source.scalars().iterator();
41+
};
42+
}
43+
44+
/**
45+
* @return total number of byte sequences that can be produced by this sequencer
46+
*/
47+
long numSequences() {
48+
return source.size();
49+
}
50+
51+
private final NdArray<T> source;
52+
private final Function<T, byte[]> byteExtractor;
53+
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/StringTensorBuffer.java renamed to tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/ByteSequenceTensorBuffer.java

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -50,23 +50,21 @@
5050
* <p>After its data has been initialized, the buffer is read-only as it is not possible to change
5151
* safely a value without reinitializing the whole data.
5252
*/
53-
public class StringTensorBuffer extends AbstractDataBuffer<byte[]> {
53+
public class ByteSequenceTensorBuffer extends AbstractDataBuffer<byte[]> {
5454

5555
/**
5656
* Computes how many bytes are required to store the given data in a string buffer.
5757
*
58-
* @param data data to store eventually by calling {@link #init(NdArray, Function)}
59-
* @param getBytes method that converts one value of the data into a sequence of bytes
58+
* @param byteSequenceProvider produces sequences of bytes
6059
* @return number of bytes required to store the data.
6160
*/
62-
public static <T> long computeSize(NdArray<T> data, Function<T, byte[]> getBytes) {
61+
public static <T> long computeSize(ByteSequenceProvider<?> byteSequenceProvider) {
6362
// reserve space to store 64-bit offsets
64-
long size = data.size() * Long.BYTES;
63+
long size = byteSequenceProvider.numSequences() * Long.BYTES;
6564

6665
// reserve space to store length and data of each values
67-
for (NdArray<T> scalar : data.scalars()) {
68-
byte[] elementBytes = getBytes.apply(scalar.getObject());
69-
size += elementBytes.length + StringTensorBuffer.varintLength(elementBytes.length);
66+
for (byte[] elementBytes : byteSequenceProvider) {
67+
size += elementBytes.length + ByteSequenceTensorBuffer.varintLength(elementBytes.length);
7068
}
7169
return size;
7270
}
@@ -79,14 +77,11 @@ public static <T> long computeSize(NdArray<T> data, Function<T, byte[]> getBytes
7977
* same set of data, calling {@link #computeSize(NdArray, Function)} priory to make sure there is
8078
* enough space to store it.
8179
*
82-
* @param data data to store
83-
* @param getBytes method that converts one value of the data into a sequence of bytes
80+
* @param byteSequenceProvider produces sequences of bytes to use as the tensor data
8481
*/
85-
public <T> void init(NdArray<T> data, Function<T, byte[]> getBytes) {
82+
public <T> void init(ByteSequenceProvider<T> byteSequenceProvider) {
8683
InitDataWriter writer = new InitDataWriter();
87-
for (NdArray<T> scalar : data.scalars()) {
88-
writer.writeNext(getBytes.apply(scalar.getObject()));
89-
}
84+
byteSequenceProvider.forEach(writer::writeNext);
9085
}
9186

9287
@Override
@@ -129,8 +124,8 @@ public boolean isReadOnly() {
129124

130125
@Override
131126
public DataBuffer<byte[]> copyTo(DataBuffer<byte[]> dst, long size) {
132-
if (size == size() && dst instanceof StringTensorBuffer) {
133-
StringTensorBuffer tensorDst = (StringTensorBuffer) dst;
127+
if (size == size() && dst instanceof ByteSequenceTensorBuffer) {
128+
ByteSequenceTensorBuffer tensorDst = (ByteSequenceTensorBuffer) dst;
134129
if (offsets.size() != size || data.size() != size) {
135130
throw new IllegalArgumentException(
136131
"Cannot copy string tensor data to another tensor of a different size");
@@ -145,20 +140,20 @@ public DataBuffer<byte[]> copyTo(DataBuffer<byte[]> dst, long size) {
145140

146141
@Override
147142
public DataBuffer<byte[]> offset(long index) {
148-
return new StringTensorBuffer(offsets.offset(index), data);
143+
return new ByteSequenceTensorBuffer(offsets.offset(index), data);
149144
}
150145

151146
@Override
152147
public DataBuffer<byte[]> narrow(long size) {
153-
return new StringTensorBuffer(offsets.narrow(size), data);
148+
return new ByteSequenceTensorBuffer(offsets.narrow(size), data);
154149
}
155150

156151
@Override
157152
public DataBuffer<byte[]> slice(long index, long size) {
158-
return new StringTensorBuffer(offsets.slice(index, size), data);
153+
return new ByteSequenceTensorBuffer(offsets.slice(index, size), data);
159154
}
160155

161-
StringTensorBuffer(LongDataBuffer offsets, ByteDataBuffer data) {
156+
ByteSequenceTensorBuffer(LongDataBuffer offsets, ByteDataBuffer data) {
162157
this.offsets = offsets;
163158
this.data = data;
164159
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/TensorBuffers.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ public static BooleanDataBuffer toBooleans(TF_Tensor nativeTensor) {
156156
* @param nativeTensor native reference to the tensor
157157
* @return a string buffer
158158
*/
159-
public static StringTensorBuffer toStrings(TF_Tensor nativeTensor, long numElements) {
159+
public static ByteSequenceTensorBuffer toStrings(TF_Tensor nativeTensor, long numElements) {
160160
Pointer tensorMemory = tensorMemory(nativeTensor);
161161
if (TensorRawDataBufferFactory.canBeUsed()) {
162162
return TensorRawDataBufferFactory.mapTensorToStrings(tensorMemory, numElements);
@@ -173,7 +173,7 @@ public static StringTensorBuffer toStrings(TF_Tensor nativeTensor, long numEleme
173173
dataBuffer.position((int)numElements * Long.BYTES);
174174
ByteDataBuffer data = DataBuffers.of(dataBuffer.slice());
175175

176-
return new StringTensorBuffer(offsets, data);
176+
return new ByteSequenceTensorBuffer(offsets, data);
177177
}
178178

179179
private static Pointer tensorMemory(TF_Tensor nativeTensor) {

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/TensorRawDataBufferFactory.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,13 @@ static BooleanDataBuffer mapTensorToBooleans(Pointer tensorMemory) {
5757
return mapNativeBooleans(tensorMemory.address(), tensorMemory.capacity(), false);
5858
}
5959

60-
static StringTensorBuffer mapTensorToStrings(Pointer tensorMemory, long numElements) {
60+
static ByteSequenceTensorBuffer mapTensorToStrings(Pointer tensorMemory, long numElements) {
6161
long offsetByteSize = numElements * Long.BYTES;
6262
LongDataBuffer offsets = mapNativeLongs(tensorMemory.address(), offsetByteSize, false);
6363
ByteDataBuffer data = mapNativeBytes(
6464
tensorMemory.address() + offsetByteSize,
6565
tensorMemory.capacity() - offsetByteSize,
6666
false);
67-
return new StringTensorBuffer(offsets, data);
67+
return new ByteSequenceTensorBuffer(offsets, data);
6868
}
6969
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package org.tensorflow.internal.types;
2+
3+
import org.tensorflow.TensorMapper;
4+
import org.tensorflow.DataType;
5+
import org.tensorflow.RawTensor;
6+
import org.tensorflow.internal.buffer.TensorBuffers;
7+
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
8+
import org.tensorflow.ndarray.buffer.layout.DataLayouts;
9+
import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray;
10+
import org.tensorflow.types.TBfloat16;
11+
12+
/**
13+
* Maps memory of {@link org.tensorflow.proto.framework.DataType#DT_BFLOAT16} tensors
14+
* to a n-dimensional data space.
15+
*/
16+
public final class TBfloat16Mapper extends TensorMapper<TBfloat16> {
17+
18+
@Override
19+
protected TBfloat16 mapDense(RawTensor tensor) {
20+
FloatDataBuffer buffer = DataLayouts.BFLOAT16.applyTo(TensorBuffers.toShorts(nativeHandle(tensor)));
21+
return new DenseTBfloat16(tensor, buffer);
22+
}
23+
24+
private static final class DenseTBfloat16 extends FloatDenseNdArray implements TBfloat16 {
25+
26+
@Override
27+
public DataType<TBfloat16> dataType() {
28+
return TBfloat16.DTYPE;
29+
}
30+
31+
@Override
32+
public RawTensor asRawTensor() {
33+
return rawTensor;
34+
}
35+
36+
final RawTensor rawTensor;
37+
38+
DenseTBfloat16(RawTensor rawTensor, FloatDataBuffer buffer) {
39+
super(buffer, rawTensor.shape());
40+
this.rawTensor = rawTensor;
41+
}
42+
}
43+
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package org.tensorflow.internal.types;
2+
3+
import org.tensorflow.DataType;
4+
import org.tensorflow.RawTensor;
5+
import org.tensorflow.TensorMapper;
6+
import org.tensorflow.internal.buffer.TensorBuffers;
7+
import org.tensorflow.ndarray.buffer.BooleanDataBuffer;
8+
import org.tensorflow.ndarray.impl.dense.BooleanDenseNdArray;
9+
import org.tensorflow.types.TBool;
10+
11+
/**
12+
* Maps memory of {@link org.tensorflow.proto.framework.DataType#DT_BOOL} tensors
13+
* to a n-dimensional data space.
14+
*/
15+
public final class TBoolMapper extends TensorMapper<TBool> {
16+
17+
@Override
18+
protected TBool mapDense(RawTensor tensor) {
19+
BooleanDataBuffer buffer = TensorBuffers.toBooleans(nativeHandle(tensor));
20+
return new DenseTBool(tensor, buffer);
21+
}
22+
23+
private static final class DenseTBool extends BooleanDenseNdArray implements TBool {
24+
25+
@Override
26+
public DataType<TBool> dataType() {
27+
return TBool.DTYPE;
28+
}
29+
30+
@Override
31+
public RawTensor asRawTensor() {
32+
return rawTensor;
33+
}
34+
35+
final RawTensor rawTensor;
36+
37+
DenseTBool(RawTensor rawTensor, BooleanDataBuffer buffer) {
38+
super(buffer, rawTensor.shape());
39+
this.rawTensor = rawTensor;
40+
}
41+
}
42+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package org.tensorflow.internal.types;
2+
3+
import org.tensorflow.DataType;
4+
import org.tensorflow.RawTensor;
5+
import org.tensorflow.TensorMapper;
6+
import org.tensorflow.internal.buffer.TensorBuffers;
7+
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
8+
import org.tensorflow.ndarray.buffer.layout.DataLayouts;
9+
import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray;
10+
import org.tensorflow.types.TFloat16;
11+
12+
/**
13+
* Maps memory of {@link org.tensorflow.proto.framework.DataType#DT_HALF} tensors
14+
* to a n-dimensional data space.
15+
*/
16+
public final class TFloat16Mapper extends TensorMapper<TFloat16> {
17+
18+
@Override
19+
protected TFloat16 mapDense(RawTensor tensor) {
20+
FloatDataBuffer buffer = DataLayouts.FLOAT16.applyTo(TensorBuffers.toShorts(nativeHandle(tensor)));
21+
return new DenseTFloat16(tensor, buffer);
22+
}
23+
24+
private static final class DenseTFloat16 extends FloatDenseNdArray implements TFloat16 {
25+
26+
@Override
27+
public DataType<?> dataType() {
28+
return TFloat16.DTYPE;
29+
}
30+
31+
@Override
32+
public RawTensor asRawTensor() {
33+
return rawTensor;
34+
}
35+
36+
final RawTensor rawTensor;
37+
38+
DenseTFloat16(RawTensor rawTensor, FloatDataBuffer buffer) {
39+
super(buffer, rawTensor.shape());
40+
this.rawTensor = rawTensor;
41+
}
42+
}
43+
}

0 commit comments

Comments
 (0)