
Commit 13f8e12

fix maxvit - need feedforwards after attention

1 parent 2d4089c · commit 13f8e12

File tree

setup.py
vit_pytorch/max_vit.py

2 files changed: +17 -1 lines changed

setup.py (1 addition, 1 deletion)

@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.33.0',
+  version = '0.33.1',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',

vit_pytorch/max_vit.py (16 additions, 0 deletions)

@@ -28,6 +28,20 @@ def __init__(self, dim, fn):
     def forward(self, x):
         return self.fn(self.norm(x)) + x

+class FeedForward(nn.Module):
+    def __init__(self, dim, mult = 4, dropout = 0.):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        self.net = nn.Sequential(
+            nn.Linear(dim, inner_dim),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(inner_dim, dim),
+            nn.Dropout(dropout)
+        )
+    def forward(self, x):
+        return self.net(x)
+
 # MBConv

 class SqueezeExcitation(nn.Module):

@@ -244,10 +258,12 @@ def __init__(
                     ),
                     Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w),  # block-like attention
                     PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
+                    PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
                     Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),

                     Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w),  # grid-like attention
                     PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
+                    PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
                     Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
                 )
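With this fix, both the block-like and the grid-like attention steps are followed by a FeedForward layer wrapped in PreNormResidual, giving the usual Transformer attention-then-MLP pairing. Below is a minimal sanity-check sketch, not taken from this commit: the MaxViT constructor arguments are assumed to match the repository README and may differ in this release.

import torch
from vit_pytorch.max_vit import MaxViT

# Hypothetical configuration: argument names assumed from the repository README,
# verify against the installed version before relying on them.
v = MaxViT(
    num_classes = 1000,
    dim = 96,                # dimension of the first stage
    dim_head = 32,           # per-head dimension passed down to Attention
    depth = (2, 2, 5, 2),    # number of MaxViT blocks per stage
    window_size = 7,         # window size w used for block- and grid-like attention
    dropout = 0.1            # dropout used in Attention and the new FeedForward
)

img = torch.randn(1, 3, 224, 224)
preds = v(img)               # expected shape: (1, 1000)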
