diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataType.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataType.java index fc6268f40a2..60657837969 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataType.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataType.java @@ -15,7 +15,6 @@ package org.tensorflow; -import org.tensorflow.internal.c_api.TF_Tensor; import org.tensorflow.types.TBfloat16; import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat16; @@ -30,23 +29,6 @@ /** Represents a type of elements in a {@link Tensor} */ public final class DataType { - @FunctionalInterface - public interface TensorMapper { - - /** - * Maps the tensor memory to a n-dimensional typed data space. - * - *

This method is designed to be invoked internally by this library only, in order to pass the - * native handle of {@code tensor} as {@code nativeHandle} (and since only classes from the - * {@code org.tensorflow} package can retrieve such handle). - * - * @param tensor the tensor to map in its raw nature - * @param nativeHandle native handle of the tensor - * @return a typed tensor of type {@code T} - */ - T apply(RawTensor tensor, TF_Tensor nativeHandle); - } - /** * Creates a new datatype * @@ -167,7 +149,7 @@ int nativeCode() { * @return data structure of elements of this type */ T map(RawTensor tensor) { - return tensorMapper.apply(tensor, tensor.nativeHandle()); + return tensorMapper.mapDense(tensor); } private final int nativeCode; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorMapper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorMapper.java new file mode 100644 index 00000000000..54148e6a019 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorMapper.java @@ -0,0 +1,35 @@ +package org.tensorflow; + +import org.tensorflow.internal.c_api.TF_Tensor; +import org.tensorflow.types.family.TType; + +/** + * Maps the native memory of a {@link RawTensor} to a n-dimensional typed data space + * accessible from the JVM. + * + *

Usage of this class is reserved for internal purposes only. + * + * @param tensor type mapped by this object + * @see {@link TType} + */ +public abstract class TensorMapper { + + /** + * Maps the provided dense raw {@code tensor} as a tensor of type {@code T}. + * + * @param tensor the dense tensor to map, in its raw nature + * @return an instance of {@code T} + */ + protected abstract T mapDense(RawTensor tensor); + + /** + * Helper for retrieving the native handle of a raw tensor + * + * @param tensor a raw tensor + * @return the native handle of that tensor + * @throws IllegalStateException if the tensor has been released + */ + protected static TF_Tensor nativeHandle(RawTensor tensor) { + return tensor.nativeHandle(); + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/ByteSequenceProvider.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/ByteSequenceProvider.java new file mode 100644 index 00000000000..15503aad32e --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/ByteSequenceProvider.java @@ -0,0 +1,53 @@ +package org.tensorflow.internal.buffer; + +import java.util.Iterator; +import java.util.function.Function; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArraySequence; + +/** + * Produces sequence of bytes to be stored in a {@link ByteSequenceTensorBuffer}. + * + * @param source of bytes (byte arrays or strings) + */ +public class ByteSequenceProvider implements Iterable { + + /** + * Constructor + * + * @param source source of data + * @param byteExtractor method that converts one value of the source into a sequence of bytes + */ + public ByteSequenceProvider(NdArray source, Function byteExtractor) { + this.source = source; + this.byteExtractor = byteExtractor; + } + + @Override + public Iterator iterator() { + return new Iterator() { + + @Override + public boolean hasNext() { + return scalarIterator.hasNext(); + } + + @Override + public byte[] next() { + return byteExtractor.apply(scalarIterator.next().getObject()); + } + + private final Iterator> scalarIterator = source.scalars().iterator(); + }; + } + + /** + * @return total number of byte sequences that can be produced by this sequencer + */ + long numSequences() { + return source.size(); + } + + private final NdArray source; + private final Function byteExtractor; +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/StringTensorBuffer.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/ByteSequenceTensorBuffer.java similarity index 81% rename from tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/StringTensorBuffer.java rename to tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/ByteSequenceTensorBuffer.java index 83cdab33452..e3b9152c6ba 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/StringTensorBuffer.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/ByteSequenceTensorBuffer.java @@ -50,23 +50,21 @@ *

After its data has been initialized, the buffer is read-only as it is not possible to change * safely a value without reinitializing the whole data. */ -public class StringTensorBuffer extends AbstractDataBuffer { +public class ByteSequenceTensorBuffer extends AbstractDataBuffer { /** * Computes how many bytes are required to store the given data in a string buffer. * - * @param data data to store eventually by calling {@link #init(NdArray, Function)} - * @param getBytes method that converts one value of the data into a sequence of bytes + * @param byteSequenceProvider produces sequences of bytes * @return number of bytes required to store the data. */ - public static long computeSize(NdArray data, Function getBytes) { + public static long computeSize(ByteSequenceProvider byteSequenceProvider) { // reserve space to store 64-bit offsets - long size = data.size() * Long.BYTES; + long size = byteSequenceProvider.numSequences() * Long.BYTES; // reserve space to store length and data of each values - for (NdArray scalar : data.scalars()) { - byte[] elementBytes = getBytes.apply(scalar.getObject()); - size += elementBytes.length + StringTensorBuffer.varintLength(elementBytes.length); + for (byte[] elementBytes : byteSequenceProvider) { + size += elementBytes.length + ByteSequenceTensorBuffer.varintLength(elementBytes.length); } return size; } @@ -79,14 +77,11 @@ public static long computeSize(NdArray data, Function getBytes * same set of data, calling {@link #computeSize(NdArray, Function)} priory to make sure there is * enough space to store it. * - * @param data data to store - * @param getBytes method that converts one value of the data into a sequence of bytes + * @param byteSequenceProvider produces sequences of bytes to use as the tensor data */ - public void init(NdArray data, Function getBytes) { + public void init(ByteSequenceProvider byteSequenceProvider) { InitDataWriter writer = new InitDataWriter(); - for (NdArray scalar : data.scalars()) { - writer.writeNext(getBytes.apply(scalar.getObject())); - } + byteSequenceProvider.forEach(writer::writeNext); } @Override @@ -129,8 +124,8 @@ public boolean isReadOnly() { @Override public DataBuffer copyTo(DataBuffer dst, long size) { - if (size == size() && dst instanceof StringTensorBuffer) { - StringTensorBuffer tensorDst = (StringTensorBuffer) dst; + if (size == size() && dst instanceof ByteSequenceTensorBuffer) { + ByteSequenceTensorBuffer tensorDst = (ByteSequenceTensorBuffer) dst; if (offsets.size() != size || data.size() != size) { throw new IllegalArgumentException( "Cannot copy string tensor data to another tensor of a different size"); @@ -145,20 +140,20 @@ public DataBuffer copyTo(DataBuffer dst, long size) { @Override public DataBuffer offset(long index) { - return new StringTensorBuffer(offsets.offset(index), data); + return new ByteSequenceTensorBuffer(offsets.offset(index), data); } @Override public DataBuffer narrow(long size) { - return new StringTensorBuffer(offsets.narrow(size), data); + return new ByteSequenceTensorBuffer(offsets.narrow(size), data); } @Override public DataBuffer slice(long index, long size) { - return new StringTensorBuffer(offsets.slice(index, size), data); + return new ByteSequenceTensorBuffer(offsets.slice(index, size), data); } - StringTensorBuffer(LongDataBuffer offsets, ByteDataBuffer data) { + ByteSequenceTensorBuffer(LongDataBuffer offsets, ByteDataBuffer data) { this.offsets = offsets; this.data = data; } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/TensorBuffers.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/TensorBuffers.java index f29396dd321..415c5ca35ef 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/TensorBuffers.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/TensorBuffers.java @@ -156,7 +156,7 @@ public static BooleanDataBuffer toBooleans(TF_Tensor nativeTensor) { * @param nativeTensor native reference to the tensor * @return a string buffer */ - public static StringTensorBuffer toStrings(TF_Tensor nativeTensor, long numElements) { + public static ByteSequenceTensorBuffer toStrings(TF_Tensor nativeTensor, long numElements) { Pointer tensorMemory = tensorMemory(nativeTensor); if (TensorRawDataBufferFactory.canBeUsed()) { return TensorRawDataBufferFactory.mapTensorToStrings(tensorMemory, numElements); @@ -173,7 +173,7 @@ public static StringTensorBuffer toStrings(TF_Tensor nativeTensor, long numEleme dataBuffer.position((int)numElements * Long.BYTES); ByteDataBuffer data = DataBuffers.of(dataBuffer.slice()); - return new StringTensorBuffer(offsets, data); + return new ByteSequenceTensorBuffer(offsets, data); } private static Pointer tensorMemory(TF_Tensor nativeTensor) { diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/TensorRawDataBufferFactory.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/TensorRawDataBufferFactory.java index 1cfb1c9ab9a..dbaf31f1dcc 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/TensorRawDataBufferFactory.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/TensorRawDataBufferFactory.java @@ -57,13 +57,13 @@ static BooleanDataBuffer mapTensorToBooleans(Pointer tensorMemory) { return mapNativeBooleans(tensorMemory.address(), tensorMemory.capacity(), false); } - static StringTensorBuffer mapTensorToStrings(Pointer tensorMemory, long numElements) { + static ByteSequenceTensorBuffer mapTensorToStrings(Pointer tensorMemory, long numElements) { long offsetByteSize = numElements * Long.BYTES; LongDataBuffer offsets = mapNativeLongs(tensorMemory.address(), offsetByteSize, false); ByteDataBuffer data = mapNativeBytes( tensorMemory.address() + offsetByteSize, tensorMemory.capacity() - offsetByteSize, false); - return new StringTensorBuffer(offsets, data); + return new ByteSequenceTensorBuffer(offsets, data); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBfloat16Mapper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBfloat16Mapper.java new file mode 100644 index 00000000000..6c7102a365b --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBfloat16Mapper.java @@ -0,0 +1,43 @@ +package org.tensorflow.internal.types; + +import org.tensorflow.TensorMapper; +import org.tensorflow.DataType; +import org.tensorflow.RawTensor; +import org.tensorflow.internal.buffer.TensorBuffers; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.buffer.layout.DataLayouts; +import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray; +import org.tensorflow.types.TBfloat16; + +/** + * Maps memory of {@link org.tensorflow.proto.framework.DataType#DT_BFLOAT16} tensors + * to a n-dimensional data space. + */ +public final class TBfloat16Mapper extends TensorMapper { + + @Override + protected TBfloat16 mapDense(RawTensor tensor) { + FloatDataBuffer buffer = DataLayouts.BFLOAT16.applyTo(TensorBuffers.toShorts(nativeHandle(tensor))); + return new DenseTBfloat16(tensor, buffer); + } + + private static final class DenseTBfloat16 extends FloatDenseNdArray implements TBfloat16 { + + @Override + public DataType dataType() { + return TBfloat16.DTYPE; + } + + @Override + public RawTensor asRawTensor() { + return rawTensor; + } + + final RawTensor rawTensor; + + DenseTBfloat16(RawTensor rawTensor, FloatDataBuffer buffer) { + super(buffer, rawTensor.shape()); + this.rawTensor = rawTensor; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBoolMapper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBoolMapper.java new file mode 100644 index 00000000000..9353f0fd7a9 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBoolMapper.java @@ -0,0 +1,42 @@ +package org.tensorflow.internal.types; + +import org.tensorflow.DataType; +import org.tensorflow.RawTensor; +import org.tensorflow.TensorMapper; +import org.tensorflow.internal.buffer.TensorBuffers; +import org.tensorflow.ndarray.buffer.BooleanDataBuffer; +import org.tensorflow.ndarray.impl.dense.BooleanDenseNdArray; +import org.tensorflow.types.TBool; + +/** + * Maps memory of {@link org.tensorflow.proto.framework.DataType#DT_BOOL} tensors + * to a n-dimensional data space. + */ +public final class TBoolMapper extends TensorMapper { + + @Override + protected TBool mapDense(RawTensor tensor) { + BooleanDataBuffer buffer = TensorBuffers.toBooleans(nativeHandle(tensor)); + return new DenseTBool(tensor, buffer); + } + + private static final class DenseTBool extends BooleanDenseNdArray implements TBool { + + @Override + public DataType dataType() { + return TBool.DTYPE; + } + + @Override + public RawTensor asRawTensor() { + return rawTensor; + } + + final RawTensor rawTensor; + + DenseTBool(RawTensor rawTensor, BooleanDataBuffer buffer) { + super(buffer, rawTensor.shape()); + this.rawTensor = rawTensor; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat16Mapper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat16Mapper.java new file mode 100644 index 00000000000..1873acc8641 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat16Mapper.java @@ -0,0 +1,43 @@ +package org.tensorflow.internal.types; + +import org.tensorflow.DataType; +import org.tensorflow.RawTensor; +import org.tensorflow.TensorMapper; +import org.tensorflow.internal.buffer.TensorBuffers; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.buffer.layout.DataLayouts; +import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray; +import org.tensorflow.types.TFloat16; + +/** + * Maps memory of {@link org.tensorflow.proto.framework.DataType#DT_HALF} tensors + * to a n-dimensional data space. + */ +public final class TFloat16Mapper extends TensorMapper { + + @Override + protected TFloat16 mapDense(RawTensor tensor) { + FloatDataBuffer buffer = DataLayouts.FLOAT16.applyTo(TensorBuffers.toShorts(nativeHandle(tensor))); + return new DenseTFloat16(tensor, buffer); + } + + private static final class DenseTFloat16 extends FloatDenseNdArray implements TFloat16 { + + @Override + public DataType dataType() { + return TFloat16.DTYPE; + } + + @Override + public RawTensor asRawTensor() { + return rawTensor; + } + + final RawTensor rawTensor; + + DenseTFloat16(RawTensor rawTensor, FloatDataBuffer buffer) { + super(buffer, rawTensor.shape()); + this.rawTensor = rawTensor; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat32Mapper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat32Mapper.java new file mode 100644 index 00000000000..68aacc0bff6 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat32Mapper.java @@ -0,0 +1,42 @@ +package org.tensorflow.internal.types; + +import org.tensorflow.DataType; +import org.tensorflow.RawTensor; +import org.tensorflow.TensorMapper; +import org.tensorflow.internal.buffer.TensorBuffers; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray; +import org.tensorflow.types.TFloat32; + +/** + * Maps memory of {@link org.tensorflow.proto.framework.DataType#DT_FLOAT} tensors + * to a n-dimensional data space. + */ +public final class TFloat32Mapper extends TensorMapper { + + @Override + protected TFloat32 mapDense(RawTensor tensor) { + FloatDataBuffer buffer = TensorBuffers.toFloats(nativeHandle(tensor)); + return new DenseTFloat32(tensor, buffer); + } + + private static final class DenseTFloat32 extends FloatDenseNdArray implements TFloat32 { + + @Override + public DataType dataType() { + return TFloat32.DTYPE; + } + + @Override + public RawTensor asRawTensor() { + return rawTensor; + } + + final RawTensor rawTensor; + + DenseTFloat32(RawTensor rawTensor, FloatDataBuffer buffer) { + super(buffer, rawTensor.shape()); + this.rawTensor = rawTensor; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat64Mapper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat64Mapper.java new file mode 100644 index 00000000000..fee1d9cc8dc --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat64Mapper.java @@ -0,0 +1,42 @@ +package org.tensorflow.internal.types; + +import org.tensorflow.DataType; +import org.tensorflow.RawTensor; +import org.tensorflow.TensorMapper; +import org.tensorflow.internal.buffer.TensorBuffers; +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.ndarray.impl.dense.DoubleDenseNdArray; +import org.tensorflow.types.TFloat64; + +/** + * Maps memory of {@link org.tensorflow.proto.framework.DataType#DT_DOUBLE} tensors + * to a n-dimensional data space. + */ +public final class TFloat64Mapper extends TensorMapper { + + @Override + protected TFloat64 mapDense(RawTensor tensor) { + DoubleDataBuffer buffer = TensorBuffers.toDoubles(nativeHandle(tensor)); + return new DenseTFloat64(tensor, buffer); + } + + private static final class DenseTFloat64 extends DoubleDenseNdArray implements TFloat64 { + + @Override + public DataType dataType() { + return TFloat64.DTYPE; + } + + @Override + public RawTensor asRawTensor() { + return rawTensor; + } + + final RawTensor rawTensor; + + DenseTFloat64(RawTensor rawTensor, DoubleDataBuffer buffer) { + super(buffer, rawTensor.shape()); + this.rawTensor = rawTensor; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TInt32Mapper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TInt32Mapper.java new file mode 100644 index 00000000000..c1e8404dde0 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TInt32Mapper.java @@ -0,0 +1,42 @@ +package org.tensorflow.internal.types; + +import org.tensorflow.DataType; +import org.tensorflow.RawTensor; +import org.tensorflow.TensorMapper; +import org.tensorflow.internal.buffer.TensorBuffers; +import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.ndarray.impl.dense.IntDenseNdArray; +import org.tensorflow.types.TInt32; + +/** + * Maps memory of {@link org.tensorflow.proto.framework.DataType#DT_INT32} tensors + * to a n-dimensional data space. + */ +public final class TInt32Mapper extends TensorMapper { + + @Override + protected TInt32 mapDense(RawTensor tensor) { + IntDataBuffer buffer = TensorBuffers.toInts(nativeHandle(tensor)); + return new DenseTInt32(tensor, buffer); + } + + private static final class DenseTInt32 extends IntDenseNdArray implements TInt32 { + + @Override + public DataType dataType() { + return TInt32.DTYPE; + } + + @Override + public RawTensor asRawTensor() { + return rawTensor; + } + + final RawTensor rawTensor; + + DenseTInt32(RawTensor rawTensor, IntDataBuffer buffer) { + super(buffer, rawTensor.shape()); + this.rawTensor = rawTensor; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TInt64Mapper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TInt64Mapper.java new file mode 100644 index 00000000000..5dddcf0e16b --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TInt64Mapper.java @@ -0,0 +1,42 @@ +package org.tensorflow.internal.types; + +import org.tensorflow.DataType; +import org.tensorflow.RawTensor; +import org.tensorflow.TensorMapper; +import org.tensorflow.internal.buffer.TensorBuffers; +import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.ndarray.impl.dense.LongDenseNdArray; +import org.tensorflow.types.TInt64; + +/** + * Maps memory of {@link org.tensorflow.proto.framework.DataType#DT_INT64} tensors + * to a n-dimensional data space. + */ +public final class TInt64Mapper extends TensorMapper { + + @Override + protected TInt64 mapDense(RawTensor tensor) { + LongDataBuffer buffer = TensorBuffers.toLongs(nativeHandle(tensor)); + return new DenseTInt64(tensor, buffer); + } + + private static final class DenseTInt64 extends LongDenseNdArray implements TInt64 { + + @Override + public DataType dataType() { + return TInt64.DTYPE; + } + + @Override + public RawTensor asRawTensor() { + return rawTensor; + } + + final RawTensor rawTensor; + + DenseTInt64(RawTensor rawTensor, LongDataBuffer buffer) { + super(buffer, rawTensor.shape()); + this.rawTensor = rawTensor; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TStringInitializer.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TStringInitializer.java new file mode 100644 index 00000000000..84c2d6aa228 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TStringInitializer.java @@ -0,0 +1,38 @@ +package org.tensorflow.internal.types; + +import java.util.function.Consumer; +import java.util.function.Function; +import org.tensorflow.internal.buffer.ByteSequenceTensorBuffer; +import org.tensorflow.internal.buffer.ByteSequenceProvider; +import org.tensorflow.internal.types.TStringMapper.TStringInternal; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.types.TString; + +/** + * Helper class for initializing a {@link TString} tensor. + * + * @param source of bytes ({@code byte[]} or {@code String}) + */ +public final class TStringInitializer implements Consumer { + + public TStringInitializer(NdArray source, Function byteExtractor) { + this.byteSequenceProvider = new ByteSequenceProvider<>(source, byteExtractor); + } + + /** + * Compute the minimum size for a tensor to hold all the data provided by the source. + * + * @return minimum tensor size, in bytes + * @see ByteSequenceTensorBuffer#computeSize(ByteSequenceProvider) + */ + public long computeRequiredSize() { + return ByteSequenceTensorBuffer.computeSize(byteSequenceProvider); + } + + @Override + public void accept(TString tensor) { + ((TStringInternal)tensor).init(byteSequenceProvider); + } + + private final ByteSequenceProvider byteSequenceProvider; +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TStringMapper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TStringMapper.java new file mode 100644 index 00000000000..ce892c58944 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TStringMapper.java @@ -0,0 +1,88 @@ +package org.tensorflow.internal.types; + +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import org.tensorflow.DataType; +import org.tensorflow.RawTensor; +import org.tensorflow.TensorMapper; +import org.tensorflow.internal.buffer.ByteSequenceTensorBuffer; +import org.tensorflow.internal.buffer.ByteSequenceProvider; +import org.tensorflow.internal.buffer.TensorBuffers; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.layout.DataLayout; +import org.tensorflow.ndarray.buffer.layout.DataLayouts; +import org.tensorflow.ndarray.impl.dense.DenseNdArray; +import org.tensorflow.types.TString; + +/** + * Maps memory of {@link org.tensorflow.proto.framework.DataType#DT_STRING} tensors + * to a n-dimensional data space. + */ +public final class TStringMapper extends TensorMapper { + + private static final DataLayout, String> UTF_8_LAYOUT = + DataLayouts.ofStrings(StandardCharsets.UTF_8); + + @Override + protected TString mapDense(RawTensor tensor) { + ByteSequenceTensorBuffer buffer = TensorBuffers.toStrings(nativeHandle(tensor), tensor.shape().size()); + return new DenseTString(tensor, buffer, UTF_8_LAYOUT); + } + + /** + * Adds package-private methods to all instances of {@code TString} + */ + interface TStringInternal extends TString { + + /** + * Initialize the buffer of this string tensor using the provided byte sequencer. + * + * @param byteSequenceProvider produces sequences of bytes to use as the tensor data + * @param source of bytes ({@code byte[]} or {@code String}) + */ + void init(ByteSequenceProvider byteSequenceProvider); + } + + private static final class DenseTString extends DenseNdArray implements TStringInternal { + + @Override + public void init(ByteSequenceProvider byteSequenceProvider) { + buffer.init(byteSequenceProvider); + } + + @Override + public TString using(Charset charset) { + return new DenseTString(rawTensor, buffer, DataLayouts.ofStrings(charset)); + } + + @Override + public NdArray asBytes() { + return NdArrays.wrap(shape(), buffer); + } + + @Override + public DataType dataType() { + return TString.DTYPE; + } + + @Override + public RawTensor asRawTensor() { + return rawTensor; + } + + final RawTensor rawTensor; + final ByteSequenceTensorBuffer buffer; + + DenseTString( + RawTensor rawTensor, + ByteSequenceTensorBuffer buffer, + DataLayout, String> layout + ) { + super(layout.applyTo(buffer), rawTensor.shape()); + this.rawTensor = rawTensor; + this.buffer = buffer; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TUint8Mapper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TUint8Mapper.java new file mode 100644 index 00000000000..72c556ee411 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TUint8Mapper.java @@ -0,0 +1,42 @@ +package org.tensorflow.internal.types; + +import org.tensorflow.DataType; +import org.tensorflow.RawTensor; +import org.tensorflow.TensorMapper; +import org.tensorflow.internal.buffer.TensorBuffers; +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.impl.dense.ByteDenseNdArray; +import org.tensorflow.types.TUint8; + +/** + * Maps memory of {@link org.tensorflow.proto.framework.DataType#DT_UINT8} tensors + * to a n-dimensional data space. + */ +public final class TUint8Mapper extends TensorMapper { + + @Override + protected TUint8 mapDense(RawTensor tensor) { + ByteDataBuffer buffer = TensorBuffers.toBytes(nativeHandle(tensor)); + return new DenseTUint8(tensor, buffer); + } + + private static final class DenseTUint8 extends ByteDenseNdArray implements TUint8 { + + @Override + public DataType dataType() { + return TUint8.DTYPE; + } + + @Override + public RawTensor asRawTensor() { + return rawTensor; + } + + final RawTensor rawTensor; + + DenseTUint8(RawTensor rawTensor, ByteDataBuffer buffer) { + super(buffer, rawTensor.shape()); + this.rawTensor = rawTensor; + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBfloat16.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBfloat16.java index e7fd03af46a..94c6f8790b5 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBfloat16.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBfloat16.java @@ -19,18 +19,14 @@ import java.util.function.Consumer; import org.tensorflow.DataType; -import org.tensorflow.RawTensor; import org.tensorflow.Tensor; import org.tensorflow.exceptions.TensorFlowException; -import org.tensorflow.internal.buffer.TensorBuffers; -import org.tensorflow.internal.c_api.TF_Tensor; +import org.tensorflow.internal.types.TBfloat16Mapper; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; import org.tensorflow.ndarray.buffer.FloatDataBuffer; -import org.tensorflow.ndarray.buffer.layout.DataLayouts; -import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray; import org.tensorflow.types.family.TFloating; /** @@ -54,7 +50,7 @@ public interface TBfloat16 extends FloatNdArray, TFloating { static final String NAME = "BFLOAT16"; /** Type metadata */ - DataType DTYPE = DataType.create(NAME, 14, 2, TBfloat16Impl::mapTensor); + DataType DTYPE = DataType.create(NAME, 14, 2, new TBfloat16Mapper()); /** * Allocates a new tensor for storing a single float value. @@ -125,28 +121,3 @@ static TBfloat16 tensorOf(Shape shape, Consumer dataInit) { } } -/** Hidden implementation of a {@code TBfloat16} */ -class TBfloat16Impl extends FloatDenseNdArray implements TBfloat16 { - - @Override - public DataType dataType() { - return TBfloat16.DTYPE; - } - - @Override - public RawTensor asRawTensor() { - return rawTensor; - } - - static TBfloat16 mapTensor(RawTensor tensor, TF_Tensor nativeHandle) { - FloatDataBuffer buffer = DataLayouts.BFLOAT16.applyTo(TensorBuffers.toShorts(nativeHandle)); - return new TBfloat16Impl(tensor, buffer); - } - - private final RawTensor rawTensor; - - private TBfloat16Impl(RawTensor rawTensor, FloatDataBuffer buffer) { - super(buffer, rawTensor.shape()); - this.rawTensor = rawTensor; - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java index 0571dce410c..bab5e7910b4 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java @@ -24,6 +24,7 @@ import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.buffer.TensorBuffers; import org.tensorflow.internal.c_api.TF_Tensor; +import org.tensorflow.internal.types.TBoolMapper; import org.tensorflow.ndarray.BooleanNdArray; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.Shape; @@ -45,7 +46,7 @@ public interface TBool extends BooleanNdArray, TType { static final String NAME = "BOOL"; /** Type metadata */ - DataType DTYPE = DataType.create(NAME, 10, 1, TBoolImpl::mapTensor); + DataType DTYPE = DataType.create(NAME, 10, 1, new TBoolMapper()); /** * Allocates a new tensor for storing a single boolean value. @@ -115,29 +116,3 @@ static TBool tensorOf(Shape shape, Consumer dataInit) { return Tensor.of(DTYPE, shape, dataInit); } } - -/** Hidden implementation of a {@code TBool} */ -class TBoolImpl extends BooleanDenseNdArray implements TBool { - - @Override - public DataType dataType() { - return TBool.DTYPE; - } - - @Override - public RawTensor asRawTensor() { - return rawTensor; - } - - static TBool mapTensor(RawTensor tensor, TF_Tensor nativeHandle) { - BooleanDataBuffer buffer = TensorBuffers.toBooleans(nativeHandle); - return new TBoolImpl(tensor, buffer); - } - - private final RawTensor rawTensor; - - private TBoolImpl(RawTensor rawTensor, BooleanDataBuffer buffer) { - super(buffer, rawTensor.shape()); - this.rawTensor = rawTensor; - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat16.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat16.java index b675701b0d0..0decbb66d12 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat16.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat16.java @@ -19,18 +19,14 @@ import java.util.function.Consumer; import org.tensorflow.DataType; -import org.tensorflow.RawTensor; import org.tensorflow.Tensor; import org.tensorflow.exceptions.TensorFlowException; -import org.tensorflow.internal.buffer.TensorBuffers; -import org.tensorflow.internal.c_api.TF_Tensor; +import org.tensorflow.internal.types.TFloat16Mapper; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; import org.tensorflow.ndarray.buffer.FloatDataBuffer; -import org.tensorflow.ndarray.buffer.layout.DataLayouts; -import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray; import org.tensorflow.types.family.TFloating; /** @@ -52,7 +48,7 @@ public interface TFloat16 extends FloatNdArray, TFloating { static final String NAME = "FLOAT16"; /** Type metadata */ - DataType DTYPE = DataType.create(NAME, 19, 2, TFloat16Impl::mapTensor); + DataType DTYPE = DataType.create(NAME, 19, 2, new TFloat16Mapper()); /** * Allocates a new tensor for storing a single float value. @@ -122,29 +118,3 @@ static TFloat16 tensorOf(Shape shape, Consumer dataInit) { return Tensor.of(DTYPE, shape, dataInit); } } - -/** Hidden implementation of a {@code TFloat16} */ -class TFloat16Impl extends FloatDenseNdArray implements TFloat16 { - - @Override - public DataType dataType() { - return TFloat16.DTYPE; - } - - @Override - public RawTensor asRawTensor() { - return rawTensor; - } - - static TFloat16 mapTensor(RawTensor tensor, TF_Tensor nativeHandle) { - FloatDataBuffer buffer = DataLayouts.FLOAT16.applyTo(TensorBuffers.toShorts(nativeHandle)); - return new TFloat16Impl(tensor, buffer); - } - - private final RawTensor rawTensor; - - private TFloat16Impl(RawTensor rawTensor, FloatDataBuffer buffer) { - super(buffer, rawTensor.shape()); - this.rawTensor = rawTensor; - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java index 9bcefd628c2..6300650811e 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java @@ -19,17 +19,14 @@ import java.util.function.Consumer; import org.tensorflow.DataType; -import org.tensorflow.RawTensor; import org.tensorflow.Tensor; import org.tensorflow.exceptions.TensorFlowException; -import org.tensorflow.internal.buffer.TensorBuffers; -import org.tensorflow.internal.c_api.TF_Tensor; +import org.tensorflow.internal.types.TFloat32Mapper; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; import org.tensorflow.ndarray.buffer.FloatDataBuffer; -import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray; import org.tensorflow.types.family.TFloating; /** IEEE-754 single-precision 32-bit float tensor type. */ @@ -39,7 +36,7 @@ public interface TFloat32 extends FloatNdArray, TFloating { static final String NAME = "FLOAT"; /** Type metadata */ - DataType DTYPE = DataType.create(NAME, 1, 4, TFloat32Impl::mapTensor); + DataType DTYPE = DataType.create(NAME, 1, 4, new TFloat32Mapper()); /** * Allocates a new tensor for storing a single float value. @@ -109,29 +106,3 @@ static TFloat32 tensorOf(Shape shape, Consumer dataInit) { return Tensor.of(DTYPE, shape, dataInit); } } - -/** Hidden implementation of a {@code TFloat32} */ -class TFloat32Impl extends FloatDenseNdArray implements TFloat32 { - - @Override - public DataType dataType() { - return TFloat32.DTYPE; - } - - @Override - public RawTensor asRawTensor() { - return rawTensor; - } - - static TFloat32 mapTensor(RawTensor tensor, TF_Tensor nativeHandle) { - FloatDataBuffer buffer = TensorBuffers.toFloats(nativeHandle); - return new TFloat32Impl(tensor, buffer); - } - - private final RawTensor rawTensor; - - private TFloat32Impl(RawTensor rawTensor, FloatDataBuffer buffer) { - super(buffer, rawTensor.shape()); - this.rawTensor = rawTensor; - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat64.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat64.java index 806725d5b21..923b9992400 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat64.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat64.java @@ -24,6 +24,7 @@ import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.buffer.TensorBuffers; import org.tensorflow.internal.c_api.TF_Tensor; +import org.tensorflow.internal.types.TFloat64Mapper; import org.tensorflow.ndarray.DoubleNdArray; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.Shape; @@ -40,7 +41,7 @@ public interface TFloat64 extends DoubleNdArray, TFloating { static final String NAME = "DOUBLE"; /** Type metadata */ - DataType DTYPE = DataType.create(NAME, 2, 8, TFloat64Impl::mapTensor); + DataType DTYPE = DataType.create(NAME, 2, 8, new TFloat64Mapper()); /** * Allocates a new tensor for storing a single double value. @@ -110,29 +111,3 @@ static TFloat64 tensorOf(Shape shape, Consumer dataInit) { return Tensor.of(DTYPE, shape, dataInit); } } - -/** Hidden implementation of a {@code TFloat64} */ -class TFloat64Impl extends DoubleDenseNdArray implements TFloat64 { - - @Override - public DataType dataType() { - return TFloat64.DTYPE; - } - - @Override - public RawTensor asRawTensor() { - return rawTensor; - } - - static TFloat64 mapTensor(RawTensor tensor, TF_Tensor nativeHandle) { - DoubleDataBuffer buffer = TensorBuffers.toDoubles(nativeHandle); - return new TFloat64Impl(tensor, buffer); - } - - private final RawTensor rawTensor; - - private TFloat64Impl(RawTensor rawTensor, DoubleDataBuffer buffer) { - super(buffer, rawTensor.shape()); - this.rawTensor = rawTensor; - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt32.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt32.java index 1aa4333f34f..ccb865e2793 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt32.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt32.java @@ -23,6 +23,7 @@ import org.tensorflow.Tensor; import org.tensorflow.internal.buffer.TensorBuffers; import org.tensorflow.internal.c_api.TF_Tensor; +import org.tensorflow.internal.types.TInt32Mapper; import org.tensorflow.ndarray.IntNdArray; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.Shape; @@ -38,7 +39,7 @@ public interface TInt32 extends IntNdArray, TNumber { static final String NAME = "INT32"; /** Type metadata */ - DataType DTYPE = DataType.create(NAME, 3, 4, TInt32Impl::mapTensor); + DataType DTYPE = DataType.create(NAME, 3, 4, new TInt32Mapper()); /** * Allocates a new tensor for storing a single int value. @@ -109,28 +110,3 @@ static TInt32 tensorOf(Shape shape, Consumer dataInit) { } } -/** Hidden implementation of a {@code TInt32} */ -class TInt32Impl extends IntDenseNdArray implements TInt32 { - - @Override - public DataType dataType() { - return TInt32.DTYPE; - } - - @Override - public RawTensor asRawTensor() { - return rawTensor; - } - - static TInt32 mapTensor(RawTensor tensor, TF_Tensor nativeHandle) { - IntDataBuffer buffer = TensorBuffers.toInts(nativeHandle); - return new TInt32Impl(tensor, buffer); - } - - private final RawTensor rawTensor; - - private TInt32Impl(RawTensor rawTensor, IntDataBuffer buffer) { - super(buffer, rawTensor.shape()); - this.rawTensor = rawTensor; - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt64.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt64.java index 0853ae9bac7..02763391ff6 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt64.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt64.java @@ -24,6 +24,7 @@ import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.buffer.TensorBuffers; import org.tensorflow.internal.c_api.TF_Tensor; +import org.tensorflow.internal.types.TInt64Mapper; import org.tensorflow.ndarray.LongNdArray; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.Shape; @@ -39,7 +40,7 @@ public interface TInt64 extends LongNdArray, TNumber { static final String NAME = "INT64"; /** Type metadata */ - DataType DTYPE = DataType.create(NAME, 9, 8, TInt64Impl::mapTensor); + DataType DTYPE = DataType.create(NAME, 9, 8, new TInt64Mapper()); /** * Allocates a new tensor for storing a single long value. @@ -109,29 +110,3 @@ static TInt64 tensorOf(Shape shape, Consumer dataInit) { return Tensor.of(DTYPE, shape, dataInit); } } - -/** Hidden implementation of a {@code TInt64} */ -class TInt64Impl extends LongDenseNdArray implements TInt64 { - - @Override - public DataType dataType() { - return TInt64.DTYPE; - } - - @Override - public RawTensor asRawTensor() { - return rawTensor; - } - - static TInt64 mapTensor(RawTensor tensor, TF_Tensor nativeHandle) { - LongDataBuffer buffer = TensorBuffers.toLongs(nativeHandle); - return new TInt64Impl(tensor, buffer); - } - - private final RawTensor rawTensor; - - private TInt64Impl(RawTensor rawTensor, LongDataBuffer buffer) { - super(buffer, rawTensor.shape()); - this.rawTensor = rawTensor; - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java index c6f3a9872a9..6d7f7426c1a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java @@ -21,18 +21,13 @@ import java.nio.charset.StandardCharsets; import java.util.function.Function; import org.tensorflow.DataType; -import org.tensorflow.RawTensor; import org.tensorflow.Tensor; -import org.tensorflow.internal.buffer.StringTensorBuffer; -import org.tensorflow.internal.buffer.TensorBuffers; -import org.tensorflow.internal.c_api.TF_Tensor; +import org.tensorflow.internal.types.TStringInitializer; +import org.tensorflow.internal.types.TStringMapper; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.buffer.DataBuffer; -import org.tensorflow.ndarray.buffer.layout.DataLayout; -import org.tensorflow.ndarray.buffer.layout.DataLayouts; -import org.tensorflow.ndarray.impl.dense.DenseNdArray; import org.tensorflow.types.family.TType; /** @@ -50,7 +45,7 @@ public interface TString extends NdArray, TType { static final String NAME = "STRING"; /** Type metadata */ - DataType DTYPE = DataType.create(NAME, 7, -1, TStringImpl::mapTensor); + DataType DTYPE = DataType.create(NAME, 7, -1, new TStringMapper()); /** * Allocates a new tensor for storing a string scalar. @@ -114,7 +109,8 @@ static TString tensorOf(NdArray src) { * @return the new tensor */ static TString tensorOf(Charset charset, NdArray src) { - return TStringImpl.createTensor(src, s -> s.getBytes(charset)); + TStringInitializer initializer = new TStringInitializer<>(src, s -> s.getBytes(charset)); + return Tensor.of(TString.DTYPE, src.shape(), initializer.computeRequiredSize(), initializer); } /** @@ -174,7 +170,8 @@ static TString tensorOf(Charset charset, Shape shape, DataBuffer data) { * @return the new tensor */ static TString tensorOfBytes(NdArray src) { - return TStringImpl.createTensor(src, Function.identity()); + TStringInitializer initializer = new TStringInitializer<>(src, Function.identity()); + return Tensor.of(TString.DTYPE, src.shape(), initializer.computeRequiredSize(), initializer); } /** @@ -218,57 +215,3 @@ static TString tensorOfBytes(Shape shape, DataBuffer data) { /** @return the tensor data as a n-dimensional array of raw byte sequences. */ NdArray asBytes(); } - -/** Hidden implementation of a {@code TString} */ -class TStringImpl extends DenseNdArray implements TString { - - @Override - public TString using(Charset charset) { - return new TStringImpl(rawTensor, tensorBuffer, DataLayouts.ofStrings(charset)); - } - - @Override - public NdArray asBytes() { - return NdArrays.wrap(shape(), tensorBuffer); - } - - @Override - public DataType dataType() { - return TString.DTYPE; - } - - @Override - public RawTensor asRawTensor() { - return rawTensor; - } - - static TString createTensor(NdArray src, Function getBytes) { - long size = StringTensorBuffer.computeSize(src, getBytes); - return Tensor.of( - TString.DTYPE, - src.shape(), - size, - data -> ((TStringImpl) data).tensorBuffer.init(src, getBytes)); - } - - static TString mapTensor(RawTensor tensor, TF_Tensor nativeHandle) { - StringTensorBuffer buffer = TensorBuffers.toStrings(nativeHandle, tensor.shape().size()); - return new TStringImpl(tensor, buffer, UTF_8_LAYOUT); - } - - private static final DataLayout, String> UTF_8_LAYOUT = - DataLayouts.ofStrings(StandardCharsets.UTF_8); - - private final RawTensor rawTensor; - private final StringTensorBuffer tensorBuffer; - - private TStringImpl( - RawTensor rawTensor, - StringTensorBuffer buffer, - DataLayout, String> layout - ) { - super(layout.applyTo(buffer), rawTensor.shape()); - this.rawTensor = rawTensor; - tensorBuffer = buffer; - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TUint8.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TUint8.java index fd05857c295..a6f5dba8971 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TUint8.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TUint8.java @@ -24,6 +24,7 @@ import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.buffer.TensorBuffers; import org.tensorflow.internal.c_api.TF_Tensor; +import org.tensorflow.internal.types.TUint8Mapper; import org.tensorflow.ndarray.ByteNdArray; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.Shape; @@ -39,7 +40,7 @@ public interface TUint8 extends ByteNdArray, TNumber { static final String NAME = "UINT8"; /** Type metadata */ - DataType DTYPE = DataType.create(NAME, 4, 1, TUint8Impl::mapTensor); + DataType DTYPE = DataType.create(NAME, 4, 1, new TUint8Mapper()); /** * Allocates a new tensor for storing a single byte value. @@ -109,29 +110,3 @@ static TUint8 tensorOf(Shape shape, Consumer dataInit) { return Tensor.of(DTYPE, shape, dataInit); } } - -/** Hidden implementation of a {@code TUint8} */ -class TUint8Impl extends ByteDenseNdArray implements TUint8 { - - @Override - public DataType dataType() { - return TUint8.DTYPE; - } - - @Override - public RawTensor asRawTensor() { - return rawTensor; - } - - static TUint8 mapTensor(RawTensor tensor, TF_Tensor nativeHandle) { - ByteDataBuffer buffer = TensorBuffers.toBytes(nativeHandle); - return new TUint8Impl(tensor, buffer); - } - - private final RawTensor rawTensor; - - private TUint8Impl(RawTensor rawTensor, ByteDataBuffer buffer) { - super(buffer, rawTensor.shape()); - this.rawTensor = rawTensor; - } -}