🐛 [Bug] RuntimeError: linear convolution has bias of type <class 'tensorrt.tensorrt.ITensor'>, Expect Optional[Tensor] when using torch_tensorrt as backend in torch.compile #2506

Closed
airalcorn2 opened this issue Nov 30, 2023 · 5 comments
Labels: bug (Something isn't working), component: converters (Issues re: Specific op converters)

Comments

@airalcorn2

Bug Description

When trying to compile a simple PyTorch module with torch.compile(model, backend="torch_tensorrt"), I get:

RuntimeError: linear convolution has bias of type <class 'tensorrt.tensorrt.ITensor'>, Expect Optional[Tensor]

I don't get any errors when using torch.compile(model) or torch_tensorrt.compile(model, inputs=inputs).

To Reproduce

Steps to reproduce the behavior:

import torch
import torch_tensorrt

from torch import nn


class PointNetLayer(nn.Module):
    def __init__(self, in_feats, out_feats):
        super().__init__()
        self.linear = nn.Conv2d(in_feats, out_feats, 1)
        self.norm = nn.BatchNorm2d(out_feats)
        self.relu = nn.ReLU()

    def forward(self, points):
        pn_feats = self.relu(self.norm(self.linear(points)))
        return pn_feats


def main():
    device = torch.device("cuda:0")
    model = PointNetLayer(3, 64).to(device)
    model.eval()
    print(model)

    points = torch.rand((1, 3, 12000, 200)).to(device)
    # Works.
    with torch.no_grad():
        _ = model(points)

    model_opt = torch.compile(model)
    # Works.
    with torch.no_grad():
        _ = model_opt(points)

    torch._dynamo.reset()
    model_opt = torch.compile(model, backend="torch_tensorrt")
    # RuntimeError.
    with torch.no_grad():
        _ = model_opt(points)

    torch._dynamo.reset()
    inputs = [torch_tensorrt.Input(points.shape)]
    # Works.
    trt_ts_module = torch_tensorrt.compile(model, inputs=inputs)


if __name__ == "__main__":
    main()

Expected behavior

Environment

Build information about Torch-TensorRT can be found by turning on debug messages (see the sketch below the list).

  • Torch-TensorRT Version: 1.4.0
  • PyTorch Version: 2.0.1+cu117
  • CPU Architecture: i7-12800H
  • OS: Linux
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.10.10
  • CUDA version: 12.2
  • GPU models and configuration: GeForce RTX 3080 Ti
  • Any other relevant information:
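
For reference, the debug messages mentioned above can be turned on from Python; a minimal sketch, assuming the torch_tensorrt.logging context managers present in recent releases (exact names may vary between versions):

import torch
import torch_tensorrt

from torch import nn

# Small stand-in model; any module works here.
model = nn.Sequential(nn.Conv2d(3, 64, 1), nn.ReLU()).to("cuda:0").eval()
inputs = [torch_tensorrt.Input((1, 3, 12000, 200))]

# Assumed API: torch_tensorrt.logging.debug() raises the log level for the
# duration of the block, so build information is printed during compilation.
with torch_tensorrt.logging.debug():
    trt_module = torch_tensorrt.compile(model, inputs=inputs)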

Additional context

@narendasan (Collaborator)

@gs-olive can you take a look at this?

@narendasan added the component: converters (Issues re: Specific op converters) label on Nov 30, 2023
@apbose (Collaborator) commented Dec 1, 2023

Hi @airalcorn2, the PointNetLayer model passes on my end with model_opt = torch.compile(model, backend="torch_tensorrt"). Could you let me know which torch and torch_tensorrt versions you are using? Your torch_tensorrt version may be out of date; this should have been fixed by #1972.
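
A quick way to report both versions, assuming only the standard __version__ attributes:

import torch
import torch_tensorrt

# Print the installed versions to include in the report.
print("torch:", torch.__version__)
print("torch_tensorrt:", torch_tensorrt.__version__)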

@airalcorn2 (Author) commented Dec 3, 2023

I was using PyTorch 2.0.1 and Torch-TensorRT 1.4.0, which is the latest Torch-TensorRT release and was released before #1972 was merged. Should users default to using the version in main?

@apbose (Collaborator) commented Dec 4, 2023

Yes, could you please try the latest version of main?

@airalcorn2 (Author)

Confirmed: I no longer get the error when using the Torch-TensorRT version included in nvcr.io/nvidia/pytorch:23.10-py3; i.e., running the script in:

docker pull nvcr.io/nvidia/pytorch:23.10-py3
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.10-py3

works.
