Skip to content

fix: conv parameter check failure #3428

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2468,13 +2468,14 @@ def aten_ops_convolution(
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
is_transposed = args[6]
is_conv1d = len(args[0].shape) == 3
if not is_transposed:
return impl.conv.convNd(
ctx,
target,
source_ir=SourceIR.ATEN,
name=name,
is_conv1d=len(args[3]) == 1,
is_conv1d=is_conv1d,
input=args[0],
weight=args[1],
bias=args_bounds_check(args, 2, None),
Expand All @@ -2489,7 +2490,7 @@ def aten_ops_convolution(
target,
source_ir=SourceIR.ATEN,
name=name,
is_deconv1d=len(args[3]) == 1,
is_deconv1d=is_conv1d,
input=args[0],
weight=args[1],
bias=args_bounds_check(args, 2, None),
Expand Down
26 changes: 22 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/impl/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import tensorrt as trt
import torch
from torch.fx.node import Target

from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
Expand Down Expand Up @@ -45,6 +44,8 @@ def convNd(
if has_dynamic_shape(input.shape):
assert input.shape[1] != -1, "Channel dim can't be dynamic for convolution."

num_dims = len(input.shape) - 2

if is_conv1d:
# Apply an unsqueeze operation to transform the conv1d problem into conv2d
input = impl.unsqueeze.unsqueeze(
Expand Down Expand Up @@ -104,9 +105,26 @@ def convNd(
conv_layer.set_input(2, bias)

# Cast certain fields to tuples, in accordance with TRT requirements
padding = (padding,) if isinstance(padding, int) else padding
stride = (stride,) if isinstance(stride, int) else stride
dilation = (dilation,) if isinstance(dilation, int) else dilation
if isinstance(padding, int):
padding = (padding,) * num_dims
elif isinstance(padding, (list, tuple)):
padding = tuple(padding)
if len(padding) == 1:
padding = (padding[0],) * num_dims

if isinstance(stride, int):
stride = (stride,) * num_dims
elif isinstance(stride, (list, tuple)):
stride = tuple(stride)
if len(stride) == 1:
stride = (stride[0],) * num_dims

if isinstance(dilation, int):
dilation = (dilation,) * num_dims
elif isinstance(dilation, (list, tuple)):
dilation = tuple(dilation)
if len(dilation) == 1:
dilation = (dilation[0],) * num_dims

# Expand parameters manually for Conv1D computations
if is_conv1d:
Expand Down
14 changes: 11 additions & 3 deletions tests/py/dynamo/conversion/test_convolution_aten.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch
from parameterized import param, parameterized
from torch.testing._internal.common_utils import run_tests

from torch_tensorrt import Input

from .harness import DispatchTestCase
Expand Down Expand Up @@ -133,8 +132,13 @@ def forward(self, x):
("default", 1),
param("no_bias", 1, bias=False),
("tuple_parameters", 1, (1, 1), (1, 1)),
param("list_stride", 2, stride=[2]),
param("non_zero_padding", 1, padding=1),
param("dilation", 1, dilation=2),
param("list_zero_padding", 1, padding=[0]),
param("list_non_padding", 1, padding=[1]),
param("dilation", 2, dilation=3),
param("tuple_dilation", 2, dilation=(3, 3)),
param("list_dilation", 2, dilation=[3]),
param("groups", 1, groups=3),
]
)
Expand Down Expand Up @@ -204,8 +208,12 @@ def forward(self, x):
("default", 1),
param("no_bias", 1, bias=False),
("tuple_parameters", 1, (1, 1, 1), (1, 1, 1)),
param("list_stride", 2, stride=[2]),
param("non_zero_padding", 1, padding=1),
param("dilation", 1, dilation=2),
param("list_zero_padding", 1, padding=[0]),
param("list_non_padding", 1, padding=[1]),
param("dilation", 2, dilation=2),
param("list_dilation", 2, dilation=[2]),
## TODO TRT 8.4.1 will trigger issue with this test. T127981773
# param("groups", 1, groups=3),
]
Expand Down