diff --git a/test/expect/ModelTester.test_vit_b_16_expect.pkl b/test/expect/ModelTester.test_vit_b_16_expect.pkl
new file mode 100644
index 00000000000..1f846beb6a0
Binary files /dev/null and b/test/expect/ModelTester.test_vit_b_16_expect.pkl differ
diff --git a/test/expect/ModelTester.test_vit_b_32_expect.pkl b/test/expect/ModelTester.test_vit_b_32_expect.pkl
new file mode 100644
index 00000000000..1f846beb6a0
Binary files /dev/null and b/test/expect/ModelTester.test_vit_b_32_expect.pkl differ
diff --git a/test/expect/ModelTester.test_vit_l_16_expect.pkl b/test/expect/ModelTester.test_vit_l_16_expect.pkl
new file mode 100644
index 00000000000..1f846beb6a0
Binary files /dev/null and b/test/expect/ModelTester.test_vit_l_16_expect.pkl differ
diff --git a/test/expect/ModelTester.test_vit_l_32_expect.pkl b/test/expect/ModelTester.test_vit_l_32_expect.pkl
new file mode 100644
index 00000000000..1f846beb6a0
Binary files /dev/null and b/test/expect/ModelTester.test_vit_l_32_expect.pkl differ
diff --git a/test/test_backbone_utils.py b/test/test_backbone_utils.py
index 9b46bdd5288..4a375b0036a 100644
--- a/test/test_backbone_utils.py
+++ b/test/test_backbone_utils.py
@@ -1,5 +1,6 @@
 import random
 from itertools import chain
+from typing import Mapping, Sequence
 
 import pytest
 import torch
@@ -89,7 +90,16 @@ def _create_feature_extractor(self, *args, **kwargs):
 
     def _get_return_nodes(self, model):
         set_rng_seed(0)
-        exclude_nodes_filter = ["getitem", "floordiv", "size", "chunk"]
+        exclude_nodes_filter = [
+            "getitem",
+            "floordiv",
+            "size",
+            "chunk",
+            "_assert",
+            "eq",
+            "dim",
+            "getattr",
+        ]
         train_nodes, eval_nodes = get_graph_node_names(
             model, tracer_kwargs={"leaf_modules": self.leaf_modules}, suppress_diff_warning=True
         )
@@ -144,7 +154,16 @@ def test_forward_backward(self, model_name):
             model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
         )
         out = model(self.inp)
-        sum(o.mean() for o in out.values()).backward()
+        out_agg = 0
+        for node_out in out.values():
+            if isinstance(node_out, Sequence):
+                out_agg += sum(o.mean() for o in node_out if o is not None)
+            elif isinstance(node_out, Mapping):
+                out_agg += sum(o.mean() for o in node_out.values() if o is not None)
+            else:
+                # Assume that the only other alternative at this point is a Tensor
+                out_agg += node_out.mean()
+        out_agg.backward()
 
     def test_feature_extraction_methods_equivalence(self):
         model = models.resnet18(**self.model_defaults).eval()
@@ -176,7 +195,16 @@ def test_jit_forward_backward(self, model_name):
         )
         model = torch.jit.script(model)
         fgn_out = model(self.inp)
-        sum(o.mean() for o in fgn_out.values()).backward()
+        out_agg = 0
+        for node_out in fgn_out.values():
+            if isinstance(node_out, Sequence):
+                out_agg += sum(o.mean() for o in node_out if o is not None)
+            elif isinstance(node_out, Mapping):
+                out_agg += sum(o.mean() for o in node_out.values() if o is not None)
+            else:
+                # Assume that the only other alternative at this point is a Tensor
+                out_agg += node_out.mean()
+        out_agg.backward()
 
     def test_train_eval(self):
         class TestModel(torch.nn.Module):
diff --git a/test/test_models.py b/test/test_models.py
index ef0a5d0260f..10f16f20081 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -507,6 +507,7 @@ def test_classification_model(model_fn, dev):
     }
     model_name = model_fn.__name__
     kwargs = {**defaults, **_model_params.get(model_name, {})}
+    num_classes = kwargs.get("num_classes")
     input_shape = kwargs.pop("input_shape")
 
     model = model_fn(**kwargs)
@@ -515,7 +516,7 @@ def test_classification_model(model_fn, dev):
     x = torch.rand(input_shape).to(device=dev)
     out = model(x)
     _assert_expected(out.cpu(), model_name, prec=0.1)
