Swin transformer handle window size smaller than input size #6222

Closed
29 changes: 24 additions & 5 deletions torchvision/models/swin_transformer.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Optional, Callable, List, Any
+from typing import Any, Callable, List, Optional, Tuple

 import torch
 import torch.nn.functional as F
@@ -9,7 +9,7 @@
 from ..ops.stochastic_depth import StochasticDepth
 from ..transforms._presets import ImageClassification, InterpolationMode
 from ..utils import _log_api_usage_once
-from ._api import WeightsEnum, Weights
+from ._api import Weights, WeightsEnum
 from ._meta import _IMAGENET_CATEGORIES
 from ._utils import _ovewrite_named_param

@@ -160,7 +160,22 @@ def shifted_window_attention(
     return x


+def _fix_window_and_shift_size(
+    input_size: List[int], window_size: List[int], shift_size: List[int]
+) -> Tuple[List[int], List[int]]:
+    # Handle case where window_size is larger than input tensor
+    # Reference on the original implementation: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py#L192-L195
+    updated_window_size = window_size.copy()
+    updated_shift_size = shift_size.copy()
+    for i in range(len(input_size)):
+        if input_size[i] <= window_size[i]:
+            updated_window_size[i] = input_size[i]
+            updated_shift_size[i] = 0
+    return updated_window_size, updated_shift_size
+
+
 torch.fx.wrap("shifted_window_attention")
+torch.fx.wrap("_fix_window_and_shift_size")


 class ShiftedWindowAttention(nn.Module):
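
To make the new helper's behavior concrete, here is a small standalone sketch (illustrative only: it copies the clamping rule added above, and the example sizes are made up, not taken from the PR):

from typing import List, Tuple


def _fix_window_and_shift_size(
    input_size: List[int], window_size: List[int], shift_size: List[int]
) -> Tuple[List[int], List[int]]:
    # Same rule as the helper in this diff: along any dimension where the input
    # is not larger than the window, clamp the window to the input and disable the shift.
    updated_window_size = window_size.copy()
    updated_shift_size = shift_size.copy()
    for i in range(len(input_size)):
        if input_size[i] <= window_size[i]:
            updated_window_size[i] = input_size[i]
            updated_shift_size[i] = 0
    return updated_window_size, updated_shift_size


# A 5x12 feature map with a 7x7 window and 3x3 shift: the height dimension
# (5 <= 7) is clamped and its shift zeroed, while the width dimension is untouched.
print(_fix_window_and_shift_size([5, 12], [7, 7], [3, 3]))  # ([5, 7], [0, 3])
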
@@ -218,8 +233,12 @@ def forward(self, x: Tensor):
         Returns:
             Tensor with same layout as input, i.e. [B, H, W, C]
         """
+        _, H, W, _ = x.shape
+        input_hw = [H, W]
+        # Handle case where the window_size is larger than the input
+        window_size, shift_size = _fix_window_and_shift_size(input_hw, self.window_size, self.shift_size)

-        N = self.window_size[0] * self.window_size[1]
+        N = window_size[0] * window_size[1]
         relative_position_bias = self.relative_position_bias_table[self.relative_position_index]  # type: ignore[index]
         relative_position_bias = relative_position_bias.view(N, N, -1)
         relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
@@ -229,9 +248,9 @@
             self.qkv.weight,
             self.proj.weight,
             relative_position_bias,
-            self.window_size,
+            window_size,
             self.num_heads,
-            shift_size=self.shift_size,
+            shift_size=shift_size,
             attention_dropout=self.attention_dropout,
             dropout=self.dropout,
             qkv_bias=self.qkv.bias,
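
For anyone who wants to exercise the new path end to end, a hedged smoke test follows (it assumes an installed torchvision build that already contains this change; swin_t and its default 7x7 window come from torchvision.models):

import torch
from torchvision.models import swin_t

# With a 64x64 image, swin_t's later stages see 4x4 and 2x2 feature maps,
# i.e. smaller than the default 7x7 window, so the clamping path above is exercised.
model = swin_t(weights=None).eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 64, 64))
print(out.shape)  # expected: torch.Size([1, 1000])
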