diff --git a/README.rst b/README.rst
index a65ffd8340a..bf1a9b53106 100644
--- a/README.rst
+++ b/README.rst
@@ -185,3 +185,10 @@ Disclaimer on Datasets
 This is a utility library that downloads and prepares public datasets. We do not host or distribute these datasets, vouch for their quality or fairness, or claim that you have license to use the dataset. It is your responsibility to determine whether you have permission to use the dataset under the dataset's license.
 
 If you're a dataset owner and wish to update any part of it (description, citation, etc.), or do not want your dataset to be included in this library, please get in touch through a GitHub issue. Thanks for your contribution to the ML community!
+
+Pre-trained Model License
+=========================
+
+The pre-trained models provided in this library may have their own licenses or terms and conditions derived from the dataset used for training. It is your responsibility to determine whether you have permission to use the models for your use case.
+
+More specifically, SWAG models are released under the CC-BY-NC 4.0 license. See `SWAG LICENSE <https://github.com/facebookresearch/SWAG/blob/main/LICENSE>`_ for additional details.
diff --git a/test/test_extended_models.py b/test/test_extended_models.py
index 577be1d2cd6..1422cc28103 100644
--- a/test/test_extended_models.py
+++ b/test/test_extended_models.py
@@ -115,7 +115,8 @@ def test_schema_meta_validation(model_fn):
                     incorrect_params.append(w)
         else:
             if w.meta.get("num_params") != weights_enum.DEFAULT.meta.get("num_params"):
-                incorrect_params.append(w)
+                if w.meta.get("num_params") != sum(p.numel() for p in model_fn(weights=w).parameters()):
+                    incorrect_params.append(w)
         if not w.name.isupper():
             bad_names.append(w)
 
diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py
index fb34cf3c8e1..59da51c1bd9 100644
--- a/torchvision/models/vision_transformer.py
+++ b/torchvision/models/vision_transformer.py
@@ -1,7 +1,7 @@
 import math
 from collections import OrderedDict
 from functools import partial
-from typing import Any, Callable, List, NamedTuple, Optional
+from typing import Any, Callable, List, NamedTuple, Optional, Sequence
 
 import torch
 import torch.nn as nn
@@ -284,10 +284,21 @@ def _vision_transformer(
     progress: bool,
     **kwargs: Any,
 ) -> VisionTransformer:
-    image_size = kwargs.pop("image_size", 224)
-
     if weights is not None:
         _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+        if isinstance(weights.meta["size"], int):
+            _ovewrite_named_param(kwargs, "image_size", weights.meta["size"])
+        elif isinstance(weights.meta["size"], Sequence):
+            if len(weights.meta["size"]) != 2 or weights.meta["size"][0] != weights.meta["size"][1]:
+                raise ValueError(
+                    f'size: {weights.meta["size"]} is not valid! Currently we only support a 2-dimensional square and width = height'
+                )
+            _ovewrite_named_param(kwargs, "image_size", weights.meta["size"][0])
+        else:
+            raise ValueError(
+                f'weights.meta["size"]: {weights.meta["size"]} is not valid, the type should be either an int or a Sequence[int]'
+            )
+    image_size = kwargs.pop("image_size", 224)
 
     model = VisionTransformer(
         image_size=image_size,
@@ -313,6 +324,14 @@
     "interpolation": InterpolationMode.BILINEAR,
 }
 
+_COMMON_SWAG_META = {
+    **_COMMON_META,
+    "publication_year": 2022,
+    "recipe": "https://github.com/facebookresearch/SWAG",
+    "license": "https://github.com/facebookresearch/SWAG/blob/main/LICENSE",
+    "interpolation": InterpolationMode.BICUBIC,
+}
+
 
 class ViT_B_16_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
@@ -328,6 +347,23 @@ class ViT_B_16_Weights(WeightsEnum):
             "acc@5": 95.318,
         },
     )
+    IMAGENET1K_SWAG_V1 = Weights(
+        url="https://download.pytorch.org/models/vit_b_16_swag-9ac1b537.pth",
+        transforms=partial(
+            ImageClassification,
+            crop_size=384,
+            resize_size=384,
+            interpolation=InterpolationMode.BICUBIC,
+        ),
+        meta={
+            **_COMMON_SWAG_META,
+            "num_params": 86859496,
+            "size": (384, 384),
+            "min_size": (384, 384),
+            "acc@1": 85.304,
+            "acc@5": 97.650,
+        },
+    )
     DEFAULT = IMAGENET1K_V1
@@ -362,6 +398,23 @@ class ViT_L_16_Weights(WeightsEnum):
             "acc@5": 94.638,
         },
     )
+    IMAGENET1K_SWAG_V1 = Weights(
+        url="https://download.pytorch.org/models/vit_l_16_swag-4f3808c9.pth",
+        transforms=partial(
+            ImageClassification,
+            crop_size=512,
+            resize_size=512,
+            interpolation=InterpolationMode.BICUBIC,
+        ),
+        meta={
+            **_COMMON_SWAG_META,
+            "num_params": 305174504,
+            "size": (512, 512),
+            "min_size": (512, 512),
+            "acc@1": 88.064,
+            "acc@5": 98.512,
+        },
+    )
     DEFAULT = IMAGENET1K_V1
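A minimal usage sketch of the new SWAG weights, assuming a torchvision build that includes this patch; the enum value, bundled transforms, and 384x384 input size all come from the diff above.

```python
import torch
from torchvision.models import vit_b_16, ViT_B_16_Weights

# Pick the SWAG-pretrained checkpoint added in this patch. _vision_transformer
# reads meta["size"] == (384, 384) and overrides image_size accordingly.
weights = ViT_B_16_Weights.IMAGENET1K_SWAG_V1
model = vit_b_16(weights=weights).eval()

# The bundled transforms resize and crop to 384x384 with bicubic interpolation,
# matching the meta entries above.
preprocess = weights.transforms()
img = torch.rand(3, 500, 400)  # stand-in for a real image tensor
batch = preprocess(img).unsqueeze(0)

with torch.no_grad():
    logits = model(batch)
print(logits.shape)  # torch.Size([1, 1000])
```

ViT_L_16_Weights.IMAGENET1K_SWAG_V1 is used the same way, with a 512x512 input per its meta.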