|
1 | 1 | import warnings
|
2 |
| -from typing import Callable, List, Optional |
| 2 | +from typing import Callable, List, Optional, Union, Tuple, Sequence |
3 | 3 |
|
4 | 4 | import torch
|
5 | 5 | from torch import Tensor
|
6 | 6 |
|
7 |
| -from ..utils import _log_api_usage_once |
| 7 | +from ..utils import _log_api_usage_once, _make_ntuple |
8 | 8 |
|
9 | 9 |
|
10 | 10 | interpolate = torch.nn.functional.interpolate
|
@@ -70,20 +70,26 @@ def __init__(
|
70 | 70 | self,
|
71 | 71 | in_channels: int,
|
72 | 72 | out_channels: int,
|
73 |
| - kernel_size: int = 3, |
74 |
| - stride: int = 1, |
75 |
| - padding: Optional[int] = None, |
| 73 | + kernel_size: Union[int, Tuple[int, ...]] = 3, |
| 74 | + stride: Union[int, Tuple[int, ...]] = 1, |
| 75 | + padding: Optional[Union[int, Tuple[int, ...], str]] = None, |
76 | 76 | groups: int = 1,
|
77 | 77 | norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
|
78 | 78 | activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
|
79 |
| - dilation: int = 1, |
| 79 | + dilation: Union[int, Tuple[int, ...]] = 1, |
80 | 80 | inplace: Optional[bool] = True,
|
81 | 81 | bias: Optional[bool] = None,
|
82 | 82 | conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d,
|
83 | 83 | ) -> None:
|
84 | 84 |
|
85 | 85 | if padding is None:
|
86 |
| - padding = (kernel_size - 1) // 2 * dilation |
| 86 | + if isinstance(kernel_size, int) and isinstance(dilation, int): |
| 87 | + padding = (kernel_size - 1) // 2 * dilation |
| 88 | + else: |
| 89 | + _conv_dim = len(kernel_size) if isinstance(kernel_size, Sequence) else len(dilation) |
| 90 | + kernel_size = _make_ntuple(kernel_size, _conv_dim) |
| 91 | + dilation = _make_ntuple(dilation, _conv_dim) |
| 92 | + padding = tuple((kernel_size[i] - 1) // 2 * dilation[i] for i in range(_conv_dim)) |
87 | 93 | if bias is None:
|
88 | 94 | bias = norm_layer is None
|
89 | 95 |
|
@@ -139,13 +145,13 @@ def __init__(
|
139 | 145 | self,
|
140 | 146 | in_channels: int,
|
141 | 147 | out_channels: int,
|
142 |
| - kernel_size: int = 3, |
143 |
| - stride: int = 1, |
144 |
| - padding: Optional[int] = None, |
| 148 | + kernel_size: Union[int, Tuple[int, int]] = 3, |
| 149 | + stride: Union[int, Tuple[int, int]] = 1, |
| 150 | + padding: Optional[Union[int, Tuple[int, int], str]] = None, |
145 | 151 | groups: int = 1,
|
146 | 152 | norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
|
147 | 153 | activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
|
148 |
| - dilation: int = 1, |
| 154 | + dilation: Union[int, Tuple[int, int]] = 1, |
149 | 155 | inplace: Optional[bool] = True,
|
150 | 156 | bias: Optional[bool] = None,
|
151 | 157 | ) -> None:
|
@@ -188,13 +194,13 @@ def __init__(
|
188 | 194 | self,
|
189 | 195 | in_channels: int,
|
190 | 196 | out_channels: int,
|
191 |
| - kernel_size: int = 3, |
192 |
| - stride: int = 1, |
193 |
| - padding: Optional[int] = None, |
| 197 | + kernel_size: Union[int, Tuple[int, int, int]] = 3, |
| 198 | + stride: Union[int, Tuple[int, int, int]] = 1, |
| 199 | + padding: Optional[Union[int, Tuple[int, int, int], str]] = None, |
194 | 200 | groups: int = 1,
|
195 | 201 | norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm3d,
|
196 | 202 | activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
|
197 |
| - dilation: int = 1, |
| 203 | + dilation: Union[int, Tuple[int, int, int]] = 1, |
198 | 204 | inplace: Optional[bool] = True,
|
199 | 205 | bias: Optional[bool] = None,
|
200 | 206 | ) -> None:
|
|
0 commit comments