2
2
from basicsr .utils .registry import ARCH_REGISTRY
3
3
4
4
5
def conv3x3(inplanes, outplanes, stride=1):
    """A simple wrapper for 3x3 convolution with padding.

    Args:
        inplanes (int): Channel number of inputs.
        outplanes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.

    Returns:
        nn.Conv2d: A 3x3 convolution layer with padding 1 and no bias
            (bias is typically folded into a following BatchNorm).
    """
    return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)
8
14
9
15
10
16
class BasicBlock (nn .Module ):
11
- expansion = 1
17
+ """Basic residual block used in the ResNetArcFace architecture.
18
+
19
+ Args:
20
+ inplanes (int): Channel number of inputs.
21
+ planes (int): Channel number of outputs.
22
+ stride (int): Stride in convolution. Default: 1.
23
+ downsample (nn.Module): The downsample module. Default: None.
24
+ """
25
+ expansion = 1 # output channel expansion ratio
12
26
13
27
def __init__ (self , inplanes , planes , stride = 1 , downsample = None ):
14
28
super (BasicBlock , self ).__init__ ()
@@ -40,7 +54,16 @@ def forward(self, x):
40
54
41
55
42
56
class IRBlock (nn .Module ):
43
- expansion = 1
57
+ """Improved residual block (IR Block) used in the ResNetArcFace architecture.
58
+
59
+ Args:
60
+ inplanes (int): Channel number of inputs.
61
+ planes (int): Channel number of outputs.
62
+ stride (int): Stride in convolution. Default: 1.
63
+ downsample (nn.Module): The downsample module. Default: None.
64
+ use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
65
+ """
66
+ expansion = 1 # output channel expansion ratio
44
67
45
68
def __init__ (self , inplanes , planes , stride = 1 , downsample = None , use_se = True ):
46
69
super (IRBlock , self ).__init__ ()
@@ -78,7 +101,15 @@ def forward(self, x):
78
101
79
102
80
103
class Bottleneck (nn .Module ):
81
- expansion = 4
104
+ """Bottleneck block used in the ResNetArcFace architecture.
105
+
106
+ Args:
107
+ inplanes (int): Channel number of inputs.
108
+ planes (int): Channel number of outputs.
109
+ stride (int): Stride in convolution. Default: 1.
110
+ downsample (nn.Module): The downsample module. Default: None.
111
+ """
112
+ expansion = 4 # output channel expansion ratio
82
113
83
114
def __init__ (self , inplanes , planes , stride = 1 , downsample = None ):
84
115
super (Bottleneck , self ).__init__ ()
@@ -116,10 +147,16 @@ def forward(self, x):
116
147
117
148
118
149
class SEBlock (nn .Module ):
150
+ """The squeeze-and-excitation block (SEBlock) used in the IRBlock.
151
+
152
+ Args:
153
+ channel (int): Channel number of inputs.
154
+ reduction (int): Channel reduction ration. Default: 16.
155
+ """
119
156
120
157
def __init__ (self , channel , reduction = 16 ):
121
158
super (SEBlock , self ).__init__ ()
122
- self .avg_pool = nn .AdaptiveAvgPool2d (1 )
159
+ self .avg_pool = nn .AdaptiveAvgPool2d (1 ) # pool to 1x1 without spatial information
123
160
self .fc = nn .Sequential (
124
161
nn .Linear (channel , channel // reduction ), nn .PReLU (), nn .Linear (channel // reduction , channel ),
125
162
nn .Sigmoid ())
@@ -133,13 +170,23 @@ def forward(self, x):
133
170
134
171
@ARCH_REGISTRY .register ()
135
172
class ResNetArcFace (nn .Module ):
173
+ """ArcFace with ResNet architectures.
174
+
175
+ Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
176
+
177
+ Args:
178
+ block (str): Block used in the ArcFace architecture.
179
+ layers (tuple(int)): Block numbers in each layer.
180
+ use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
181
+ """
136
182
137
183
def __init__ (self , block , layers , use_se = True ):
138
184
if block == 'IRBlock' :
139
185
block = IRBlock
140
186
self .inplanes = 64
141
187
self .use_se = use_se
142
188
super (ResNetArcFace , self ).__init__ ()
189
+
143
190
self .conv1 = nn .Conv2d (1 , 64 , kernel_size = 3 , padding = 1 , bias = False )
144
191
self .bn1 = nn .BatchNorm2d (64 )
145
192
self .prelu = nn .PReLU ()
@@ -153,6 +200,7 @@ def __init__(self, block, layers, use_se=True):
153
200
self .fc5 = nn .Linear (512 * 8 * 8 , 512 )
154
201
self .bn5 = nn .BatchNorm1d (512 )
155
202
203
+ # initialization
156
204
for m in self .modules ():
157
205
if isinstance (m , nn .Conv2d ):
158
206
nn .init .xavier_normal_ (m .weight )
@@ -163,7 +211,7 @@ def __init__(self, block, layers, use_se=True):
163
211
nn .init .xavier_normal_ (m .weight )
164
212
nn .init .constant_ (m .bias , 0 )
165
213
166
- def _make_layer (self , block , planes , blocks , stride = 1 ):
214
+ def _make_layer (self , block , planes , num_blocks , stride = 1 ):
167
215
downsample = None
168
216
if stride != 1 or self .inplanes != planes * block .expansion :
169
217
downsample = nn .Sequential (
@@ -173,7 +221,7 @@ def _make_layer(self, block, planes, blocks, stride=1):
173
221
layers = []
174
222
layers .append (block (self .inplanes , planes , stride , downsample , use_se = self .use_se ))
175
223
self .inplanes = planes
176
- for _ in range (1 , blocks ):
224
+ for _ in range (1 , num_blocks ):
177
225
layers .append (block (self .inplanes , planes , use_se = self .use_se ))
178
226
179
227
return nn .Sequential (* layers )
0 commit comments