From 0bf603983ca9adfa648a2dda02e817a6af9b7092 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Sat, 19 Feb 2022 21:38:39 +0530 Subject: [PATCH 01/13] Add ops.conv3d --- docs/source/ops.rst | 1 + torchvision/ops/__init__.py | 3 +- torchvision/ops/misc.py | 59 +++++++++++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 1 deletion(-) diff --git a/docs/source/ops.rst b/docs/source/ops.rst index 2a960474205..19ca9761186 100644 --- a/docs/source/ops.rst +++ b/docs/source/ops.rst @@ -46,4 +46,5 @@ Operators StochasticDepth FrozenBatchNorm2d ConvNormActivation + Conv3dNormActivation SqueezeExcitation diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index 8ba10080c1f..e0685e009c8 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -14,7 +14,7 @@ from .feature_pyramid_network import FeaturePyramidNetwork from .focal_loss import sigmoid_focal_loss from .giou_loss import generalized_box_iou_loss -from .misc import FrozenBatchNorm2d, ConvNormActivation, SqueezeExcitation +from .misc import FrozenBatchNorm2d, ConvNormActivation, Conv3dNormActivation, SqueezeExcitation from .poolers import MultiScaleRoIAlign from .ps_roi_align import ps_roi_align, PSRoIAlign from .ps_roi_pool import ps_roi_pool, PSRoIPool @@ -52,6 +52,7 @@ "StochasticDepth", "FrozenBatchNorm2d", "ConvNormActivation", + "Conv3dNormActivation", "SqueezeExcitation", "generalized_box_iou_loss", ] diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index 268962e204c..c424f03fa14 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -161,3 +161,62 @@ def _scale(self, input: Tensor) -> Tensor: def forward(self, input: Tensor) -> Tensor: scale = self._scale(input) return scale * input + + +class Conv3dNormActivation(torch.nn.Sequential): + """ + Configurable block used for Convolution-Normalzation-Activation blocks. + + Args: + in_channels (int): Number of channels in the input video. + out_channels (int): Number of channels produced by the Convolution-Normalzation-Activation block + kernel_size: (int, optional): Size of the convolving kernel. Default: 3 + stride (int, optional): Stride of the convolution. Default: 1 + padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in wich case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation`` + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm3d`` + activation_layer (Callable[..., torch.nn.Module], optinal): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU`` + dilation (int): Spacing between kernel elements. Default: 1 + inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True`` + bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``. 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + padding: Optional[int] = None, + groups: int = 1, + norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm3d, + activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, + dilation: int = 1, + inplace: Optional[bool] = True, + bias: Optional[bool] = None, + ) -> None: + + if padding is None: + padding = (kernel_size - 1) // 2 * dilation + if bias is None: + bias = norm_layer is None + layers = [ + torch.nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + ] + if norm_layer is not None: + layers.append(norm_layer(out_channels)) + if activation_layer is not None: + params = {} if inplace is None else {"inplace": inplace} + layers.append(activation_layer(**params)) + super().__init__(*layers) + _log_api_usage_once(self) + self.out_channels = out_channels From e3cb4698ac11b106155bf7f4527e7e36cc2a392b Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Tue, 22 Feb 2022 22:00:41 +0530 Subject: [PATCH 02/13] Refactor for conv2d and 3d --- docs/source/ops.rst | 1 + torchvision/ops/__init__.py | 4 +- torchvision/ops/misc.py | 229 ++++++++++++++++++++++++------------ 3 files changed, 160 insertions(+), 74 deletions(-) diff --git a/docs/source/ops.rst b/docs/source/ops.rst index 19ca9761186..c84097e8768 100644 --- a/docs/source/ops.rst +++ b/docs/source/ops.rst @@ -46,5 +46,6 @@ Operators StochasticDepth FrozenBatchNorm2d ConvNormActivation + Conv2dNormActivation Conv3dNormActivation SqueezeExcitation diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index e0685e009c8..f1e2322f11f 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -14,7 +14,7 @@ from .feature_pyramid_network import FeaturePyramidNetwork from .focal_loss import sigmoid_focal_loss from .giou_loss import generalized_box_iou_loss -from .misc import FrozenBatchNorm2d, ConvNormActivation, Conv3dNormActivation, SqueezeExcitation +from .misc import FrozenBatchNorm2d, ConvNormActivation, Conv2dNormActivation, Conv3dNormActivation, SqueezeExcitation from .poolers import MultiScaleRoIAlign from .ps_roi_align import ps_roi_align, PSRoIAlign from .ps_roi_pool import ps_roi_pool, PSRoIPool @@ -52,7 +52,7 @@ "StochasticDepth", "FrozenBatchNorm2d", "ConvNormActivation", - "Conv3dNormActivation", + "Conv2dNormActivation" "Conv3dNormActivation", "SqueezeExcitation", "generalized_box_iou_loss", ] diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index c424f03fa14..5d418793173 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -1,3 +1,4 @@ +import warnings from typing import Callable, List, Optional import torch @@ -65,29 +66,12 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps})" -class ConvNormActivation(torch.nn.Sequential): - """ - Configurable block used for Convolution-Normalzation-Activation blocks. - - Args: - in_channels (int): Number of channels in the input image - out_channels (int): Number of channels produced by the Convolution-Normalzation-Activation block - kernel_size: (int, optional): Size of the convolving kernel. Default: 3 - stride (int, optional): Stride of the convolution. Default: 1 - padding (int, tuple or str, optional): Padding added to all four sides of the input. 
Default: None, in wich case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation`` - groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 - norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolutiuon layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm2d`` - activation_layer (Callable[..., torch.nn.Module], optinal): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU`` - dilation (int): Spacing between kernel elements. Default: 1 - inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True`` - bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``. - - """ - +class _ConvNormActivation(torch.nn.Sequential): def __init__( self, in_channels: int, out_channels: int, + layer: Callable[..., torch.nn.Module], kernel_size: int = 3, stride: int = 1, padding: Optional[int] = None, @@ -98,12 +82,14 @@ def __init__( inplace: Optional[bool] = True, bias: Optional[bool] = None, ) -> None: + if padding is None: padding = (kernel_size - 1) // 2 * dilation if bias is None: bias = norm_layer is None + layers = [ - torch.nn.Conv2d( + layer( in_channels, out_channels, kernel_size, @@ -114,56 +100,125 @@ def __init__( bias=bias, ) ] + if norm_layer is not None: layers.append(norm_layer(out_channels)) + if activation_layer is not None: params = {} if inplace is None else {"inplace": inplace} layers.append(activation_layer(**params)) super().__init__(*layers) - _log_api_usage_once(self) + self.out_channels = out_channels -class SqueezeExcitation(torch.nn.Module): +class ConvNormActivation(_ConvNormActivation): """ - This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507 (see Fig. 1). - Parameters ``activation``, and ``scale_activation`` correspond to ``delta`` and ``sigma`` in in eq. 3. + Configurable block used for Convolution-Normalzation-Activation blocks. Args: - input_channels (int): Number of channels in the input image - squeeze_channels (int): Number of squeeze channels - activation (Callable[..., torch.nn.Module], optional): ``delta`` activation. Default: ``torch.nn.ReLU`` - scale_activation (Callable[..., torch.nn.Module]): ``sigma`` activation. Default: ``torch.nn.Sigmoid`` + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the Convolution-Normalzation-Activation block + kernel_size: (int, optional): Size of the convolving kernel. Default: 3 + stride (int, optional): Stride of the convolution. Default: 1 + padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in wich case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation`` + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm2d`` + activation_layer (Callable[..., torch.nn.Module], optinal): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. 
Default: ``torch.nn.ReLU`` + dilation (int): Spacing between kernel elements. Default: 1 + inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True`` + bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``. + """ def __init__( self, - input_channels: int, - squeeze_channels: int, - activation: Callable[..., torch.nn.Module] = torch.nn.ReLU, - scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + padding: Optional[int] = None, + groups: int = 1, + norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d, + activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, + dilation: int = 1, + inplace: Optional[bool] = True, + bias: Optional[bool] = None, ) -> None: - super().__init__() - _log_api_usage_once(self) - self.avgpool = torch.nn.AdaptiveAvgPool2d(1) - self.fc1 = torch.nn.Conv2d(input_channels, squeeze_channels, 1) - self.fc2 = torch.nn.Conv2d(squeeze_channels, input_channels, 1) - self.activation = activation() - self.scale_activation = scale_activation() + warnings.warn( + "The ConvNormActivation class are deprecated since 0.13 and will be removed in 0.15. " + "Use torchvision.ops.Conv2dNormActivation instead.", + FutureWarning, + ) + layer = torch.nn.Conv2d + super().__init__( + in_channels, + out_channels, + layer, + kernel_size, + stride, + padding, + groups, + norm_layer, + activation_layer, + dilation, + inplace, + bias, + ) - def _scale(self, input: Tensor) -> Tensor: - scale = self.avgpool(input) - scale = self.fc1(scale) - scale = self.activation(scale) - scale = self.fc2(scale) - return self.scale_activation(scale) - def forward(self, input: Tensor) -> Tensor: - scale = self._scale(input) - return scale * input +class Conv2dNormActivation(_ConvNormActivation): + """ + Configurable block used for Convolution-Normalzation-Activation blocks. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the Convolution-Normalzation-Activation block + kernel_size: (int, optional): Size of the convolving kernel. Default: 3 + stride (int, optional): Stride of the convolution. Default: 1 + padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in wich case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation`` + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm2d`` + activation_layer (Callable[..., torch.nn.Module], optinal): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU`` + dilation (int): Spacing between kernel elements. Default: 1 + inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True`` + bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``. 
+ + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + padding: Optional[int] = None, + groups: int = 1, + norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d, + activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, + dilation: int = 1, + inplace: Optional[bool] = True, + bias: Optional[bool] = None, + ) -> None: + + layer = torch.nn.Conv2d + super().__init__( + in_channels, + out_channels, + layer, + kernel_size, + stride, + padding, + groups, + norm_layer, + activation_layer, + dilation, + inplace, + bias, + ) -class Conv3dNormActivation(torch.nn.Sequential): +class Conv3dNormActivation(_ConvNormActivation): """ Configurable block used for Convolution-Normalzation-Activation blocks. @@ -196,27 +251,57 @@ def __init__( bias: Optional[bool] = None, ) -> None: - if padding is None: - padding = (kernel_size - 1) // 2 * dilation - if bias is None: - bias = norm_layer is None - layers = [ - torch.nn.Conv3d( - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation=dilation, - groups=groups, - bias=bias, - ) - ] - if norm_layer is not None: - layers.append(norm_layer(out_channels)) - if activation_layer is not None: - params = {} if inplace is None else {"inplace": inplace} - layers.append(activation_layer(**params)) - super().__init__(*layers) + layer = torch.nn.Conv3d + super().__init__( + in_channels, + out_channels, + layer, + kernel_size, + stride, + padding, + groups, + norm_layer, + activation_layer, + dilation, + inplace, + bias, + ) + + +class SqueezeExcitation(torch.nn.Module): + """ + This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507 (see Fig. 1). + Parameters ``activation``, and ``scale_activation`` correspond to ``delta`` and ``sigma`` in in eq. 3. + + Args: + input_channels (int): Number of channels in the input image + squeeze_channels (int): Number of squeeze channels + activation (Callable[..., torch.nn.Module], optional): ``delta`` activation. Default: ``torch.nn.ReLU`` + scale_activation (Callable[..., torch.nn.Module]): ``sigma`` activation. 
Default: ``torch.nn.Sigmoid`` + """ + + def __init__( + self, + input_channels: int, + squeeze_channels: int, + activation: Callable[..., torch.nn.Module] = torch.nn.ReLU, + scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid, + ) -> None: + super().__init__() _log_api_usage_once(self) - self.out_channels = out_channels + self.avgpool = torch.nn.AdaptiveAvgPool2d(1) + self.fc1 = torch.nn.Conv2d(input_channels, squeeze_channels, 1) + self.fc2 = torch.nn.Conv2d(squeeze_channels, input_channels, 1) + self.activation = activation() + self.scale_activation = scale_activation() + + def _scale(self, input: Tensor) -> Tensor: + scale = self.avgpool(input) + scale = self.fc1(scale) + scale = self.activation(scale) + scale = self.fc2(scale) + return self.scale_activation(scale) + + def forward(self, input: Tensor) -> Tensor: + scale = self._scale(input) + return scale * input From 996a5a88878c16c35b66a416988a6fbe9c6bbee8 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Tue, 22 Feb 2022 23:08:41 +0530 Subject: [PATCH 03/13] Refactor --- torchvision/ops/__init__.py | 3 +- torchvision/ops/misc.py | 132 ++++++------------------------------ 2 files changed, 21 insertions(+), 114 deletions(-) diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index f1e2322f11f..62a7dfaa128 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -52,7 +52,8 @@ "StochasticDepth", "FrozenBatchNorm2d", "ConvNormActivation", - "Conv2dNormActivation" "Conv3dNormActivation", + "Conv2dNormActivation", + "Conv3dNormActivation", "SqueezeExcitation", "generalized_box_iou_loss", ] diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index 5d418793173..f690e2f02ec 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -1,5 +1,5 @@ import warnings -from typing import Callable, List, Optional +from typing import Callable, List, Optional, Any import torch from torch import Tensor @@ -112,61 +112,6 @@ def __init__( self.out_channels = out_channels -class ConvNormActivation(_ConvNormActivation): - """ - Configurable block used for Convolution-Normalzation-Activation blocks. - - Args: - in_channels (int): Number of channels in the input image - out_channels (int): Number of channels produced by the Convolution-Normalzation-Activation block - kernel_size: (int, optional): Size of the convolving kernel. Default: 3 - stride (int, optional): Stride of the convolution. Default: 1 - padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in wich case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation`` - groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 - norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm2d`` - activation_layer (Callable[..., torch.nn.Module], optinal): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU`` - dilation (int): Spacing between kernel elements. Default: 1 - inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True`` - bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``. 
- - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int = 3, - stride: int = 1, - padding: Optional[int] = None, - groups: int = 1, - norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d, - activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, - dilation: int = 1, - inplace: Optional[bool] = True, - bias: Optional[bool] = None, - ) -> None: - warnings.warn( - "The ConvNormActivation class are deprecated since 0.13 and will be removed in 0.15. " - "Use torchvision.ops.Conv2dNormActivation instead.", - FutureWarning, - ) - layer = torch.nn.Conv2d - super().__init__( - in_channels, - out_channels, - layer, - kernel_size, - stride, - padding, - groups, - norm_layer, - activation_layer, - dilation, - inplace, - bias, - ) - - class Conv2dNormActivation(_ConvNormActivation): """ Configurable block used for Convolution-Normalzation-Activation blocks. @@ -186,36 +131,9 @@ class Conv2dNormActivation(_ConvNormActivation): """ - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int = 3, - stride: int = 1, - padding: Optional[int] = None, - groups: int = 1, - norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d, - activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, - dilation: int = 1, - inplace: Optional[bool] = True, - bias: Optional[bool] = None, - ) -> None: - + def __init__(self, *args: Any, **kwargs: Any) -> None: layer = torch.nn.Conv2d - super().__init__( - in_channels, - out_channels, - layer, - kernel_size, - stride, - padding, - groups, - norm_layer, - activation_layer, - dilation, - inplace, - bias, - ) + super().__init__(layer, *args, **kwargs) class Conv3dNormActivation(_ConvNormActivation): @@ -236,36 +154,24 @@ class Conv3dNormActivation(_ConvNormActivation): bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``. """ - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int = 3, - stride: int = 1, - padding: Optional[int] = None, - groups: int = 1, - norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm3d, - activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, - dilation: int = 1, - inplace: Optional[bool] = True, - bias: Optional[bool] = None, - ) -> None: - + def __init__(self, *args: Any, **kwargs: Any) -> None: layer = torch.nn.Conv3d - super().__init__( - in_channels, - out_channels, - layer, - kernel_size, - stride, - padding, - groups, - norm_layer, - activation_layer, - dilation, - inplace, - bias, + super().__init__(layer, *args, **kwargs) + + +class ConvNormActivation(Conv2dNormActivation): + """ + DEPRECATED + Use Conv2dNormActivation instead. + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + warnings.warn( + "The ConvNormActivation class are deprecated since 0.13 and will be removed in 0.15. 
" + "Use torchvision.ops.Conv2dNormActivation instead.", + FutureWarning, ) + super().__init__(*args, **kwargs) class SqueezeExcitation(torch.nn.Module): From 0eb697858951b8f9e118a9849279cb9bbd87b67e Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Tue, 22 Feb 2022 23:21:56 +0530 Subject: [PATCH 04/13] Fix bug --- torchvision/ops/misc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index f690e2f02ec..4d7e8a6b16b 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -133,7 +133,7 @@ class Conv2dNormActivation(_ConvNormActivation): def __init__(self, *args: Any, **kwargs: Any) -> None: layer = torch.nn.Conv2d - super().__init__(layer, *args, **kwargs) + super().__init__(layer=layer, *args, **kwargs) class Conv3dNormActivation(_ConvNormActivation): @@ -156,7 +156,7 @@ class Conv3dNormActivation(_ConvNormActivation): def __init__(self, *args: Any, **kwargs: Any) -> None: layer = torch.nn.Conv3d - super().__init__(layer, *args, **kwargs) + super().__init__(layer=layer, *args, **kwargs) class ConvNormActivation(Conv2dNormActivation): From 07be45ba4ed7135909d90ecadfcf044c47060880 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Wed, 23 Feb 2022 00:03:56 +0530 Subject: [PATCH 05/13] Addres review --- torchvision/ops/misc.py | 80 ++++++++++++++++++++++++++++++++++------- 1 file changed, 67 insertions(+), 13 deletions(-) diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index 4d7e8a6b16b..f58d6301132 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -1,5 +1,5 @@ import warnings -from typing import Callable, List, Optional, Any +from typing import Callable, List, Optional import torch from torch import Tensor @@ -71,12 +71,12 @@ def __init__( self, in_channels: int, out_channels: int, - layer: Callable[..., torch.nn.Module], + conv_layer: Callable[..., torch.nn.Module], kernel_size: int = 3, stride: int = 1, padding: Optional[int] = None, groups: int = 1, - norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d, + norm_layer: Optional[Callable[..., torch.nn.Module]] = None, activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, dilation: int = 1, inplace: Optional[bool] = True, @@ -89,7 +89,7 @@ def __init__( bias = norm_layer is None layers = [ - layer( + conv_layer( in_channels, out_channels, kernel_size, @@ -108,13 +108,13 @@ def __init__( params = {} if inplace is None else {"inplace": inplace} layers.append(activation_layer(**params)) super().__init__(*layers) - + _log_api_usage_once(self) self.out_channels = out_channels class Conv2dNormActivation(_ConvNormActivation): """ - Configurable block used for Convolution-Normalzation-Activation blocks. + Configurable block used for Convolution2d-Normalzation-Activation blocks. 
Args: in_channels (int): Number of channels in the input image @@ -131,14 +131,41 @@ class Conv2dNormActivation(_ConvNormActivation): """ - def __init__(self, *args: Any, **kwargs: Any) -> None: - layer = torch.nn.Conv2d - super().__init__(layer=layer, *args, **kwargs) + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + padding: Optional[int] = None, + groups: int = 1, + norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d, + activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, + dilation: int = 1, + inplace: Optional[bool] = True, + bias: Optional[bool] = None, + ) -> None: + + conv_layer = torch.nn.Conv2d + super().__init__( + in_channels, + out_channels, + conv_layer, + kernel_size, + stride, + padding, + groups, + norm_layer, + activation_layer, + dilation, + inplace, + bias, + ) class Conv3dNormActivation(_ConvNormActivation): """ - Configurable block used for Convolution-Normalzation-Activation blocks. + Configurable block used for Convolution3d-Normalzation-Activation blocks. Args: in_channels (int): Number of channels in the input video. @@ -154,9 +181,36 @@ class Conv3dNormActivation(_ConvNormActivation): bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``. """ - def __init__(self, *args: Any, **kwargs: Any) -> None: - layer = torch.nn.Conv3d - super().__init__(layer=layer, *args, **kwargs) + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + padding: Optional[int] = None, + groups: int = 1, + norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm3d, + activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, + dilation: int = 1, + inplace: Optional[bool] = True, + bias: Optional[bool] = None, + ) -> None: + + conv_layer = torch.nn.Conv3d + super().__init__( + in_channels, + out_channels, + conv_layer, + kernel_size, + stride, + padding, + groups, + norm_layer, + activation_layer, + dilation, + inplace, + bias, + ) class ConvNormActivation(Conv2dNormActivation): From eecb9e40c8714c2c5447203757f26eca51a781ea Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Wed, 23 Feb 2022 22:10:58 +0530 Subject: [PATCH 06/13] Fix bug --- torchvision/ops/misc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index f58d6301132..a9c140324e9 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -1,5 +1,5 @@ import warnings -from typing import Callable, List, Optional +from typing import Callable, List, Optional, Any import torch from torch import Tensor From 6ab0f960ce0ce5b87470a7b5d4ffb9da9dc2d060 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Thu, 24 Feb 2022 14:26:13 +0530 Subject: [PATCH 07/13] nit fix --- torchvision/ops/misc.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index a9c140324e9..363df39293c 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -146,11 +146,10 @@ def __init__( bias: Optional[bool] = None, ) -> None: - conv_layer = torch.nn.Conv2d super().__init__( in_channels, out_channels, - conv_layer, + torch.nn.Conv2d, kernel_size, stride, padding, @@ -196,11 +195,10 @@ def __init__( bias: Optional[bool] = None, ) -> None: - conv_layer = torch.nn.Conv3d super().__init__( in_channels, out_channels, - conv_layer, + torch.nn.Conv3d, 
kernel_size, stride, padding, From 1f772cf284ac9002e3aa1d1734115df67d665f10 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Thu, 24 Feb 2022 23:19:41 +0530 Subject: [PATCH 08/13] Fix flake --- torchvision/ops/misc.py | 37 +++++++++++++------------------------ 1 file changed, 13 insertions(+), 24 deletions(-) diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index 363df39293c..48750e2e3c9 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -1,5 +1,4 @@ -import warnings -from typing import Callable, List, Optional, Any +from typing import Callable, List, Optional import torch from torch import Tensor @@ -66,21 +65,26 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps})" -class _ConvNormActivation(torch.nn.Sequential): +class ConvNormActivation(torch.nn.Sequential): + """ + DEPRECATED + Use Conv2dNormActivation instead. + """ + def __init__( self, in_channels: int, out_channels: int, - conv_layer: Callable[..., torch.nn.Module], kernel_size: int = 3, stride: int = 1, padding: Optional[int] = None, groups: int = 1, - norm_layer: Optional[Callable[..., torch.nn.Module]] = None, + norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d, activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, dilation: int = 1, inplace: Optional[bool] = True, bias: Optional[bool] = None, + conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d, ) -> None: if padding is None: @@ -112,7 +116,7 @@ def __init__( self.out_channels = out_channels -class Conv2dNormActivation(_ConvNormActivation): +class Conv2dNormActivation(ConvNormActivation): """ Configurable block used for Convolution2d-Normalzation-Activation blocks. @@ -149,7 +153,6 @@ def __init__( super().__init__( in_channels, out_channels, - torch.nn.Conv2d, kernel_size, stride, padding, @@ -159,10 +162,11 @@ def __init__( dilation, inplace, bias, + torch.nn.Conv2d, ) -class Conv3dNormActivation(_ConvNormActivation): +class Conv3dNormActivation(ConvNormActivation): """ Configurable block used for Convolution3d-Normalzation-Activation blocks. @@ -198,7 +202,6 @@ def __init__( super().__init__( in_channels, out_channels, - torch.nn.Conv3d, kernel_size, stride, padding, @@ -208,24 +211,10 @@ def __init__( dilation, inplace, bias, + torch.nn.Conv3d, ) -class ConvNormActivation(Conv2dNormActivation): - """ - DEPRECATED - Use Conv2dNormActivation instead. - """ - - def __init__(self, *args: Any, **kwargs: Any) -> None: - warnings.warn( - "The ConvNormActivation class are deprecated since 0.13 and will be removed in 0.15. " - "Use torchvision.ops.Conv2dNormActivation instead.", - FutureWarning, - ) - super().__init__(*args, **kwargs) - - class SqueezeExcitation(torch.nn.Module): """ This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507 (see Fig. 1). From 748391e0e4dd3a80d9bfacc72da63a4a48e76822 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Thu, 24 Feb 2022 23:48:09 +0530 Subject: [PATCH 09/13] Final fix --- torchvision/ops/misc.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index 48750e2e3c9..3eee0fdb5b3 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -1,3 +1,4 @@ +import warnings from typing import Callable, List, Optional import torch @@ -67,7 +68,6 @@ def __repr__(self) -> str: class ConvNormActivation(torch.nn.Sequential): """ - DEPRECATED Use Conv2dNormActivation instead. 
""" @@ -115,6 +115,9 @@ def __init__( _log_api_usage_once(self) self.out_channels = out_channels + if self.__class__ == ConvNormActivation: + warnings.warn("Don't use ConvNormActivation directly. Use Conv2dNormActivation instead.") + class Conv2dNormActivation(ConvNormActivation): """ From d832e02398c5af9db31d8190940c575e46b37380 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 24 Feb 2022 21:01:33 +0000 Subject: [PATCH 10/13] remove documentation --- torchvision/ops/misc.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index 3eee0fdb5b3..ad18f4ad446 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -67,9 +67,6 @@ def __repr__(self) -> str: class ConvNormActivation(torch.nn.Sequential): - """ - Use Conv2dNormActivation instead. - """ def __init__( self, From 6a6c8448a92cf3bd906f4681864202c7b5b5f772 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 24 Feb 2022 21:07:38 +0000 Subject: [PATCH 11/13] fix linter --- torchvision/ops/misc.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index ad18f4ad446..b542de90742 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -67,7 +67,6 @@ def __repr__(self) -> str: class ConvNormActivation(torch.nn.Sequential): - def __init__( self, in_channels: int, From 9a5ab98c2566486cb608e1c41cde02c654d199eb Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Fri, 25 Feb 2022 12:58:13 +0530 Subject: [PATCH 12/13] Update all the implementations to use new Conv --- torchvision/models/convnext.py | 4 +-- torchvision/models/detection/ssdlite.py | 10 +++--- torchvision/models/efficientnet.py | 12 +++---- torchvision/models/mobilenetv2.py | 14 ++++---- torchvision/models/mobilenetv3.py | 12 +++---- torchvision/models/optical_flow/raft.py | 32 ++++++++++--------- .../models/quantization/mobilenetv2.py | 4 +-- .../models/quantization/mobilenetv3.py | 4 +-- torchvision/models/regnet.py | 12 +++---- torchvision/models/vision_transformer.py | 4 +-- 10 files changed, 55 insertions(+), 53 deletions(-) diff --git a/torchvision/models/convnext.py b/torchvision/models/convnext.py index 9067b6876fd..3a0dcdb31cd 100644 --- a/torchvision/models/convnext.py +++ b/torchvision/models/convnext.py @@ -6,7 +6,7 @@ from torch.nn import functional as F from .._internally_replaced_utils import load_state_dict_from_url -from ..ops.misc import ConvNormActivation +from ..ops.misc import Conv2dNormActivation from ..ops.stochastic_depth import StochasticDepth from ..utils import _log_api_usage_once @@ -127,7 +127,7 @@ def __init__( # Stem firstconv_output_channels = block_setting[0].input_channels layers.append( - ConvNormActivation( + Conv2dNormActivation( 3, firstconv_output_channels, kernel_size=4, diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index 47ecf59f1e2..1ee59e069ea 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -7,7 +7,7 @@ from torch import nn, Tensor from ..._internally_replaced_utils import load_state_dict_from_url -from ...ops.misc import ConvNormActivation +from ...ops.misc import Conv2dNormActivation from ...utils import _log_api_usage_once from .. import mobilenet from . 
import _utils as det_utils @@ -29,7 +29,7 @@ def _prediction_block( ) -> nn.Sequential: return nn.Sequential( # 3x3 depthwise with stride 1 and padding 1 - ConvNormActivation( + Conv2dNormActivation( in_channels, in_channels, kernel_size=kernel_size, @@ -47,11 +47,11 @@ def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[..., intermediate_channels = out_channels // 2 return nn.Sequential( # 1x1 projection to half output channels - ConvNormActivation( + Conv2dNormActivation( in_channels, intermediate_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation ), # 3x3 depthwise with stride 2 and padding 1 - ConvNormActivation( + Conv2dNormActivation( intermediate_channels, intermediate_channels, kernel_size=3, @@ -61,7 +61,7 @@ def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[..., activation_layer=activation, ), # 1x1 projetion to output channels - ConvNormActivation( + Conv2dNormActivation( intermediate_channels, out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation ), ) diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index f7eba46cb39..f245d00cffe 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -8,7 +8,7 @@ from torchvision.ops import StochasticDepth from .._internally_replaced_utils import load_state_dict_from_url -from ..ops.misc import ConvNormActivation, SqueezeExcitation +from ..ops.misc import Conv2dNormActivation, SqueezeExcitation from ..utils import _log_api_usage_once from ._utils import _make_divisible @@ -104,7 +104,7 @@ def __init__( expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio) if expanded_channels != cnf.input_channels: layers.append( - ConvNormActivation( + Conv2dNormActivation( cnf.input_channels, expanded_channels, kernel_size=1, @@ -115,7 +115,7 @@ def __init__( # depthwise layers.append( - ConvNormActivation( + Conv2dNormActivation( expanded_channels, expanded_channels, kernel_size=cnf.kernel, @@ -132,7 +132,7 @@ def __init__( # project layers.append( - ConvNormActivation( + Conv2dNormActivation( expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None ) ) @@ -193,7 +193,7 @@ def __init__( # building first layer firstconv_output_channels = inverted_residual_setting[0].input_channels layers.append( - ConvNormActivation( + Conv2dNormActivation( 3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=nn.SiLU ) ) @@ -224,7 +224,7 @@ def __init__( lastconv_input_channels = inverted_residual_setting[-1].out_channels lastconv_output_channels = 4 * lastconv_input_channels layers.append( - ConvNormActivation( + Conv2dNormActivation( lastconv_input_channels, lastconv_output_channels, kernel_size=1, diff --git a/torchvision/models/mobilenetv2.py b/torchvision/models/mobilenetv2.py index e24c5962d7e..930f68d13e9 100644 --- a/torchvision/models/mobilenetv2.py +++ b/torchvision/models/mobilenetv2.py @@ -6,7 +6,7 @@ from torch import nn from .._internally_replaced_utils import load_state_dict_from_url -from ..ops.misc import ConvNormActivation +from ..ops.misc import Conv2dNormActivation from ..utils import _log_api_usage_once from ._utils import _make_divisible @@ -20,11 +20,11 @@ # necessary for backwards compatibility -class _DeprecatedConvBNAct(ConvNormActivation): +class _DeprecatedConvBNAct(Conv2dNormActivation): def __init__(self, *args, **kwargs): warnings.warn( "The ConvBNReLU/ConvBNActivation 
classes are deprecated since 0.12 and will be removed in 0.14. " - "Use torchvision.ops.misc.ConvNormActivation instead.", + "Use torchvision.ops.misc.Conv2dNormActivation instead.", FutureWarning, ) if kwargs.get("norm_layer", None) is None: @@ -56,12 +56,12 @@ def __init__( if expand_ratio != 1: # pw layers.append( - ConvNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6) + Conv2dNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6) ) layers.extend( [ # dw - ConvNormActivation( + Conv2dNormActivation( hidden_dim, hidden_dim, stride=stride, @@ -144,7 +144,7 @@ def __init__( input_channel = _make_divisible(input_channel * width_mult, round_nearest) self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) features: List[nn.Module] = [ - ConvNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6) + Conv2dNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6) ] # building inverted residual blocks for t, c, n, s in inverted_residual_setting: @@ -155,7 +155,7 @@ def __init__( input_channel = output_channel # building last several layers features.append( - ConvNormActivation( + Conv2dNormActivation( input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6 ) ) diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index 711888b7c8b..530467d6d53 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -6,7 +6,7 @@ from torch import nn, Tensor from .._internally_replaced_utils import load_state_dict_from_url -from ..ops.misc import ConvNormActivation, SqueezeExcitation as SElayer +from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer from ..utils import _log_api_usage_once from ._utils import _make_divisible @@ -83,7 +83,7 @@ def __init__( # expand if cnf.expanded_channels != cnf.input_channels: layers.append( - ConvNormActivation( + Conv2dNormActivation( cnf.input_channels, cnf.expanded_channels, kernel_size=1, @@ -95,7 +95,7 @@ def __init__( # depthwise stride = 1 if cnf.dilation > 1 else cnf.stride layers.append( - ConvNormActivation( + Conv2dNormActivation( cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel, @@ -112,7 +112,7 @@ def __init__( # project layers.append( - ConvNormActivation( + Conv2dNormActivation( cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None ) ) @@ -172,7 +172,7 @@ def __init__( # building first layer firstconv_output_channels = inverted_residual_setting[0].input_channels layers.append( - ConvNormActivation( + Conv2dNormActivation( 3, firstconv_output_channels, kernel_size=3, @@ -190,7 +190,7 @@ def __init__( lastconv_input_channels = inverted_residual_setting[-1].out_channels lastconv_output_channels = 6 * lastconv_input_channels layers.append( - ConvNormActivation( + Conv2dNormActivation( lastconv_input_channels, lastconv_output_channels, kernel_size=1, diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py index 83645cf38df..4dfd232d499 100644 --- a/torchvision/models/optical_flow/raft.py +++ b/torchvision/models/optical_flow/raft.py @@ -6,7 +6,7 @@ from torch import Tensor from torch.nn.modules.batchnorm import BatchNorm2d from torch.nn.modules.instancenorm import InstanceNorm2d -from torchvision.ops.misc import ConvNormActivation +from torchvision.ops 
import Conv2dNormActivation from ..._internally_replaced_utils import load_state_dict_from_url from ...utils import _log_api_usage_once @@ -38,17 +38,17 @@ def __init__(self, in_channels, out_channels, *, norm_layer, stride=1): # and frozen for the rest of the training process (i.e. set as eval()). The bias term is thus still useful # for the rest of the datasets. Technically, we could remove the bias for other norm layers like Instance norm # because these aren't frozen, but we don't bother (also, we woudn't be able to load the original weights). - self.convnormrelu1 = ConvNormActivation( + self.convnormrelu1 = Conv2dNormActivation( in_channels, out_channels, norm_layer=norm_layer, kernel_size=3, stride=stride, bias=True ) - self.convnormrelu2 = ConvNormActivation( + self.convnormrelu2 = Conv2dNormActivation( out_channels, out_channels, norm_layer=norm_layer, kernel_size=3, bias=True ) if stride == 1: self.downsample = nn.Identity() else: - self.downsample = ConvNormActivation( + self.downsample = Conv2dNormActivation( in_channels, out_channels, norm_layer=norm_layer, @@ -77,13 +77,13 @@ def __init__(self, in_channels, out_channels, *, norm_layer, stride=1): super().__init__() # See note in ResidualBlock for the reason behind bias=True - self.convnormrelu1 = ConvNormActivation( + self.convnormrelu1 = Conv2dNormActivation( in_channels, out_channels // 4, norm_layer=norm_layer, kernel_size=1, bias=True ) - self.convnormrelu2 = ConvNormActivation( + self.convnormrelu2 = Conv2dNormActivation( out_channels // 4, out_channels // 4, norm_layer=norm_layer, kernel_size=3, stride=stride, bias=True ) - self.convnormrelu3 = ConvNormActivation( + self.convnormrelu3 = Conv2dNormActivation( out_channels // 4, out_channels, norm_layer=norm_layer, kernel_size=1, bias=True ) self.relu = nn.ReLU(inplace=True) @@ -91,7 +91,7 @@ def __init__(self, in_channels, out_channels, *, norm_layer, stride=1): if stride == 1: self.downsample = nn.Identity() else: - self.downsample = ConvNormActivation( + self.downsample = Conv2dNormActivation( in_channels, out_channels, norm_layer=norm_layer, @@ -124,7 +124,9 @@ def __init__(self, *, block=ResidualBlock, layers=(64, 64, 96, 128, 256), norm_l assert len(layers) == 5 # See note in ResidualBlock for the reason behind bias=True - self.convnormrelu = ConvNormActivation(3, layers[0], norm_layer=norm_layer, kernel_size=7, stride=2, bias=True) + self.convnormrelu = Conv2dNormActivation( + 3, layers[0], norm_layer=norm_layer, kernel_size=7, stride=2, bias=True + ) self.layer1 = self._make_2_blocks(block, layers[0], layers[1], norm_layer=norm_layer, first_stride=1) self.layer2 = self._make_2_blocks(block, layers[1], layers[2], norm_layer=norm_layer, first_stride=2) @@ -170,17 +172,17 @@ def __init__(self, *, in_channels_corr, corr_layers=(256, 192), flow_layers=(128 assert len(flow_layers) == 2 assert len(corr_layers) in (1, 2) - self.convcorr1 = ConvNormActivation(in_channels_corr, corr_layers[0], norm_layer=None, kernel_size=1) + self.convcorr1 = Conv2dNormActivation(in_channels_corr, corr_layers[0], norm_layer=None, kernel_size=1) if len(corr_layers) == 2: - self.convcorr2 = ConvNormActivation(corr_layers[0], corr_layers[1], norm_layer=None, kernel_size=3) + self.convcorr2 = Conv2dNormActivation(corr_layers[0], corr_layers[1], norm_layer=None, kernel_size=3) else: self.convcorr2 = nn.Identity() - self.convflow1 = ConvNormActivation(2, flow_layers[0], norm_layer=None, kernel_size=7) - self.convflow2 = ConvNormActivation(flow_layers[0], flow_layers[1], norm_layer=None, 
kernel_size=3) + self.convflow1 = Conv2dNormActivation(2, flow_layers[0], norm_layer=None, kernel_size=7) + self.convflow2 = Conv2dNormActivation(flow_layers[0], flow_layers[1], norm_layer=None, kernel_size=3) # out_channels - 2 because we cat the flow (2 channels) at the end - self.conv = ConvNormActivation( + self.conv = Conv2dNormActivation( corr_layers[-1] + flow_layers[-1], out_channels - 2, norm_layer=None, kernel_size=3 ) @@ -301,7 +303,7 @@ class MaskPredictor(nn.Module): def __init__(self, *, in_channels, hidden_size, multiplier=0.25): super().__init__() - self.convrelu = ConvNormActivation(in_channels, hidden_size, norm_layer=None, kernel_size=3) + self.convrelu = Conv2dNormActivation(in_channels, hidden_size, norm_layer=None, kernel_size=3) # 8 * 8 * 9 because the predicted flow is downsampled by 8, from the downsampling of the initial FeatureEncoder # and we interpolate with all 9 surrounding neighbors. See paper and appendix B. self.conv = nn.Conv2d(hidden_size, 8 * 8 * 9, 1, padding=0) diff --git a/torchvision/models/quantization/mobilenetv2.py b/torchvision/models/quantization/mobilenetv2.py index b1e38f2cbbb..8cd9f16d13e 100644 --- a/torchvision/models/quantization/mobilenetv2.py +++ b/torchvision/models/quantization/mobilenetv2.py @@ -6,7 +6,7 @@ from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls from ..._internally_replaced_utils import load_state_dict_from_url -from ...ops.misc import ConvNormActivation +from ...ops.misc import Conv2dNormActivation from .utils import _fuse_modules, _replace_relu, quantize_model @@ -54,7 +54,7 @@ def forward(self, x: Tensor) -> Tensor: def fuse_model(self, is_qat: Optional[bool] = None) -> None: for m in self.modules(): - if type(m) is ConvNormActivation: + if type(m) is Conv2dNormActivation: _fuse_modules(m, ["0", "1", "2"], is_qat, inplace=True) if type(m) is QuantizableInvertedResidual: m.fuse_model(is_qat) diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 2f58cd96ace..4d7e2f7baad 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -5,7 +5,7 @@ from torch.ao.quantization import QuantStub, DeQuantStub from ..._internally_replaced_utils import load_state_dict_from_url -from ...ops.misc import ConvNormActivation, SqueezeExcitation +from ...ops.misc import Conv2dNormActivation, SqueezeExcitation from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3, model_urls, _mobilenet_v3_conf from .utils import _fuse_modules, _replace_relu @@ -103,7 +103,7 @@ def forward(self, x: Tensor) -> Tensor: def fuse_model(self, is_qat: Optional[bool] = None) -> None: for m in self.modules(): - if type(m) is ConvNormActivation: + if type(m) is Conv2dNormActivation: modules_to_fuse = ["0", "1"] if len(m) == 3 and type(m[2]) is nn.ReLU: modules_to_fuse.append("2") diff --git a/torchvision/models/regnet.py b/torchvision/models/regnet.py index 3f393c8e82d..74abd20b237 100644 --- a/torchvision/models/regnet.py +++ b/torchvision/models/regnet.py @@ -12,7 +12,7 @@ from torch import nn, Tensor from .._internally_replaced_utils import load_state_dict_from_url -from ..ops.misc import ConvNormActivation, SqueezeExcitation +from ..ops.misc import Conv2dNormActivation, SqueezeExcitation from ..utils import _log_api_usage_once from ._utils import _make_divisible @@ -55,7 +55,7 @@ } -class SimpleStemIN(ConvNormActivation): +class SimpleStemIN(Conv2dNormActivation): """Simple stem for 
ImageNet: 3x3, BN, ReLU.""" def __init__( @@ -88,10 +88,10 @@ def __init__( w_b = int(round(width_out * bottleneck_multiplier)) g = w_b // group_width - layers["a"] = ConvNormActivation( + layers["a"] = Conv2dNormActivation( width_in, w_b, kernel_size=1, stride=1, norm_layer=norm_layer, activation_layer=activation_layer ) - layers["b"] = ConvNormActivation( + layers["b"] = Conv2dNormActivation( w_b, w_b, kernel_size=3, stride=stride, groups=g, norm_layer=norm_layer, activation_layer=activation_layer ) @@ -105,7 +105,7 @@ def __init__( activation=activation_layer, ) - layers["c"] = ConvNormActivation( + layers["c"] = Conv2dNormActivation( w_b, width_out, kernel_size=1, stride=1, norm_layer=norm_layer, activation_layer=None ) super().__init__(layers) @@ -131,7 +131,7 @@ def __init__( self.proj = None should_proj = (width_in != width_out) or (stride != 1) if should_proj: - self.proj = ConvNormActivation( + self.proj = Conv2dNormActivation( width_in, width_out, kernel_size=1, stride=stride, norm_layer=norm_layer, activation_layer=None ) self.f = BottleneckTransform( diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index b36658e34d8..29f756ccbe5 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -7,7 +7,7 @@ import torch.nn as nn from .._internally_replaced_utils import load_state_dict_from_url -from ..ops.misc import ConvNormActivation +from ..ops.misc import Conv2dNormActivation from ..utils import _log_api_usage_once __all__ = [ @@ -163,7 +163,7 @@ def __init__( for i, conv_stem_layer_config in enumerate(conv_stem_configs): seq_proj.add_module( f"conv_bn_relu_{i}", - ConvNormActivation( + Conv2dNormActivation( in_channels=prev_channels, out_channels=conv_stem_layer_config.out_channels, kernel_size=conv_stem_layer_config.kernel_size, From 7a6100c072e6dd129d0bf1d22bb3c64394df0d5f Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Fri, 25 Feb 2022 14:34:12 +0530 Subject: [PATCH 13/13] Small doc fix --- torchvision/ops/misc.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index b542de90742..c7c52a86ff1 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -112,7 +112,9 @@ def __init__( self.out_channels = out_channels if self.__class__ == ConvNormActivation: - warnings.warn("Don't use ConvNormActivation directly. Use Conv2dNormActivation instead.") + warnings.warn( + "Don't use ConvNormActivation directly, please use Conv2dNormActivation and Conv3dNormActivation instead." + ) class Conv2dNormActivation(ConvNormActivation):
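
Usage sketch for the blocks introduced by this series. The import path, constructor signatures, and defaults (kernel_size=3, BatchNorm2d/BatchNorm3d, ReLU, padding inferred as ``(kernel_size - 1) // 2 * dilation``) are taken from the diffs above; the tensor shapes are only illustrative assumptions, not part of the patches:

    import torch
    from torchvision.ops import Conv2dNormActivation, Conv3dNormActivation

    # Conv2d -> BatchNorm2d -> ReLU; with kernel_size=3, dilation=1 the inferred padding is 1
    block2d = Conv2dNormActivation(in_channels=3, out_channels=16, kernel_size=3, stride=2)
    out2d = block2d(torch.randn(1, 3, 224, 224))  # -> shape (1, 16, 112, 112)

    # Conv3d -> BatchNorm3d -> ReLU; same interface, applied to video-shaped (N, C, T, H, W) input
    block3d = Conv3dNormActivation(in_channels=3, out_channels=16, kernel_size=3, stride=2)
    out3d = block3d(torch.randn(1, 3, 8, 112, 112))  # -> shape (1, 16, 4, 56, 56)

Instantiating the plain ``ConvNormActivation`` base class still works but, per the last patch, emits a warning pointing users to the 2d/3d variants.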