diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py
index 2f2cfd44445..e7082ddd9cb 100644
--- a/torchvision/models/swin_transformer.py
+++ b/torchvision/models/swin_transformer.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Optional, Callable, List, Any
+from typing import Any, Callable, List, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
@@ -9,7 +9,7 @@
 from ..ops.stochastic_depth import StochasticDepth
 from ..transforms._presets import ImageClassification, InterpolationMode
 from ..utils import _log_api_usage_once
-from ._api import WeightsEnum, Weights
+from ._api import Weights, WeightsEnum
 from ._meta import _IMAGENET_CATEGORIES
 from ._utils import _ovewrite_named_param
 
@@ -160,7 +160,22 @@ def shifted_window_attention(
     return x
 
 
+def _fix_window_and_shift_size(
+    input_size: List[int], window_size: List[int], shift_size: List[int]
+) -> Tuple[List[int], List[int]]:
+    # Handle case where window_size is larger than input tensor
+    # Reference on the original implementation: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py#L192-L195
+    updated_window_size = window_size.copy()
+    updated_shift_size = shift_size.copy()
+    for i in range(len(input_size)):
+        if input_size[i] <= window_size[i]:
+            updated_window_size[i] = input_size[i]
+            updated_shift_size[i] = 0
+    return updated_window_size, updated_shift_size
+
+
 torch.fx.wrap("shifted_window_attention")
+torch.fx.wrap("_fix_window_and_shift_size")
 
 
 class ShiftedWindowAttention(nn.Module):
@@ -218,8 +233,12 @@ def forward(self, x: Tensor):
         Returns:
             Tensor with same layout as input, i.e. [B, H, W, C]
         """
+        _, H, W, _ = x.shape
+        input_hw = [H, W]
+        # Handle case where the window_size is larger than the input
+        window_size, shift_size = _fix_window_and_shift_size(input_hw, self.window_size, self.shift_size)
 
-        N = self.window_size[0] * self.window_size[1]
+        N = window_size[0] * window_size[1]
         relative_position_bias = self.relative_position_bias_table[self.relative_position_index]  # type: ignore[index]
         relative_position_bias = relative_position_bias.view(N, N, -1)
         relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
@@ -229,9 +248,9 @@ def forward(self, x: Tensor):
             self.qkv.weight,
             self.proj.weight,
             relative_position_bias,
-            self.window_size,
+            window_size,
             self.num_heads,
-            shift_size=self.shift_size,
+            shift_size=shift_size,
             attention_dropout=self.attention_dropout,
             dropout=self.dropout,
             qkv_bias=self.qkv.bias,
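
For context, here is a minimal standalone sketch (not part of the patch itself) that copies the `_fix_window_and_shift_size` helper added above, so its clamping behaviour can be checked in isolation; the feature-map, window, and shift sizes used in the example are purely illustrative.

```python
from typing import List, Tuple


def _fix_window_and_shift_size(
    input_size: List[int], window_size: List[int], shift_size: List[int]
) -> Tuple[List[int], List[int]]:
    # Clamp each window dimension to the input size and disable the shift
    # along any dimension where the window had to be clamped.
    updated_window_size = window_size.copy()
    updated_shift_size = shift_size.copy()
    for i in range(len(input_size)):
        if input_size[i] <= window_size[i]:
            updated_window_size[i] = input_size[i]
            updated_shift_size[i] = 0
    return updated_window_size, updated_shift_size


# A 4x6 feature map with a 7x7 window and 3x3 shift: both window dimensions
# exceed the input, so they are clamped and shifting is turned off.
print(_fix_window_and_shift_size([4, 6], [7, 7], [3, 3]))  # ([4, 6], [0, 0])
```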