Skip to content

Commit 1cd215d

Browse files
authored
fix: conv parameter check failure (#3428)
1 parent 153f921 commit 1cd215d

File tree

3 files changed

+36
-9
lines changed

3 files changed

+36
-9
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -2468,13 +2468,14 @@ def aten_ops_convolution(
24682468
name: str,
24692469
) -> Union[TRTTensor, Sequence[TRTTensor]]:
24702470
is_transposed = args[6]
2471+
is_conv1d = len(args[0].shape) == 3
24712472
if not is_transposed:
24722473
return impl.conv.convNd(
24732474
ctx,
24742475
target,
24752476
source_ir=SourceIR.ATEN,
24762477
name=name,
2477-
is_conv1d=len(args[3]) == 1,
2478+
is_conv1d=is_conv1d,
24782479
input=args[0],
24792480
weight=args[1],
24802481
bias=args_bounds_check(args, 2, None),
@@ -2489,7 +2490,7 @@ def aten_ops_convolution(
24892490
target,
24902491
source_ir=SourceIR.ATEN,
24912492
name=name,
2492-
is_deconv1d=len(args[3]) == 1,
2493+
is_deconv1d=is_conv1d,
24932494
input=args[0],
24942495
weight=args[1],
24952496
bias=args_bounds_check(args, 2, None),

py/torch_tensorrt/dynamo/conversion/impl/conv.py

+22-4
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import tensorrt as trt
77
import torch
88
from torch.fx.node import Target
9-
109
from torch_tensorrt.dynamo.conversion import impl
1110
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
1211
from torch_tensorrt.dynamo.conversion.converter_utils import (
@@ -45,6 +44,8 @@ def convNd(
4544
if has_dynamic_shape(input.shape):
4645
assert input.shape[1] != -1, "Channel dim can't be dynamic for convolution."
4746

47+
num_dims = len(input.shape) - 2
48+
4849
if is_conv1d:
4950
# Apply an unsqueeze operation to transform the conv1d problem into conv2d
5051
input = impl.unsqueeze.unsqueeze(
@@ -104,9 +105,26 @@ def convNd(
104105
conv_layer.set_input(2, bias)
105106

106107
# Cast certain fields to tuples, in accordance with TRT requirements
107-
padding = (padding,) if isinstance(padding, int) else padding
108-
stride = (stride,) if isinstance(stride, int) else stride
109-
dilation = (dilation,) if isinstance(dilation, int) else dilation
108+
if isinstance(padding, int):
109+
padding = (padding,) * num_dims
110+
elif isinstance(padding, (list, tuple)):
111+
padding = tuple(padding)
112+
if len(padding) == 1:
113+
padding = (padding[0],) * num_dims
114+
115+
if isinstance(stride, int):
116+
stride = (stride,) * num_dims
117+
elif isinstance(stride, (list, tuple)):
118+
stride = tuple(stride)
119+
if len(stride) == 1:
120+
stride = (stride[0],) * num_dims
121+
122+
if isinstance(dilation, int):
123+
dilation = (dilation,) * num_dims
124+
elif isinstance(dilation, (list, tuple)):
125+
dilation = tuple(dilation)
126+
if len(dilation) == 1:
127+
dilation = (dilation[0],) * num_dims
110128

111129
# Expand parameters manually for Conv1D computations
112130
if is_conv1d:

tests/py/dynamo/conversion/test_convolution_aten.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import torch
22
from parameterized import param, parameterized
33
from torch.testing._internal.common_utils import run_tests
4-
54
from torch_tensorrt import Input
65

76
from .harness import DispatchTestCase
@@ -133,8 +132,13 @@ def forward(self, x):
133132
("default", 1),
134133
param("no_bias", 1, bias=False),
135134
("tuple_parameters", 1, (1, 1), (1, 1)),
135+
param("list_stride", 2, stride=[2]),
136136
param("non_zero_padding", 1, padding=1),
137-
param("dilation", 1, dilation=2),
137+
param("list_zero_padding", 1, padding=[0]),
138+
param("list_non_padding", 1, padding=[1]),
139+
param("dilation", 2, dilation=3),
140+
param("tuple_dilation", 2, dilation=(3, 3)),
141+
param("list_dilation", 2, dilation=[3]),
138142
param("groups", 1, groups=3),
139143
]
140144
)
@@ -204,8 +208,12 @@ def forward(self, x):
204208
("default", 1),
205209
param("no_bias", 1, bias=False),
206210
("tuple_parameters", 1, (1, 1, 1), (1, 1, 1)),
211+
param("list_stride", 2, stride=[2]),
207212
param("non_zero_padding", 1, padding=1),
208-
param("dilation", 1, dilation=2),
213+
param("list_zero_padding", 1, padding=[0]),
214+
param("list_non_padding", 1, padding=[1]),
215+
param("dilation", 2, dilation=2),
216+
param("list_dilation", 2, dilation=[2]),
209217
## TODO TRT 8.4.1 will trigger issue with this test. T127981773
210218
# param("groups", 1, groups=3),
211219
]

0 commit comments

Comments
 (0)