diff --git a/docs/source/models.rst b/docs/source/models.rst index f84d9c7fd1a..51881e505e4 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -42,6 +42,7 @@ architectures for image classification: - `RegNet`_ - `VisionTransformer`_ - `ConvNeXt`_ +- `SwinTransformer`_ You can construct a model with random weights by calling its constructor: @@ -97,6 +98,7 @@ You can construct a model with random weights by calling its constructor: convnext_small = models.convnext_small() convnext_base = models.convnext_base() convnext_large = models.convnext_large() + swin_t = models.swin_t() We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`. @@ -219,6 +221,7 @@ convnext_tiny 82.520 96.146 convnext_small 83.616 96.650 convnext_base 84.062 96.870 convnext_large 84.414 96.976 +swin_t 81.358 95.526 ================================ ============= ============= @@ -238,6 +241,7 @@ convnext_large 84.414 96.976 .. _RegNet: https://arxiv.org/abs/2003.13678 .. _VisionTransformer: https://arxiv.org/abs/2010.11929 .. _ConvNeXt: https://arxiv.org/abs/2201.03545 +.. _SwinTransformer: https://arxiv.org/abs/2103.14030 .. currentmodule:: torchvision.models @@ -450,6 +454,15 @@ ConvNeXt convnext_base convnext_large +SwinTransformer +-------- + +.. autosummary:: + :toctree: generated/ + :template: function.rst + + swin_t + Quantized Models ---------------- diff --git a/docs/source/models/swin_transformer.rst b/docs/source/models/swin_transformer.rst new file mode 100644 index 00000000000..b8726d71d2a --- /dev/null +++ b/docs/source/models/swin_transformer.rst @@ -0,0 +1,25 @@ +SwinTransformer +=============== + +.. currentmodule:: torchvision.models + +The SwinTransformer model is based on the `Swin Transformer: Hierarchical Vision +Transformer using Shifted Windows `__ +paper. + + +Model builders +-------------- + +The following model builders can be used to instanciate an SwinTransformer model. +`swin_t` can be instantiated with pre-trained weights and all others without. +All the model builders internally rely on the ``torchvision.models.swin_transformer.SwinTransformer`` +base class. Please refer to the `source code +`_ for +more details about this class. + +.. autosummary:: + :toctree: generated/ + :template: function.rst + + swin_t diff --git a/docs/source/models_new.rst b/docs/source/models_new.rst index 7df9522f306..12249571e5c 100644 --- a/docs/source/models_new.rst +++ b/docs/source/models_new.rst @@ -46,6 +46,7 @@ weights: models/resnet models/resnext models/squeezenet + models/swin_transformer models/vgg models/vision_transformer diff --git a/references/classification/README.md b/references/classification/README.md index c274c997791..758c57ce27d 100644 --- a/references/classification/README.md +++ b/references/classification/README.md @@ -224,6 +224,18 @@ Note that the above command corresponds to training on a single node with 8 GPUs For generatring the pre-trained weights, we trained with 2 nodes, each with 8 GPUs (for a total of 16 GPUs), and `--batch_size 64`. + +### SwinTransformer +``` +torchrun --nproc_per_node=8 train.py\ +--model swin_t --epochs 300 --batch-size 128 --opt adamw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0\ +--bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr --lr-min 0.00001 --lr-warmup-method linear\ +--lr-warmup-epochs 20 --lr-warmup-decay 0.01 --amp --label-smoothing 0.1 --mixup-alpha 0.8\ +--clip-grad-norm 5.0 --cutmix-alpha 1.0 --random-erase 0.25 --interpolation bicubic --auto-augment ra +``` +Note that `--val-resize-size` was optimized in a post-training step, see their `Weights` entry for the exact value. + + ## Mixed precision training Automatic Mixed Precision (AMP) training on GPU for Pytorch can be enabled with the [torch.cuda.amp](https://pytorch.org/docs/stable/amp.html?highlight=amp#module-torch.cuda.amp). diff --git a/references/classification/train.py b/references/classification/train.py index 6a3c289bc04..96703bfdf85 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -233,7 +233,7 @@ def main(args): if args.bias_weight_decay is not None: custom_keys_weight_decay.append(("bias", args.bias_weight_decay)) if args.transformer_embedding_decay is not None: - for key in ["class_token", "position_embedding", "relative_position_bias"]: + for key in ["class_token", "position_embedding", "relative_position_bias_table"]: custom_keys_weight_decay.append((key, args.transformer_embedding_decay)) parameters = utils.set_weight_decay( model, @@ -267,7 +267,7 @@ def main(args): main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) elif args.lr_scheduler == "cosineannealinglr": main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( - optimizer, T_max=args.epochs - args.lr_warmup_epochs + optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min ) elif args.lr_scheduler == "exponentiallr": main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma) @@ -424,6 +424,7 @@ def get_args_parser(add_help=True): parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr") parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs") parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma") + parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)") parser.add_argument("--print-freq", default=10, type=int, help="print frequency") parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs") parser.add_argument("--resume", default="", type=str, help="path of checkpoint") diff --git a/test/expect/ModelTester.test_swin_t_expect.pkl b/test/expect/ModelTester.test_swin_t_expect.pkl new file mode 100644 index 00000000000..7326683b7a5 Binary files /dev/null and b/test/expect/ModelTester.test_swin_t_expect.pkl differ diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 83e49908348..00b5ebefe55 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -12,6 +12,7 @@ from .squeezenet import * from .vgg import * from .vision_transformer import * +from .swin_transformer import * from . import detection from . import optical_flow from . import quantization diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py new file mode 100644 index 00000000000..455397c8403 --- /dev/null +++ b/torchvision/models/swin_transformer.py @@ -0,0 +1,462 @@ +from functools import partial +from typing import Optional, Callable, List, Any + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +from ..ops.stochastic_depth import StochasticDepth +from ..transforms._presets import ImageClassification, InterpolationMode +from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import _ovewrite_named_param +from .convnext import Permute +from .vision_transformer import MLPBlock + + +__all__ = [ + "SwinTransformer", + "Swin_T_Weights", + "swin_t", +] + + +class PatchMerging(nn.Module): + """Patch Merging Layer. + Args: + dim (int): Number of input channels. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + """ + + def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x: Tensor): + B, H, W, C = x.shape + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + x = x.view(B, H // 2, W // 2, 2 * C) + return x + + +def shifted_window_attention( + input: Tensor, + qkv_weight: Tensor, + proj_weight: Tensor, + relative_position_bias: Tensor, + window_size: int, + num_heads: int, + shift_size: int = 0, + attention_dropout: float = 0.0, + dropout: float = 0.0, + qkv_bias: Optional[Tensor] = None, + proj_bias: Optional[Tensor] = None, +): + """ + Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + input (Tensor[N, H, W, C]): The input tensor or 4-dimensions. + qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value. + proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection. + relative_position_bias (Tensor): The learned relative position bias added to attention. + window_size (int): Window size. + num_heads (int): Number of attention heads. + shift_size (int): Shift size for shifted window attention. Default: 0. + attention_dropout (float): Dropout ratio of attention weight. Default: 0.0. + dropout (float): Dropout ratio of output. Default: 0.0. + qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None. + proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None. + Returns: + Tensor[N, H, W, C]: The output tensor after shifted window attention. + """ + B, H, W, C = input.shape + # pad feature maps to multiples of window size + pad_r = (window_size - W % window_size) % window_size + pad_b = (window_size - H % window_size) % window_size + x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b)) + _, pad_H, pad_W, _ = x.shape + + # If window size is larger than feature size, there is no need to shift window. + if window_size == min(pad_H, pad_W): + shift_size = 0 + + # cyclic shift + if shift_size > 0: + x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2)) + + # partition windows + num_windows = (pad_H // window_size) * (pad_W // window_size) + x = x.view(B, pad_H // window_size, window_size, pad_W // window_size, window_size, C) + x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size * window_size, C) # B*nW, Ws*Ws, C + + # multi-head attention + qkv = F.linear(x, qkv_weight, qkv_bias) + qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + q = q * (C // num_heads) ** -0.5 + attn = q.matmul(k.transpose(-2, -1)) + # add relative position bias + attn = attn + relative_position_bias + + if shift_size > 0: + # generate attention mask + attn_mask = x.new_zeros((pad_H, pad_W)) + slices = ((0, -window_size), (-window_size, -shift_size), (-shift_size, None)) + count = 0 + for h in slices: + for w in slices: + attn_mask[h[0] : h[1], w[0] : w[1]] = count + count += 1 + attn_mask = attn_mask.view(pad_H // window_size, window_size, pad_W // window_size, window_size) + attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size * window_size) + attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1)) + attn = attn + attn_mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, num_heads, x.size(1), x.size(1)) + + attn = F.softmax(attn, dim=-1) + attn = F.dropout(attn, p=attention_dropout) + + x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C) + x = F.linear(x, proj_weight, proj_bias) + x = F.dropout(x, p=dropout) + + # reverse windows + x = x.view(B, pad_H // window_size, pad_W // window_size, window_size, window_size, C) + x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C) + + # reverse cyclic shift + if shift_size > 0: + x = torch.roll(x, shifts=(shift_size, shift_size), dims=(1, 2)) + + # unpad features + x = x[:, :H, :W, :].contiguous() + return x + + +torch.fx.wrap("shifted_window_attention") + + +class ShiftedWindowAttention(nn.Module): + """ + See :func:`shifted_window_attention`. + """ + + def __init__( + self, + dim: int, + window_size: int, + shift_size: int, + num_heads: int, + qkv_bias: bool = True, + proj_bias: bool = True, + attention_dropout: float = 0.0, + dropout: float = 0.0, + ): + super().__init__() + self.window_size = window_size + self.shift_size = shift_size + self.num_heads = num_heads + self.attention_dropout = attention_dropout + self.dropout = dropout + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size) + coords_w = torch.arange(self.window_size) + coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size - 1 + relative_coords[:, :, 0] *= 2 * self.window_size - 1 + relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) + + def forward(self, x: Tensor): + relative_position_bias = self.relative_position_bias_table[self.relative_position_index] # type: ignore[index] + relative_position_bias = relative_position_bias.view( + self.window_size * self.window_size, self.window_size * self.window_size, -1 + ) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) + + return shifted_window_attention( + x, + self.qkv.weight, + self.proj.weight, + relative_position_bias, + self.window_size, + self.num_heads, + shift_size=self.shift_size, + attention_dropout=self.attention_dropout, + dropout=self.dropout, + qkv_bias=self.qkv.bias, + proj_bias=self.proj.bias, + ) + + +class SwinTransformerBlock(nn.Module): + """ + Swin Transformer Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. Default: 7. + shift_size (int): Shift size for shifted window attention. Default: 0. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. + dropout (float): Dropout rate. Default: 0.0. + attention_dropout (float): Attention dropout rate. Default: 0.0. + stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + """ + + def __init__( + self, + dim: int, + num_heads: int, + window_size: int = 7, + shift_size: int = 0, + mlp_ratio: float = 4.0, + dropout: float = 0.0, + attention_dropout: float = 0.0, + stochastic_depth_prob: float = 0.0, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + ): + super().__init__() + + self.norm1 = norm_layer(dim) + self.attn = ShiftedWindowAttention( + dim, + window_size, + shift_size, + num_heads, + attention_dropout=attention_dropout, + dropout=dropout, + ) + self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(dim, int(dim * mlp_ratio), dropout) + + def forward(self, x: Tensor): + x = x + self.stochastic_depth(self.attn(self.norm1(x))) + x = x + self.stochastic_depth(self.mlp(self.norm2(x))) + return x + + +class SwinTransformer(nn.Module): + """ + Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using + Shifted Windows" `_ paper. + Args: + patch_size (int): Patch size. + embed_dim (int): Patch embedding dimension. + depths (List(int)): Depth of each Swin Transformer layer. + num_heads (List(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. + dropout (float): Dropout rate. Default: 0.0. + attention_dropout (float): Attention dropout rate. Default: 0.0. + stochastic_depth_prob (float): Stochastic depth rate. Default: 0.0. + num_classes (int): Number of classes for classification head. Default: 1000. + block (nn.Module, optional): SwinTransformer Block. Default: None. + norm_layer (nn.Module, optional): Normalization layer. Default: None. + """ + + def __init__( + self, + patch_size: int, + embed_dim: int, + depths: List[int], + num_heads: List[int], + window_size: int = 7, + mlp_ratio: float = 4.0, + dropout: float = 0.0, + attention_dropout: float = 0.0, + stochastic_depth_prob: float = 0.0, + num_classes: int = 1000, + norm_layer: Optional[Callable[..., nn.Module]] = None, + block: Optional[Callable[..., nn.Module]] = None, + ): + super().__init__() + _log_api_usage_once(self) + self.num_classes = num_classes + + if block is None: + block = SwinTransformerBlock + + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-5) + + layers: List[nn.Module] = [] + # split image into non-overlapping patches + layers.append( + nn.Sequential( + nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size), + Permute([0, 2, 3, 1]), + norm_layer(embed_dim), + ) + ) + + total_stage_blocks = sum(depths) + stage_block_id = 0 + # build SwinTransformer blocks + for i_stage in range(len(depths)): + stage: List[nn.Module] = [] + dim = embed_dim * 2 ** i_stage + for i_layer in range(depths[i_stage]): + # adjust stochastic depth probability based on the depth of the stage block + sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1) + stage.append( + block( + dim, + num_heads[i_stage], + window_size=window_size, + shift_size=0 if i_layer % 2 == 0 else window_size // 2, + mlp_ratio=mlp_ratio, + dropout=dropout, + attention_dropout=attention_dropout, + stochastic_depth_prob=sd_prob, + norm_layer=norm_layer, + ) + ) + stage_block_id += 1 + layers.append(nn.Sequential(*stage)) + # add patch merging layer + if i_stage < (len(depths) - 1): + layers.append(PatchMerging(dim, norm_layer)) + self.features = nn.Sequential(*layers) + + num_features = embed_dim * 2 ** (len(depths) - 1) + self.norm = norm_layer(num_features) + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.head = nn.Linear(num_features, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward(self, x): + x = self.features(x) + x = self.norm(x) + x = x.permute(0, 3, 1, 2) + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.head(x) + return x + + +def _swin_transformer( + patch_size: int, + embed_dim: int, + depths: List[int], + num_heads: List[int], + window_size: int, + stochastic_depth_prob: float, + weights: Optional[WeightsEnum], + progress: bool, + **kwargs: Any, +) -> SwinTransformer: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + + model = SwinTransformer( + patch_size=patch_size, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + window_size=window_size, + stochastic_depth_prob=stochastic_depth_prob, + **kwargs, + ) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + + return model + + +_COMMON_META = { + "categories": _IMAGENET_CATEGORIES, +} + + +class Swin_T_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/swin_t-81486767.pth", + transforms=partial( + ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META, + "num_params": 28288354, + "min_size": (224, 224), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swin_t", + "metrics": { + "acc@1": 81.358, + "acc@5": 95.526, + }, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: + """ + Constructs a swin_tiny architecture from + `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows `_. + + Args: + weights (:class:`~torchvision.models.Swin_T_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.Swin_T_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.Swin_T_Weights + :members: + """ + weights = Swin_T_Weights.verify(weights) + + return _swin_transformer( + patch_size=4, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + stochastic_depth_prob=0.2, + weights=weights, + progress=progress, + **kwargs, + )