diff --git a/docs/source/models.rst b/docs/source/models.rst
index 62c104cf927..f1331d5baa9 100644
--- a/docs/source/models.rst
+++ b/docs/source/models.rst
@@ -88,6 +88,7 @@ You can construct a model with random weights by calling its constructor:
     vit_b_32 = models.vit_b_32()
     vit_l_16 = models.vit_l_16()
     vit_l_32 = models.vit_l_32()
+    vit_h_14 = models.vit_h_14()
 
 We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
 These can be constructed by passing ``pretrained=True``:
@@ -460,6 +461,7 @@ VisionTransformer
     vit_b_32
     vit_l_16
     vit_l_32
+    vit_h_14
 
 Quantized Models
 ----------------
diff --git a/hubconf.py b/hubconf.py
index 2b2eeb1c166..1b3b191efa4 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -63,4 +63,5 @@
     vit_b_32,
     vit_l_16,
     vit_l_32,
+    vit_h_14,
 )
diff --git a/test/expect/ModelTester.test_vit_h_14_expect.pkl b/test/expect/ModelTester.test_vit_h_14_expect.pkl
new file mode 100644
index 00000000000..1f846beb6a0
Binary files /dev/null and b/test/expect/ModelTester.test_vit_h_14_expect.pkl differ
diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py
index 11ecd9d97ad..a64f342e1a0 100644
--- a/torchvision/models/vision_transformer.py
+++ b/torchvision/models/vision_transformer.py
@@ -15,6 +15,7 @@
     "vit_b_32",
     "vit_l_16",
     "vit_l_32",
+    "vit_h_14",
 ]
 
 model_urls = {
@@ -260,6 +261,8 @@ def _vision_transformer(
     )
 
     if pretrained:
+        if arch not in model_urls:
+            raise ValueError(f"No checkpoint is available for model type '{arch}'!")
         state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
         model.load_state_dict(state_dict)
 
@@ -354,6 +357,26 @@ def vit_l_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
     )
 
 
+def vit_h_14(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
+    """
+    Constructs a vit_h_14 architecture from
+    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
+
+    NOTE: Pretrained weights are not available for this model.
+    """
+    return _vision_transformer(
+        arch="vit_h_14",
+        patch_size=14,
+        num_layers=32,
+        num_heads=16,
+        hidden_dim=1280,
+        mlp_dim=5120,
+        pretrained=pretrained,
+        progress=progress,
+        **kwargs,
+    )
+
+
 def interpolate_embeddings(
     image_size: int,
     patch_size: int,
diff --git a/torchvision/prototype/models/vision_transformer.py b/torchvision/prototype/models/vision_transformer.py
index 1cd186a2d82..af742c7ee01 100644
--- a/torchvision/prototype/models/vision_transformer.py
+++ b/torchvision/prototype/models/vision_transformer.py
@@ -19,10 +19,12 @@
     "ViT_B_32_Weights",
     "ViT_L_16_Weights",
     "ViT_L_32_Weights",
+    "ViT_H_14_Weights",
     "vit_b_16",
     "vit_b_32",
     "vit_l_16",
     "vit_l_32",
+    "vit_h_14",
 ]
 
 
@@ -99,6 +101,11 @@ class ViT_L_32_Weights(WeightsEnum):
     default = ImageNet1K_V1
 
 
+class ViT_H_14_Weights(WeightsEnum):
+    # Weights are not available yet.
+    pass
+
+
 def _vision_transformer(
     patch_size: int,
     num_layers: int,
@@ -192,3 +199,19 @@ def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = Tru
         progress=progress,
         **kwargs,
     )
+
+
+@handle_legacy_interface(weights=("pretrained", None))
+def vit_h_14(*, weights: Optional[ViT_H_14_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
+    weights = ViT_H_14_Weights.verify(weights)
+
+    return _vision_transformer(
+        patch_size=14,
+        num_layers=32,
+        num_heads=16,
+        hidden_dim=1280,
+        mlp_dim=5120,
+        weights=weights,
+        progress=progress,
+        **kwargs,
+    )
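
For reference, a quick sanity check of the new builder could look like the sketch below. It only exercises the code added in this diff and assumes the defaults shared by the other ViT builders (``image_size=224``, ``num_classes=1000``); a forward pass of ViT-H/14 is large, so expect it to be slow on CPU:

    import torch
    from torchvision import models

    # Build ViT-H/14 with random weights; no checkpoint is published for it yet.
    model = models.vit_h_14()
    model.eval()

    # One 3x224x224 image: 224 / 14 = 16 patches per side.
    x = torch.rand(1, 3, 224, 224)
    with torch.no_grad():
        logits = model(x)
    print(logits.shape)  # torch.Size([1, 1000])

    # Asking for pretrained weights should raise, since "vit_h_14" has no model_urls entry.
    try:
        models.vit_h_14(pretrained=True)
    except ValueError as err:
        print(err)  # No checkpoint is available for model type 'vit_h_14'!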