@@ -28,6 +28,20 @@ def __init__(self, dim, fn):
    def forward(self, x):
        return self.fn(self.norm(x)) + x

+class FeedForward(nn.Module):
+    def __init__(self, dim, mult = 4, dropout = 0.):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        self.net = nn.Sequential(
+            nn.Linear(dim, inner_dim),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(inner_dim, dim),
+            nn.Dropout(dropout)
+        )
+    def forward(self, x):
+        return self.net(x)
+
# MBConv

class SqueezeExcitation(nn.Module):
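
For reference, a minimal standalone sketch of the FeedForward module added above. The import path is an assumption about where this file lives, and the 6-D shape mirrors the windowed layout used later in the file; because nn.Linear and nn.Dropout act on the last dimension, the MLP is applied per token no matter how many leading window dimensions are present.

# standalone sketch, not part of the commit; the import path is assumed
import torch
from vit_pytorch.max_vit import FeedForward  # hypothetical module path for this file

ff = FeedForward(dim = 96, mult = 4, dropout = 0.1)  # 96 -> 384 -> 96 MLP with GELU and dropout
x = torch.randn(2, 8, 8, 7, 7, 96)                   # (b, x, y, w1, w2, d) windowed tokens
assert ff(x).shape == x.shape                        # applied position-wise over the last dim
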
@@ -244,10 +258,12 @@ def __init__(
                    ),
                    Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w),  # block-like attention
                    PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
+                    PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
                    Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),

                    Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w),  # grid-like attention
                    PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
+                    PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
                    Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
                )
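
Put together, each windowed attention stage now follows the standard transformer pattern of attention plus a pre-norm residual MLP. Below is a sketch of one block-like stage under this change, using the classes from this file; the import path and the concrete dimensions are assumptions, not part of the commit.

# illustrative sketch, not part of the commit; import path and dimensions are assumed
import torch
from torch import nn
from einops.layers.torch import Rearrange
from vit_pytorch.max_vit import PreNormResidual, Attention, FeedForward  # hypothetical module path

layer_dim, dim_head, dropout, w = 96, 32, 0.1, 7
stage = nn.Sequential(
    Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w),   # window the feature map
    PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
    PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),  # the newly added residual MLP
    Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),                   # restore (b, d, h, w)
)
out = stage(torch.randn(1, layer_dim, 56, 56))  # height and width must be divisible by w
assert out.shape == (1, layer_dim, 56, 56)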