Skip to content

Commit d57f929

Browse files
authored
Move Permute layer to ops. (#6055)
1 parent 77cad12 commit d57f929

File tree

5 files changed

+21
-15
lines changed

5 files changed

+21
-15
lines changed

docs/source/ops.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,9 @@ TorchVision provides commonly used building blocks as layers:
8787
DeformConv2d
8888
DropBlock2d
8989
DropBlock3d
90-
MLP
9190
FrozenBatchNorm2d
91+
MLP
92+
Permute
9293
SqueezeExcitation
9394
StochasticDepth
9495

torchvision/models/convnext.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from torch import nn, Tensor
66
from torch.nn import functional as F
77

8-
from ..ops.misc import Conv2dNormActivation
8+
from ..ops.misc import Conv2dNormActivation, Permute
99
from ..ops.stochastic_depth import StochasticDepth
1010
from ..transforms._presets import ImageClassification
1111
from ..utils import _log_api_usage_once
@@ -35,15 +35,6 @@ def forward(self, x: Tensor) -> Tensor:
3535
return x
3636

3737

38-
class Permute(nn.Module):
39-
def __init__(self, dims: List[int]):
40-
super().__init__()
41-
self.dims = dims
42-
43-
def forward(self, x):
44-
return torch.permute(x, self.dims)
45-
46-
4738
class CNBlock(nn.Module):
4839
def __init__(
4940
self,

torchvision/models/swin_transformer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,13 @@
55
import torch.nn.functional as F
66
from torch import nn, Tensor
77

8-
from ..ops.misc import MLP
8+
from ..ops.misc import MLP, Permute
99
from ..ops.stochastic_depth import StochasticDepth
1010
from ..transforms._presets import ImageClassification, InterpolationMode
1111
from ..utils import _log_api_usage_once
1212
from ._api import WeightsEnum, Weights
1313
from ._meta import _IMAGENET_CATEGORIES
1414
from ._utils import _ovewrite_named_param
15-
from .convnext import Permute # TODO: move Permute on ops
1615

1716

1817
__all__ = [

torchvision/ops/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from .feature_pyramid_network import FeaturePyramidNetwork
2020
from .focal_loss import sigmoid_focal_loss
2121
from .giou_loss import generalized_box_iou_loss
22-
from .misc import FrozenBatchNorm2d, Conv2dNormActivation, Conv3dNormActivation, SqueezeExcitation, MLP
22+
from .misc import FrozenBatchNorm2d, Conv2dNormActivation, Conv3dNormActivation, SqueezeExcitation, MLP, Permute
2323
from .poolers import MultiScaleRoIAlign
2424
from .ps_roi_align import ps_roi_align, PSRoIAlign
2525
from .ps_roi_pool import ps_roi_pool, PSRoIPool
@@ -62,6 +62,7 @@
6262
"Conv3dNormActivation",
6363
"SqueezeExcitation",
6464
"MLP",
65+
"Permute",
6566
"generalized_box_iou_loss",
6667
"distance_box_iou_loss",
6768
"complete_box_iou_loss",

torchvision/ops/misc.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
interpolate = torch.nn.functional.interpolate
1111

1212

13-
# This is not in nn
1413
class FrozenBatchNorm2d(torch.nn.Module):
1514
"""
1615
BatchNorm2d where the batch statistics and the affine parameters are fixed
@@ -297,3 +296,18 @@ def __init__(
297296

298297
super().__init__(*layers)
299298
_log_api_usage_once(self)
299+
300+
301+
class Permute(torch.nn.Module):
302+
"""This module returns a view of the tensor input with its dimensions permuted.
303+
304+
Args:
305+
dims (List[int]): The desired ordering of dimensions
306+
"""
307+
308+
def __init__(self, dims: List[int]):
309+
super().__init__()
310+
self.dims = dims
311+
312+
def forward(self, x: Tensor) -> Tensor:
313+
return torch.permute(x, self.dims)

0 commit comments

Comments
 (0)