Generalize ConvNormActivation function to accept tuple for some parameters #6251

Merged
Changes from 2 commits
39 changes: 24 additions & 15 deletions torchvision/ops/misc.py
@@ -1,10 +1,10 @@
 import warnings
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, Union, Tuple

 import torch
 from torch import Tensor

-from ..utils import _log_api_usage_once
+from ..utils import _log_api_usage_once, _make_ntuple


 interpolate = torch.nn.functional.interpolate
@@ -70,20 +70,29 @@ def __init__(
         self,
         in_channels: int,
         out_channels: int,
-        kernel_size: int = 3,
-        stride: int = 1,
-        padding: Optional[int] = None,
+        kernel_size: Union[int, Tuple[int, ...]] = 3,
Contributor (@oke-aditya):

I was worried that adding Union here might break JIT compatibility; I had avoided Union for that reason.

Contributor Author:

Thanks for the input @oke-aditya, let me check whether this is JIT compatible. I will update here.

Contributor Author (@YosuaMichael, Jul 8, 2022):

I have verified that it is JIT compatible. Here is the script I used to check:
import torchvision.ops.misc as misc
import torch

conv = misc.Conv2dNormActivation(10, 5, kernel_size=(1, 3), stride=(1, 2))

x = torch.rand(1, 10, 32, 32)
out = conv(x)

conv_jit = torch.jit.script(conv)
out_jit = conv_jit(x)

print(torch.allclose(out, out_jit))  # True

Contributor:

There are no JIT-scriptability concerns here. Constructors can contain whatever calls they like; the restrictions apply only to the forward methods. :)

Contributor Author:

I have also tried:
import torchvision.models as models
import torch

m = models.efficientnet_b0()
torch.jit.script(m)

and it works as well (note: EfficientNet uses Conv2dNormActivation).

+        stride: Union[int, Tuple[int, ...]] = 1,
+        padding: Optional[Union[int, Tuple[int, ...], str]] = None,
         groups: int = 1,
         norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
         activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
-        dilation: int = 1,
+        dilation: Union[int, Tuple[int, ...]] = 1,
         inplace: Optional[bool] = True,
         bias: Optional[bool] = None,
         conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d,
     ) -> None:

         if padding is None:
-            padding = (kernel_size - 1) // 2 * dilation
+            if isinstance(kernel_size, int) and isinstance(dilation, int):
+                padding = (kernel_size - 1) // 2 * dilation
+            else:
+                if isinstance(kernel_size, Tuple):
+                    _conv_dim = len(kernel_size)
+                elif isinstance(dilation, Tuple):
+                    _conv_dim = len(dilation)
+                kernel_size = _make_ntuple(kernel_size, _conv_dim)
+                dilation = _make_ntuple(dilation, _conv_dim)
+                padding = tuple((kernel_size[i] - 1) // 2 * dilation[i] for i in range(_conv_dim))
         if bias is None:
             bias = norm_layer is None
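
As a quick illustration of the new padding logic (a standalone sketch, not part of this diff): when one argument is a tuple and the other a scalar, the scalar is broadcast to the tuple's length before the per-dimension padding is computed.

from itertools import repeat

def make_ntuple(x, n):
    # Mirrors torchvision's _make_ntuple: pass iterables through, broadcast scalars.
    return tuple(x) if isinstance(x, (tuple, list)) else tuple(repeat(x, n))

kernel_size, dilation = (1, 3), 1
conv_dim = len(kernel_size)                        # the tuple argument fixes the dimension
kernel_size = make_ntuple(kernel_size, conv_dim)   # (1, 3)
dilation = make_ntuple(dilation, conv_dim)         # (1, 1)
padding = tuple((kernel_size[i] - 1) // 2 * dilation[i] for i in range(conv_dim))
print(padding)  # (0, 1)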

@@ -139,13 +148,13 @@ def __init__(
         self,
         in_channels: int,
         out_channels: int,
-        kernel_size: int = 3,
-        stride: int = 1,
-        padding: Optional[int] = None,
+        kernel_size: Union[int, Tuple[int, int]] = 3,
+        stride: Union[int, Tuple[int, int]] = 1,
+        padding: Optional[Union[int, Tuple[int, int], str]] = None,
         groups: int = 1,
         norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
         activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
-        dilation: int = 1,
+        dilation: Union[int, Tuple[int, int]] = 1,
         inplace: Optional[bool] = True,
         bias: Optional[bool] = None,
     ) -> None:
@@ -188,13 +197,13 @@ def __init__(
         self,
         in_channels: int,
         out_channels: int,
-        kernel_size: int = 3,
-        stride: int = 1,
-        padding: Optional[int] = None,
+        kernel_size: Union[int, Tuple[int, int, int]] = 3,
+        stride: Union[int, Tuple[int, int, int]] = 1,
+        padding: Optional[Union[int, Tuple[int, int, int], str]] = None,
         groups: int = 1,
         norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm3d,
         activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
-        dilation: int = 1,
+        dilation: Union[int, Tuple[int, int, int]] = 1,
         inplace: Optional[bool] = True,
         bias: Optional[bool] = None,
     ) -> None:
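
With this change, the 3D variant accepts per-dimension arguments as well. A minimal usage sketch (my own example, not from the PR):

import torch
from torchvision.ops.misc import Conv3dNormActivation

# Anisotropic kernel: no mixing along the temporal axis, 3x3 spatially.
conv3d = Conv3dNormActivation(8, 16, kernel_size=(1, 3, 3), stride=(1, 2, 2))

x = torch.rand(1, 8, 4, 32, 32)  # (N, C, T, H, W)
print(conv3d(x).shape)           # torch.Size([1, 16, 4, 16, 16])
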
17 changes: 17 additions & 0 deletions torchvision/utils.py
@@ -1,6 +1,8 @@
+import collections
 import math
 import pathlib
 import warnings
+from itertools import repeat
 from types import FunctionType
 from typing import Any, BinaryIO, List, Optional, Tuple, Union

@@ -569,3 +571,18 @@ def _log_api_usage_once(obj: Any) -> None:
     if isinstance(obj, FunctionType):
         name = obj.__name__
     torch._C._log_api_usage_once(f"{module}.{name}")
+
+
+def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]:
+    """
+    Make an n-tuple from input x. If x is an iterable, we just convert it to a tuple.
+    Otherwise, we make a tuple of length n, every element with value x.
+    reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8
+
+    Args:
+        x (Any): input value
+        n (int): length of the resulting tuple
+    """
+    if isinstance(x, collections.abc.Iterable):
+        return tuple(x)
+    return tuple(repeat(x, n))
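
A quick behavioral check of the helper (illustrative only; _make_ntuple is a private utility):

from torchvision.utils import _make_ntuple

print(_make_ntuple(3, 2))          # (3, 3)    scalar broadcast to length n
print(_make_ntuple((1, 3), 2))     # (1, 3)    iterables pass through as tuples
print(_make_ntuple([5, 6, 7], 2))  # (5, 6, 7) note: n is ignored for iterables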