
[Converter] Add support for aten::Int in TRTorch #513


Closed
ApluUalberta opened this issue Jun 28, 2021 · 9 comments · Fixed by #870 or #1937
Labels
Blocked (Issue cannot be resolved until a change in a dependency) · feature request (New feature or request) · No Activity

Comments

@ApluUalberta

ApluUalberta commented Jun 28, 2021

aten::Int.Tensor(a)

  • Function Schema:
    aten::Int.Tensor(a) -> (int)

Is your feature request related to a problem?
Converting a model with TRTorch fails because TRTorch does not support a specific operator (aten::Int.Tensor(a) -> (int)).

Alternatives

Additional Context

ApluUalberta added the feature request label Jun 28, 2021
@narendasan
Collaborator

narendasan commented Jun 28, 2021

This op is a known limitation of TRTorch. TRTorch programs have two distinct phases, compilation and execution, and each has a different set of information available. aten::Int.Tensor(Tensor a) -> (int) is something we are currently unable to handle, because the concrete value of the tensor is not known at compile time, which is when the integer must be extracted for use later in the compilation process; the tensor's value only gets realized at runtime.

My suggestion is to try using the partial compilation feature to force aten::Int to run in PyTorch, or to rewrite your model so that it does not cast a 0-D tensor to an int.
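To make the two-phase problem concrete, here is a toy Python model of it. This is not how TRTorch is implemented; `RuntimeTensor`, `build_slice_layer`, and `run_slice_layer` are hypothetical names. The "compile" step has to bake a concrete end index into the layer it builds, which works for a Python int but not for a value that only exists once the program runs.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass(frozen=True)
class RuntimeTensor:
    """Hypothetical stand-in for a tensor whose value only exists at runtime."""
    name: str


def build_slice_layer(end: Union[int, RuntimeTensor]) -> dict:
    """Toy compile phase: the built layer must embed a concrete integer end index."""
    if isinstance(end, RuntimeTensor):
        # This is the aten::Int situation: the int would have to be read out of a
        # tensor that has not been computed yet, so there is nothing to embed.
        raise ValueError(f"cannot extract an int from {end.name} at compile time")
    return {"op": "slice", "end": end}


def run_slice_layer(layer: dict, data: List[int]) -> List[int]:
    """Toy execution phase: runs a layer that was fully specified at compile time."""
    return data[: layer["end"]]


print(run_slice_layer(build_slice_layer(3), [10, 20, 30, 40]))  # [10, 20, 30]

try:
    build_slice_layer(RuntimeTensor("%234"))
except ValueError as err:
    print(err)  # cannot extract an int from %234 at compile time
```

Both suggestions in the comment amount to moving the cast out of the compile phase: either run it in PyTorch at runtime, or make the value a plain int before compilation starts.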

narendasan added the Blocked label Jun 28, 2021
@itsliupeng
Contributor

itsliupeng commented Jun 29, 2021

My temporary solution is to use c = int(a / b) instead of c = a // b in the PyTorch module's forward method.
It produces the warning below, but the warning can be ignored.

TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
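One thing to watch with this workaround: `int(a / b)` truncates toward zero, while `a // b` floors, so the two differ when the quotient is negative. For shape arithmetic on non-negative sizes they agree. The sketch below uses plain Python ints as stand-ins for the 0-D tensor values (an assumption; it does not trace anything), and also shows `int(a // b)`, which keeps the floor semantics while still making the cast explicit.

```python
def floor_div(a: int, b: int) -> int:
    """Original code path: Python floor division (rounds toward negative infinity)."""
    return a // b


def int_of_true_div(a: int, b: int) -> int:
    """Workaround from this thread: true division, then int() truncates toward zero."""
    return int(a / b)


def int_of_floor_div(a: int, b: int) -> int:
    """Alternative explicit cast that keeps floor semantics."""
    return int(a // b)


# Plain Python ints stand in for the values that would be 0-D tensors inside forward().
for a, b in [(7, 2), (-7, 2), (6, 3)]:
    print(a, b, floor_div(a, b), int_of_true_div(a, b), int_of_floor_div(a, b))
```

Only the (-7, 2) case disagrees: floor division gives -4, while int(-7 / 2) gives -3.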

@borisfom
Collaborator

@narendasan: Can we actually work around this issue by internally pretending that the result of that Int() operation is a Tensor?
We could probably have a rewrite rule for that. Here is a real-life example of how the value is actually used:

    %232 : int = aten::size(%input_ids.1, %6) # /usr/local/lib/python3.6/dist-packages/transformers/models/bert/modeling_bert.py:187:0
    %seq_length.1 : Tensor = prim::NumToTensor(%232) # :0:0
    %234 : Tensor = aten::add(%seq_length.1, %18, %6) # /usr/local/lib/python3.6/dist-packages/transformers/models/bert/modeling_bert.py:194:0
    %235 : int = aten::Int(%234)
    %input0.23 : Tensor = aten::slice(%19, %6, %4, %235, %6) # /usr/local/lib/python3.6/dist-packages/transformers/models/bert/modeling_bert.py:194:0

So, if we rewrite "%235 : int = aten::Int(%234)" --> "%235 := %234", wouldn't that work?
Or would that require us to also bend the rules on what slice expects as its 4th argument, since we currently only allow an int there?
Having said that, fixing the Python code seems much more beneficial: without an explicit int(), PyTorch generates a lot of useless operations (storing the literal 0 as {0}, converting an int to a Tensor, adding {0} to it, converting it back to an int, and passing that to slice()). With int(), the constant folds away and even slice() disappears into something else.
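The proposed rewrite can be sketched as a text-level pass over the IR lines quoted in this comment. This is a hypothetical illustration, not a TRTorch or TorchScript lowering pass: `alias_int_casts` drops each `aten::Int` line and renames later uses of its output to the input tensor. As the comment points out, `aten::slice` then receives a Tensor as its end argument, so the schema question remains open.

```python
import re
from typing import List

# Matches TorchScript IR lines of the form "%out : int = aten::Int(%in)".
INT_CAST = re.compile(r"^\s*(%[\w.]+) : int = aten::Int\((%[\w.]+)\)")


def alias_int_casts(ir_lines: List[str]) -> List[str]:
    """Hypothetical rewrite: drop each aten::Int cast and reuse its input tensor.

    Every later use of the int value is renamed to the tensor it was cast from.
    Downstream ops such as aten::slice then receive a Tensor where their schema
    expects an int, so this text-level rewrite alone does not make the graph valid.
    """
    renames = {}
    rewritten = []
    for line in ir_lines:
        match = INT_CAST.match(line)
        if match:
            out_name, in_name = match.groups()
            # Resolve chains so an alias of an alias points at the original tensor.
            renames[out_name] = renames.get(in_name, in_name)
            continue
        for old, new in renames.items():
            # The lookahead keeps "%235" from also matching "%2350" or "%235.1".
            line = re.sub(re.escape(old) + r"(?![\w.])", new, line)
        rewritten.append(line)
    return rewritten


ir = [
    "%234 : Tensor = aten::add(%seq_length.1, %18, %6)",
    "%235 : int = aten::Int(%234)",
    "%input0.23 : Tensor = aten::slice(%19, %6, %4, %235, %6)",
]
for line in alias_int_casts(ir):
    print(line)
```

The aten::Int line disappears and the slice now reads %234 directly, which is exactly the "%235 := %234" aliasing asked about.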

@narendasan
Copy link
Collaborator

narendasan commented Jul 16, 2021

I could see that if there is an aten::Int, prim::NumToTensor pair and no other uses of the Int output, we could fuse them, but I don't know how general that is. We could also explore setting up a torch.fx pass to do the int() source conversion so users don't need to rewrite their code manually, but that pass would need to be run by the user before TorchScript conversion.
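A minimal sketch of the fusion condition described here, on a toy SSA-style IR rather than TorchScript's real graph API (`Node` and `fuse_int_num_to_tensor` are hypothetical). It assumes the pair is an `aten::Int` whose output feeds only a `prim::NumToTensor`, and that the tensor being cast is already an integer 0-D tensor, so replacing the round trip with the original tensor does not change its value or dtype.

```python
from collections import Counter
from typing import List, NamedTuple, Tuple


class Node(NamedTuple):
    """One node of a toy SSA-style IR (not TorchScript's real graph API)."""
    output: str
    op: str
    inputs: Tuple[str, ...]


def fuse_int_num_to_tensor(nodes: List[Node]) -> List[Node]:
    """Fold `%i = aten::Int(%t); %t2 = prim::NumToTensor(%i)` by renaming %t2 -> %t.

    The pair is folded only when %i has no other uses, so the int never needs to
    be extracted at compile time. Graph outputs are not modeled in this toy IR.
    """
    producers = {node.output: node for node in nodes}
    uses = Counter(name for node in nodes for name in node.inputs)

    renames = {}
    dead = set()
    for node in nodes:
        if node.op != "prim::NumToTensor":
            continue
        src = producers.get(node.inputs[0])
        if src is not None and src.op == "aten::Int" and uses[src.output] == 1:
            tensor = src.inputs[0]
            # Resolve chains in case the tensor itself came from an earlier folded pair.
            renames[node.output] = renames.get(tensor, tensor)
            dead.update({src.output, node.output})

    return [
        Node(n.output, n.op, tuple(renames.get(name, name) for name in n.inputs))
        for n in nodes
        if n.output not in dead
    ]


g = [
    Node("%t", "aten::add", ("%a", "%b", "%alpha")),
    Node("%i", "aten::Int", ("%t",)),
    Node("%t2", "prim::NumToTensor", ("%i",)),
    Node("%y", "aten::mul", ("%t2", "%c")),
]
for node in fuse_int_num_to_tensor(g):
    print(node)
```

In the BERT example quoted earlier, the aten::Int output feeds aten::slice directly rather than a prim::NumToTensor, so this pattern would leave it untouched; that is one concrete reason the fusion may not be very general.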

@oliver-batchelor
Copy link

I'm having trouble with this in a model that does some math on the shape of the tensor. I've had this model working in torch2trt before, because torch2trt just treats the value as a constant (which is fine here), whereas torch.jit tries to capture the math on the shape somehow.

Is there any way around this? Ideally this math wouldn't be in the compiled model at all. Is there a way to make torch.jit treat it as a constant?

@narendasan
Copy link
Collaborator

Have you tried the workaround above? Basically, explicitly cast the value to int in Python.

@p1x31
Copy link
Contributor

p1x31 commented Aug 22, 2021

I stumbled upon this issue as well. I tried explicitly wrapping all floor-division operations in the forward method with int(), and it worked.

@github-actions

This issue has not seen activity for 90 days, Remove stale label or comment or this will be closed in 10 days


6 participants