[Type Refactor] Use Java type system instead of custom one for typing tensors #174
Force-pushed from 2a24e5d to ba3d471.
Would be nice to clean up the import generation, as that's induced a lot of noise in this PR, but it doesn't have to be done here. I think that the framework module could do with a pass to enforce stricter types now TFloating and TIntegral exist, but again that doesn't have to happen here.
tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/bitwise/BitwiseOr.java
...orflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/collective/BroadcastRecv.java
@@ -46,16 +47,12 @@
   * @return a new instance of AssertCardinalityDataset
   */
  @Endpoint(describeByClass = true)
- public static AssertCardinalityDataset create(Scope scope, Operand<?> inputDataset, Operand<TInt64> cardinality, List<DataType<?>> outputTypes, List<Shape> outputShapes) {
+ public static AssertCardinalityDataset create(Scope scope, Operand<?> inputDataset, Operand<TInt64> cardinality, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
These seem to be the same as the classes in org.tensorflow.op.data?
I don't see an AssertCardinalityDataset class under org.tensorflow.op.data, is that what you meant?
Oops, sorry, I meant AssertNextDataset. There are a bunch of classes that are in both org.tensorflow.op.data and org.tensorflow.op.data.experimental.
These are two different kernels in the TensorFlow runtime, like the other duplicates found in these packages as far as I can see. For example, the op under data is using AssertNextDataset while the one under experimental is using ExperimentalAssertNextDataset. So it is correct that they coexist; they probably achieve the same purpose but with different implementations. See here
tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java
tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java
tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java
tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java
tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java
tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java
tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java
Thanks a lot for the quick review @Craigacp!
In the TestSession classes, there is still a lot of checking on the types to call the right method for evaluating and printing, e.g. comparing a float value to TFloat32 data, a double value to TFloat64 data, etc. Is there now a better way to do this?
I don't think this PR changes that case, but here's an idea: since you know that the supported types are all numeric values, you can probably just cast the boxed value to a Number. For example:

o.asTensor().scalars().forEach(f -> assertEquals(expected, ((Number)f.getObject()).doubleValue()));
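The one-liner above works because java.lang.Number is the common supertype of all boxed numeric values. Here is a self-contained plain-Java sketch of the same trick (no TensorFlow types involved; the helper name is hypothetical):

```java
class NumberCastDemo {
    // Any boxed numeric scalar can be widened through the Number interface,
    // regardless of whether it originally held a float, a double or an int.
    static double asDouble(Object boxedScalar) {
        return ((Number) boxedScalar).doubleValue();
    }

    public static void main(String[] args) {
        // Simulates scalars pulled out of TFloat32, TFloat64 and TInt32 tensors.
        assert asDouble(3.0f) == 3.0;
        assert asDouble(3.0d) == 3.0;
        assert asDouble(3) == 3.0;
        System.out.println("ok");
    }
}
```

This lets one comparison path serve every numeric tensor type instead of branching on the type.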
Let me play with your suggestion. I am redoing the TestSession classes to add support for Placeholders and Feeds, so now would be a good time to clean it up.
The labels are restricted to the values -1, 0 and 1. So restricting it to integers seems appropriate.
Jim
On Dec 23, 2020, at 11:15 PM, Karl Lessard wrote:
@karllessard commented on this pull request.
In tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java:
> @@ -124,16 +124,13 @@ public Hinge(Ops tf, String name, Reduction reduction) {
public <T extends TNumber, U extends TNumber> Operand<T> call(
Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights) {
Just to make sure I understood @JimClarke5 , so labels could be any numeric values while predictions and sampleWeights must be restricted to floating-points only in all three losses you've mentioned?
Ok, so I've pushed the last changes that, I think, cover the comments from @Craigacp and @JimClarke5. @deansher, are you interested in taking a look as well, since you were a major instigator of this proposal? If you want to but don't have time right now, just let me know and I'll wait before merging, thanks!
Thank you, Karl — I would love to take a look, if waiting until early New Year's week to merge doesn't seem like a drag.
Dean
@rnett, @JimClarke5: I don't expect that there will be major changes after Dean's review, so if you want to get unblocked faster, I invite you to rebase your work on this PR branch; you can even start producing new PRs that we could merge after this one.
I'm working on the <U extends TType, T extends TType> Cast<U> cast(Operand<T> x, Class<U> DstT, Cast.Options... options) signature. It might be out of scope of this PR to change it now, since it is probably best done in the cc generation code (to change the op class
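A Class<U> parameter like DstT in the signature above acts as a runtime type token, which a bare generic parameter cannot do after erasure. A rough plain-Java analogue of why that matters (hypothetical helper, not the TensorFlow op):

```java
class ClassTokenCast {
    // A Class<U> argument survives erasure and can actually check the
    // conversion at runtime, unlike an unchecked (U) cast.
    static <U> U castTo(Object value, Class<U> target) {
        return target.cast(value); // throws ClassCastException on mismatch
    }

    public static void main(String[] args) {
        Number n = castTo(42, Number.class);
        assert n.intValue() == 42;

        boolean threw = false;
        try {
            castTo("not a number", Number.class); // fails at the call, not later
        } catch (ClassCastException e) {
            threw = true;
        }
        assert threw;
        System.out.println("ok");
    }
}
```

The failure surfaces exactly where the bad conversion is requested, instead of at some distant use site.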
I love how this came out! I did propose changes, but they are nits.
Thanks, @karllessard, for inviting me back into the loop on this. It's exciting to see it land! What a lot of work!
// Minimum requirements for datatypes of variable length cannot be verified in a relevant way so
// we only validate them for fixed length datatypes
if (!dtype.isVariableLength() && shape.size() * dtype.byteSize() > size) {

static RawTensor allocate(Class<? extends TType> type, Shape shape, long size) {
Augment this method to handle a shape of UNKNOWN_SIZE? (The previous version accidentally handled it.)
In fact, passing a totally or partially unknown shape to this constructor should be forbidden, I'll add a check for this.
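A minimal sketch of what such a check could look like, using a plain long[] in place of the real Shape API (the names and structure here are hypothetical, not the actual implementation):

```java
class ShapeGuard {
    static final long UNKNOWN_SIZE = -1;

    // Rejects shapes that are totally or partially unknown, as proposed above:
    // allocation needs every dimension to compute a byte size.
    static void requireFullyKnown(long[] dims) {
        for (long d : dims) {
            if (d == UNKNOWN_SIZE) {
                throw new IllegalArgumentException(
                    "Cannot allocate a tensor with an unknown dimension");
            }
        }
    }

    public static void main(String[] args) {
        requireFullyKnown(new long[] {2, 2}); // fully known: accepted
        boolean threw = false;
        try {
            requireFullyKnown(new long[] {2, UNKNOWN_SIZE}); // partially unknown
        } catch (IllegalArgumentException e) {
            threw = true;
        }
        assert threw;
        System.out.println("ok");
    }
}
```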
  * @param shape shape of the tensor
- * @param size size, in bytes, of the tensor
+ * @param size size in bytes of the tensor or -1 to compute the size from the shape
Document the behavior when shape has UNKNOWN_SIZE and -1 is passed for size?
same as in #174 (comment)
}

/**
 * Allocates a tensor of a given datatype, shape and size.
 *
- * <p>This method is identical to {@link #of(DataType, Shape, Consumer)}, except that the final
+ * <p>This method is identical to {@link #of(Class, Shape, Consumer)}, except that the final
 * size for the tensor is explicitly set instead of being computed from the datatype and shape.
Given the new support for a size of -1, perhaps change to "can be explicitly set".
tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java
/**
 * Returns the class of this tensor type
 */
public Class<T> typeClass() {
Elsewhere (such as in Tensor), we simply call this type(). I do feel the emotional tug to be more explicit in this case, but I wonder if it will simply feel like inconsistency once we have lived with the new paradigm for a while.
I was hesitating on this one as well, so it sounds like we are two now, meaning that it should be type(). I'll rename it.
@@ -83,7 +79,7 @@
   */
  @Endpoint(name = "flatten")
  public static <T extends TType, U extends TNumber> Operand<T> flatten(
-     Scope scope, Operand<T> operand, DataType<U> dType) {
+     Scope scope, Operand<T> operand, Class<U> dType) {
dType -> type throughout this file
-      return Cast.create(scope, result, logits.asOutput().dataType());
-    } else if(!logits.asOutput().dataType().equals(labels.asOutput().dataType())) {
+      return Cast.create(scope, result, logits.asOutput().type());
+    } else if(!logits.asOutput().type().equals(labels.asOutput().type())) {
!= would be more consistent.
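The != suggestion is safe here because the JVM canonicalizes Class objects per class loader, so reference comparison agrees with equals() for type comparisons. A quick plain-Java check:

```java
class ClassIdentityDemo {
    public static void main(String[] args) {
        // A class loaded by a given loader is represented by a single Class
        // object, so == and != behave exactly like equals() for these tokens.
        Class<?> a = Integer.class;
        Class<?> b = Integer.class;
        assert a == b;
        assert a.equals(b);

        Class<?> c = Float.class;
        assert a != c;
        assert !a.equals(c);
        System.out.println("ok");
    }
}
```

(The wildcard Class<?> variables are needed only because the compiler rejects == between provably distinct parameterized types like Class<Integer> and Class<Float>.)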
@@ -197,8 +192,8 @@
   */
  private static <T extends TNumber, U extends TNumber> Operand<T> moveDimToEnd(
      Scope scope, Operand<T> input, int dimIndex, Operand<U> rank) {
-    DataType<? extends TNumber> rankDType = rank.asOutput().dataType();
-    Operand one = Cast.create(scope, Constant.scalarOf(scope, 1), rankDType);
+    Class<U> rankDType = rank.asOutput().type();
rankDType -> rankType
TFloat32 floatTensor = rawTensor.asTypedTensor();
assertSame(floatTensor.asRawTensor(), rawTensor);
try {
  TInt32 intTensor = rawTensor.asTypedTensor();
This test code throws the expected exception because of the assignment to intTensor, rather than in asTypedTensor(). Here is alternative test code that also passes, but demonstrates that asTypedTensor() doesn't have its documented behavior:
@Test
public void rawToTypedTensor() {
  RawTensor rawTensor = RawTensor.allocate(TFloat32.class, Shape.of(2, 2), -1);
  TFloat32 floatTensor = rawTensor.asTypedTensor();
  assertSame(floatTensor.asRawTensor(), rawTensor);
  Object objTensor = rawTensor.<TInt32>asTypedTensor();
  try {
    TInt32 intTensor = (TInt32) objTensor;
    fail();
  } catch (ClassCastException e) {
    // ok
  }
}
Good point, I've never tried that before. Do you want me to update the doc, or try to throw a ClassCastException in this case as well? Note that asTypedTensor() is an internal package-private method and, where it is currently being used, it is implicitly enforced that the returned type matches the type of the tensor in all cases.
Updating the doc seems good to me, for both of the reasons you give.
To be honest, I'm not too sure how to document this. It looks to me like this behavior is some kind of "defect" in the generics specification, since the explicit <TInt32> parameterization of the method invocation seems to be overridden implicitly by the type inferred from the target (Object in this case). So basically <TInt32> is ignored, and that is probably the case for Java type inference in general when dealing with type parameters inferred from a target; is it worth documenting it here then?
The type will be erased at runtime, so the cast is erased to its bound (TType), which always succeeds. It might be better to have it explicitly return TType rather than have it promise something that can't be enforced by the type system. This is going to pollute the internals of our code with casts, but at least we'll remember to put them in rather than having it mysteriously blow up.
Ok, sounds fair, let's do that; anyway, all these additional casts will happen internally only.
Update: it turned out that only one additional cast was required in the source code...
On further reflection, I lean the same way as @Craigacp: have this method return TType.
Here's some analysis to support this choice: at runtime, the current behavior is "apply the mapper from typeInfo and return whatever TType it produces." Since the type parameter T is erased at runtime, the only way to guarantee a return of T would be to have the method take the tensor type class as a runtime parameter and explicitly verify that the mapper returns that type class. I might advocate this approach for a public method or even a widely used package-private method, but not for a rarely used package-private method.
But also, a review of the call points of this method raises another issue: whenever we treat the return value of asTypedTensor as simply a Tensor (rather than a TType), we "forget" at compile time the semantic upgrade that was presumably the point of the asTypedTensor call in the first place.
As an example, let's explore what happens underneath the following method:
class EagerOperation extends AbstractOperation {
  // ...
  /**
   * Returns the tensor of the {@code outputIdx}th output of this operation.
   *
   * <p>This is only supported in an eager execution environment.
   *
   * @param outputIdx index of the output of this operation
   * @return output tensor
   */
  @Override
  Tensor tensor(int outputIndex) {
    Tensor tensor = outputTensors.get(outputIndex);
    if (tensor == null) {
      tensor = resolveTensor(outputIndex);
    }
    return tensor;
  }
  // ...
}
(I copied the Javadoc above the @Override for easy reference.)
The caller should presumably be agnostic as to whether this returns a RawTensor or a TType. But looking deeper into the call path, here's how the returned Tensor is constructed:
private static Tensor resolveTensorHandle(TFE_TensorHandle handle, EagerSession session) {
  requireTensorHandle(handle);
  try (PointerScope scope = new PointerScope()) {
    TF_Status status = TF_Status.newStatus();
    TF_Tensor tensor = TFE_TensorHandleResolve(handle, status).withDeallocator();
    status.throwExceptionIfNotOK();
    return RawTensor.fromHandle(tensor, session).asTypedTensor();
  }
}
What's the point of asTypedTensor() on the last substantive line, above? It causes this method to return a special kind of Tensor -- a TType -- but that fact is immediately forgotten by the type system and is not even asserted in the Javadoc. Presumably, we should either drop the call to .asTypedTensor() or change the whole call path to return TType.
:-) But not for this PR!
}

@Test
public void allocateTensorWithoutSize() {
Add a test to show the intended behavior of a shape with UNKNOWN_SIZE.
I've just pushed a last version that should now cover all the topics discussed during the review. I'll merge it as soon as I receive a green light from you, and we will be done with this massive refactoring (and hopefully the last one of this nature).
The Eagle has landed.
Good job @karllessard
The typing in TensorTypeRegistry should probably be relaxed as I've indicated, but that can wait till the next PR.
  OperationBuilder opBuilder = scope.env().opBuilder("CollectiveBcastRecv", scope.makeOpName("BroadcastRecv"));
  opBuilder = scope.apply(opBuilder);
- opBuilder.setAttr("T", T);
+ opBuilder.setAttr("T", Operands.toDataType(T));
At some point we should fix the generator so it doesn't shadow the type name with a variable name, that's just confusing.
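The Operands.toDataType(T) call in the diff above maps a tensor type class to the runtime's data type enum, which is conceptually a class-to-enum registry lookup. A minimal plain-Java sketch (all names here are hypothetical stand-ins, not the actual implementation):

```java
import java.util.HashMap;
import java.util.Map;

class TypeRegistrySketch {
    // Stand-ins for the protobuf DataType enum values.
    enum DataType { DT_FLOAT, DT_INT32 }

    // Marker interfaces standing in for the TType hierarchy.
    interface TType {}
    interface TFloat32 extends TType {}
    interface TInt32 extends TType {}

    static final Map<Class<? extends TType>, DataType> REGISTRY = new HashMap<>();
    static {
        REGISTRY.put(TFloat32.class, DataType.DT_FLOAT);
        REGISTRY.put(TInt32.class, DataType.DT_INT32);
    }

    // Resolves the low-level enum from the high-level class token.
    static DataType toDataType(Class<? extends TType> type) {
        DataType dt = REGISTRY.get(type);
        if (dt == null) {
            throw new IllegalArgumentException("No data type registered for " + type);
        }
        return dt;
    }

    public static void main(String[] args) {
        assert toDataType(TFloat32.class) == DataType.DT_FLOAT;
        assert toDataType(TInt32.class) == DataType.DT_INT32;
        System.out.println("ok");
    }
}
```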
 * @return type registered information
 * @throws IllegalArgumentException if no tensor type for this data type has been registered
 */
public static <T extends TType> TensorTypeInfo<T> find(DataType dataType) {
This is a place where it will infer the type bound from the calling context, and if that's wrong then we'll get a weird class cast error when people use it. It might be better to return the wildcard as at least the user will get a warning when they make the cast.
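The hazard described here is general to any method whose type parameter is inferred from the calling context. A self-contained plain-Java reproduction (hypothetical names, not the registry itself):

```java
import java.util.HashMap;
import java.util.Map;

class InferredBoundDemo {
    static final Map<String, Object> VALUES = new HashMap<>();

    // T is inferred from the calling context; the unchecked cast inside
    // cannot actually verify it, mirroring find(DataType) above.
    @SuppressWarnings("unchecked")
    static <T> T find(String key) {
        return (T) VALUES.get(key);
    }

    public static void main(String[] args) {
        VALUES.put("x", 42); // stores an Integer

        Object ok = InferredBoundDemo.<Object>find("x"); // inferred as Object: fine
        assert ok.equals(42);

        boolean threw = false;
        try {
            // Inferred as String: the compiler inserts a checkcast at the
            // call site, so the "weird" ClassCastException lands on the user.
            String s = find("x");
            System.out.println(s);
        } catch (ClassCastException e) {
            threw = true;
        }
        assert threw;
        System.out.println("ok");
    }
}
```

Returning a wildcard instead of T would force callers to cast explicitly, turning the silent failure into a visible unchecked-cast warning.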
Again, note that the TensorTypeRegistry is (for now) an internal class, so we do have some control over the context it is called from. But like you've suggested, let's review this later.
Yeah, I know. But I'm in favour of having the compiler warn me before I make a silly mistake, and 6 months from now I won't necessarily remember that this method doesn't quite live up to its contract.
All right, let's merge this before someone changes his mind :) Thank you all for your reviews and comments!
This PR is the last major one in the tensor type refactoring currently in progress, and it contains a lot of backward-incompatible changes.
We drop the current DataType custom class to leverage the standard Java type system instead, using .class. So now, each of the tensor types (i.e. subinterfaces of TType, like TFloat) itself carries the information required for allocating and mapping tensors of this type, via the annotation @TensorType.

All references to the static variable DTYPE of the different tensor types can now be replaced by .class. For example, TFloat32.DTYPE becomes TFloat32.class.

In most places where the custom DataType class has been dropped, it has been replaced with its protobuf equivalent of the same name, which consists of a simple enum. This provides type information at a low level from the runtime library directly, and the API of the Op framework will naturally guide users to favor the tensor type classes (e.g. TFloat32.class) instead.

In a nutshell, the PR consists of:
- refactoring the DataType and DataTypes classes to use Java Class instead
- moving type information onto the tensor type interfaces (subinterfaces of TType)
- replacing T*.DTYPE by T*.class
- regenerating the op classes (tensorflow-core/tensorflow-core-api/src/gen)

It is important to note that after this PR has been merged, the ongoing work from @rnett (Kotlin) and @JimClarke5 (Framework) could be resumed.
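The class-literal-plus-annotation mechanism the description refers to can be sketched in plain Java; the annotation and its field below are hypothetical stand-ins, not the actual @TensorType definition:

```java
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;

class AnnotatedTypeSketch {
    // Hypothetical stand-in for the @TensorType annotation described above.
    @Retention(RetentionPolicy.RUNTIME)
    @interface TensorType {
        int byteSize();
    }

    interface TType {}

    @TensorType(byteSize = 4)
    interface TFloat32 extends TType {}

    // The class literal alone is enough to recover allocation metadata,
    // which is what lets TFloat32.class replace the old TFloat32.DTYPE.
    static int byteSizeOf(Class<? extends TType> type) {
        TensorType info = type.getAnnotation(TensorType.class);
        if (info == null) {
            throw new IllegalArgumentException(type + " is not a registered tensor type");
        }
        return info.byteSize();
    }

    public static void main(String[] args) {
        assert byteSizeOf(TFloat32.class) == 4;
        System.out.println("ok");
    }
}
```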
CC @deansher, @Craigacp