
🐛 [Bug] bug with torchvision.transforms.GaussianBlur #1526

Closed
mjack3 opened this issue Dec 5, 2022 · 6 comments
Labels: bug (Something isn't working), No Activity

mjack3 commented Dec 5, 2022

Bug Description

Your test does not work correctly. I have two models that use torchvision.transforms.GaussianBlur, and torch_executed_modules is not able to skip the operation in the second model.

To Reproduce

I have prepared a toy example:

import torch
import torchvision
import tensorrt
import torch_tensorrt

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.gaus = torchvision.transforms.GaussianBlur([33, 33], [4., 4.])
        self.conv = torch.nn.Conv2d(3, 64, (3,3))
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv(x))
        x = self.gaus(x)
        return x

model1 = ToyModel().eval().to('cuda')
model2 = ToyModel().eval().to('cuda')

traced1 = torch.jit.trace(model1, torch.randn(1,3,224,224).to('cuda'))
traced2 = torch.jit.trace(model2, torch.randn(1,3,224,224).to('cuda'))

print(traced1.graph) ## Here we can see torchvision.transforms.transforms.GaussianBlur
print(traced2.graph) ## But here we can see torchvision.transforms.transforms.___torch_mangle_4.GaussianBlur

trt_model1 = torch_tensorrt.compile(
    traced1,
    "default",
    [torch_tensorrt.Input((1, 3, 224, 224), dtype=torch.float32)],
    torch.float32,
    truncate_long_and_double = True,
    torch_executed_modules = ['torchvision.transforms.transforms.GaussianBlur']
)

print("**** Done first! *****")

trt_model2 = torch_tensorrt.compile(
    traced2,
    "default",
    [torch_tensorrt.Input((1, 3, 224, 224), dtype=torch.float32)],
    torch.float32,
    truncate_long_and_double = True,
    torch_executed_modules = ['torchvision.transforms.transforms.GaussianBlur']
)
print("***** Done second! *****")

Expected behavior

According to your API test, the second model should be correctly converted to TensorRT, skipping the unsupported GaussianBlur operation.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): v1.1.0
  • PyTorch Version (e.g. 1.0): 1.11.0+cu113
  • OS (e.g., Linux): Ubuntu 22
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Python version: 3.10
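
For completeness, this is how I turn on the debug messages to get the build information. I am assuming the torch_tensorrt.logging helpers here; the exact call may differ between versions:

import torch_tensorrt

# Assumed logging API: raise the reportable log level to Debug so that
# Torch-TensorRT prints its build information during compilation.
torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Debug)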

Additional context

This example can also be reproduced with resnet18, similar to your API test:

model1 = torchvision.models.resnet.resnet18(pretrained=True).eval().to('cuda')
model2 = torchvision.models.resnet.resnet18(pretrained=True).eval().to('cuda')

scripted_model1 = torch.jit.trace(model1, torch.randn(1, 3, 224, 224).cuda())
scripted_model2 = torch.jit.trace(model2, torch.randn(1, 3, 224, 224).cuda())

print(scripted_model1.graph) # Here we can see  torchvision.models.resnet.ResNet
print(scripted_model2.graph) # And here torchvision.models.resnet.___torch_mangle_194.ResNet

The graphs have different names, so a module passed as torch_executed_modules = ['torchvision.models.resnet.BasicBlock'] may not be properly skipped.

The question is

How could I manage to skip the GaussianBlur in the second model? It seems that the two models cannot be handled at the same time.
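
The only workaround I can think of is to read the (possibly mangled) qualified name out of the traced graph and pass it to torch_executed_modules as well. This is an untested sketch; the regex is just my guess at matching the mangled class name:

import re

# Untested sketch: collect every qualified name ending in GaussianBlur
# (mangled or not) from the traced graph and ask Torch-TensorRT to skip them all.
graph_text = str(traced2.graph)
mangled_names = set(re.findall(
    r"__torch__\.(torchvision\.transforms\.transforms\.\S*?GaussianBlur)", graph_text))
skip_modules = ['torchvision.transforms.transforms.GaussianBlur', *sorted(mangled_names)]

trt_model2 = torch_tensorrt.compile(
    traced2,
    "default",
    [torch_tensorrt.Input((1, 3, 224, 224), dtype=torch.float32)],
    torch.float32,
    truncate_long_and_double = True,
    torch_executed_modules = skip_modules
)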

mjack3 added the bug (Something isn't working) label on Dec 5, 2022

narendasan (Collaborator) commented

I am a bit confused here. Is the issue that torch_executed_modules is not recognizing torchvision.transforms.transforms.GaussianBlur, or that it is not properly detecting torchvision.transforms.transforms.___torch_mangle_4.GaussianBlur given that you provide torchvision.transforms.transforms.GaussianBlur to skip?

Did you observe this in torch_tensorrt 1.3.0 / master as well? There was a recent patch that might address this: #1454

mjack3 (Author) commented Dec 5, 2022

Hello and thanks for your answer.

The problem is that it does not recognize torchvision.transforms.transforms.___torch_mangle_4.GaussianBlur, when it should be recognized in the same way as in the first case. :)

I have not tested it in v1.3.0. Could you test this toy example?

mjack3 (Author) commented Dec 6, 2022

Hello @narendasan, I finally managed to solve the problem by changing the logic of my code. Now something even stranger is happening.

Error 1: trying not to convert the Gaussian blur

WARNING: [Torch-TensorRT TorchScript Conversion Context] - CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage. See `CUDA_MODULE_LOADING` in https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars
WARNING: [Torch-TensorRT] - CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage. See `CUDA_MODULE_LOADING` in https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars
Traceback (most recent call last):
  File "/home/mjack3/projects/02_AD/code/TKHAD/kk.py", line 36, in <module>
    loaded = torch.jit.load('trt32.torch-tensorrt')
  File "/home/mjack3/projects/02_AD/code/TKHAD/venv/lib/python3.10/site-packages/torch/jit/_serialization.py", line 162, in load
    cpp_module = torch._C.import_ir_module(cu, str(f), map_location, _extra_files)
RuntimeError: expected ) but found 'number' here:
Serialized   File "code/__torch__.py", line 6
  __torch___ToyModel_trt_engine_0x55f01779c490 : __torch__.torch.classes.tensorrt.Engine
  def forward(self_1: __torch__.ToyModel_trt,
    x.1: Tensor) -> Tensor:
     ~~ <--- HERE
    __torch___ToyModel_trt_engine_0x55f01779c490 = self_1.__torch___ToyModel_trt_engine_0x55f01779c490
    _0 = ops.tensorrt.execute_engine([x.1], __torch___ToyModel_trt_engine_0x55f01779c490)

Code to reproduce it

import torch
import torchvision
import torch_tensorrt

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.gauss = torchvision.transforms.GaussianBlur([33, 33], [4., 4.])
        self.conv1 = torch.nn.Conv2d(3, 64, (3, 3))
        self.conv2 = torch.nn.Conv2d(64, 124, (3, 3))
        self.activation = torch.nn.ReLU()


    def forward(self, x):
        x = self.conv1(x)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.activation(x)
        x = self.gauss(x)
        return x

model = ToyModel().eval().to('cuda')
traced = torch.jit.trace(model, torch.randn(1,3,224,224).to('cuda'))


trt_model = torch_tensorrt.compile(
    traced,
    "default",
    [torch_tensorrt.Input((1, 3, 224, 224), dtype=torch.float32)],
    torch.float32,
    truncate_long_and_double = True,
    torch_executed_modules = ['torchvision.transforms.transforms.GaussianBlur']
)
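
For reference, the traceback above is raised when I save the compiled module and load it back; the save/load step is roughly this (the file name is taken from the traceback):

# The compile call itself finishes; the RuntimeError shown above appears
# when the serialized module is loaded back from disk.
torch.jit.save(trt_model, 'trt32.torch-tensorrt')
loaded = torch.jit.load('trt32.torch-tensorrt')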

Error 2: trying to convert the Gaussian blur

WARNING: [Torch-TensorRT TorchScript Conversion Context] - CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage. See `CUDA_MODULE_LOADING` in https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size

and the output stays stuck there (after waiting 8 minutes).

Code to reproduce it

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.gauss = torchvision.transforms.GaussianBlur([33, 33], [4., 4.])
        self.conv1 = torch.nn.Conv2d(3, 64, (3, 3))
        self.conv2 = torch.nn.Conv2d(64, 124, (3, 3))
        self.activation = torch.nn.ReLU()


    def forward(self, x):
        x = self.conv1(x)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.activation(x)
        x = self.gauss(x)
        return x

model = ToyModel().eval().to('cuda')
traced = torch.jit.trace(model, torch.randn(1,3,224,224).to('cuda'))


trt_model = torch_tensorrt.compile(
    traced,
    "default",
    [torch_tensorrt.Input((1, 3, 224, 224), dtype=torch.float32)],
    torch.float32,
    truncate_long_and_double = True,
    # torch_executed_modules = ['torchvision.transforms.transforms.GaussianBlur']
)
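
In case it helps to narrow this down, my guess is that the hang comes from converting GaussianBlur itself. An isolated test along these lines (untested sketch; the input shape is just the feature-map size after the two convolutions) should show whether the blur alone already gets stuck:

# Untested isolation sketch: trace and compile only the GaussianBlur module.
# (1, 124, 220, 220) is the feature-map shape after the two 3x3 convolutions.
blur = torchvision.transforms.GaussianBlur([33, 33], [4., 4.]).eval().to('cuda')
traced_blur = torch.jit.trace(blur, torch.randn(1, 124, 220, 220).to('cuda'))

trt_blur = torch_tensorrt.compile(
    traced_blur,
    "default",
    [torch_tensorrt.Input((1, 124, 220, 220), dtype=torch.float32)],
    torch.float32,
    truncate_long_and_double = True,
)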

mjack3 changed the title from "🐛 [Bug] bug if using torch_executed_modules" to "🐛 [Bug] bug with torchvision.transforms.GaussianBlur" on Dec 6, 2022

mjack3 (Author) commented Dec 6, 2022

Tested in v1.1.0 and v1.2.0

peri044 (Collaborator) commented Dec 13, 2022

We have fixed a similar `RuntimeError: expected ) but found 'number' here:` issue in the past. Can you try with the latest 1.3 release or with the latest master?

github-actions commented

This issue has not seen activity for 90 days. Remove the stale label or comment, or this will be closed in 10 days.
