Skip to content

Commit f14682a

Browse files
authored
Generalize ConvNormActivation function to accept tuple for some parameters (#6251)
* Make ConvNormActivation function accept tuple for kernel_size, stride, padding, and dilation * Fix the method to get the conv_dim * Simplify if-elif logic
1 parent 5416031 commit f14682a

File tree

2 files changed

+38
-15
lines changed

2 files changed

+38
-15
lines changed

torchvision/ops/misc.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import warnings
2-
from typing import Callable, List, Optional
2+
from typing import Callable, List, Optional, Union, Tuple, Sequence
33

44
import torch
55
from torch import Tensor
66

7-
from ..utils import _log_api_usage_once
7+
from ..utils import _log_api_usage_once, _make_ntuple
88

99

1010
interpolate = torch.nn.functional.interpolate
@@ -70,20 +70,26 @@ def __init__(
7070
self,
7171
in_channels: int,
7272
out_channels: int,
73-
kernel_size: int = 3,
74-
stride: int = 1,
75-
padding: Optional[int] = None,
73+
kernel_size: Union[int, Tuple[int, ...]] = 3,
74+
stride: Union[int, Tuple[int, ...]] = 1,
75+
padding: Optional[Union[int, Tuple[int, ...], str]] = None,
7676
groups: int = 1,
7777
norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
7878
activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
79-
dilation: int = 1,
79+
dilation: Union[int, Tuple[int, ...]] = 1,
8080
inplace: Optional[bool] = True,
8181
bias: Optional[bool] = None,
8282
conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d,
8383
) -> None:
8484

8585
if padding is None:
86-
padding = (kernel_size - 1) // 2 * dilation
86+
if isinstance(kernel_size, int) and isinstance(dilation, int):
87+
padding = (kernel_size - 1) // 2 * dilation
88+
else:
89+
_conv_dim = len(kernel_size) if isinstance(kernel_size, Sequence) else len(dilation)
90+
kernel_size = _make_ntuple(kernel_size, _conv_dim)
91+
dilation = _make_ntuple(dilation, _conv_dim)
92+
padding = tuple((kernel_size[i] - 1) // 2 * dilation[i] for i in range(_conv_dim))
8793
if bias is None:
8894
bias = norm_layer is None
8995

@@ -139,13 +145,13 @@ def __init__(
139145
self,
140146
in_channels: int,
141147
out_channels: int,
142-
kernel_size: int = 3,
143-
stride: int = 1,
144-
padding: Optional[int] = None,
148+
kernel_size: Union[int, Tuple[int, int]] = 3,
149+
stride: Union[int, Tuple[int, int]] = 1,
150+
padding: Optional[Union[int, Tuple[int, int], str]] = None,
145151
groups: int = 1,
146152
norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
147153
activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
148-
dilation: int = 1,
154+
dilation: Union[int, Tuple[int, int]] = 1,
149155
inplace: Optional[bool] = True,
150156
bias: Optional[bool] = None,
151157
) -> None:
@@ -188,13 +194,13 @@ def __init__(
188194
self,
189195
in_channels: int,
190196
out_channels: int,
191-
kernel_size: int = 3,
192-
stride: int = 1,
193-
padding: Optional[int] = None,
197+
kernel_size: Union[int, Tuple[int, int, int]] = 3,
198+
stride: Union[int, Tuple[int, int, int]] = 1,
199+
padding: Optional[Union[int, Tuple[int, int, int], str]] = None,
194200
groups: int = 1,
195201
norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm3d,
196202
activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
197-
dilation: int = 1,
203+
dilation: Union[int, Tuple[int, int, int]] = 1,
198204
inplace: Optional[bool] = True,
199205
bias: Optional[bool] = None,
200206
) -> None:

torchvision/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import collections
12
import math
23
import pathlib
34
import warnings
5+
from itertools import repeat
46
from types import FunctionType
57
from typing import Any, BinaryIO, List, Optional, Tuple, Union
68

@@ -569,3 +571,18 @@ def _log_api_usage_once(obj: Any) -> None:
569571
if isinstance(obj, FunctionType):
570572
name = obj.__name__
571573
torch._C._log_api_usage_once(f"{module}.{name}")
574+
575+
576+
def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]:
577+
"""
578+
Make n-tuple from input x. If x is an iterable, then we just convert it to tuple.
579+
Otherwise we will make a tuple of length n, all with value of x.
580+
reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8
581+
582+
Args:
583+
x (Any): input value
584+
n (int): length of the resulting tuple
585+
"""
586+
if isinstance(x, collections.abc.Iterable):
587+
return tuple(x)
588+
return tuple(repeat(x, n))

0 commit comments

Comments
 (0)