Skip to content

Commit 4e3d900

Browse files
committed
Adding multiweight support for shufflenetv2 prototype models
1 parent 85e4429 commit 4e3d900

File tree

2 files changed

+122
-0
lines changed

2 files changed

+122
-0
lines changed

torchvision/prototype/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .mobilenetv2 import *
88
from .mnasnet import *
99
from .regnet import *
10+
from .shufflenetv2 import *
1011
from . import detection
1112
from . import quantization
1213
from . import segmentation
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import warnings
2+
from functools import partial
3+
from typing import Any, Optional
4+
5+
from torchvision.transforms.functional import InterpolationMode
6+
7+
from ...models.shufflenetv2 import ShuffleNetV2
8+
from ..transforms.presets import ImageNetEval
9+
from ._api import Weights, WeightEntry
10+
from ._meta import _IMAGENET_CATEGORIES
11+
12+
13+
__all__ = [
14+
"ShuffleNetV2",
15+
"ShuffleNetV2_x0_5Weights",
16+
"ShuffleNetV2_x1_0Weights",
17+
"ShuffleNetV2_x1_5Weights",
18+
"ShuffleNetV2_x2_0Weights",
19+
"shufflenet_v2_x0_5",
20+
"shufflenet_v2_x1_0",
21+
"shufflenet_v2_x1_5",
22+
"shufflenet_v2_x2_0",
23+
]
24+
25+
26+
def _shufflenetv2(
27+
weights: Optional[Weights],
28+
progress: bool,
29+
*args: Any,
30+
**kwargs: Any,
31+
) -> ShuffleNetV2:
32+
if weights is not None:
33+
kwargs["num_classes"] = len(weights.meta["categories"])
34+
35+
model = ShuffleNetV2(*args, **kwargs)
36+
37+
if weights is not None:
38+
model.load_state_dict(weights.state_dict(progress=progress))
39+
40+
return model
41+
42+
43+
_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
44+
45+
46+
class ShuffleNetV2_x0_5Weights(Weights):
47+
ImageNet1K_RefV1 = WeightEntry(
48+
url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
49+
transforms=partial(ImageNetEval, crop_size=224),
50+
meta={
51+
**_common_meta,
52+
"recipe": "",
53+
"acc@1": 69.362,
54+
"acc@5": 88.316,
55+
},
56+
)
57+
58+
59+
class ShuffleNetV2_x1_0Weights(Weights):
60+
ImageNet1K_RefV1 = WeightEntry(
61+
url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
62+
transforms=partial(ImageNetEval, crop_size=224),
63+
meta={
64+
**_common_meta,
65+
"recipe": "",
66+
"acc@1": 60.552,
67+
"acc@5": 81.746,
68+
},
69+
)
70+
71+
72+
class ShuffleNetV2_x1_5Weights(Weights):
73+
pass
74+
75+
76+
class ShuffleNetV2_x2_0Weights(Weights):
77+
pass
78+
79+
80+
def shufflenet_v2_x0_5(
81+
weights: Optional[ShuffleNetV2_x0_5Weights] = None, progress: bool = True, **kwargs: Any
82+
) -> ShuffleNetV2:
83+
if "pretrained" in kwargs:
84+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
85+
weights = ShuffleNetV2_x0_5Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
86+
weights = ShuffleNetV2_x0_5Weights.verify(weights)
87+
88+
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
89+
90+
91+
def shufflenet_v2_x1_0(
92+
weights: Optional[ShuffleNetV2_x1_0Weights] = None, progress: bool = True, **kwargs: Any
93+
) -> ShuffleNetV2:
94+
if "pretrained" in kwargs:
95+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
96+
weights = ShuffleNetV2_x1_0Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
97+
weights = ShuffleNetV2_x1_0Weights.verify(weights)
98+
99+
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
100+
101+
102+
def shufflenet_v2_x1_5(
103+
weights: Optional[ShuffleNetV2_x1_5Weights] = None, progress: bool = True, **kwargs: Any
104+
) -> ShuffleNetV2:
105+
if "pretrained" in kwargs:
106+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
107+
weights = ShuffleNetV2_x1_5Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
108+
weights = ShuffleNetV2_x1_5Weights.verify(weights)
109+
110+
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
111+
112+
113+
def shufflenet_v2_x2_0(
114+
weights: Optional[ShuffleNetV2_x2_0Weights] = None, progress: bool = True, **kwargs: Any
115+
) -> ShuffleNetV2:
116+
if "pretrained" in kwargs:
117+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
118+
weights = ShuffleNetV2_x2_0Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
119+
weights = ShuffleNetV2_x2_0Weights.verify(weights)
120+
121+
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)

0 commit comments

Comments
 (0)