🐛 [Bug] inception_v3 pretrained compilation - Unsupported ATen data type Double #1096

Closed · apivovarov opened this issue Jun 1, 2022 · 3 comments · Labels: bug

Comments

apivovarov commented Jun 1, 2022

Bug Description

Compiling a traced, pretrained inception_v3 model with Torch-TensorRT fails with "Unsupported ATen data type Double".

To Reproduce

Steps to reproduce the behavior:

import torch
import torchvision.models as models
import torch_tensorrt

# Get the pretrained inception_v3 model. pretrained=True also implicitly sets
# transform_input=True, which causes the Double type issue during compilation.
model = models.inception_v3(pretrained=True).eval()
x = torch.rand(1, 3, 299, 299)
y = model(x)
tmodel = torch.jit.trace(model, x)
trt_model = torch_tensorrt.compile(
    tmodel,
    inputs=[torch_tensorrt.Input((1, 3, 299, 299))],
)

Error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/_compile.py", line 115, in compile
    return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/ts/_compiler.py", line 116, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: [Error thrown at core/util/trt_util.cpp:283] Expected type to be true but got false
Unsupported ATen data type Double

Environment

I use the latest NVIDIA PyTorch Docker image, nvcr.io/nvidia/pytorch:22.04-py3:

docker run -ti --gpus all \
--ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \
nvcr.io/nvidia/pytorch:22.04-py3
>>> torch.__version__
'1.12.0a0+bd13bc6'
>>> torch_tensorrt.__version__
'1.1.0a0'

Additional context

To work around the issue I need to replace the existing Inception model method _transform_input() with a fixed one. The tracer records the Python float literals in this method as 0-dim Double constants, which TensorRT cannot ingest.

Existing _transform_input() method:

    def _transform_input(self, x: Tensor) -> Tensor:
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

Fixed _transform_input(), where I explicitly wrap the constants in float32 tensors. It works, but PyTorch developers usually do not write model code this way.

    def _transform_input(self, x: Tensor) -> Tensor:
        if self.transform_input:
            # Wrapping the scalars in explicit float32 tensors keeps the
            # tracer from recording them as Double constants.
            a0 = torch.tensor(0.229 / 0.5, dtype=torch.float32)
            a1 = torch.tensor(0.224 / 0.5, dtype=torch.float32)
            a2 = torch.tensor(0.225 / 0.5, dtype=torch.float32)
            b0 = torch.tensor((0.485 - 0.5) / 0.5, dtype=torch.float32)
            b1 = torch.tensor((0.456 - 0.5) / 0.5, dtype=torch.float32)
            b2 = torch.tensor((0.406 - 0.5) / 0.5, dtype=torch.float32)
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * a0 + b0
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * a1 + b1
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * a2 + b2
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x
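
For reference, the fixed method can be applied without editing the torchvision sources by binding it to the model instance before tracing. A minimal sketch, assuming the fixed method above has been defined as a module-level function named fixed_transform_input(self, x) (a hypothetical name for illustration):

import types

# Assumption: fixed_transform_input(self, x) is the fixed method shown above,
# defined as a free function. The instance attribute shadows the class method,
# so the trace picks up the float32-wrapped constants.
model._transform_input = types.MethodType(fixed_transform_input, model)
tmodel = torch.jit.trace(model, x)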

Is it possible to ask torch.jit.trace to automatically wrap constants in float32 tensors?
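
Another possible workaround (an untested sketch on my side): since pretrained=True implicitly sets transform_input=True, constructing the model with transform_input=False skips the Double-producing transform entirely. The caveat is that the model then expects inputs that are already normalized the way the pretrained weights were trained, so preprocessing has to compensate.

# Untested sketch: skip the input transform (and its Double constants).
# Inputs must then be normalized in preprocessing to match the pretrained weights.
model = models.inception_v3(pretrained=True, transform_input=False).eval()
tmodel = torch.jit.trace(model, torch.rand(1, 3, 299, 299))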

tmodel.graph shows a Double constant for the original model:

  %1512 : Double(requires_grad=0, device=cpu) = prim::Constant[value={0.458}]()

For the fixed model (with the patched _transform_input() method) the same constant shows Float, which the TRT compiler accepts:

  %1502 : Float(requires_grad=0, device=cpu) = prim::Constant[value={0.458}]()
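
As a quick check, a small snippet (my own helper, not part of the original repro) lists the graph lines that declare Double constants:

# Print only the traced-graph lines that declare Double constants.
for line in str(tmodel.graph).splitlines():
    if "Double(" in line:
        print(line.strip())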
narendasan (Collaborator) commented

Can you try using the setting truncate_long_and_double? https://pytorch.org/TensorRT/py_api/ts.html
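
For example, a sketch of that setting applied to the repro above (truncate_long_and_double truncates Long/Double weights and constants to Int/Float at compile time):

# Truncate Double graph constants to Float so TensorRT can ingest them.
trt_model = torch_tensorrt.compile(
    tmodel,
    inputs=[torch_tensorrt.Input((1, 3, 299, 299))],
    truncate_long_and_double=True,
)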

apivovarov (Author) commented

Looks like it was just fixed - #1259

#266 (comment)

ncomly-nvidia (Contributor) commented

Closing based on #1096 (comment).

Please comment if this is not resolved.
