✨[Feature] Add lowering pass cases to avoid aten::Int.Tensor calls #1880

Closed
gs-olive opened this issue May 2, 2023 · 0 comments · Fixed by #1937
Labels: feature request (New feature or request)

gs-olive (Collaborator) commented May 2, 2023

Problem Context

The function schema aten::Int.Tensor(Tensor a) -> int is a known problematic case for Torch-TensorRT (see #513). This is because once an integer has been converted into a TensorRT tensor, its contained value can no longer be extracted. An example graph that exhibits this issue is shown below:

  %37436 : int = aten::size(%t.1, %25)  
  %element.2 : Tensor = prim::NumToTensor(%37436)
  %result.2 : Tensor = aten::mul(%element.2, %20)
  %37445 : Tensor = aten::mul(%result.2, %element368.1)
  %37446 : Tensor = aten::mul(%37445, %element369.1)
  %37447 : int = aten::Int(%37446)

In the above graph, none of the intermediate aten::mul operations actually needs Tensor inputs, since each one simply multiplies single-element Tensors. We already have lowering passes that replace generic cases of this sort; however, catching more of these scenarios would help performance. See:

  void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g);
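As a rough illustration of the simplest case of this sort, the sketch below folds an aten::Int(prim::NumToTensor(x)) round trip back into x. It is not a reproduction of RemoveSingleUse0DTensors: it runs on a hypothetical toy IR, and Node, Graph, FoldIntOfNumToTensor and the other helpers are illustrative names rather than the torch::jit API. It also assumes every node is side-effect free, so its dead-code elimination may drop any unused node.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical toy IR for illustration only (not the torch::jit::Graph API).
// Assumption: every node is side-effect free.
struct Node {
  std::string kind;         // e.g. "aten::size", "prim::NumToTensor"
  std::vector<int> inputs;  // ids of input values
  int output;               // id of the node's single output value
};

struct Graph {
  std::vector<Node> nodes;   // kept in topological order
  std::vector<int> outputs;  // values returned by the graph
};

// Returns the node that produces value v, or nullptr for a graph input.
const Node* Producer(const Graph& g, int v) {
  for (const Node& n : g.nodes)
    if (n.output == v) return &n;
  return nullptr;
}

// Redirects every use of `from` (node inputs and graph outputs) to `to`.
void ReplaceAllUses(Graph& g, int from, int to) {
  for (Node& n : g.nodes)
    for (int& in : n.inputs)
      if (in == from) in = to;
  for (int& o : g.outputs)
    if (o == from) o = to;
}

bool IsUsed(const Graph& g, int v) {
  for (int o : g.outputs)
    if (o == v) return true;
  for (const Node& n : g.nodes)
    for (int in : n.inputs)
      if (in == v) return true;
  return false;
}

// One backward sweep suffices: in topological order all users of a node
// come after it, so they have already been kept or erased.
void EliminateDeadCode(Graph& g) {
  for (std::size_t i = g.nodes.size(); i-- > 0;)
    if (!IsUsed(g, g.nodes[i].output)) g.nodes.erase(g.nodes.begin() + i);
}

// Folds  %y : int = aten::Int(prim::NumToTensor(%x))  into a direct use of %x.
void FoldIntOfNumToTensor(Graph& g) {
  for (const Node& n : g.nodes) {
    if (n.kind != "aten::Int") continue;
    assert(n.inputs.size() == 1);
    const Node* src = Producer(g, n.inputs[0]);
    if (src != nullptr && src->kind == "prim::NumToTensor")
      ReplaceAllUses(g, n.output, src->inputs[0]);
  }
  EliminateDeadCode(g);  // drops the now-unused aten::Int and NumToTensor
}
```

For instance, a graph returning aten::Int(prim::NumToTensor(aten::size(t, d))) is reduced to the single aten::size node, with the graph returning the size directly.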

Proposed Solution

The proposed solution is to add a new lowering pass that resolves cases like the above by detecting operators, such as aten::mul or aten::floor_divide, that operate on single-element Tensors. More specifically, if each input to aten::mul is one of the following:

  • A prim::NumToTensor output
  • A prim::Constant constructing a single-element Tensor
  • A Scalar (int or float) value

Then that aten::mul can be replaced by a new aten::mul that takes the original integer arguments as inputs and outputs an integer. For example:

  %20 : Tensor = prim::Constant[value={1}]()
  %37436 : int = aten::size(%t.1, %25)  
  %element.2 : Tensor = prim::NumToTensor(%37436)
  %result.2 : Tensor = aten::mul(%element.2, %20)
  %37445 : Tensor = aten::mul(%result.2, %element368.1)
  %37446 : Tensor = aten::mul(%37445, %element369.1)
  %37447 : int = aten::Int(%37446)

##### REPLACED WITH #####

  %20 : Tensor = prim::Constant[value={1}]()
  %37436 : int = aten::size(%t.1, %25)  
  %result.2 : int = aten::mul(%37436, 1)
   ...
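The rewrite above can be sketched as a small pass over a hypothetical toy IR; Graph, IsScalarLike, AsInt, LowerScalarMuls and the other names are illustrative, not the torch::jit API. The sketch assumes every scalar is an int, ignores dtype promotion, and treats all nodes as side-effect free. Each aten::mul whose two inputs both fall into one of the three categories listed above becomes an int aten::mul followed by a prim::NumToTensor, so the next aten::mul in the chain qualifies as well; dead-code elimination then removes the unused Tensor intermediates.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Hypothetical toy IR for illustration only (not the torch::jit::Graph API).
// Assumptions: every scalar is an int and every node is side-effect free.
struct Node {
  std::string kind;                // e.g. "aten::mul", "prim::NumToTensor"
  std::vector<int> inputs;         // ids of input values
  int output;                      // id of the single output value
  std::vector<std::int64_t> data;  // payload of prim::Constant nodes
};

struct Graph {
  std::vector<Node> nodes;           // kept in topological order
  std::vector<int> outputs;          // values returned by the graph
  std::map<int, std::string> types;  // value id -> "int" or "Tensor"
  int next_value = 0;

  int NewValue(const std::string& type) {
    types[next_value] = type;
    return next_value++;
  }
  const Node* Producer(int v) const {  // nullptr for graph inputs
    for (const Node& n : nodes)
      if (n.output == v) return &n;
    return nullptr;
  }
};

// The three input kinds listed in the issue: a plain int, a
// prim::NumToTensor output, or a single-element prim::Constant tensor.
bool IsScalarLike(const Graph& g, int v) {
  if (g.types.at(v) == "int") return true;
  const Node* p = g.Producer(v);
  if (p == nullptr) return false;
  if (p->kind == "prim::NumToTensor") return true;
  return p->kind == "prim::Constant" && p->data.size() == 1;
}

// Returns an int value equal to scalar-like v. For a single-element tensor
// constant, inserts an int prim::Constant at index pos and advances pos.
int AsInt(Graph& g, int v, std::size_t& pos) {
  if (g.types.at(v) == "int") return v;
  const Node* p = g.Producer(v);
  assert(p != nullptr && "v must be scalar-like");
  if (p->kind == "prim::NumToTensor") return p->inputs[0];
  const std::int64_t value = p->data[0];  // copy before nodes is modified
  const int c = g.NewValue("int");
  g.nodes.insert(g.nodes.begin() + pos, Node{"prim::Constant", {}, c, {value}});
  ++pos;
  return c;
}

// Rewrites  %r : Tensor = aten::mul(%a, %b)  with scalar-like inputs into
//   %r_int : int = aten::mul(%a_int, %b_int)
//   %r : Tensor = prim::NumToTensor(%r_int)
// so later muls in the chain see a NumToTensor producer and qualify too.
void LowerScalarMuls(Graph& g) {
  for (std::size_t i = 0; i < g.nodes.size(); ++i) {
    const Node& n = g.nodes[i];
    if (n.kind != "aten::mul" || g.types.at(n.output) != "Tensor") continue;
    const int a = n.inputs[0], b = n.inputs[1], r = n.output;
    if (!IsScalarLike(g, a) || !IsScalarLike(g, b)) continue;
    // AsInt may insert constants before the mul; i keeps pointing at it.
    const int a_int = AsInt(g, a, i);
    const int b_int = AsInt(g, b, i);
    const int r_int = g.NewValue("int");
    g.nodes[i] = Node{"aten::mul", {a_int, b_int}, r_int, {}};
    g.nodes.insert(g.nodes.begin() + i + 1, Node{"prim::NumToTensor", {r_int}, r, {}});
    ++i;  // skip the prim::NumToTensor just inserted
  }
}

bool IsUsed(const Graph& g, int v) {
  for (int o : g.outputs)
    if (o == v) return true;
  for (const Node& n : g.nodes)
    for (int in : n.inputs)
      if (in == v) return true;
  return false;
}

// One backward sweep suffices: in topological order all users of a node
// come after it, so they have already been kept or erased.
void EliminateDeadCode(Graph& g) {
  for (std::size_t i = g.nodes.size(); i-- > 0;)
    if (!IsUsed(g, g.nodes[i].output)) g.nodes.erase(g.nodes.begin() + i);
}
```

Run on a version of the example in which %element368.1 is also a prim::NumToTensor output (an assumption, since its producer is not shown), the pass leaves aten::size, an int constant replacing the {1} tensor, two int aten::mul nodes, and a final aten::Int(prim::NumToTensor(...)) pair that a round-trip fold can then remove. Re-wrapping each int result in prim::NumToTensor keeps any other Tensor users of the old output valid, at the cost of relying on dead-code elimination for cleanup.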

Additional Context

Relates to #1836; this is the first step in developing a solution.
