🐛 [Bug] Encountered bug when using Torch-TensorRT #1687
Comments
@narendasan @bowang007, any update on this? I am observing the same behaviour.
Actually, I see some deprecation warnings when compiling Torch-TensorRT that seem to come from `interpolate_plugin` (among other files):
Could that explain why interpolate behaves strangely? Are these warnings expected?
This issue has not seen activity for 90 days. Remove the stale label or comment, or this will be closed in 10 days.
I'd like to see this issue fixed, as it currently prevents us from using the software for our use case.
@narendasan or @bowang007, is anyone working on this? If not, what could be done to help?
Hi @philippewarren, I looked into this several weeks ago, but I don't think we have seen such issues before.
This seems to be caused by a combination of an upscale and multiple return values for the …

```python
import torch
import torch.nn as nn
import torch_tensorrt


class MRE(nn.Module):
    def __init__(self):
        super(MRE, self).__init__()
        self._upsample = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        y = self._upsample(x)
        # Reportedly, upsampling combined with returning several tensors triggers the error
        return [x, y]


model = MRE()
x = torch.ones((1, 3, 416, 416))
model.eval()
device = torch.device('cuda')
model = model.to(device)
trt_module = torch_tensorrt.compile(
    model,
    inputs=[x.to(device)],
    enabled_precisions={torch.float},
)
torch.jit.save(trt_module, "mre.trt.pth")
```

The same error happens if the list returned (… This is the associated output:
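For comparison, the same module runs without issue in eager PyTorch. The minimal sketch below assumes only a CPU build of PyTorch (no Torch-TensorRT, no GPU) and computes the reference outputs that the compiled module should reproduce. The helper `outputs_match` is hypothetical and not part of PyTorch or Torch-TensorRT.

```python
import torch
import torch.nn as nn


class MRE(nn.Module):
    """Same module as in the reproducer: an upsample plus two return values."""

    def __init__(self):
        super().__init__()
        self._upsample = nn.Upsample(scale_factor=2, mode="nearest")

    def forward(self, x):
        y = self._upsample(x)
        return [x, y]


def outputs_match(ref, out, atol=1e-5):
    """Hypothetical helper: compare two lists of tensors element-wise."""
    if len(ref) != len(out):
        return False
    return all(
        r.shape == o.shape and torch.allclose(r, o.to(r.device), atol=atol)
        for r, o in zip(ref, out)
    )


model = MRE().eval()
x = torch.ones((1, 3, 416, 416))
with torch.no_grad():
    ref = model(x)  # eager reference outputs, computed on the CPU

print([tuple(t.shape) for t in ref])
# [(1, 3, 416, 416), (1, 3, 832, 832)]
```

Once compilation succeeds, and assuming the compiled module also returns a list of tensors, `outputs_match(ref, trt_module(x.to(device)))` could be used to check its result against this reference.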
Any update?
Compiling a scripted model gives an error at `torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")`.
To Reproduce
Steps to reproduce the behavior:
Expected behavior
No error.
Environment
- Installation method (`conda`, `pip`, `libtorch`, source): conda

Additional context
The error occurs at the `torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")` line in the network definition. JIT scripting works.
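Since TorchScript scripting is reported to work, one way to check that the scripted graph itself is sound (suggesting the failure lies in the Torch-TensorRT compilation step) is to script a module that contains only this call and compare it with eager mode. This is a minimal sketch assuming a CPU-only PyTorch install; the module name `Up` is hypothetical. It also checks that nearest-neighbour upsampling by a factor of 2 is the same as duplicating every pixel along H and W.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Up(nn.Module):
    """Hypothetical minimal network containing only the failing call."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.interpolate(x, scale_factor=2.0, mode="nearest")


model = Up().eval()
scripted = torch.jit.script(model)  # plain TorchScript scripting, no Torch-TensorRT

x = torch.arange(12, dtype=torch.float32).reshape(1, 1, 3, 4)
eager_out = model(x)
scripted_out = scripted(x)

# Nearest upsampling by an integer factor of 2 duplicates every pixel along
# H and W, i.e. it equals repeat_interleave on the last two dimensions.
expected = x.repeat_interleave(2, dim=2).repeat_interleave(2, dim=3)

print(scripted_out.shape)  # torch.Size([1, 1, 6, 8])
print(torch.equal(eager_out, scripted_out), torch.equal(scripted_out, expected))  # True True
```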