@@ -64,9 +64,10 @@ def __init__(self, inplanes, increase_dim=False, last_relu=False, downsampling="
         if increase_dim:
             if downsampling == "stride":
                 self.downsampler = DownsampleStride()
-                self.downsample = lambda x: self.pad(self.downsampler(x))
+                self._need_pad = True
             else:
-                self.downsample = DownsampleConv(inplanes, planes)
+                self.downsampler = DownsampleConv(inplanes, planes)
+                self._need_pad = False
 
 
         self.last_relu = last_relu
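Note (an aside, not stated in the patch): a plausible reason for replacing the stored lambda with a `_need_pad` flag is that lambdas cannot be pickled, so saving the full module object with `torch.save` (rather than its `state_dict`) would fail. A minimal sketch of that failure, assuming nothing beyond standard PyTorch:

    import pickle
    import torch.nn as nn

    class Block(nn.Module):
        def __init__(self):
            super().__init__()
            # A lambda attribute makes the whole module unpicklable.
            self.downsample = lambda x: x

    try:
        pickle.dumps(Block())
    except Exception as e:
        print("pickling failed:", e)  # can't pickle the local lambda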
@@ -83,7 +84,9 @@ def forward(self, x):
         y = self.bn_b(y)
 
         if self.increase_dim:
-            x = self.downsample(x)
+            x = self.downsampler(x)
+            if self._need_pad:
+                x = self.pad(x)
 
 
         y = x + y
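For context on why only the stride branch needs the extra pad step (a sketch hedged on how `DownsampleStride` and `self.pad` are assumed to work elsewhere in this file): strided downsampling halves the spatial resolution without adding parameters, and the pad then doubles the channel count by concatenating zeros, whereas the conv path already produces the right channel count. Assumed shapes:

    import torch

    x = torch.randn(2, 16, 32, 32)
    strided = x[..., ::2, ::2]  # assumed DownsampleStride: halve H and W
    padded = torch.cat((strided, strided.mul(0.)), dim=1)  # assumed pad: zero channels
    print(strided.shape)  # torch.Size([2, 16, 16, 16])
    print(padded.shape)   # torch.Size([2, 32, 16, 16])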
@@ -212,7 +215,7 @@ def __init__(
         )
 
         if pooling_config["type"] == "avg":
-            self.pool = nn.AvgPool2d(8)
+            self.pool = nn.AdaptiveAvgPool2d((1, 1))
         elif pooling_config["type"] == "weldon":
             self.pool = pooling.WeldonPool2d(**pooling_config)
         else:
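`nn.AvgPool2d(8)` hard-codes an 8x8 feature map (what a 32x32 CIFAR input yields here), while `nn.AdaptiveAvgPool2d((1, 1))` always reduces to 1x1, so the flattened feature size stays `out_dim` at any input resolution. A quick comparison:

    import torch
    import torch.nn as nn

    feats = torch.randn(1, 64, 16, 16)  # e.g. from a 64x64 input instead of 32x32
    print(nn.AvgPool2d(8)(feats).shape)               # torch.Size([1, 64, 2, 2])
    print(nn.AdaptiveAvgPool2d((1, 1))(feats).shape)  # torch.Size([1, 64, 1, 1])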
@@ -222,22 +225,33 @@ def __init__(
         if final_layer in (True, "conv"):
             self.final_layer = nn.Conv2d(self.out_dim, self.out_dim, kernel_size=1, bias=False)
         elif isinstance(final_layer, dict):
-            if final_layer["type"] == "bn_relu_fc":
+            if final_layer["type"] == "one_layer":
                 self.final_layer = nn.Sequential(
-                    nn.BatchNorm1d(self.out_dim), nn.ReLU(),
+                    nn.BatchNorm1d(self.out_dim), nn.ReLU(inplace=True),
                     nn.Linear(self.out_dim, int(self.out_dim * final_layer["reduction_factor"]))
                 )
+                self.out_dim = int(self.out_dim * final_layer["reduction_factor"])
+            elif final_layer["type"] == "two_layers":
+                self.final_layer = nn.Sequential(
+                    nn.BatchNorm1d(self.out_dim), nn.ReLU(inplace=True),
+                    nn.Linear(self.out_dim, self.out_dim), nn.BatchNorm1d(self.out_dim),
+                    nn.ReLU(inplace=True),
+                    nn.Linear(self.out_dim, int(self.out_dim * final_layer["reduction_factor"]))
+                )
+                self.out_dim = int(self.out_dim * final_layer["reduction_factor"])
             else:
                 raise ValueError("Unknown final layer type {}.".format(final_layer["type"]))
         else:
-            self.final_layer = lambda x: x
+            self.final_layer = None
 
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
             elif isinstance(m, nn.BatchNorm2d):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
 
         if zero_residual:
             for m in self.modules():
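The head is driven by a config dict; the important detail is that `self.out_dim` is updated after building the head, so downstream classifiers see the reduced width. A hypothetical config (only the keys "type" and "reduction_factor" appear in the diff; the values are illustrative):

    one_layer_cfg = {"type": "one_layer", "reduction_factor": 0.5}
    two_layers_cfg = {"type": "two_layers", "reduction_factor": 0.5}
    # With out_dim == 64 and reduction_factor == 0.5, either head ends in
    # nn.Linear(64, int(64 * 0.5)) and self.out_dim becomes 32.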
@@ -289,7 +303,9 @@ def forward(self, x):
     def end_features(self, x):
         x = self.pool(x)
         x = x.view(x.size(0), -1)
-        x = self.final_layer(x)
+
+        if self.final_layer is not None:
+            x = self.final_layer(x)
 
         return x
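An aside on the `None` sentinel: PyTorch also ships `nn.Identity`, a picklable no-op module that would remove the call-site check entirely. This is not what the patch does, just the idiomatic alternative:

    import torch.nn as nn

    final_layer = nn.Identity()  # forward(x) returns x unchanged
    # end_features could then call self.final_layer(x) unconditionally.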