
✨[Feature] Automatic conversion for int32<->int64 in fallback #1382


Closed
inocsin opened this issue Sep 27, 2022 · 1 comment · Fixed by #1407
Labels: component: partitioning · feature request (New feature or request) · release: v1.3 (Tagged to be included in v1.3)

Comments

inocsin (Contributor) commented Sep 27, 2022

As we know, when a model contains operators that Torch-TensorRT does not support, the model is partitioned into TensorRT and Torch subgraphs. TensorRT does not support int64 values and truncates int64 to int32.

In some cases an operator in the Torch subgraph consumes an int64 value (like aten::index), while that value is produced by a TensorRT subgraph and has already been truncated to int32; this causes an error. We need to track these data type conversions and automatically cast values back to their original type at the boundary between Torch and TensorRT segments.
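
The dtype requirement on the Torch side is easy to see in plain PyTorch, independent of Torch-TensorRT: scatter_ rejects an int32 index outright, which is exactly the kind of value a TensorRT segment hands back after truncation. A minimal sketch (the variable names here are just for illustration):

import torch

data = torch.randn(5, 5)
index64 = torch.randint(0, 4, (2, 2))   # default integer dtype is int64
index32 = index64.to(torch.int32)       # what a TensorRT segment would produce after truncation

data.clone().scatter_(1, index64, 1)    # works
data.clone().scatter_(1, index32, 1)    # raises a RuntimeError because scatter_ requires an int64 index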

Here is a typical case:

import torch
import torch.nn as nn
import torch_tensorrt


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

    def forward(self, data, index):
        src = 1
        index = index.to(torch.int64)          # the Torch-executed scatter needs an int64 index
        data = data * data
        data = data.scatter_(1, index, src)
        data = data + 1
        return data


data = torch.randn([5, 5])
index = torch.randint(0, 4, [2, 2], dtype=torch.int32)

compile_spec = {
    "inputs": None,
    "device": {
        "device_type": torch_tensorrt.DeviceType.GPU,
        "gpu_id": 0,
        "allow_gpu_fallback": False,
        "disable_tf32": False
    },
    "truncate_long_and_double": True,
    "require_full_compilation": False,
    # force scatter into a Torch fallback segment
    "torch_executed_ops": ["aten::scatter_", "aten::scatter"],
    "min_block_size": 1
}

net = Net()
model = torch.jit.trace(net, (data, index))

data2 = torch_tensorrt.Input(shape=[5, 5], dtype=torch.float32)
index2 = torch_tensorrt.Input(shape=[2, 2], dtype=torch.int32)  # shape matches `index`, not `data`

inputs = [data2, index2]

compile_spec["inputs"] = inputs

with torch_tensorrt.logging.debug():
    trt_mod = torch_tensorrt.ts.compile(model, **compile_spec)

inputs = [data.cuda(), index.cuda()]
output = trt_mod(*inputs)
print(output)

Subgraph log:

INFO: [Torch-TensorRT - Debug Build] - Partitioned Graph: [Segment Block @0:
    Target: TensorRT

    Graph: graph(%index.1 : Tensor,
      %data.1 : Tensor):
  %2 : int = prim::Constant[value=4]() # test_int64.py:28:0
  %3 : bool = prim::Constant[value=0]() # test_int64.py:28:0
  %4 : NoneType = prim::Constant()
  %index : Tensor = aten::to(%index.1, %2, %3, %3, %4) # test_int64.py:28:0
  %data.3 : Tensor = aten::mul(%data.1, %data.1) # test_int64.py:29:0
  return (%index, %data.3)

Segment Block @1:
    Target: Torch

    Graph: graph(%data.3 : Tensor,
      %index : Tensor):
  %2 : int = prim::Constant[value=1]() # test_int64.py:30:0
  %0 : Tensor = aten::scatter(%data.3, %2, %index, %2) # test_int64.py:30:0
  return (%0)

Segment Block @2:
    Target: TensorRT

    Graph: graph(%1 : Tensor):
  %2 : Tensor = prim::Constant[value={1}]() # test_int64.py:31:0
  %3 : int = prim::Constant[value=1]() # test_int64.py:30:0
  %0 : Tensor = aten::add(%1, %2, %3) # test_int64.py:31:0
  return (%0)

]
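
For reference, the requested behavior could be pictured as recording the original dtype of every value that crosses a segment boundary and casting it back before the consuming segment runs. The helper below is purely illustrative, not the actual Torch-TensorRT partitioning code; restore_boundary_dtypes and the recorded dtypes are hypothetical names for this sketch:

import torch

def restore_boundary_dtypes(tensors, original_dtypes):
    # Hypothetical helper: cast values that TensorRT truncated to int32
    # back to their recorded original int64 dtype before a Torch segment runs.
    return [
        t.to(torch.int64) if orig == torch.int64 and t.dtype == torch.int32 else t
        for t, orig in zip(tensors, original_dtypes)
    ]

# For the partition above, %index was int64 in the original graph and %data.3 float32,
# so the outputs of Segment Block @0 (TensorRT) would be fixed up before
# Segment Block @1 (Torch) executes aten::scatter:
#   index_fixed, data_fixed = restore_boundary_dtypes(
#       [index_trt, data_trt], [torch.int64, torch.float32])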
inocsin added the feature request label on Sep 27, 2022
inocsin (Contributor, Author) commented Sep 27, 2022

@bowang007
