Eliminating dynamic shaped tensors for the static inputs #911

Closed · erman-gurses opened this issue Jun 7, 2022 · 18 comments

@erman-gurses (Collaborator) commented Jun 7, 2022:

This issue is related to #935

For the torch IR of the MiniLM model, we observed dynamically shaped tensors like those below, even though we provided a static input such as torch.randint(2, (1, 128)):

    %186 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
     .
     .
     .
    %190 = torch.prim.NumToTensor.Scalar %int128 : !torch.int -> !torch.vtensor<[],si64>
    %191 = torch.aten.add.Tensor %190, %186, %int1 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64>
    %192 = torch.aten.Int.Tensor %191 : !torch.vtensor<[],si64> -> !torch.int
    %193 = torch.prim.ListConstruct %int1, %192 : (!torch.int, !torch.int) -> !torch.list<int>
    %194 = torch.aten.ones %193, %int6, %none, %cpu, %false : !torch.list<int>, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[1,?],f32>

The dynamic shape of %194 makes many of the torch.aten.add.Tensor operations in the next part of the IR dynamic as well, so making this snippet static might also resolve other dynamic-shape issues.

I am working on finding a canonical way to eliminate these dynamically shaped tensors.
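
For reference, a rough sketch of how such a static input would be fed in (the MiniLM checkpoint name and the torch.jit.trace path are assumptions for illustration, not taken from this issue):

    import torch
    from transformers import AutoModel

    # Assumed MiniLM checkpoint; any encoder with the same interface would do.
    model = AutoModel.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
    model.eval()

    # The static input mentioned above: both dimensions of (1, 128) are known at trace time.
    example_input = torch.randint(2, (1, 128))

    # Tracing bakes this shape in, so the torch-dialect IR derived from `traced`
    # should ideally contain only statically shaped !torch.vtensor types.
    traced = torch.jit.trace(model, example_input, strict=False)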

erman-gurses self-assigned this Jun 7, 2022
erman-gurses changed the title from "Eliminating dynamic tensor shapes for the static inputs" to "Eliminating dynamic tensors for the static inputs" Jun 7, 2022
erman-gurses changed the title from "Eliminating dynamic tensors for the static inputs" to "Eliminating dynamic shaped tensors for the static inputs" Jun 7, 2022
@nirvedhmeshram (Collaborator) commented Jun 7, 2022:

The current plan is to add a canonicalization for add.Tensor and friends that folds 0-d operands. These operands can come either from a torch.prim.NumToTensor.Scalar op or from a 0-d torch.vtensor.literal. Below are two examples, before and after canonicalization.
Example 1

%0 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
%1 = torch.prim.NumToTensor.Scalar %int2 : !torch.int -> !torch.vtensor<[],si64>
%2 = torch.aten.add.Tensor %0, %1, %int3 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64>

After canonicalization:

%0 = torch.aten.add.int %int0, %int2 : !torch.int, !torch.int -> !torch.int
%1 = torch.aten.add.int %0, %int3 : !torch.int, !torch.int -> !torch.int
%2 = torch.prim.NumToTensor.Scalar %1 : !torch.int -> !torch.vtensor<[],si64>

Example 2:

%0 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%1 = torch.prim.NumToTensor.Scalar %int2 : !torch.int -> !torch.vtensor<[],si64>
%2 = torch.aten.add.Tensor %0, %1, %int3 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64>

Example 2 will also produce the exact same IR after canonicalization.

@silvasean (Contributor):

👍

@vivekkhandelwal1 (Collaborator):

Hi @nirvedhmeshram, I think this IR:

%0 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
%1 = torch.prim.NumToTensor.Scalar %int2 : !torch.int -> !torch.vtensor<[],si64>
%2 = torch.aten.add.Tensor %0, %1, %int3 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64>

after canonicalization should look like this:

%0 = torch.aten.mul.int %int2, %int3 : !torch.int, !torch.int -> !torch.int
%1 = torch.aten.add.int %0, %int0 : !torch.int, !torch.int -> !torch.int
%2 = torch.prim.NumToTensor.Scalar %1 : !torch.int -> !torch.vtensor<[],si64>

The argument %int3 in the original IR for the torch.aten.add.Tensor op is alpha, and it should be multiplied with the second operand %1.

Please correct me if I'm wrong.
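
A quick scalar check in plain PyTorch (illustrative only; the variable names are just stand-ins for the SSA values above) agrees that alpha scales the second operand:

    import torch

    a = torch.tensor(0)  # the value behind %int0, as a 0-d tensor
    b = torch.tensor(2)  # the value behind %int2, as a 0-d tensor

    # aten.add.Tensor computes self + alpha * other, so with alpha == 3:
    out = torch.add(a, b, alpha=3)
    assert out.item() == 0 + 3 * 2  # i.e. mul.int(%int2, %int3), then add.int with %int0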

@nirvedhmeshram (Collaborator):

Yup, I misunderstood that. Thanks for the correction.

@powderluv (Collaborator):

@sjarus FYI

@sjarus (Collaborator) commented Jun 13, 2022:

SUPER helpful for our TorchToTosa work!

@powderluv (Collaborator):

@sjarus do you think we can try the TOSA BERT lowering now that we have a WIP for this?

@sjarus (Collaborator) commented Jun 21, 2022:

I'll try it out as a local patch @powderluv

@sjarus (Collaborator) commented Jun 22, 2022:

Update: I tried out #935 locally using the code that was present yesterday evening. I see that aten.ones does not appear to generate a statically shaped result. It is consumed by a slice, which also propagates the unknown shape, so it looks like aten.ones is the blocker here:

    ...
    %112 = torch.prim.ListConstruct %int1, %int7 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc22)
    %113 = torch.aten.ones %112, %int6, %none, %cpu, %false : !torch.list<int>, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[1,?],f32> loc(#loc26)
    ...
    %118 = torch.aten.slice.Tensor %113, %int0, %int0, %int9223372036854775807, %int1 : !torch.vtensor<[1,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?],f32> loc(#loc30)

@silvasean (Contributor):

I suspect this is a phase-ordering issue related to when we apply this canonicalization transformation. See my comments here: #963

@silvasean (Contributor):

@powderluv did we ever find out where in the source Python code this pattern was coming from?

@nirvedhmeshram (Collaborator):

@silvasean I tried briefly, but I could not get the location information in the heavydep test from which we derive the model, so I could not find the source.

@silvasean (Contributor):

> @silvasean I tried briefly, but I could not get the location information in the heavydep test from which we derive the model, so I could not find the source.

Could you try "non-briefly"? :D This is pretty important, as the way that PyTorch is architected "should not" produce this, so it's really valuable information for us to know how it gets created.

@sjarus (Collaborator) commented Jun 23, 2022:

The aten.ones op takes %int1 and %int7 as inputs in list form and should emit a tensor of shape (1, 7). However, it computes the shape as (1, ?). It is picking up the parameters explicitly, because changing the shape to something else still picks up the first parameter correctly. Where is this shape resolved?

@silvasean (Contributor):

The shape transfer function for aten.ones is trivial. So if the shape isn't resolved, it means that at the point where we are simplifying shapes (SimplifyShapeCalculations), the list elements are not yet fully resolved to %int1 and %int7. The IR before that pass (ideally reduced) is necessary to debug this.
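
Conceptually (a minimal sketch of the idea, not the exact code in torch-mlir's shape library), that transfer function just forwards the requested size list, so the result is static exactly when every list element has already been folded to a constant:

    from typing import List

    def ones_shape(size: List[int]) -> List[int]:
        # aten.ones produces a tensor with exactly the requested shape, so the
        # result is static precisely when each element of `size` is a known
        # constant (e.g. %int1 and %int7 above); any unresolved element stays '?'.
        return size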

@sjarus (Collaborator) commented Jun 23, 2022:

Thanks @silvasean. Yes, I was wondering what I was missing about what seemed like a fairly straightforward op in terms of shape resolution.

@silvasean (Contributor):

Is this solved by https://github.com//pull/935 ?

@erman-gurses (Collaborator, Author):

Yes @silvasean, this one was solved by #935.
