diff --git a/torchvision/models/regnet.py b/torchvision/models/regnet.py index 72093686d84..0748ee0460d 100644 --- a/torchvision/models/regnet.py +++ b/torchvision/models/regnet.py @@ -412,6 +412,15 @@ def _regnet( "interpolation": InterpolationMode.BILINEAR, } +_COMMON_SWAG_META = { + **_COMMON_META, + "publication_year": 2022, + "size": (384, 384), + "recipe": "https://github.com/facebookresearch/SWAG", + "license": "https://github.com/facebookresearch/SWAG/blob/main/LICENSE", + "interpolation": InterpolationMode.BICUBIC, +} + class RegNet_Y_400MF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( @@ -566,6 +575,18 @@ class RegNet_Y_16GF_Weights(WeightsEnum): "acc@5": 96.328, }, ) + IMAGENET1K_SWAG_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_16gf_swag-43afe44d.pth", + transforms=partial( + ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_SWAG_META, + "num_params": 83590140, + "acc@1": 86.012, + "acc@5": 98.054, + }, + ) DEFAULT = IMAGENET1K_V2 @@ -592,6 +613,18 @@ class RegNet_Y_32GF_Weights(WeightsEnum): "acc@5": 96.498, }, ) + IMAGENET1K_SWAG_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_32gf_swag-04fdfa75.pth", + transforms=partial( + ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_SWAG_META, + "num_params": 145046770, + "acc@1": 86.838, + "acc@5": 98.362, + }, + ) DEFAULT = IMAGENET1K_V2