From fb1a7496a175c9f414ea30431758b3889ff28eaf Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Wed, 25 May 2022 11:04:00 +0100 Subject: [PATCH 01/14] Use List[int] instead of int for window_size and shift_size --- torchvision/models/swin_transformer.py | 88 ++++++++++++++------------ 1 file changed, 46 insertions(+), 42 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 25e8900db56..722b0826082 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -59,9 +59,9 @@ def shifted_window_attention( qkv_weight: Tensor, proj_weight: Tensor, relative_position_bias: Tensor, - window_size: int, + window_size: List[int], num_heads: int, - shift_size: int = 0, + shift_size: List[int], attention_dropout: float = 0.0, dropout: float = 0.0, qkv_bias: Optional[Tensor] = None, @@ -75,9 +75,9 @@ def shifted_window_attention( qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value. proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection. relative_position_bias (Tensor): The learned relative position bias added to attention. - window_size (int): Window size. + window_size (List[int]): Window size. num_heads (int): Number of attention heads. - shift_size (int): Shift size for shifted window attention. Default: 0. + shift_size (List[int]): Shift size for shifted window attention. attention_dropout (float): Dropout ratio of attention weight. Default: 0.0. dropout (float): Dropout ratio of output. Default: 0.0. qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None. @@ -87,23 +87,25 @@ def shifted_window_attention( """ B, H, W, C = input.shape # pad feature maps to multiples of window size - pad_r = (window_size - W % window_size) % window_size - pad_b = (window_size - H % window_size) % window_size + pad_r = (window_size[1] - W % window_size[1]) % window_size[1] + pad_b = (window_size[0] - H % window_size[0]) % window_size[0] x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b)) _, pad_H, pad_W, _ = x.shape # If window size is larger than feature size, there is no need to shift window. - if window_size == min(pad_H, pad_W): - shift_size = 0 + if window_size[0] >= pad_H: + shift_size[0] = 0 + if window_size[1] >= pad_W: + shift_size[1] = 0 # cyclic shift - if shift_size > 0: - x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2)) + if sum(shift_size) > 0: + x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2)) # partition windows - num_windows = (pad_H // window_size) * (pad_W // window_size) - x = x.view(B, pad_H // window_size, window_size, pad_W // window_size, window_size, C) - x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size * window_size, C) # B*nW, Ws*Ws, C + num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1]) + x = x.view(B, pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1], C) + x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C) # B*nW, Ws*Ws, C # multi-head attention qkv = F.linear(x, qkv_weight, qkv_bias) @@ -114,17 +116,17 @@ def shifted_window_attention( # add relative position bias attn = attn + relative_position_bias - if shift_size > 0: + if sum(shift_size) > 0: # generate attention mask attn_mask = x.new_zeros((pad_H, pad_W)) - slices = ((0, -window_size), (-window_size, -shift_size), (-shift_size, None)) + slices = [((0, -window_size[i]), (-window_size[i], -shift_size[i]), (-shift_size[i], None)) for i in range(2)] count = 0 - for h in slices: - for w in slices: + for h in slices[0]: + for w in slices[1]: attn_mask[h[0] : h[1], w[0] : w[1]] = count count += 1 - attn_mask = attn_mask.view(pad_H // window_size, window_size, pad_W // window_size, window_size) - attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size * window_size) + attn_mask = attn_mask.view(pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1]) + attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size[0] * window_size[1]) attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1)) @@ -139,12 +141,12 @@ def shifted_window_attention( x = F.dropout(x, p=dropout) # reverse windows - x = x.view(B, pad_H // window_size, pad_W // window_size, window_size, window_size, C) + x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C) x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C) # reverse cyclic shift - if shift_size > 0: - x = torch.roll(x, shifts=(shift_size, shift_size), dims=(1, 2)) + if sum(shift_size) > 0: + x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2)) # unpad features x = x[:, :H, :W, :].contiguous() @@ -162,8 +164,8 @@ class ShiftedWindowAttention(nn.Module): def __init__( self, dim: int, - window_size: int, - shift_size: int, + window_size: List[int], + shift_size: List[int], num_heads: int, qkv_bias: bool = True, proj_bias: bool = True, @@ -171,6 +173,8 @@ def __init__( dropout: float = 0.0, ): super().__init__() + # Assert for 2d shift window attention + assert len(window_size) == 2 and len(shift_size) == 2, "window_size and shift_size must be of length 2" self.window_size = window_size self.shift_size = shift_size self.num_heads = num_heads @@ -182,19 +186,19 @@ def __init__( # define a parameter table of relative position bias self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads) + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) ) # 2*Wh-1 * 2*Ww-1, nH # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.window_size) - coords_w = torch.arange(self.window_size) + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.window_size - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size - 1 - relative_coords[:, :, 0] *= 2 * self.window_size - 1 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww self.register_buffer("relative_position_index", relative_position_index) @@ -203,7 +207,7 @@ def __init__( def forward(self, x: Tensor): relative_position_bias = self.relative_position_bias_table[self.relative_position_index] # type: ignore[index] relative_position_bias = relative_position_bias.view( - self.window_size * self.window_size, self.window_size * self.window_size, -1 + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 ) relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) @@ -228,8 +232,8 @@ class SwinTransformerBlock(nn.Module): Args: dim (int): Number of input channels. num_heads (int): Number of attention heads. - window_size (int): Window size. Default: 7. - shift_size (int): Shift size for shifted window attention. Default: 0. + window_size (List[int]): Window size. + shift_size (List[int]): Shift size for shifted window attention. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. dropout (float): Dropout rate. Default: 0.0. attention_dropout (float): Attention dropout rate. Default: 0.0. @@ -241,8 +245,8 @@ def __init__( self, dim: int, num_heads: int, - window_size: int = 7, - shift_size: int = 0, + window_size: List[int], + shift_size: List[int], mlp_ratio: float = 4.0, dropout: float = 0.0, attention_dropout: float = 0.0, @@ -285,7 +289,7 @@ class SwinTransformer(nn.Module): embed_dim (int): Patch embedding dimension. depths (List(int)): Depth of each Swin Transformer layer. num_heads (List(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 7. + window_size (List[int]): Window size. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. dropout (float): Dropout rate. Default: 0.0. attention_dropout (float): Attention dropout rate. Default: 0.0. @@ -301,7 +305,7 @@ def __init__( embed_dim: int, depths: List[int], num_heads: List[int], - window_size: int = 7, + window_size: List[int], mlp_ratio: float = 4.0, dropout: float = 0.0, attention_dropout: float = 0.0, @@ -344,7 +348,7 @@ def __init__( dim, num_heads[i_stage], window_size=window_size, - shift_size=0 if i_layer % 2 == 0 else window_size // 2, + shift_size=[0 if i_layer % 2 == 0 else w // 2 for w in window_size], mlp_ratio=mlp_ratio, dropout=dropout, attention_dropout=attention_dropout, @@ -385,7 +389,7 @@ def _swin_transformer( embed_dim: int, depths: List[int], num_heads: List[int], - window_size: int, + window_size: List[int], stochastic_depth_prob: float, weights: Optional[WeightsEnum], progress: bool, @@ -512,7 +516,7 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, * embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], - window_size=7, + window_size=[7, 7], stochastic_depth_prob=0.2, weights=weights, progress=progress, @@ -548,7 +552,7 @@ def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, * embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24], - window_size=7, + window_size=[7, 7], stochastic_depth_prob=0.3, weights=weights, progress=progress, @@ -584,7 +588,7 @@ def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, * embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], - window_size=7, + window_size=[7, 7], stochastic_depth_prob=0.5, weights=weights, progress=progress, From c3902aea0539c56aeaabf3533bec91d71e90527b Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Wed, 25 May 2022 12:58:27 +0100 Subject: [PATCH 02/14] Make PatchMerging and SwinTransformerBlock able to handle 2d and 3d cases --- torchvision/models/swin_transformer.py | 27 ++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 722b0826082..bec0591e5d6 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -39,18 +39,23 @@ def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm self.norm = norm_layer(4 * dim) def forward(self, x: Tensor): - B, H, W, C = x.shape + """ + Args: + x (Tensor): input tensor with expected layout of [... H W C] + """ + H, W, C = x.shape[-3:] - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + if (H % 2 == 1) or (W % 2 == 1): + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C + x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C + x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C + x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # ... H/2 W/2 4*C x = self.norm(x) - x = self.reduction(x) - x = x.view(B, H // 2, W // 2, 2 * C) + x = self.reduction(x) # ... H/2 W/2 2*C return x @@ -239,6 +244,7 @@ class SwinTransformerBlock(nn.Module): attention_dropout (float): Attention dropout rate. Default: 0.0. stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttention """ def __init__( @@ -252,11 +258,12 @@ def __init__( attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_layer: Callable[..., nn.Module] = ShiftedWindowAttention, ): super().__init__() self.norm1 = norm_layer(dim) - self.attn = ShiftedWindowAttention( + self.attn = attn_layer( dim, window_size, shift_size, From 13d4d857a72352800ae2970ccb83f7602e0a7629 Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Wed, 25 May 2022 15:58:59 +0100 Subject: [PATCH 03/14] Separate patch embedding from SwinTransformer and enable to get model without head by specifying num_heads=None --- torchvision/models/swin_transformer.py | 124 ++++++++++++++++++------- 1 file changed, 92 insertions(+), 32 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index bec0591e5d6..4b13a947505 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -5,7 +5,7 @@ import torch.nn.functional as F from torch import nn, Tensor -from ..ops.misc import MLP, Permute +from ..ops.misc import MLP from ..ops.stochastic_depth import StochasticDepth from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once @@ -41,7 +41,9 @@ def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm def forward(self, x: Tensor): """ Args: - x (Tensor): input tensor with expected layout of [... H W C] + x (Tensor): input tensor with expected layout of [..., H, W, C] + Returns: + Tensor with layout of [..., H/2, W/2, 2*C] """ H, W, C = x.shape[-3:] @@ -97,12 +99,6 @@ def shifted_window_attention( x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b)) _, pad_H, pad_W, _ = x.shape - # If window size is larger than feature size, there is no need to shift window. - if window_size[0] >= pad_H: - shift_size[0] = 0 - if window_size[1] >= pad_W: - shift_size[1] = 0 - # cyclic shift if sum(shift_size) > 0: x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2)) @@ -204,16 +200,30 @@ def __init__( relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww self.register_buffer("relative_position_index", relative_position_index) nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) def forward(self, x: Tensor): - relative_position_bias = self.relative_position_bias_table[self.relative_position_index] # type: ignore[index] - relative_position_bias = relative_position_bias.view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 - ) + """ + Args: + x (Tensor): Tensor with layout of [B, H, W, C] + Returns: + Tensor with same layout as input, i.e. [B, H, W, C] + """ + _, H, W, _ = x.shape + size_hw = [H, W] + window_size, shift_size = self.window_size.copy(), self.shift_size.copy() + # Handle case where window_size is larger than input tensor + for i in range(2): + if size_hw[i] <= window_size[i]: + window_size[i] = size_hw[i] + shift_size[i] = 0 + + N = window_size[0] * window_size[1] + relative_position_bias = self.relative_position_bias_table[self.relative_position_index[:N, :N]] # type: ignore[index] + relative_position_bias = relative_position_bias.view(N, N, -1) relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) return shifted_window_attention( @@ -221,9 +231,9 @@ def forward(self, x: Tensor): self.qkv.weight, self.proj.weight, relative_position_bias, - self.window_size, + window_size, self.num_heads, - shift_size=self.shift_size, + shift_size=shift_size, attention_dropout=self.attention_dropout, dropout=self.dropout, qkv_bias=self.qkv.bias, @@ -287,12 +297,59 @@ def forward(self, x: Tensor): return x +class PatchEmbed(nn.Module): + """Split image into non-overlapping patches + + Args: + patch_size (List[int]): Patch token size. + embed_dim (int): Number of linear projection output channels. + in_chanels (int): Number of input channels. Default: 3. + norm_layer (Optional[Callable[..., nn.Module]]): Normalization layer. Default: None. + """ + + def __init__( + self, + patch_size: List[int], + embed_dim: int, + in_channels: int = 3, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ): + super().__init__() + self.patch_size = patch_size + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size[0], patch_size[1]), stride=(patch_size[0], patch_size[1]) + ) + self.norm: Optional[nn.Module] + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x (Tensor): batch of images with layout [B, C, H, W] + Returns: + Tensor with layout [B, H/patch_size[0], W/patch_size[1], embed_dim] + """ + _, _, H, W = x.shape + pad_r = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] + pad_b = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] + x = F.pad(x, (0, pad_r, 0, pad_b)) + + x = self.proj(x) + x = x.permute(0, 2, 3, 1) + if self.norm is not None: + x = self.norm(x) + return x + + class SwinTransformer(nn.Module): """ Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" `_ paper. Args: - patch_size (int): Patch size. + patch_size (List[int]): Patch size. embed_dim (int): Patch embedding dimension. depths (List(int)): Depth of each Swin Transformer layer. num_heads (List(int)): Number of attention heads in different layers. @@ -304,11 +361,12 @@ class SwinTransformer(nn.Module): num_classes (int): Number of classes for classification head. Default: 1000. block (nn.Module, optional): SwinTransformer Block. Default: None. norm_layer (nn.Module, optional): Normalization layer. Default: None. + patch_embed (nn.Module, optional): Patch Embedding layer. Default: None. """ def __init__( self, - patch_size: int, + patch_size: List[int], embed_dim: int, depths: List[int], num_heads: List[int], @@ -320,6 +378,7 @@ def __init__( num_classes: int = 1000, norm_layer: Optional[Callable[..., nn.Module]] = None, block: Optional[Callable[..., nn.Module]] = None, + patch_embed: Optional[Callable[..., nn.Module]] = None, ): super().__init__() _log_api_usage_once(self) @@ -331,16 +390,11 @@ def __init__( if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-5) - layers: List[nn.Module] = [] - # split image into non-overlapping patches - layers.append( - nn.Sequential( - nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size), - Permute([0, 2, 3, 1]), - norm_layer(embed_dim), - ) - ) + if patch_embed is None: + patch_embed = PatchEmbed + self.patch_embed = patch_embed(patch_size, embed_dim, norm_layer=norm_layer) + layers: List[nn.Module] = [] total_stage_blocks = sum(depths) stage_block_id = 0 # build SwinTransformer blocks @@ -373,7 +427,11 @@ def __init__( num_features = embed_dim * 2 ** (len(depths) - 1) self.norm = norm_layer(num_features) self.avgpool = nn.AdaptiveAvgPool2d(1) - self.head = nn.Linear(num_features, num_classes) + self.head: Optional[nn.Module] + if num_classes is not None: + self.head = nn.Linear(num_features, num_classes) + else: + self.head = None for m in self.modules(): if isinstance(m, nn.Linear): @@ -382,17 +440,19 @@ def __init__( nn.init.zeros_(m.bias) def forward(self, x): + x = self.patch_embed(x) x = self.features(x) x = self.norm(x) x = x.permute(0, 3, 1, 2) x = self.avgpool(x) x = torch.flatten(x, 1) - x = self.head(x) + if self.num_classes is not None: + x = self.head(x) return x def _swin_transformer( - patch_size: int, + patch_size: List[int], embed_dim: int, depths: List[int], num_heads: List[int], @@ -519,7 +579,7 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, * weights = Swin_T_Weights.verify(weights) return _swin_transformer( - patch_size=4, + patch_size=[4, 4], embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], @@ -555,7 +615,7 @@ def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, * weights = Swin_S_Weights.verify(weights) return _swin_transformer( - patch_size=4, + patch_size=[4, 4], embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24], @@ -591,7 +651,7 @@ def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, * weights = Swin_B_Weights.verify(weights) return _swin_transformer( - patch_size=4, + patch_size=[4, 4], embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], From a3d11929b748a64f9c6ad0591e2d323da031de78 Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Wed, 25 May 2022 16:45:01 +0100 Subject: [PATCH 04/14] Dont use if before padding so it is fx friendly --- torchvision/models/swin_transformer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index bec0591e5d6..6dfc7b4a2b3 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -43,10 +43,8 @@ def forward(self, x: Tensor): Args: x (Tensor): input tensor with expected layout of [... H W C] """ - H, W, C = x.shape[-3:] - - if (H % 2 == 1) or (W % 2 == 1): - x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + H, W, _ = x.shape[-3:] + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C From ba5f8f90cda99d39aca43f70ed8e4cae41f051ac Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Wed, 25 May 2022 19:31:23 +0100 Subject: [PATCH 05/14] Put the handling on window_size edge cases on separate function and wrap with torch.fx.wrap so it is excluded from tracing --- torchvision/models/swin_transformer.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index c62108067d2..ee119a52129 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -155,6 +155,18 @@ def shifted_window_attention( torch.fx.wrap("shifted_window_attention") +def _fix_window_and_shift_size(input_hw: List[int], window_size: List[int], shift_size: List[int]): + # Handle case where window_size is larger than input tensor + for i in range(2): + if input_hw[i] <= window_size[i]: + window_size[i] = input_hw[i] + shift_size[i] = 0 + return window_size, shift_size + + +torch.fx.wrap("_fix_window_and_shift_size") + + class ShiftedWindowAttention(nn.Module): """ See :func:`shifted_window_attention`. @@ -211,13 +223,8 @@ def forward(self, x: Tensor): Tensor with same layout as input, i.e. [B, H, W, C] """ _, H, W, _ = x.shape - size_hw = [H, W] - window_size, shift_size = self.window_size.copy(), self.shift_size.copy() - # Handle case where window_size is larger than input tensor - for i in range(2): - if size_hw[i] <= window_size[i]: - window_size[i] = size_hw[i] - shift_size[i] = 0 + input_hw = [H, W] + window_size, shift_size = _fix_window_and_shift_size(input_hw, self.window_size, self.shift_size) N = window_size[0] * window_size[1] relative_position_bias = self.relative_position_bias_table[self.relative_position_index[:N, :N]] # type: ignore[index] From ed14bd7f21de793b6db76852f703eb338e6ea385 Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Wed, 25 May 2022 20:34:53 +0100 Subject: [PATCH 06/14] Update the weight url to the converted weight with new structure --- torchvision/models/swin_transformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index ee119a52129..56736b8f8fe 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -493,7 +493,7 @@ def _swin_transformer( class Swin_T_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/swin_t-704ceda3.pth", + url="https://download.pytorch.org/models/swin_t-2a53f41a.pth", transforms=partial( ImageClassification, crop_size=224, resize_size=232, interpolation=InterpolationMode.BICUBIC ), @@ -516,7 +516,7 @@ class Swin_T_Weights(WeightsEnum): class Swin_S_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/swin_s-5e29d889.pth", + url="https://download.pytorch.org/models/swin_s-5acd2824.pth", transforms=partial( ImageClassification, crop_size=224, resize_size=246, interpolation=InterpolationMode.BICUBIC ), @@ -539,7 +539,7 @@ class Swin_S_Weights(WeightsEnum): class Swin_B_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/swin_b-68c6b09e.pth", + url="https://download.pytorch.org/models/swin_b-ff10ecec.pth", transforms=partial( ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC ), From 548109aaefa85785eeb9f67a220f9761ec4b041d Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Wed, 25 May 2022 22:54:55 +0100 Subject: [PATCH 07/14] Update the accuracy of swin_transformer --- torchvision/models/swin_transformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 56736b8f8fe..e100884006a 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -504,7 +504,7 @@ class Swin_T_Weights(WeightsEnum): "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer", "_metrics": { "ImageNet-1K": { - "acc@1": 81.474, + "acc@1": 81.470, "acc@5": 95.776, } }, @@ -528,7 +528,7 @@ class Swin_S_Weights(WeightsEnum): "_metrics": { "ImageNet-1K": { "acc@1": 83.196, - "acc@5": 96.360, + "acc@5": 96.362, } }, "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", @@ -550,7 +550,7 @@ class Swin_B_Weights(WeightsEnum): "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer", "_metrics": { "ImageNet-1K": { - "acc@1": 83.582, + "acc@1": 83.584, "acc@5": 96.640, } }, From 2d767f84082f8cdf57524d6aa53d021d63439b3d Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Wed, 25 May 2022 23:12:13 +0100 Subject: [PATCH 08/14] Change assert to Exception and nit --- torchvision/models/swin_transformer.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index e100884006a..1aff25289fa 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -118,10 +118,11 @@ def shifted_window_attention( if sum(shift_size) > 0: # generate attention mask attn_mask = x.new_zeros((pad_H, pad_W)) - slices = [((0, -window_size[i]), (-window_size[i], -shift_size[i]), (-shift_size[i], None)) for i in range(2)] + h_slices = ((0, -window_size[0]), (-window_size[0], -shift_size[0]), (-shift_size[0], None)) + w_slices = ((0, -window_size[1]), (-window_size[1], -shift_size[1]), (-shift_size[1], None)) count = 0 - for h in slices[0]: - for w in slices[1]: + for h in h_slices: + for w in w_slices: attn_mask[h[0] : h[1], w[0] : w[1]] = count count += 1 attn_mask = attn_mask.view(pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1]) @@ -184,8 +185,8 @@ def __init__( dropout: float = 0.0, ): super().__init__() - # Assert for 2d shift window attention - assert len(window_size) == 2 and len(shift_size) == 2, "window_size and shift_size must be of length 2" + if len(window_size) != 2 or len(shift_size) != 2: + raise ValueError("window_size and shift_size must be of length 2") self.window_size = window_size self.shift_size = shift_size self.num_heads = num_heads From 65c8439fd65c403723b2561229daf109b7ccf8a0 Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Wed, 25 May 2022 23:41:50 +0100 Subject: [PATCH 09/14] Make num_classes optional --- torchvision/models/swin_transformer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 1aff25289fa..99eac3729a0 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -364,7 +364,8 @@ class SwinTransformer(nn.Module): dropout (float): Dropout rate. Default: 0.0. attention_dropout (float): Attention dropout rate. Default: 0.0. stochastic_depth_prob (float): Stochastic depth rate. Default: 0.0. - num_classes (int): Number of classes for classification head. Default: 1000. + num_classes (int, optional): Number of classes for classification head, + if None then the model have no head. Default: 1000. block (nn.Module, optional): SwinTransformer Block. Default: None. norm_layer (nn.Module, optional): Normalization layer. Default: None. patch_embed (nn.Module, optional): Patch Embedding layer. Default: None. @@ -381,7 +382,7 @@ def __init__( dropout: float = 0.0, attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, - num_classes: int = 1000, + num_classes: Optional[int] = 1000, norm_layer: Optional[Callable[..., nn.Module]] = None, block: Optional[Callable[..., nn.Module]] = None, patch_embed: Optional[Callable[..., nn.Module]] = None, From f0f872fb3154689f6b3a3f0e115d0898cad94da3 Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Thu, 26 May 2022 08:27:51 +0100 Subject: [PATCH 10/14] Add typing output for _fix_window_and_shift_size function --- torchvision/models/swin_transformer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 99eac3729a0..a1ce215e28d 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Optional, Callable, List, Any +from typing import Optional, Callable, List, Tuple, Any import torch import torch.nn.functional as F @@ -156,7 +156,9 @@ def shifted_window_attention( torch.fx.wrap("shifted_window_attention") -def _fix_window_and_shift_size(input_hw: List[int], window_size: List[int], shift_size: List[int]): +def _fix_window_and_shift_size( + input_hw: List[int], window_size: List[int], shift_size: List[int] +) -> Tuple[List[int], List[int]]: # Handle case where window_size is larger than input tensor for i in range(2): if input_hw[i] <= window_size[i]: From e2317f41fc28297ffc95c8aeef1af794c3fe048d Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Thu, 26 May 2022 08:43:10 +0100 Subject: [PATCH 11/14] init head to None to make it jit scriptable --- torchvision/models/swin_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index a1ce215e28d..4401bdd75bc 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -436,7 +436,7 @@ def __init__( num_features = embed_dim * 2 ** (len(depths) - 1) self.norm = norm_layer(num_features) self.avgpool = nn.AdaptiveAvgPool2d(1) - self.head: Optional[nn.Module] + self.head: Optional[nn.Module] = None if num_classes is not None: self.head = nn.Linear(num_features, num_classes) else: From 77ce2b9bb4e00b36f8bddaf63d7347710f3ecedf Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Thu, 26 May 2022 11:03:26 +0100 Subject: [PATCH 12/14] Revert the change to make num_classes optional --- torchvision/models/swin_transformer.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 4401bdd75bc..f97a769488b 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -366,8 +366,7 @@ class SwinTransformer(nn.Module): dropout (float): Dropout rate. Default: 0.0. attention_dropout (float): Attention dropout rate. Default: 0.0. stochastic_depth_prob (float): Stochastic depth rate. Default: 0.0. - num_classes (int, optional): Number of classes for classification head, - if None then the model have no head. Default: 1000. + num_classes (int): Number of classes for classification head. Default: 1000. block (nn.Module, optional): SwinTransformer Block. Default: None. norm_layer (nn.Module, optional): Normalization layer. Default: None. patch_embed (nn.Module, optional): Patch Embedding layer. Default: None. @@ -384,7 +383,7 @@ def __init__( dropout: float = 0.0, attention_dropout: float = 0.0, stochastic_depth_prob: float = 0.0, - num_classes: Optional[int] = 1000, + num_classes: int = 1000, norm_layer: Optional[Callable[..., nn.Module]] = None, block: Optional[Callable[..., nn.Module]] = None, patch_embed: Optional[Callable[..., nn.Module]] = None, @@ -437,10 +436,7 @@ def __init__( self.norm = norm_layer(num_features) self.avgpool = nn.AdaptiveAvgPool2d(1) self.head: Optional[nn.Module] = None - if num_classes is not None: - self.head = nn.Linear(num_features, num_classes) - else: - self.head = None + self.head = nn.Linear(num_features, num_classes) for m in self.modules(): if isinstance(m, nn.Linear): @@ -455,8 +451,7 @@ def forward(self, x): x = x.permute(0, 3, 1, 2) x = self.avgpool(x) x = torch.flatten(x, 1) - if self.num_classes is not None: - x = self.head(x) + x = self.head(x) return x From f04801fce7c6510e1946981538af65242498b92c Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Thu, 26 May 2022 12:00:57 +0100 Subject: [PATCH 13/14] Revert unneccesarry changes that might be risky --- torchvision/models/swin_transformer.py | 114 +++++++------------------ 1 file changed, 30 insertions(+), 84 deletions(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index f97a769488b..357d3e3a3da 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -1,11 +1,11 @@ from functools import partial -from typing import Optional, Callable, List, Tuple, Any +from typing import Optional, Callable, List, Any import torch import torch.nn.functional as F from torch import nn, Tensor -from ..ops.misc import MLP +from ..ops.misc import MLP, Permute from ..ops.stochastic_depth import StochasticDepth from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once @@ -97,6 +97,12 @@ def shifted_window_attention( x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b)) _, pad_H, pad_W, _ = x.shape + # If window size is larger than feature size, there is no need to shift window + if window_size[0] >= pad_H: + shift_size[0] = 0 + if window_size[1] >= pad_W: + shift_size[1] = 0 + # cyclic shift if sum(shift_size) > 0: x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2)) @@ -156,20 +162,6 @@ def shifted_window_attention( torch.fx.wrap("shifted_window_attention") -def _fix_window_and_shift_size( - input_hw: List[int], window_size: List[int], shift_size: List[int] -) -> Tuple[List[int], List[int]]: - # Handle case where window_size is larger than input tensor - for i in range(2): - if input_hw[i] <= window_size[i]: - window_size[i] = input_hw[i] - shift_size[i] = 0 - return window_size, shift_size - - -torch.fx.wrap("_fix_window_and_shift_size") - - class ShiftedWindowAttention(nn.Module): """ See :func:`shifted_window_attention`. @@ -213,7 +205,7 @@ def __init__( relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww self.register_buffer("relative_position_index", relative_position_index) nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) @@ -225,12 +217,9 @@ def forward(self, x: Tensor): Returns: Tensor with same layout as input, i.e. [B, H, W, C] """ - _, H, W, _ = x.shape - input_hw = [H, W] - window_size, shift_size = _fix_window_and_shift_size(input_hw, self.window_size, self.shift_size) - N = window_size[0] * window_size[1] - relative_position_bias = self.relative_position_bias_table[self.relative_position_index[:N, :N]] # type: ignore[index] + N = self.window_size[0] * self.window_size[1] + relative_position_bias = self.relative_position_bias_table[self.relative_position_index] # type: ignore[index] relative_position_bias = relative_position_bias.view(N, N, -1) relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) @@ -239,9 +228,9 @@ def forward(self, x: Tensor): self.qkv.weight, self.proj.weight, relative_position_bias, - window_size, + self.window_size, self.num_heads, - shift_size=shift_size, + shift_size=self.shift_size, attention_dropout=self.attention_dropout, dropout=self.dropout, qkv_bias=self.qkv.bias, @@ -305,53 +294,6 @@ def forward(self, x: Tensor): return x -class PatchEmbed(nn.Module): - """Split image into non-overlapping patches - - Args: - patch_size (List[int]): Patch token size. - embed_dim (int): Number of linear projection output channels. - in_chanels (int): Number of input channels. Default: 3. - norm_layer (Optional[Callable[..., nn.Module]]): Normalization layer. Default: None. - """ - - def __init__( - self, - patch_size: List[int], - embed_dim: int, - in_channels: int = 3, - norm_layer: Optional[Callable[..., nn.Module]] = None, - ): - super().__init__() - self.patch_size = patch_size - self.proj = nn.Conv2d( - in_channels, embed_dim, kernel_size=(patch_size[0], patch_size[1]), stride=(patch_size[0], patch_size[1]) - ) - self.norm: Optional[nn.Module] - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x: Tensor) -> Tensor: - """ - Args: - x (Tensor): batch of images with layout [B, C, H, W] - Returns: - Tensor with layout [B, H/patch_size[0], W/patch_size[1], embed_dim] - """ - _, _, H, W = x.shape - pad_r = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] - pad_b = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] - x = F.pad(x, (0, pad_r, 0, pad_b)) - - x = self.proj(x) - x = x.permute(0, 2, 3, 1) - if self.norm is not None: - x = self.norm(x) - return x - - class SwinTransformer(nn.Module): """ Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using @@ -369,7 +311,6 @@ class SwinTransformer(nn.Module): num_classes (int): Number of classes for classification head. Default: 1000. block (nn.Module, optional): SwinTransformer Block. Default: None. norm_layer (nn.Module, optional): Normalization layer. Default: None. - patch_embed (nn.Module, optional): Patch Embedding layer. Default: None. """ def __init__( @@ -386,7 +327,6 @@ def __init__( num_classes: int = 1000, norm_layer: Optional[Callable[..., nn.Module]] = None, block: Optional[Callable[..., nn.Module]] = None, - patch_embed: Optional[Callable[..., nn.Module]] = None, ): super().__init__() _log_api_usage_once(self) @@ -398,11 +338,18 @@ def __init__( if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-5) - if patch_embed is None: - patch_embed = PatchEmbed - self.patch_embed = patch_embed(patch_size, embed_dim, norm_layer=norm_layer) - layers: List[nn.Module] = [] + # split image into non-overlapping patches + layers.append( + nn.Sequential( + nn.Conv2d( + 3, embed_dim, kernel_size=(patch_size[0], patch_size[1]), stride=(patch_size[0], patch_size[1]) + ), + Permute([0, 2, 3, 1]), + norm_layer(embed_dim), + ) + ) + total_stage_blocks = sum(depths) stage_block_id = 0 # build SwinTransformer blocks @@ -445,7 +392,6 @@ def __init__( nn.init.zeros_(m.bias) def forward(self, x): - x = self.patch_embed(x) x = self.features(x) x = self.norm(x) x = x.permute(0, 3, 1, 2) @@ -492,7 +438,7 @@ def _swin_transformer( class Swin_T_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/swin_t-2a53f41a.pth", + url="https://download.pytorch.org/models/swin_t-704ceda3.pth", transforms=partial( ImageClassification, crop_size=224, resize_size=232, interpolation=InterpolationMode.BICUBIC ), @@ -503,7 +449,7 @@ class Swin_T_Weights(WeightsEnum): "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer", "_metrics": { "ImageNet-1K": { - "acc@1": 81.470, + "acc@1": 81.474, "acc@5": 95.776, } }, @@ -515,7 +461,7 @@ class Swin_T_Weights(WeightsEnum): class Swin_S_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/swin_s-5acd2824.pth", + url="https://download.pytorch.org/models/swin_s-5e29d889.pth", transforms=partial( ImageClassification, crop_size=224, resize_size=246, interpolation=InterpolationMode.BICUBIC ), @@ -527,7 +473,7 @@ class Swin_S_Weights(WeightsEnum): "_metrics": { "ImageNet-1K": { "acc@1": 83.196, - "acc@5": 96.362, + "acc@5": 96.360, } }, "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", @@ -538,7 +484,7 @@ class Swin_S_Weights(WeightsEnum): class Swin_B_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/swin_b-ff10ecec.pth", + url="https://download.pytorch.org/models/swin_b-68c6b09e.pth", transforms=partial( ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC ), @@ -549,7 +495,7 @@ class Swin_B_Weights(WeightsEnum): "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer", "_metrics": { "ImageNet-1K": { - "acc@1": 83.584, + "acc@1": 83.582, "acc@5": 96.640, } }, From 8cf5e39e636a3a4d4fc05220c722895d60c9ab83 Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Thu, 26 May 2022 12:30:44 +0100 Subject: [PATCH 14/14] Remove self.head declaration --- torchvision/models/swin_transformer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py index 357d3e3a3da..c56093ed4bf 100644 --- a/torchvision/models/swin_transformer.py +++ b/torchvision/models/swin_transformer.py @@ -382,7 +382,6 @@ def __init__( num_features = embed_dim * 2 ** (len(depths) - 1) self.norm = norm_layer(num_features) self.avgpool = nn.AdaptiveAvgPool2d(1) - self.head: Optional[nn.Module] = None self.head = nn.Linear(num_features, num_classes) for m in self.modules():