✨[Feature] Add lowering pass cases to avoid `aten::Int.Tensor` calls
#1880
Labels: feature request
Problem Context
The function schema `aten::Int.Tensor(Tensor a) -> int` is a known problematic case for Torch-TensorRT (see #513). The issue arises because once an integer becomes a TensorRT tensor, we can no longer extract its contained data. In graphs that exhibit this issue, none of the intermediate `aten::mul` operations need to operate on Tensor inputs, since all of them are simply multiplying single-element Tensors. We already have lowering passes that replace generic cases of this sort, but catching more of these scenarios would help performance. See `TensorRT/core/lowering/passes/remove_unnecessary_casts.cpp`, line 58 at bf5d663.
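A hypothetical IR fragment of the problematic shape (illustrative only, not taken from a real model) might look like:

```
graph(%x : Tensor):
  %c0 : int = prim::Constant[value=0]()
  %c2 : int = prim::Constant[value=2]()
  %d : int = aten::size(%x, %c0)        # dimension extracted as an int
  %dt : Tensor = prim::NumToTensor(%d)  # int wrapped into a 0-dim Tensor
  %p : Tensor = aten::mul(%dt, %c2)     # single-element Tensor arithmetic
  %i : int = aten::Int(%p)              # problematic: Tensor -> int
  ...
```

Once `%dt` is converted to a TensorRT tensor, the `aten::Int` call at the end can no longer recover the value at compile time.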
Proposed Solution
The proposed solution is to add a new lowering pass which can resolve cases like the above by detecting operators such as `aten::mul` or `aten::floor_divide` that operate on single-element Tensors. More specifically, if both inputs to `aten::mul` are any of the following:

- the output of a `prim::NumToTensor` node
- a `prim::Constant` constructing a single-element Tensor of an integer `ScalarType`
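The matching conditions above can be sketched as follows. This is a hedged, self-contained illustration: the `Node` class and the kind strings are mock stand-ins, not the actual `torch::jit` IR API.

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    kind: str                      # e.g. "aten::mul", "prim::NumToTensor"
    inputs: list = field(default_factory=list)

def is_single_element_tensor_source(node: Node) -> bool:
    # Condition from the issue: the input is either the output of
    # prim::NumToTensor, or a prim::Constant holding a single-element
    # Tensor (modeled here with a hypothetical "prim::Constant[1-elem]" kind).
    return node.kind in ("prim::NumToTensor", "prim::Constant[1-elem]")

def can_lower_to_int_op(op: Node) -> bool:
    # An aten::mul (or aten::floor_divide) qualifies for the rewrite when
    # every one of its inputs is a known single-element Tensor.
    if op.kind not in ("aten::mul", "aten::floor_divide"):
        return False
    return all(is_single_element_tensor_source(i) for i in op.inputs)

# A mul of a NumToTensor output and a single-element constant qualifies;
# a mul fed by an arbitrary Tensor op does not.
a = Node("prim::NumToTensor")
b = Node("prim::Constant[1-elem]")
print(can_lower_to_int_op(Node("aten::mul", [a, b])))                   # True
print(can_lower_to_int_op(Node("aten::mul", [a, Node("aten::relu")])))  # False
```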
Then that `aten::mul` can be replaced by a new `aten::mul` which takes the original integer arguments as input and outputs an integer.

Additional Context
Relates to #1836 - a first step in developing that solution.
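As a hedged sketch of the rewrite described under Proposed Solution (hypothetical IR, not from a real model), the pass would transform a pattern like:

```
%a : Tensor = prim::NumToTensor(%i)
%b : Tensor = prim::NumToTensor(%j)
%c : Tensor = aten::mul(%a, %b)
%k : int = aten::Int(%c)
```

into:

```
%k : int = aten::mul(%i, %j)
```

This relies on the integer overload `aten::mul.int(int a, int b) -> int`, so no single-element Tensor (and hence no `aten::Int.Tensor` call) needs to be materialized at all.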