-    assert out.shape[-1] == 50
+    assert out.shape[-1] == num_classes
     _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None))
     _check_fx_compatible(model, x)
 
diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py
index 2cb378520a4..f53299bcf51 100644
--- a/test/test_prototype_models.py
+++ b/test/test_prototype_models.py
@@ -122,8 +122,11 @@ def test_old_vs_new_factory(model_fn, module_name, dev):
         x = [x]
 
     # compare with new model builder parameterized in the old fashion way
-    model_old = _build_model(_get_original_model(model_fn), **kwargs).to(device=dev)
-    model_new = _build_model(model_fn, **kwargs).to(device=dev)
+    try:
+        model_old = _build_model(_get_original_model(model_fn), **kwargs).to(device=dev)
+        model_new = _build_model(model_fn, **kwargs).to(device=dev)
+    except ModuleNotFoundError:
+        pytest.skip(f"Model '{model_name}' not available in both modules.")
 
     torch.testing.assert_close(model_new(x), model_old(x), rtol=0.0, atol=0.0, check_dtype=False)
 
diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py
index 5077b7fd178..f675dc37f25 100644
--- a/torchvision/prototype/models/__init__.py
+++ b/torchvision/prototype/models/__init__.py
@@ -10,6 +10,7 @@
 from .shufflenetv2 import *
 from .squeezenet import *
 from .vgg import *
+from .vision_transformer import *
 from . import detection
 from . import quantization
 from . import segmentation
