
Commit aa79a1f

Commit end rush N
1 parent 57b3586 commit aa79a1f


43 files changed: +6442 −1181 lines changed

inclearn/__init__.pyc

243 Bytes
Binary file not shown.

inclearn/convnet/__init__.py

+4 −1

@@ -1 +1,4 @@
-from . import (cifar_resnet, densenet, my_resnet, my_resnet2, my_resnet_brn, resnet, ucir_resnet)
+from . import (
+    cifar_resnet, densenet, my_resnet, my_resnet2, my_resnet_brn, my_resnet_mcbn, my_resnet_mtl,
+    resnet, resnet_mtl, ucir_resnet, vgg
+)

inclearn/convnet/my_resnet.py

+24 −8
@@ -64,9 +64,10 @@ def __init__(self, inplanes, increase_dim=False, last_relu=False, downsampling="
         if increase_dim:
             if downsampling == "stride":
                 self.downsampler = DownsampleStride()
-                self.downsample = lambda x: self.pad(self.downsampler(x))
+                self._need_pad = True
             else:
-                self.downsample = DownsampleConv(inplanes, planes)
+                self.downsampler = DownsampleConv(inplanes, planes)
+                self._need_pad = False

         self.last_relu = last_relu

@@ -83,7 +84,9 @@ def forward(self, x):
         y = self.bn_b(y)

         if self.increase_dim:
-            x = self.downsample(x)
+            x = self.downsampler(x)
+            if self._need_pad:
+                x = self.pad(x)

         y = x + y

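The lambda-based self.downsample is split here into an explicit self.downsampler module plus a _need_pad flag, with the padding step made visible in forward(). A minimal sketch of the new path, assuming DownsampleStride subsamples every other pixel and the block's pad zero-pads the channel dimension as in typical CIFAR ResNets (neither is shown in this hunk):

```python
import torch
import torch.nn as nn


class DownsampleStride(nn.Module):
    # Assumption: spatial subsampling by taking every other row/column.
    def forward(self, x):
        return x[..., ::2, ::2]


def pad_channels(x):
    # Assumption: double the channel count by concatenating zeros,
    # mirroring what self.pad does in the block.
    return torch.cat((x, x.mul(0)), dim=1)


downsampler = DownsampleStride()  # the "stride" branch; _need_pad would be True
need_pad = True

x = torch.randn(2, 16, 32, 32)
out = downsampler(x)
if need_pad:
    out = pad_channels(out)
print(out.shape)  # torch.Size([2, 32, 16, 16])
```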
@@ -212,7 +215,7 @@ def __init__(
         )

         if pooling_config["type"] == "avg":
-            self.pool = nn.AvgPool2d(8)
+            self.pool = nn.AdaptiveAvgPool2d((1, 1))
         elif pooling_config["type"] == "weldon":
             self.pool = pooling.WeldonPool2d(**pooling_config)
         else:
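For the 8×8 feature maps a CIFAR-style ResNet produces from 32×32 inputs, AdaptiveAvgPool2d((1, 1)) gives the same result as AvgPool2d(8), but it also handles other input resolutions; a quick sketch with illustrative shapes:

```python
import torch
import torch.nn as nn

feats = torch.randn(4, 64, 8, 8)
fixed = nn.AvgPool2d(8)(feats)                  # only collapses exactly 8x8 maps
adaptive = nn.AdaptiveAvgPool2d((1, 1))(feats)  # collapses any spatial size to 1x1
print(torch.allclose(fixed, adaptive))          # True

bigger = torch.randn(4, 64, 14, 14)             # e.g. a larger input resolution
print(nn.AdaptiveAvgPool2d((1, 1))(bigger).shape)  # torch.Size([4, 64, 1, 1])
```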
@@ -222,22 +225,33 @@ def __init__(
         if final_layer in (True, "conv"):
             self.final_layer = nn.Conv2d(self.out_dim, self.out_dim, kernel_size=1, bias=False)
         elif isinstance(final_layer, dict):
-            if final_layer["type"] == "bn_relu_fc":
+            if final_layer["type"] == "one_layer":
                 self.final_layer = nn.Sequential(
-                    nn.BatchNorm1d(self.out_dim), nn.ReLU(),
+                    nn.BatchNorm1d(self.out_dim), nn.ReLU(inplace=True),
                     nn.Linear(self.out_dim, int(self.out_dim * final_layer["reduction_factor"]))
                 )
+                self.out_dim = int(self.out_dim * final_layer["reduction_factor"])
+            elif final_layer["type"] == "two_layers":
+                self.final_layer = nn.Sequential(
+                    nn.BatchNorm1d(self.out_dim), nn.ReLU(inplace=True),
+                    nn.Linear(self.out_dim, self.out_dim), nn.BatchNorm1d(self.out_dim),
+                    nn.ReLU(inplace=True),
+                    nn.Linear(self.out_dim, int(self.out_dim * final_layer["reduction_factor"]))
+                )
+                self.out_dim = int(self.out_dim * final_layer["reduction_factor"])
             else:
                 raise ValueError("Unknown final layer type {}.".format(final_layer["type"]))
         else:
-            self.final_layer = lambda x: x
+            self.final_layer = None

         for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
             elif isinstance(m, nn.BatchNorm2d):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")

         if zero_residual:
             for m in self.modules():
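A hedged sketch of how the renamed "one_layer"/"two_layers" head config appears to be consumed, including the new shrinking of out_dim; the dict values and dimensions below are illustrative, not taken from the commit:

```python
import torch
import torch.nn as nn

out_dim = 64
final_layer = {"type": "two_layers", "reduction_factor": 0.5}  # illustrative config

if final_layer["type"] == "one_layer":
    head = nn.Sequential(
        nn.BatchNorm1d(out_dim), nn.ReLU(inplace=True),
        nn.Linear(out_dim, int(out_dim * final_layer["reduction_factor"]))
    )
elif final_layer["type"] == "two_layers":
    head = nn.Sequential(
        nn.BatchNorm1d(out_dim), nn.ReLU(inplace=True),
        nn.Linear(out_dim, out_dim), nn.BatchNorm1d(out_dim),
        nn.ReLU(inplace=True),
        nn.Linear(out_dim, int(out_dim * final_layer["reduction_factor"]))
    )
out_dim = int(out_dim * final_layer["reduction_factor"])  # feature dim shrinks to 32

x = torch.randn(8, 64)
print(head(x).shape)  # torch.Size([8, 32])
```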
@@ -289,7 +303,9 @@ def forward(self, x):
     def end_features(self, x):
         x = self.pool(x)
         x = x.view(x.size(0), -1)
-        x = self.final_layer(x)
+
+        if self.final_layer is not None:
+            x = self.final_layer(x)

         return x

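Replacing the identity lambda with None (and guarding in end_features) is plausibly about serialization: an nn.Module holding a lambda attribute cannot be pickled, for example when saving the whole model object or spawning worker processes. A small illustration of that failure mode, not taken from the repository:

```python
import pickle
import torch.nn as nn


class WithLambdaHead(nn.Module):
    # Hypothetical module mimicking the old `self.final_layer = lambda x: x`.
    def __init__(self):
        super().__init__()
        self.final_layer = lambda x: x


try:
    pickle.dumps(WithLambdaHead())
except Exception as exc:  # lambdas cannot be pickled
    print(type(exc).__name__)
```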