[Type Refactor] Use Java type system instead of custom one for typing tensors #174

Merged (4 commits, Dec 30, 2020)

Conversation

karllessard
Collaborator

This PR is the last major one in the tensor type refactoring currently in progress, and it contains many backward-incompatible changes.

We drop the current custom DataType class to leverage the standard Java type system instead, using .class. Now each of the tensor types (i.e. subinterfaces of TType, like TFloat32) itself carries the information required for allocating and mapping tensors of this type, via the @TensorType annotation.

All references to the static variable DTYPE of the different tensor types can now be replaced by .class. For example:

// Before
Operand<TFloat32> x = tf.dtypes.cast(tf.constant(1), TFloat32.DTYPE);
assertSame(x.dataType(), TFloat32.DTYPE);
if (x.dataType() == TFloat32.DTYPE || x.dataType() == TFloat64.DTYPE) {
    // floating-point only computations...
}

// After
Operand<TFloat32> x = tf.dtypes.cast(tf.constant(1), TFloat32.class); 
assertSame(x.type(), TFloat32.class);
if (TFloating.class.isAssignableFrom(x.type())) {
    // floating-point only computations...
}

In most places where the custom DataType class has been dropped, it has been replaced with its protobuf equivalent of the same name, which is a simple enum. This provides low-level type information directly from the runtime library, and the API of the Op framework will naturally guide users to favor the tensor type classes (e.g. TFloat32.class) instead.

TFloat32 t = TFloat32.scalarOf(1.0f);
assertSame(t.type(), TFloat32.class);
assertSame(t.dataType(), DataType.DT_FLOAT);

In a nutshell, the PR consists of:

  • Dropping the custom DataType and DataTypes classes in favor of the Java Class
  • Adding an annotation and other utilities to register the tensor type classes (subinterfaces of TType)
  • Replacing all references to T*.DTYPE with T*.class
  • Reshuffling some functions of the framework a little to leverage the new type system
  • A lot of changes in generated files that can be skipped during the review (everything under tensorflow-core/tensorflow-core-api/src/gen)

It is important to note that after this PR has been merged, the ongoing work from @rnett (Kotlin) and @JimClarke5 (Framework) could be resumed.

CC @deansher, @Craigacp

@karllessard karllessard requested a review from Craigacp December 22, 2020 03:16
Collaborator

@Craigacp Craigacp left a comment


Would be nice to clean up the import generation, as that's induced a lot of noise in this PR, but it doesn't have to be done here. I think that the framework module could do with a pass to enforce stricter types now TFloating and TIntegral exist, but again that doesn't have to happen here.
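The stricter bounds mentioned above can be sketched with a minimal, self-contained example. The marker interfaces below are hypothetical stand-ins for the real TType/TNumber/TFloating hierarchy (which lives in the library's types packages), kept local so the sketch compiles on its own:

```java
// Hypothetical stand-ins for the real tensor type hierarchy, local to this sketch.
interface TType {}
interface TNumber extends TType {}
interface TFloating extends TNumber {}

final class TFloat32 implements TFloating {}
final class TInt32 implements TNumber {}

class Bounds {
    // The bound <T extends TFloating> is enforced at compile time, so no runtime
    // dtype check is needed inside the method.
    static <T extends TFloating> String floatOnly(Class<T> type) {
        return "ok: " + type.getSimpleName();
    }

    public static void main(String[] args) {
        System.out.println(floatOnly(TFloat32.class)); // accepted
        // floatOnly(TInt32.class);                    // rejected by the compiler
    }
}
```

A framework pass could tighten method signatures this way wherever an op only makes sense for floating-point or integral tensors.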

@@ -46,16 +47,12 @@
* @return a new instance of AssertCardinalityDataset
*/
@Endpoint(describeByClass = true)
public static AssertCardinalityDataset create(Scope scope, Operand<?> inputDataset, Operand<TInt64> cardinality, List<DataType<?>> outputTypes, List<Shape> outputShapes) {
public static AssertCardinalityDataset create(Scope scope, Operand<?> inputDataset, Operand<TInt64> cardinality, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
Collaborator

These seem to be the same as the classes in org.tensorflow.op.data?

Collaborator Author

I don't see an AssertCardinalityDataset class under org.tensorflow.op.data, is that what you meant?

Collaborator

Oops, sorry, I meant AssertNextDataset. There are a bunch of classes that are in both org.tensorflow.op.data and org.tensorflow.op.data.experimental.

Collaborator Author

These are two different kernels in the TensorFlow runtime, like the other duplicates found in these packages as far as I can see. For example, the op under data uses AssertNextDataset while the one under experimental uses ExperimentalAssertNextDataset. So it is correct that they coexist; they probably achieve the same purpose but with different implementations. See here

@karllessard
Collaborator Author

Thanks a lot for the quick review @Craigacp !

@JimClarke5
Contributor

JimClarke5 commented Dec 22, 2020

In the TestSession classes, there is still a lot of checking on the types to call the right method for evaluating and printing. For example, compare a float value to TFloat32 data, a double value to TFloat64 data, etc. Is there now a better way to do this?

@karllessard
Collaborator Author

In the TestSession classes, there is still a lot of checking on the types to call the right method for evaluating and printing. For example, compare a float value to TFloat32 data, a double value to TFloat64 data, etc. Is there now a better way to do this?

I don't think this PR directly relates to that case, but here's an idea along those lines: since you know that the supported types are all numeric, you can probably just cast the boxed value returned by getObject() to a Number and then use the doubleValue() of that number?

For example, where expected is a double value,

o.asTensor().scalars().forEach(f -> assertEquals(expected, ((Number)f.getObject()).doubleValue()));
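A standalone version of that suggestion, with a plain Object standing in for what getObject() would return on a scalar element (the helper name and tolerance are illustrative, not from the library):

```java
// Sketch: treat the boxed scalar uniformly as a Number, whatever the tensor
// element type was (float, double, int, ...).
class NumberCompare {
    static boolean matches(Object boxedScalar, double expected, double epsilon) {
        double actual = ((Number) boxedScalar).doubleValue();
        return Math.abs(actual - expected) < epsilon;
    }

    public static void main(String[] args) {
        System.out.println(matches(1.0f, 1.0, 1e-6)); // float element
        System.out.println(matches(1.0d, 1.0, 1e-6)); // double element
        System.out.println(matches(1, 1.0, 1e-6));    // int element
    }
}
```

This removes the per-type branches entirely, at the cost of comparing everything in double precision.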

@JimClarke5
Contributor

Let me play with your suggestion. I am redoing the TestSession classes to add support for Placeholders and Feeds, so now would be a good time to clean it up.

@JimClarke5
Contributor

JimClarke5 commented Dec 24, 2020 via email

@karllessard
Collaborator Author

Ok, so I've pushed the last changes that, I think, cover the comments from @Craigacp and @JimClarke5.

@deansher, are you interested in taking a look as well, since you were a major instigator of this proposal? If you want to but don't have time right now, just let me know and I'll wait before merging, thanks!

@deansher
Contributor

deansher commented Dec 25, 2020 via email

@karllessard
Collaborator Author

@rnett , @JimClarke5 : I don't expect that there will be major changes after Dean's review so if you want to get unblocked faster, I invite you to rebase your work on this PR branch and you can even start producing new PRs that we could merge after this one.

@rnett
Contributor

rnett commented Dec 28, 2020

I'm working on the reified generation for the Kotlin API, and I'm noticing that lots of methods have unnecessary type parameters that make the reified usage much less nice (since you have to specify all type parameters if you specify one). The best example is probably cast, which has the signature:

<U extends TType, T extends TType> Cast<U> cast(Operand<T> x, Class<U> DstT, Cast.Options... options)

T is completely unnecessary and could be replaced with ? without issue, but it prevents cast<TInt32>(x) usage from Kotlin. This shows up in a number of ops, mostly as unnecessary type parameters on the inputs. It essentially needs an "is this type param only bounded by TType and only used on inputs" check.

It might be out of scope for this PR to change this now, since it is probably best done in the cc generation code (to change the op class create methods).
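The signature difference can be contrasted in a minimal sketch; Operand, CastResult, and the two cast methods below are hypothetical stand-ins, not the generated op classes:

```java
// Hypothetical stand-ins for Operand/Cast, just to contrast the two signatures.
interface TType {}
final class TFloat32 implements TType {}
final class TInt32 implements TType {}

class Operand<T extends TType> {}

class CastResult<U extends TType> {
    final Class<U> target;
    CastResult(Class<U> target) { this.target = target; }
}

class CastSigs {
    // Current shape: T appears only on the input, so a caller (or a Kotlin reified
    // wrapper) that names U explicitly must name T as well.
    static <U extends TType, T extends TType> CastResult<U> cast(Operand<T> x, Class<U> dstT) {
        return new CastResult<>(dstT);
    }

    // Proposed shape: the input-only parameter becomes a wildcard, leaving U as the
    // single type parameter.
    static <U extends TType> CastResult<U> cast2(Operand<?> x, Class<U> dstT) {
        return new CastResult<>(dstT);
    }

    public static void main(String[] args) {
        Operand<TFloat32> x = new Operand<>();
        // Only the destination type needs to be named now:
        CastResult<TInt32> r = CastSigs.<TInt32>cast2(x, TInt32.class);
        System.out.println(r.target.getSimpleName());
    }
}
```

From plain Java, inference works either way because dstT is a value argument; the pain described above is in Kotlin's reified wrappers, where naming U forces naming T too under the two-parameter shape.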

@karllessard
Collaborator Author

That's very interesting @rnett , and yes it should be addressed outside this PR, I've created a new issue so we can continue the discussion from there: #176

Contributor

@deansher deansher left a comment


I love how this came out! I did propose changes, but they are nits.

Thanks, @karllessard , for inviting me back into the loop on this. It's exciting to see it land! What a lot of work!

// Minimum requirements for datatypes of variable length cannot be verified in a relevant way so
// we only validate them for fixed length datatypes
if (!dtype.isVariableLength() && shape.size() * dtype.byteSize() > size) {
static RawTensor allocate(Class<? extends TType> type, Shape shape, long size) {
Contributor

Augment this method to handle a shape of UNKNOWN_SIZE? (The previous version accidentally handled it.)

Collaborator Author

In fact, passing a totally or partially unknown shape to this constructor should be forbidden, I'll add a check for this.

* @param shape shape of the tensor
* @param size size, in bytes, of the tensor
* @param size size in bytes of the tensor or -1 to compute the size from the shape
Contributor

Document the behavior when shape has UNKNOWN_SIZE and -1 is passed for size?

Collaborator Author

same as in #174 (comment)

}

/**
* Allocates a tensor of a given datatype, shape and size.
*
* <p>This method is identical to {@link #of(DataType, Shape, Consumer)}, except that the final
* <p>This method is identical to {@link #of(Class, Shape, Consumer)}, except that the final
* size for the tensor is explicitly set instead of being computed from the datatype and shape.
Contributor

Given the new support for size of -1, perhaps change to "can be explicitly set".

/**
* Returns the class of this tensor type
*/
public Class<T> typeClass() {
Contributor

Elsewhere (such as in Tensor), we simply call this type(). I do feel the emotional tug to be more explicit in this case, but I wonder if it will simply feel like inconsistency once we have lived with the new paradigm for a while.

Collaborator Author

I was hesitating on this one as well, so it sounds like we are two now, meaning that it should be type(). I'll rename it.

@@ -83,7 +79,7 @@
*/
@Endpoint(name = "flatten")
public static <T extends TType, U extends TNumber> Operand<T> flatten(
Scope scope, Operand<T> operand, DataType<U> dType) {
Scope scope, Operand<T> operand, Class<U> dType) {
Contributor

dType -> type throughout this file

return Cast.create(scope, result, logits.asOutput().dataType());
} else if(!logits.asOutput().dataType().equals(labels.asOutput().dataType())) {
return Cast.create(scope, result, logits.asOutput().type());
} else if(!logits.asOutput().type().equals(labels.asOutput().type())) {
Contributor

!= would be more consistent.

@@ -197,8 +192,8 @@
*/
private static <T extends TNumber, U extends TNumber> Operand<T> moveDimToEnd(
Scope scope, Operand<T> input, int dimIndex, Operand<U> rank) {
DataType<? extends TNumber> rankDType = rank.asOutput().dataType();
Operand one = Cast.create(scope, Constant.scalarOf(scope, 1), rankDType);
Class<U> rankDType = rank.asOutput().type();
Contributor

rankDType -> rankType

TFloat32 floatTensor = rawTensor.asTypedTensor();
assertSame(floatTensor.asRawTensor(), rawTensor);
try {
TInt32 intTensor = rawTensor.asTypedTensor();
Contributor

This test code throws the expected exception because of the assignment to intTensor, rather than in asTypedTensor(). Here is alternative test code that also passes, but demonstrates that asTypedTensor() doesn't have its documented behavior:

  @Test
  public void rawToTypedTensor() {
    RawTensor rawTensor = RawTensor.allocate(TFloat32.class, Shape.of(2, 2), -1);
    TFloat32 floatTensor = rawTensor.asTypedTensor();
    assertSame(floatTensor.asRawTensor(), rawTensor);
    Object objTensor = rawTensor.<TInt32>asTypedTensor();
    try {
      TInt32 intTensor = (TInt32) objTensor;
      fail();
    } catch (ClassCastException e) {
      // ok
    }
  }

Collaborator Author

Good point, I've never tried that before. Do you want me to update the doc or try to throw a ClassCastException in this case as well?

Note that asTypedTensor() is an internal package-private method and where it is currently being used, it is implicitly enforced that the returned type matches the type of the tensor for all cases.

Contributor

Updating the doc seems good to me for both of the reasons you give.

Collaborator Author

To be honest, I'm not too sure how to document this. It looks to me like this behavior is some kind of "defect" in the generics specification, since the explicit <TInt32> parameterization of the method invocation seems to be implicitly overridden by the type inferred from the target (Object in this case).

So basically, <TInt32> is ignored, and that is probably the case for Java type inference in general when dealing with type parameters inferred from a target; is it worth documenting it here then?

Collaborator

The type will be erased at runtime so the cast is erased to its bound (TType) which always succeeds. It might be better to have it explicitly return TType rather than have it promise something that can't be enforced by the type system. This is going to pollute the internals of our code with casts, but at least we'll remember to put them in rather than having it mysteriously blow up.
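The erasure behavior described above can be demonstrated in isolation; pretendCast below is a hypothetical method, not asTypedTensor(), but it has the same shape:

```java
// Self-contained demonstration of the erased cast: inside the generic method the
// (T) cast compiles to a no-op, so a wrong type argument only fails (if ever) at
// the caller's use site, where javac inserts the real checkcast.
class Erasure {
    @SuppressWarnings("unchecked")
    static <T> T pretendCast(Object o) {
        return (T) o; // erased: never throws here, whatever T is claimed to be
    }

    public static void main(String[] args) {
        // The explicit <Integer> argument is erased; assigning to Object succeeds.
        Object silently = Erasure.<Integer>pretendCast("not an int");
        System.out.println(silently);

        try {
            // Here the compiler inserts a checkcast to Integer, which fails.
            Integer boom = Erasure.<Integer>pretendCast("not an int");
        } catch (ClassCastException e) {
            System.out.println("ClassCastException at the use site");
        }
    }
}
```

This is why returning TType and making callers cast explicitly is safer: the cast (and its unchecked warning) then sits visibly at the call site.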

Collaborator Author

@karllessard karllessard Dec 30, 2020


Ok, sounds fair, let's do that; anyway, all these additional casts will happen internally only.

Update: it turned out that only one additional cast was required in the source code...

Contributor

On further reflection, I lean the same way as @Craigacp: have this method return TType.

Here's some analysis to support this choice: At runtime, the current behavior is "apply the mapper from typeInfo and return whatever TType it produces." Since the type parameter T is erased at runtime, the only way to guarantee a return of T would be to have the method take the tensor type class as a runtime parameter and explicitly verify that the mapper returns that type class. I might advocate this approach for a public method or even a widely used package-private method, but not for a rarely used package-private method.

But also, a review of the call points of this method raises another issue: whenever we treat the return value of asTypedTensor as simply a Tensor (rather than a TType), we "forget" at compile time the semantic upgrade that was presumably the point of the asTypedTensor call in the first place.

As an example, let's explore what happens underneath the following method:

class EagerOperation extends AbstractOperation {
  // ...
  /**
   * Returns the tensor of the {@code outputIdx}th output of this operation.
   *
   * <p>This is only supported in an eager execution environment.
   *
   * @param outputIdx index of the output of this operation
   * @return output tensor
   */
  @Override
  Tensor tensor(int outputIndex) {
    Tensor tensor = outputTensors.get(outputIndex);
    if (tensor == null) {
      tensor = resolveTensor(outputIndex);
    }
    return tensor;
  }
  // ...
}

(I copied the Javadoc above the @Override for easy reference.)

The caller should presumably be agnostic as to whether this returns a RawTensor or a TType. But looking deeper into the call path, here's how the returned Tensor is constructed:

  private static Tensor resolveTensorHandle(TFE_TensorHandle handle, EagerSession session) {
    requireTensorHandle(handle);
    try (PointerScope scope = new PointerScope()) {
      TF_Status status = TF_Status.newStatus();
      TF_Tensor tensor = TFE_TensorHandleResolve(handle, status).withDeallocator();
      status.throwExceptionIfNotOK();
      return RawTensor.fromHandle(tensor, session).asTypedTensor();
    }
  }

What's the point of asTypedTensor() on the last substantive line, above? It causes this method to return a special kind of Tensor -- a TType -- but that fact is immediately forgotten by the type system and is not even asserted in the Javadoc. Presumably, we should either drop the call to .asTypedTensor() or change the whole call path to return TType.

Contributor

:-) But not for this PR!

}

@Test
public void allocateTensorWithoutSize() {
Contributor

Add test to show intended behavior of shape with UNKNOWN_SIZE.

@karllessard
Collaborator Author

I've just pushed a last version that should now cover all the topics discussed during the review. I'll merge it as soon as I get a green light from you, and we will be done with this massive refactoring (hopefully the last one of this nature).

Contributor

@deansher deansher left a comment


The Eagle has landed.

Contributor

@JimClarke5 JimClarke5 left a comment


Good job @karllessard

Collaborator

@Craigacp Craigacp left a comment


The typing in TensorTypeRegistry should probably be relaxed as I've indicated, but that can wait till the next PR.

OperationBuilder opBuilder = scope.env().opBuilder("CollectiveBcastRecv", scope.makeOpName("BroadcastRecv"));
opBuilder = scope.apply(opBuilder);
opBuilder.setAttr("T", T);
opBuilder.setAttr("T", Operands.toDataType(T));
Collaborator

At some point we should fix the generator so it doesn't shadow the type name with a variable name, that's just confusing.

* @return type registered information
* @throws IllegalArgumentException if no tensor type for this data type has been registered
*/
public static <T extends TType> TensorTypeInfo<T> find(DataType dataType) {
Collaborator

This is a place where it will infer the type bound from the calling context, and if that's wrong then we'll get a weird class cast error when people use it. It might be better to return the wildcard as at least the user will get a warning when they make the cast.
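The hazard with an inferred return type can be sketched as follows; TensorTypeInfo, find(), and the registry contents here are hypothetical stand-ins for the real TensorTypeRegistry, not its actual code:

```java
// Hypothetical sketch of the two find() shapes discussed above.
interface TType {}
final class TFloat32 implements TType {}
final class TInt32 implements TType {}

class TensorTypeInfo<T extends TType> {
    final Class<T> type;
    TensorTypeInfo(Class<T> type) { this.type = type; }
}

class Registry {
    // Inferred-from-context version: the caller can silently pick a wrong T.
    @SuppressWarnings("unchecked")
    static <T extends TType> TensorTypeInfo<T> find() {
        return (TensorTypeInfo<T>) new TensorTypeInfo<>(TFloat32.class);
    }

    // Wildcard version: the caller must cast explicitly, and gets an unchecked
    // warning reminding them that the cast is their responsibility.
    static TensorTypeInfo<?> findWildcard() {
        return new TensorTypeInfo<>(TFloat32.class);
    }

    public static void main(String[] args) {
        // Compiles without any warning at the call site, yet the payload is TFloat32:
        TensorTypeInfo<TInt32> wrong = Registry.find();
        System.out.println(wrong.type.getSimpleName()); // prints TFloat32
    }
}
```

With the wildcard shape, the mistaken assignment above would not compile without an explicit (and warned-about) cast, which is exactly the compiler nudge being asked for.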

Collaborator Author

Again, note that TensorTypeRegistry is (for now) an internal class, so we do have some control over the context in which it is being called. But as you've suggested, let's review this later.

Collaborator

Yeah, I know. But I'm in favour of having the compiler warn me before I make a silly mistake, and 6 months from now I won't necessarily remember that this method doesn't quite live up to its contract.

@karllessard
Collaborator Author

All right, let's merge this before someone changes his mind :) Thank you all for your reviews and comments!

@karllessard karllessard merged commit f85623e into tensorflow:master Dec 30, 2020
@JimClarke5 JimClarke5 mentioned this pull request Jan 3, 2021