Generalize ConvNormActivation function to accept tuple for some parameters #6251
Conversation
…, padding, and dilation
Thanks @YosuaMichael. I left a few comments below, let me know what you think.
```diff
-    kernel_size: int = 3,
+    kernel_size: Union[int, Tuple[int, ...]] = 3,
     stride: int = 1,
     padding: Optional[int] = None,
```
I was worried: are we still JIT compatible after adding `Union`? I had avoided `Union` for that reason.
Thanks for the input @oke-aditya, let me check whether this is JIT compatible or not. Will update here.
I have verified that it is JIT compatible. Here is my script to check:

```python
import torchvision.ops.misc as misc
import torch

conv = misc.Conv2dNormActivation(10, 5, kernel_size=(1, 3), stride=(1, 2))
x = torch.rand(1, 10, 32, 32)
out = conv(x)

conv_jit = torch.jit.script(conv)
out_jit = conv_jit(x)
print(torch.allclose(out, out_jit))  # True
```
There are no JIT-scriptability concerns here. Constructors can contain arbitrary Python calls; the restrictions apply only to the `forward` calls. :)
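To illustrate the point above (a minimal sketch, not code from this PR; the module `M` is a hypothetical example): a `Union`-typed argument in a constructor does not block scripting, because `__init__` runs in eager Python and only `forward` is compiled by TorchScript.

```python
from typing import Tuple, Union

import torch


class M(torch.nn.Module):
    def __init__(self, kernel_size: Union[int, Tuple[int, int]] = 3):
        super().__init__()
        # Union in __init__ is fine: this code runs eagerly, and only the
        # resulting attributes (here, a plain Conv2d) get scripted.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.conv = torch.nn.Conv2d(1, 1, kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


# Scripting succeeds even though the constructor took a Union argument.
m = torch.jit.script(M(kernel_size=(1, 3)))
```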
I have also tried:

```python
import torchvision.models as models
import torch

m = models.efficientnet_b0()
torch.jit.script(m)
```

and it works as well (note: EfficientNet uses `Conv2dNormActivation`).
LGTM, just a minor nit. I think we can merge on green CI.
If JIT is happy then we are all good 😄
Generalize ConvNormActivation function to accept tuple for some parameters (#6251)

Summary:
* Make ConvNormActivation function accept tuple for kernel_size, stride, padding, and dilation
* Fix the method to get the conv_dim
* Simplify if-elif logic

Reviewed By: jdsgomes

Differential Revision: D37993422

fbshipit-source-id: db76e88960f1da8b2ed7715903f4cc2ff88f0464
Currently, the function `ConvNormActivation` and its variants `Conv2dNormActivation`/`Conv3dNormActivation` only accept an integer for `kernel_size`, `stride`, `padding`, and `dilation`. However, some models require a `kernel_size` tuple with different values per dimension (for example, the S3D model). With this PR, we enable these four parameters to accept a tuple as well as an integer, and we make sure the change is backward compatible.