@@ -20,6 +20,18 @@ def divisible_by(val, d):
20
20
21
21
# helper classes
22
22
23
+ class ChanLayerNorm (nn .Module ):
24
+ def __init__ (self , dim , eps = 1e-5 ):
25
+ super ().__init__ ()
26
+ self .eps = eps
27
+ self .g = nn .Parameter (torch .ones (1 , dim , 1 , 1 ))
28
+ self .b = nn .Parameter (torch .zeros (1 , dim , 1 , 1 ))
29
+
30
+ def forward (self , x ):
31
+ var = torch .var (x , dim = 1 , unbiased = False , keepdim = True )
32
+ mean = torch .mean (x , dim = 1 , keepdim = True )
33
+ return (x - mean ) / (var + self .eps ).sqrt () * self .g + self .b
34
+
23
35
class Downsample (nn .Module ):
24
36
def __init__ (self , dim_in , dim_out ):
25
37
super ().__init__ ()
@@ -212,10 +224,10 @@ def __init__(
212
224
if tokenize_local_3_conv :
213
225
self .local_encoder = nn .Sequential (
214
226
nn .Conv2d (3 , init_dim , 3 , 2 , 1 ),
215
- nn . LayerNorm (init_dim ),
227
+ ChanLayerNorm (init_dim ),
216
228
nn .GELU (),
217
229
nn .Conv2d (init_dim , init_dim , 3 , 2 , 1 ),
218
- nn . LayerNorm (init_dim ),
230
+ ChanLayerNorm (init_dim ),
219
231
nn .GELU (),
220
232
nn .Conv2d (init_dim , init_dim , 3 , 1 , 1 )
221
233
)
0 commit comments