diff --git a/torchvision/prototype/models/vision_transformer.py b/torchvision/prototype/models/vision_transformer.py
new file mode 100644
index 00000000000..987f3af1bb4
--- /dev/null
+++ b/torchvision/prototype/models/vision_transformer.py
@@ -0,0 +1,399 @@
+# References:
+# https://github.com/google-research/vision_transformer
+# https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/vision_transformer.py
+
+import math
+from collections import OrderedDict
+from functools import partial
+from typing import Any, Callable, Optional
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+from ._api import Weights
+from ._utils import _deprecated_param, _deprecated_positional
+
+
+__all__ = [
+    "VisionTransformer",
+    "VisionTransformer_B_16Weights",
+    "VisionTransformer_B_32Weights",
+    "VisionTransformer_L_16Weights",
+    "VisionTransformer_L_32Weights",
+    "vit_b_16",
+    "vit_b_32",
+    "vit_l_16",
+    "vit_l_32",
+]
+
+
+class MLPBlock(nn.Sequential):
+    """Transformer MLP block."""
+
+    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
+        super().__init__()
+        self.linear_1 = nn.Linear(in_dim, mlp_dim)
+        self.act = nn.GELU()
+        self.dropout_1 = nn.Dropout(dropout)
+        self.linear_2 = nn.Linear(mlp_dim, in_dim)
+        self.dropout_2 = nn.Dropout(dropout)
+        self._init_weights()
+
+    def _init_weights(self):
+        nn.init.xavier_uniform_(self.linear_1.weight)
+        nn.init.xavier_uniform_(self.linear_2.weight)
+        nn.init.normal_(self.linear_1.bias, std=1e-6)
+        nn.init.normal_(self.linear_2.bias, std=1e-6)
+
+
+class EncoderBlock(nn.Module):
+    """Transformer encoder block."""
+
+    def __init__(
+        self,
+        num_heads: int,
+        hidden_dim: int,
+        mlp_dim: int,
+        dropout: float,
+        attention_dropout: float,
+        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
+    ):
+        super().__init__()
+        self.num_heads = num_heads
+
+        # Attention block
+        self.ln_1 = norm_layer(hidden_dim)
+        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
+        self.dropout = nn.Dropout(dropout)
+
+        # MLP block
+        self.ln_2 = norm_layer(hidden_dim)
+        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)
+
+    def forward(self, input: Tensor):
+        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
+        x = self.ln_1(input)
+        x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False)
+        x = self.dropout(x)
+        x = x + input
+
+        y = self.ln_2(x)
+        y = self.mlp(y)
+        return x + y
+
+
+class Encoder(nn.Module):
+    """Transformer Model Encoder for sequence to sequence translation."""
+
+    def __init__(
+        self,
+        seq_length: int,
+        num_layers: int,
+        num_heads: int,
+        hidden_dim: int,
+        mlp_dim: int,
+        dropout: float,
+        attention_dropout: float,
+        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
+    ):
+        super().__init__()
+        # Note that batch_size is on the first dim because
+        # we have batch_first=True in nn.MultiheadAttention() by default
+        self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))  # from BERT
+        self.dropout = nn.Dropout(dropout)
+        layers: OrderedDict[str, nn.Module] = OrderedDict()
+        for i in range(num_layers):
+            layers[f"encoder_layer_{i}"] = EncoderBlock(
+                num_heads,
+                hidden_dim,
+                mlp_dim,
+                dropout,
+                attention_dropout,
+                norm_layer,
+            )
+        self.layers = nn.Sequential(layers)
+        self.ln = norm_layer(hidden_dim)
+
+    def forward(self, input: Tensor):
+        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
+        input = input + self.pos_embedding
+        return self.ln(self.layers(self.dropout(input)))
+
+
+class VisionTransformer(nn.Module):
+    """Vision Transformer as per https://arxiv.org/abs/2010.11929."""
+
+    def __init__(
+        self,
+        image_size: int,
+        patch_size: int,
+        num_layers: int,
+        num_heads: int,
+        hidden_dim: int,
+        mlp_dim: int,
+        dropout: float = 0.0,
+        attention_dropout: float = 0.0,
+        num_classes: int = 1000,
+        representation_size: Optional[int] = None,
+        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
+    ):
+        super().__init__()
+        torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.hidden_dim = hidden_dim
+        self.mlp_dim = mlp_dim
+        self.attention_dropout = attention_dropout
+        self.dropout = dropout
+        self.num_classes = num_classes
+        self.representation_size = representation_size
+        self.norm_layer = norm_layer
+
+        input_channels = 3
+
+        # The conv_proj is a more efficient version of reshaping, permuting
+        # and projecting the input
+        self.conv_proj = nn.Conv2d(input_channels, hidden_dim, kernel_size=patch_size, stride=patch_size)
+
+        seq_length = (image_size // patch_size) ** 2
+
+        # Add a class token
+        self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
+        seq_length += 1
+
+        self.encoder = Encoder(
+            seq_length,
+            num_layers,
+            num_heads,
+            hidden_dim,
+            mlp_dim,
+            dropout,
+            attention_dropout,
+            norm_layer,
+        )
+        self.seq_length = seq_length
+
+        heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
+        if representation_size is None:
+            heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
+        else:
+            heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
+            heads_layers["act"] = nn.Tanh()
+            heads_layers["head"] = nn.Linear(representation_size, num_classes)
+
+        self.heads = nn.Sequential(heads_layers)
+        self._init_weights()
+
+    def _init_weights(self):
+        fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
+        nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
+        nn.init.zeros_(self.conv_proj.bias)
+
+        if hasattr(self.heads, "pre_logits"):
+            fan_in = self.heads.pre_logits.in_features
+            nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
+            nn.init.zeros_(self.heads.pre_logits.bias)
+
+        nn.init.zeros_(self.heads.head.weight)
+        nn.init.zeros_(self.heads.head.bias)
+
+    def forward(self, x: torch.Tensor):
+        n, c, h, w = x.shape
+        p = self.patch_size
+        torch._assert(h == self.image_size, "Wrong image height!")
+        torch._assert(w == self.image_size, "Wrong image width!")
+        n_h = h // p
+        n_w = w // p
+
+        # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
+        x = self.conv_proj(x)
+        # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
+        x = x.reshape(n, self.hidden_dim, n_h * n_w)
+
+        # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
+        # The self attention layer expects inputs in the format (N, S, E)
+        # where S is the source sequence length, N is the batch size, E is the
+        # embedding dimension
+        x = x.permute(0, 2, 1)
+
+        # Expand the class token to the full batch.
+        batch_class_token = self.class_token.expand(n, -1, -1)
+        x = torch.cat([batch_class_token, x], dim=1)
+
+        x = self.encoder(x)
+
+        # Classifier "token" as used by standard language architectures
+        x = x[:, 0]
+
+        x = self.heads(x)
+
+        return x
+
+
+class VisionTransformer_B_16Weights(Weights):
+    # If a default model is added here, the corresponding changes need to be done in vit_b_16
+    pass
+
+
+class VisionTransformer_B_32Weights(Weights):
+    # If a default model is added here, the corresponding changes need to be done in vit_b_32
+    pass
+
+
+class VisionTransformer_L_16Weights(Weights):
+    # If a default model is added here, the corresponding changes need to be done in vit_l_16
+    pass
+
+
+class VisionTransformer_L_32Weights(Weights):
+    # If a default model is added here, the corresponding changes need to be done in vit_l_32
+    pass
+
+
+def _vision_transformer(
+    patch_size: int,
+    num_layers: int,
+    num_heads: int,
+    hidden_dim: int,
+    mlp_dim: int,
+    weights: Optional[Weights],
+    progress: bool,
+    **kwargs: Any,
+) -> VisionTransformer:
+    image_size = kwargs.pop("image_size", 224)
+
+    model = VisionTransformer(
+        image_size=image_size,
+        patch_size=patch_size,
+        num_layers=num_layers,
+        num_heads=num_heads,
+        hidden_dim=hidden_dim,
+        mlp_dim=mlp_dim,
+        **kwargs,
+    )
+
+    if weights:
+        model.load_state_dict(weights.get_state_dict(progress=progress))
+
+    return model
+
+
+def vit_b_16(
+    weights: Optional[VisionTransformer_B_16Weights] = None, progress: bool = True, **kwargs: Any
+) -> VisionTransformer:
+    """
+    Constructs a vit_b_16 architecture from
+    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.
+
+    Args:
+        weights (VisionTransformer_B_16Weights, optional): If not None, returns a model pre-trained on ImageNet.
+            Default: None.
+        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
+ """ + if type(weights) == bool and weights: + _deprecated_positional(kwargs, "pretrained", "weights", True) + if "pretrained" in kwargs: + weights = _deprecated_param(kwargs, "pretrained", "weights", None) + weights = VisionTransformer_B_16Weights.verify(weights) + + return _vision_transformer( + patch_size=16, + num_layers=12, + num_heads=12, + hidden_dim=768, + mlp_dim=3072, + weights=weights, + progress=progress, + **kwargs, + ) + + +def vit_b_32( + weights: Optional[VisionTransformer_B_32Weights] = None, progress: bool = True, **kwargs: Any +) -> VisionTransformer: + """ + Constructs a vit_b_32 architecture from + `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. + + Args: + weights (VisionTransformer_B_32Weights, optional): If not None, returns a model pre-trained on ImageNet. + Default: None. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True. + """ + if type(weights) == bool and weights: + _deprecated_positional(kwargs, "pretrained", "weights", True) + if "pretrained" in kwargs: + weights = _deprecated_param(kwargs, "pretrained", "weights", None) + weights = VisionTransformer_B_32Weights.verify(weights) + + return _vision_transformer( + patch_size=32, + num_layers=12, + num_heads=12, + hidden_dim=768, + mlp_dim=3072, + weights=weights, + progress=progress, + **kwargs, + ) + + +def vit_l_16( + weights: Optional[VisionTransformer_L_16Weights] = None, progress: bool = True, **kwargs: Any +) -> VisionTransformer: + """ + Constructs a vit_l_16 architecture from + `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. + + Args: + weights (VisionTransformer_L_16Weights, optional): If not None, returns a model pre-trained on ImageNet. + Default: None. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True. + """ + if type(weights) == bool and weights: + _deprecated_positional(kwargs, "pretrained", "weights", True) + if "pretrained" in kwargs: + weights = _deprecated_param(kwargs, "pretrained", "weights", None) + weights = VisionTransformer_L_16Weights.verify(weights) + + return _vision_transformer( + patch_size=16, + num_layers=24, + num_heads=16, + hidden_dim=1024, + mlp_dim=4096, + weights=weights, + progress=progress, + **kwargs, + ) + + +def vit_l_32( + weights: Optional[VisionTransformer_B_32Weights] = None, progress: bool = True, **kwargs: Any +) -> VisionTransformer: + """ + Constructs a vit_l_32 architecture from + `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. + + Args: + weights (VisionTransformer_L_16Weights, optional): If not None, returns a model pre-trained on ImageNet. + Default: None. + progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True. + """ + if type(weights) == bool and weights: + _deprecated_positional(kwargs, "pretrained", "weights", True) + if "pretrained" in kwargs: + weights = _deprecated_param(kwargs, "pretrained", "weights", None) + weights = VisionTransformer_L_32Weights.verify(weights) + + return _vision_transformer( + patch_size=32, + num_layers=24, + num_heads=16, + hidden_dim=1024, + mlp_dim=4096, + weights=weights, + progress=progress, + **kwargs, + )