
Commit 1dc3bc5

datumbox authored and facebook-github-bot committed
[fbsync] Adding ConvNeXt architecture in prototype (#5197)
Summary:

* Adding CNBlock and skeleton architecture
* Completed implementation.
* Adding model in prototypes.
* Add test and minor refactor for JIT.
* Fix mypy.
* Fixing naming conventions.
* Fixing tests.
* Fix stochastic depth percentages.
* Adding stochastic depth to tiny variant.
* Minor refactoring and adding comments.
* Adding weights.
* Update default weights.
* Fix transforms issue
* Move convnext to prototype.
* linter fix
* fix docs
* Addressing code review comments.

Reviewed By: jdsgomes, prabhat00155

Differential Revision: D33739375

fbshipit-source-id: 9df87bff1030cb629faf7d056957d1153a58af42
1 parent dc3d569 commit 1dc3bc5
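As a quick orientation before the diff: the commit exposes a `convnext_tiny` builder plus a weights enum under the prototype namespace. A minimal usage sketch, assuming the prototype API as of this revision (all names appear in the new convnext.py below):

```
import torch

from torchvision.prototype import models

# Build the model with the pretrained weights added in this commit.
weights = models.ConvNeXt_Tiny_Weights.ImageNet1K_V1
model = models.convnext_tiny(weights=weights)
model.eval()

# Each Weights entry bundles its inference preprocessing (ImageNetEval here:
# resize to 236, center-crop to 224, normalize).
preprocess = weights.transforms()
batch = preprocess(torch.rand(3, 256, 256)).unsqueeze(0)

with torch.no_grad():
    logits = model(batch)
print(logits.shape)  # torch.Size([1, 1000])
```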

6 files changed: +249 lines, -3 lines


docs/source/models.rst

Lines changed: 4 additions & 1 deletion
@@ -41,6 +41,7 @@ architectures for image classification:
 - `EfficientNet`_
 - `RegNet`_
 - `VisionTransformer`_
+- `ConvNeXt`_
 
 You can construct a model with random weights by calling its constructor:
@@ -88,7 +89,7 @@ You can construct a model with random weights by calling its constructor:
     vit_b_32 = models.vit_b_32()
     vit_l_16 = models.vit_l_16()
     vit_l_32 = models.vit_l_32()
-   vit_h_14 = models.vit_h_14()
+    vit_h_14 = models.vit_h_14()
 
 We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
 These can be constructed by passing ``pretrained=True``:
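The new model follows both construction patterns documented here. A sketch, assuming the prototype namespace this commit targets:

```
from torchvision.prototype import models

# Random weights, mirroring the constructor list above.
convnext_tiny = models.convnext_tiny()

# Legacy flag, mapped to ConvNeXt_Tiny_Weights.ImageNet1K_V1 by the
# handle_legacy_interface decorator shown later in this diff.
convnext_tiny_pretrained = models.convnext_tiny(pretrained=True)
```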
@@ -248,6 +249,7 @@ vit_b_16 81.072 95.318
 vit_b_32                         75.912        92.466
 vit_l_16                         79.662        94.638
 vit_l_32                         76.972        93.070
+convnext_tiny (prototype)        82.520        96.146
 ================================ ============= =============
@@ -266,6 +268,7 @@ vit_l_32 76.972 93.070
 .. _EfficientNet: https://arxiv.org/abs/1905.11946
 .. _RegNet: https://arxiv.org/abs/2003.13678
 .. _VisionTransformer: https://arxiv.org/abs/2010.11929
+.. _ConvNeXt: https://arxiv.org/abs/2201.03545
 
 .. currentmodule:: torchvision.models

references/classification/README.md

Lines changed: 14 additions & 0 deletions
@@ -197,6 +197,20 @@ Note that the above command corresponds to training on a single node with 8 GPUs
 For generating the pre-trained weights, we trained with 8 nodes, each with 8 GPUs (for a total of 64 GPUs),
 and `--batch_size 64`.
 
+
+### ConvNeXt
+```
+torchrun --nproc_per_node=8 train.py \
+--model convnext_tiny --batch-size 128 --opt adamw --lr 1e-3 --lr-scheduler cosineannealinglr \
+--lr-warmup-epochs 5 --lr-warmup-method linear --auto-augment ta_wide --epochs 600 --random-erase 0.1 \
+--label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 --weight-decay 0.05 --norm-weight-decay 0.0 \
+--train-crop-size 176 --model-ema --val-resize-size 236 --ra-sampler --ra-reps 4
+```
+
+Note that the above command corresponds to training on a single node with 8 GPUs.
+For generating the pre-trained weights, we trained with 2 nodes, each with 8 GPUs (for a total of 16 GPUs),
+and `--batch_size 64`.
+
 ## Mixed precision training
 Automatic Mixed Precision (AMP) training on GPU for Pytorch can be enabled with the [torch.cuda.amp](https://pytorch.org/docs/stable/amp.html?highlight=amp#module-torch.cuda.amp).
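As a rough reading of the multi-node recipe above (assuming one process per GPU under `torchrun`): 16 GPUs with a per-GPU `--batch_size` of 64 gives an effective global batch size of 16 × 64 = 1024.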

One of the changed files is binary and is not shown.

torchvision/ops/misc.py

Lines changed: 3 additions & 2 deletions
@@ -131,7 +131,7 @@ def __init__(
         norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
         activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
         dilation: int = 1,
-        inplace: bool = True,
+        inplace: Optional[bool] = True,
         bias: Optional[bool] = None,
     ) -> None:
         if padding is None:

@@ -153,7 +153,8 @@ def __init__(
         if norm_layer is not None:
             layers.append(norm_layer(out_channels))
         if activation_layer is not None:
-            layers.append(activation_layer(inplace=inplace))
+            params = {} if inplace is None else {"inplace": inplace}
+            layers.append(activation_layer(**params))
         super().__init__(*layers)
         _log_api_usage_once(self)
         self.out_channels = out_channels
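The `Optional[bool]` change exists because not every activation accepts an `inplace` keyword; `nn.GELU` (used by the new CNBlock below) did not at the time of this commit. A standalone sketch of the pattern; `make_activation` is a hypothetical helper, not part of the diff:

```
import torch.nn as nn

def make_activation(activation_layer, inplace):
    # None means "don't pass the kwarg at all", so activation classes
    # without an `inplace` parameter can still be constructed.
    params = {} if inplace is None else {"inplace": inplace}
    return activation_layer(**params)

relu = make_activation(nn.ReLU, True)   # nn.ReLU(inplace=True)
gelu = make_activation(nn.GELU, None)   # nn.GELU(), no kwarg passed
```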

torchvision/prototype/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 from .alexnet import *
+from .convnext import *
 from .densenet import *
 from .efficientnet import *
 from .googlenet import *
torchvision/prototype/models/convnext.py

Lines changed: 227 additions & 0 deletions
@@ -0,0 +1,227 @@
+from functools import partial
+from typing import Any, Callable, List, Optional, Sequence
+
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+from torchvision.prototype.transforms import ImageNetEval
+from torchvision.transforms.functional import InterpolationMode
+
+from ...ops.misc import ConvNormActivation
+from ...ops.stochastic_depth import StochasticDepth
+from ...utils import _log_api_usage_once
+from ._api import WeightsEnum, Weights
+from ._meta import _IMAGENET_CATEGORIES
+from ._utils import handle_legacy_interface, _ovewrite_named_param
+
+
+__all__ = ["ConvNeXt", "ConvNeXt_Tiny_Weights", "convnext_tiny"]
+
+
+class LayerNorm2d(nn.LayerNorm):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        self.channels_last = kwargs.pop("channels_last", False)
+        super().__init__(*args, **kwargs)
+
+    def forward(self, x: Tensor) -> Tensor:
+        # TODO: Benchmark this against the approach described at https://github.com/pytorch/vision/pull/5197#discussion_r786251298
+        if not self.channels_last:
+            x = x.permute(0, 2, 3, 1)
+        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        if not self.channels_last:
+            x = x.permute(0, 3, 1, 2)
+        return x
+
+
+class CNBlock(nn.Module):
+    def __init__(
+        self, dim, layer_scale: float, stochastic_depth_prob: float, norm_layer: Callable[..., nn.Module]
+    ) -> None:
+        super().__init__()
+        self.block = nn.Sequential(
+            ConvNormActivation(
+                dim,
+                dim,
+                kernel_size=7,
+                groups=dim,
+                norm_layer=norm_layer,
+                activation_layer=None,
+                bias=True,
+            ),
+            ConvNormActivation(dim, 4 * dim, kernel_size=1, norm_layer=None, activation_layer=nn.GELU, inplace=None),
+            ConvNormActivation(
+                4 * dim,
+                dim,
+                kernel_size=1,
+                norm_layer=None,
+                activation_layer=None,
+            ),
+        )
+        self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
+        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
+
+    def forward(self, input: Tensor) -> Tensor:
+        result = self.layer_scale * self.block(input)
+        result = self.stochastic_depth(result)
+        result += input
+        return result
+
+
+class CNBlockConfig:
+    # Stores information listed at Section 3 of the ConvNeXt paper
+    def __init__(
+        self,
+        input_channels: int,
+        out_channels: Optional[int],
+        num_layers: int,
+    ) -> None:
+        self.input_channels = input_channels
+        self.out_channels = out_channels
+        self.num_layers = num_layers
+
+    def __repr__(self) -> str:
+        s = self.__class__.__name__ + "("
+        s += "input_channels={input_channels}"
+        s += ", out_channels={out_channels}"
+        s += ", num_layers={num_layers}"
+        s += ")"
+        return s.format(**self.__dict__)
+
+
+class ConvNeXt(nn.Module):
+    def __init__(
+        self,
+        block_setting: List[CNBlockConfig],
+        stochastic_depth_prob: float = 0.0,
+        layer_scale: float = 1e-6,
+        num_classes: int = 1000,
+        block: Optional[Callable[..., nn.Module]] = None,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__()
+        _log_api_usage_once(self)
+
+        if not block_setting:
+            raise ValueError("The block_setting should not be empty")
+        elif not (isinstance(block_setting, Sequence) and all([isinstance(s, CNBlockConfig) for s in block_setting])):
+            raise TypeError("The block_setting should be List[CNBlockConfig]")
+
+        if block is None:
+            block = CNBlock
+
+        if norm_layer is None:
+            norm_layer = partial(LayerNorm2d, eps=1e-6)
+
+        layers: List[nn.Module] = []
+
+        # Stem
+        firstconv_output_channels = block_setting[0].input_channels
+        layers.append(
+            ConvNormActivation(
+                3,
+                firstconv_output_channels,
+                kernel_size=4,
+                stride=4,
+                padding=0,
+                norm_layer=norm_layer,
+                activation_layer=None,
+                bias=True,
+            )
+        )
+
+        total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
+        stage_block_id = 0
+        for cnf in block_setting:
+            # Bottlenecks
+            stage: List[nn.Module] = []
+            for _ in range(cnf.num_layers):
+                # adjust stochastic depth probability based on the depth of the stage block
+                sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
+                stage.append(block(cnf.input_channels, layer_scale, sd_prob, norm_layer))
+                stage_block_id += 1
+            layers.append(nn.Sequential(*stage))
+            if cnf.out_channels is not None:
+                # Downsampling
+                layers.append(
+                    nn.Sequential(
+                        norm_layer(cnf.input_channels),
+                        nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2),
+                    )
+                )
+
+        self.features = nn.Sequential(*layers)
+        self.avgpool = nn.AdaptiveAvgPool2d(1)
+
+        lastblock = block_setting[-1]
+        lastconv_output_channels = (
+            lastblock.out_channels if lastblock.out_channels is not None else lastblock.input_channels
+        )
+        self.classifier = nn.Sequential(
+            norm_layer(lastconv_output_channels), nn.Flatten(1), nn.Linear(lastconv_output_channels, num_classes)
+        )
+
+        for m in self.modules():
+            if isinstance(m, (nn.Conv2d, nn.Linear)):
+                nn.init.trunc_normal_(m.weight, std=0.02)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+
+    def _forward_impl(self, x: Tensor) -> Tensor:
+        x = self.features(x)
+        x = self.avgpool(x)
+        x = self.classifier(x)
+        return x
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self._forward_impl(x)
+
+
+class ConvNeXt_Tiny_Weights(WeightsEnum):
+    ImageNet1K_V1 = Weights(
+        url="https://download.pytorch.org/models/convnext_tiny-47b116bd.pth",
+        transforms=partial(ImageNetEval, crop_size=224, resize_size=236),
+        meta={
+            "task": "image_classification",
+            "architecture": "ConvNeXt",
+            "publication_year": 2022,
+            "num_params": 28589128,
+            "size": (224, 224),
+            "min_size": (32, 32),
+            "categories": _IMAGENET_CATEGORIES,
+            "interpolation": InterpolationMode.BILINEAR,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext",
+            "acc@1": 82.520,
+            "acc@5": 96.146,
+        },
+    )
+    default = ImageNet1K_V1
+
+
+@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.ImageNet1K_V1))
+def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
+    r"""ConvNeXt model architecture from the
+    `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.
+
+    Args:
+        weights (ConvNeXt_Tiny_Weights, optional): The pre-trained weights of the model
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    weights = ConvNeXt_Tiny_Weights.verify(weights)
+
+    if weights is not None:
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+
+    block_setting = [
+        CNBlockConfig(96, 192, 3),
+        CNBlockConfig(192, 384, 3),
+        CNBlockConfig(384, 768, 9),
+        CNBlockConfig(768, None, 3),
+    ]
+    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1)
+    model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress))
+
+    return model
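Two details of the new file are easy to verify in isolation. First, the stochastic depth probability grows linearly with block index, from 0 on the first block to `stochastic_depth_prob` on the last; a quick check using the tiny variant's defaults from the diff above:

```
# Block counts from the convnext_tiny block_setting above.
num_layers = [3, 3, 9, 3]
total_stage_blocks = sum(num_layers)  # 18
stochastic_depth_prob = 0.1           # default in convnext_tiny

sd_probs = [stochastic_depth_prob * i / (total_stage_blocks - 1.0) for i in range(total_stage_blocks)]
print(round(sd_probs[0], 4), round(sd_probs[-1], 4))  # 0.0 0.1
```

Second, LayerNorm2d normalizes over the channel dimension of an NCHW tensor by permuting to NHWC, applying `F.layer_norm`, and permuting back; the round-trip preserves the input shape:

```
import torch
from torch.nn import functional as F

x = torch.randn(2, 96, 56, 56)  # NCHW; 96 is the tiny variant's stem width
y = F.layer_norm(x.permute(0, 2, 3, 1), (96,), torch.ones(96), torch.zeros(96), 1e-6)
y = y.permute(0, 3, 1, 2)
print(y.shape)  # torch.Size([2, 96, 56, 56])
```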
