Skip to content

fix: output shape bug in deconv #2537

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2106,7 +2106,7 @@ def aten_ops_convolution(
is_conv1d=len(args[3]) == 1,
input=args[0],
weight=args[1],
bias=args[2],
bias=args_bounds_check(args, 2, None),
stride=args[3],
padding=args[4],
dilation=args[5],
Expand All @@ -2121,7 +2121,7 @@ def aten_ops_convolution(
is_deconv1d=len(args[3]) == 1,
input=args[0],
weight=args[1],
bias=args[2],
bias=args_bounds_check(args, 2, None),
stride=args[3],
padding=args[4],
dilation=args[5],
Expand Down
7 changes: 4 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/impl/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@ def convNd(
input: TRTTensor,
weight: Union[TRTTensor, torch.Tensor, np.ndarray],
bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
stride: Optional[Union[int, Sequence[int]]],
padding: Optional[Union[int, Sequence[int]]],
dilation: Optional[Union[int, Sequence[int]]],
stride: Union[int, Sequence[int]],
padding: Union[int, Sequence[int]],
dilation: Union[int, Sequence[int]],
groups: Optional[int],
output_padding: Union[int, Sequence[int]] = 0,
scale: Optional[Union[torch.Tensor, float]] = None,
zero_point: Optional[Union[torch.Tensor, float]] = None,
) -> TRTTensor:
Expand Down
9 changes: 5 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/impl/deconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@ def deconvNd(
input: TRTTensor,
weight: Union[TRTTensor, torch.Tensor, np.ndarray],
bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
stride: Optional[Union[int, Sequence[int]]],
padding: Optional[Union[int, Sequence[int]]],
stride: Union[int, Sequence[int]],
padding: Union[int, Sequence[int]],
dilation: Union[int, Sequence[int]],
groups: Optional[int],
dilation: Optional[Union[int, Sequence[int]]],
output_padding: Union[int, Sequence[int]] = 0,
scale: Optional[Union[torch.Tensor, float]] = None,
zero_point: Optional[Union[torch.Tensor, float]] = None,
) -> TRTTensor:
Expand Down Expand Up @@ -86,7 +87,7 @@ def deconvNd(
# add deconv layer
deconv_layer = ctx.net.add_deconvolution_nd(
input=input,
num_output_maps=weight.shape[0],
num_output_maps=weight.shape[1] * groups,
kernel_shape=weight.shape[2:],
kernel=trt.Weights() if isinstance(weight, TRTTensor) else weight,
bias=trt.Weights() if isinstance(bias, TRTTensor) else bias,
Expand Down
34 changes: 34 additions & 0 deletions tests/py/dynamo/backend/test_specialized_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,5 +365,39 @@ def forward(self, x, y):
torch._dynamo.reset()


class TestDeconvolution(TestCase):
def test_ConvTranspose2d(self):
class Up(torch.nn.Module):
def __init__(self, in_channels, out_channels, upsample_stride):
super().__init__()
self.up = torch.nn.ConvTranspose2d(
in_channels,
out_channels,
upsample_stride,
stride=upsample_stride,
bias=False,
)

def forward(self, x):
return self.up(x)

device = torch.device("cuda:0")
model = Up(64, 128, 2).to(device)
model.eval()
print(model)

x = torch.rand((1, 64, 100, 100)).to(device)
model_opt = torch.compile(
model,
backend="torch_tensorrt",
options={
"min_block_size": 1,
"debug": True,
},
)
with torch.no_grad():
_ = model_opt(x)


if __name__ == "__main__":
run_tests()