From 6a9a90f7cd9c72ca6fc98d29e8fad6fe46541d52 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 12 Dec 2023 19:12:38 -0800 Subject: [PATCH 1/2] fix: output shape bug in deconv --- .../dynamo/conversion/aten_ops_converters.py | 4 ++-- py/torch_tensorrt/dynamo/conversion/impl/conv.py | 7 ++++--- py/torch_tensorrt/dynamo/conversion/impl/deconv.py | 9 +++++---- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index d0646e4bc6..fb5db527fb 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2106,7 +2106,7 @@ def aten_ops_convolution( is_conv1d=len(args[3]) == 1, input=args[0], weight=args[1], - bias=args[2], + bias=args_bounds_check(args, 2, None), stride=args[3], padding=args[4], dilation=args[5], @@ -2121,7 +2121,7 @@ def aten_ops_convolution( is_deconv1d=len(args[3]) == 1, input=args[0], weight=args[1], - bias=args[2], + bias=args_bounds_check(args, 2, None), stride=args[3], padding=args[4], dilation=args[5], diff --git a/py/torch_tensorrt/dynamo/conversion/impl/conv.py b/py/torch_tensorrt/dynamo/conversion/impl/conv.py index 33b5fcbd87..26e0d59b8f 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/conv.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/conv.py @@ -32,10 +32,11 @@ def convNd( input: TRTTensor, weight: Union[TRTTensor, torch.Tensor, np.ndarray], bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]], - stride: Optional[Union[int, Sequence[int]]], - padding: Optional[Union[int, Sequence[int]]], - dilation: Optional[Union[int, Sequence[int]]], + stride: Union[int, Sequence[int]], + padding: Union[int, Sequence[int]], + dilation: Union[int, Sequence[int]], groups: Optional[int], + output_padding: Union[int, Sequence[int]] = 0, scale: Optional[Union[torch.Tensor, float]] = None, zero_point: Optional[Union[torch.Tensor, float]] = None, ) -> TRTTensor: diff --git a/py/torch_tensorrt/dynamo/conversion/impl/deconv.py b/py/torch_tensorrt/dynamo/conversion/impl/deconv.py index ebb9b1bec2..f66bff7c82 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/deconv.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/deconv.py @@ -32,10 +32,11 @@ def deconvNd( input: TRTTensor, weight: Union[TRTTensor, torch.Tensor, np.ndarray], bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]], - stride: Optional[Union[int, Sequence[int]]], - padding: Optional[Union[int, Sequence[int]]], + stride: Union[int, Sequence[int]], + padding: Union[int, Sequence[int]], + dilation: Union[int, Sequence[int]], groups: Optional[int], - dilation: Optional[Union[int, Sequence[int]]], + output_padding: Union[int, Sequence[int]] = 0, scale: Optional[Union[torch.Tensor, float]] = None, zero_point: Optional[Union[torch.Tensor, float]] = None, ) -> TRTTensor: @@ -86,7 +87,7 @@ def deconvNd( # add deconv layer deconv_layer = ctx.net.add_deconvolution_nd( input=input, - num_output_maps=weight.shape[0], + num_output_maps=weight.shape[1] * groups, kernel_shape=weight.shape[2:], kernel=trt.Weights() if isinstance(weight, TRTTensor) else weight, bias=trt.Weights() if isinstance(bias, TRTTensor) else bias, From a3932ed3f0866eaffa8bcfb6af95cfaf05d94b04 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 12 Dec 2023 20:44:15 -0800 Subject: [PATCH 2/2] add a deconv model to tests --- .../dynamo/backend/test_specialized_models.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/py/dynamo/backend/test_specialized_models.py b/tests/py/dynamo/backend/test_specialized_models.py index db9520ccd5..3885627b5f 100644 --- a/tests/py/dynamo/backend/test_specialized_models.py +++ b/tests/py/dynamo/backend/test_specialized_models.py @@ -365,5 +365,39 @@ def forward(self, x, y): torch._dynamo.reset() +class TestDeconvolution(TestCase): + def test_ConvTranspose2d(self): + class Up(torch.nn.Module): + def __init__(self, in_channels, out_channels, upsample_stride): + super().__init__() + self.up = torch.nn.ConvTranspose2d( + in_channels, + out_channels, + upsample_stride, + stride=upsample_stride, + bias=False, + ) + + def forward(self, x): + return self.up(x) + + device = torch.device("cuda:0") + model = Up(64, 128, 2).to(device) + model.eval() + print(model) + + x = torch.rand((1, 64, 100, 100)).to(device) + model_opt = torch.compile( + model, + backend="torch_tensorrt", + options={ + "min_block_size": 1, + "debug": True, + }, + ) + with torch.no_grad(): + _ = model_opt(x) + + if __name__ == "__main__": run_tests()