Skip to content

Commit e75a333

Browse files
authored
Move out the pad operation from PatchMerging in swin transformer to make it fx compatible (#6252)
1 parent f14682a commit e75a333

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

torchvision/models/swin_transformer.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@
2525
]
2626

2727

28+
def _patch_merging_pad(x):
29+
H, W, _ = x.shape[-3:]
30+
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
31+
return x
32+
33+
34+
torch.fx.wrap("_patch_merging_pad")
35+
36+
2837
class PatchMerging(nn.Module):
2938
"""Patch Merging Layer.
3039
Args:
@@ -46,8 +55,7 @@ def forward(self, x: Tensor):
4655
Returns:
4756
Tensor with layout of [..., H/2, W/2, 2*C]
4857
"""
49-
H, W, _ = x.shape[-3:]
50-
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
58+
x = _patch_merging_pad(x)
5159

5260
x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C
5361
x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C

0 commit comments

Comments
 (0)