-
Notifications
You must be signed in to change notification settings - Fork 214
[Type Refactor] Move type implementations and mappers to internal package #172
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
karllessard
merged 3 commits into
tensorflow:master
from
karllessard:type-refactor-mappers
Dec 18, 2020
Merged
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
35 changes: 35 additions & 0 deletions
35
tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorMapper.java
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
package org.tensorflow; | ||
|
||
import org.tensorflow.internal.c_api.TF_Tensor; | ||
import org.tensorflow.types.family.TType; | ||
|
||
/** | ||
* Maps the native memory of a {@link RawTensor} to a n-dimensional typed data space | ||
* accessible from the JVM. | ||
* | ||
* <p>Usage of this class is reserved for internal purposes only. | ||
* | ||
* @param <T> tensor type mapped by this object | ||
* @see {@link TType} | ||
*/ | ||
public abstract class TensorMapper<T extends TType> { | ||
|
||
/** | ||
* Maps the provided dense raw {@code tensor} as a tensor of type {@code T}. | ||
* | ||
* @param tensor the dense tensor to map, in its raw nature | ||
* @return an instance of {@code T} | ||
*/ | ||
protected abstract T mapDense(RawTensor tensor); | ||
|
||
/** | ||
* Helper for retrieving the native handle of a raw tensor | ||
* | ||
* @param tensor a raw tensor | ||
* @return the native handle of that tensor | ||
* @throws IllegalStateException if the tensor has been released | ||
*/ | ||
protected static TF_Tensor nativeHandle(RawTensor tensor) { | ||
return tensor.nativeHandle(); | ||
} | ||
} |
53 changes: 53 additions & 0 deletions
53
...ensorflow-core-api/src/main/java/org/tensorflow/internal/buffer/ByteSequenceProvider.java
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
package org.tensorflow.internal.buffer; | ||
|
||
import java.util.Iterator; | ||
import java.util.function.Function; | ||
import org.tensorflow.ndarray.NdArray; | ||
import org.tensorflow.ndarray.NdArraySequence; | ||
|
||
/** | ||
* Produces sequence of bytes to be stored in a {@link ByteSequenceTensorBuffer}. | ||
* | ||
* @param <T> source of bytes (byte arrays or strings) | ||
*/ | ||
public class ByteSequenceProvider<T> implements Iterable<byte[]> { | ||
|
||
/** | ||
* Constructor | ||
* | ||
* @param source source of data | ||
* @param byteExtractor method that converts one value of the source into a sequence of bytes | ||
*/ | ||
public ByteSequenceProvider(NdArray<T> source, Function<T, byte[]> byteExtractor) { | ||
this.source = source; | ||
this.byteExtractor = byteExtractor; | ||
} | ||
|
||
@Override | ||
public Iterator<byte[]> iterator() { | ||
return new Iterator<byte[]>() { | ||
|
||
@Override | ||
public boolean hasNext() { | ||
return scalarIterator.hasNext(); | ||
} | ||
|
||
@Override | ||
public byte[] next() { | ||
return byteExtractor.apply(scalarIterator.next().getObject()); | ||
} | ||
|
||
private final Iterator<? extends NdArray<T>> scalarIterator = source.scalars().iterator(); | ||
}; | ||
} | ||
|
||
/** | ||
* @return total number of byte sequences that can be produced by this sequencer | ||
*/ | ||
long numSequences() { | ||
return source.size(); | ||
} | ||
|
||
private final NdArray<T> source; | ||
private final Function<T, byte[]> byteExtractor; | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
43 changes: 43 additions & 0 deletions
43
...core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBfloat16Mapper.java
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
package org.tensorflow.internal.types; | ||
|
||
import org.tensorflow.TensorMapper; | ||
import org.tensorflow.DataType; | ||
import org.tensorflow.RawTensor; | ||
import org.tensorflow.internal.buffer.TensorBuffers; | ||
import org.tensorflow.ndarray.buffer.FloatDataBuffer; | ||
import org.tensorflow.ndarray.buffer.layout.DataLayouts; | ||
import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray; | ||
import org.tensorflow.types.TBfloat16; | ||
|
||
/** | ||
* Maps memory of {@link org.tensorflow.proto.framework.DataType#DT_BFLOAT16} tensors | ||
* to a n-dimensional data space. | ||
*/ | ||
public final class TBfloat16Mapper extends TensorMapper<TBfloat16> { | ||
|
||
@Override | ||
protected TBfloat16 mapDense(RawTensor tensor) { | ||
FloatDataBuffer buffer = DataLayouts.BFLOAT16.applyTo(TensorBuffers.toShorts(nativeHandle(tensor))); | ||
return new DenseTBfloat16(tensor, buffer); | ||
} | ||
|
||
private static final class DenseTBfloat16 extends FloatDenseNdArray implements TBfloat16 { | ||
|
||
@Override | ||
public DataType<TBfloat16> dataType() { | ||
return TBfloat16.DTYPE; | ||
} | ||
|
||
@Override | ||
public RawTensor asRawTensor() { | ||
return rawTensor; | ||
} | ||
|
||
final RawTensor rawTensor; | ||
|
||
DenseTBfloat16(RawTensor rawTensor, FloatDataBuffer buffer) { | ||
super(buffer, rawTensor.shape()); | ||
this.rawTensor = rawTensor; | ||
} | ||
} | ||
} |
42 changes: 42 additions & 0 deletions
42
...low-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TBoolMapper.java
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
package org.tensorflow.internal.types; | ||
|
||
import org.tensorflow.DataType; | ||
import org.tensorflow.RawTensor; | ||
import org.tensorflow.TensorMapper; | ||
import org.tensorflow.internal.buffer.TensorBuffers; | ||
import org.tensorflow.ndarray.buffer.BooleanDataBuffer; | ||
import org.tensorflow.ndarray.impl.dense.BooleanDenseNdArray; | ||
import org.tensorflow.types.TBool; | ||
|
||
/** | ||
* Maps memory of {@link org.tensorflow.proto.framework.DataType#DT_BOOL} tensors | ||
* to a n-dimensional data space. | ||
*/ | ||
public final class TBoolMapper extends TensorMapper<TBool> { | ||
|
||
@Override | ||
protected TBool mapDense(RawTensor tensor) { | ||
BooleanDataBuffer buffer = TensorBuffers.toBooleans(nativeHandle(tensor)); | ||
return new DenseTBool(tensor, buffer); | ||
} | ||
|
||
private static final class DenseTBool extends BooleanDenseNdArray implements TBool { | ||
|
||
@Override | ||
public DataType<TBool> dataType() { | ||
return TBool.DTYPE; | ||
} | ||
|
||
@Override | ||
public RawTensor asRawTensor() { | ||
return rawTensor; | ||
} | ||
|
||
final RawTensor rawTensor; | ||
|
||
DenseTBool(RawTensor rawTensor, BooleanDataBuffer buffer) { | ||
super(buffer, rawTensor.shape()); | ||
this.rawTensor = rawTensor; | ||
} | ||
} | ||
} |
43 changes: 43 additions & 0 deletions
43
...-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/types/TFloat16Mapper.java
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
package org.tensorflow.internal.types; | ||
|
||
import org.tensorflow.DataType; | ||
import org.tensorflow.RawTensor; | ||
import org.tensorflow.TensorMapper; | ||
import org.tensorflow.internal.buffer.TensorBuffers; | ||
import org.tensorflow.ndarray.buffer.FloatDataBuffer; | ||
import org.tensorflow.ndarray.buffer.layout.DataLayouts; | ||
import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray; | ||
import org.tensorflow.types.TFloat16; | ||
|
||
/** | ||
* Maps memory of {@link org.tensorflow.proto.framework.DataType#DT_HALF} tensors | ||
* to a n-dimensional data space. | ||
*/ | ||
public final class TFloat16Mapper extends TensorMapper<TFloat16> { | ||
|
||
@Override | ||
protected TFloat16 mapDense(RawTensor tensor) { | ||
FloatDataBuffer buffer = DataLayouts.FLOAT16.applyTo(TensorBuffers.toShorts(nativeHandle(tensor))); | ||
return new DenseTFloat16(tensor, buffer); | ||
} | ||
|
||
private static final class DenseTFloat16 extends FloatDenseNdArray implements TFloat16 { | ||
|
||
@Override | ||
public DataType<?> dataType() { | ||
return TFloat16.DTYPE; | ||
} | ||
|
||
@Override | ||
public RawTensor asRawTensor() { | ||
return rawTensor; | ||
} | ||
|
||
final RawTensor rawTensor; | ||
|
||
DenseTFloat16(RawTensor rawTensor, FloatDataBuffer buffer) { | ||
super(buffer, rawTensor.shape()); | ||
this.rawTensor = rawTensor; | ||
} | ||
} | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.