6
6
7
7
from .._internally_replaced_utils import load_state_dict_from_url
8
8
from ..utils import _log_api_usage_once
9
+ from ..ops import DropBlock2d
9
10
10
11
11
12
__all__ = [
@@ -122,6 +123,7 @@ def __init__(
122
123
base_width : int = 64 ,
123
124
dilation : int = 1 ,
124
125
norm_layer : Optional [Callable [..., nn .Module ]] = None ,
126
+ p : float = 0.0 ,
125
127
) -> None :
126
128
super ().__init__ ()
127
129
if norm_layer is None :
@@ -130,31 +132,40 @@ def __init__(
130
132
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
131
133
self .conv1 = conv1x1 (inplanes , width )
132
134
self .bn1 = norm_layer (width )
135
+ # NOTE: the drop probability p is fixed here; it is not scheduled (ramped) over the course of training
136
+ self .drop1 = DropBlock2d (p , 7 )
133
137
self .conv2 = conv3x3 (width , width , stride , groups , dilation )
134
138
self .bn2 = norm_layer (width )
139
+ self .drop2 = DropBlock2d (p , 7 )
135
140
self .conv3 = conv1x1 (width , planes * self .expansion )
136
141
self .bn3 = norm_layer (planes * self .expansion )
142
+ self .drop3 = DropBlock2d (p , 7 )
137
143
self .relu = nn .ReLU (inplace = True )
138
144
self .downsample = downsample
145
+ self .drop4 = DropBlock2d (p , 7 )
139
146
self .stride = stride
140
147
141
148
def forward (self , x : Tensor ) -> Tensor :
142
149
identity = x
143
-
150
+ # as in https://github.com/tensorflow/tpu/blob/b24729de804fdb751b06467d3dce0637fa652060/models/official/resnet/resnet_model.py#L545-L579
144
151
out = self .conv1 (x )
145
152
out = self .bn1 (out )
146
153
out = self .relu (out )
154
+ out = self .drop1 (out )
147
155
148
156
out = self .conv2 (out )
149
157
out = self .bn2 (out )
150
158
out = self .relu (out )
159
+ out = self .drop2 (out )
151
160
152
161
out = self .conv3 (out )
153
162
out = self .bn3 (out )
163
+ out = self .drop3 (out )
154
164
155
165
if self .downsample is not None :
156
166
identity = self .downsample (x )
157
167
168
+ identity = self .drop4 (identity )
158
169
out += identity
159
170
out = self .relu (out )
160
171
@@ -198,8 +209,9 @@ def __init__(
198
209
self .maxpool = nn .MaxPool2d (kernel_size = 3 , stride = 2 , padding = 1 )
199
210
self .layer1 = self ._make_layer (block , 64 , layers [0 ])
200
211
self .layer2 = self ._make_layer (block , 128 , layers [1 ], stride = 2 , dilate = replace_stride_with_dilation [0 ])
201
- self .layer3 = self ._make_layer (block , 256 , layers [2 ], stride = 2 , dilate = replace_stride_with_dilation [1 ])
202
- self .layer4 = self ._make_layer (block , 512 , layers [3 ], stride = 2 , dilate = replace_stride_with_dilation [2 ])
212
+ # https://github.com/tensorflow/tpu/blob/b24729de804fdb751b06467d3dce0637fa652060/models/official/resnet/resnet_main.py#L393-L394
213
+ self .layer3 = self ._make_layer (block , 256 , layers [2 ], stride = 2 , dilate = replace_stride_with_dilation [1 ], p = 0.1 / 4 )
214
+ self .layer4 = self ._make_layer (block , 512 , layers [3 ], stride = 2 , dilate = replace_stride_with_dilation [2 ], p = 0.1 )
203
215
self .avgpool = nn .AdaptiveAvgPool2d ((1 , 1 ))
204
216
self .fc = nn .Linear (512 * block .expansion , num_classes )
205
217
@@ -227,6 +239,7 @@ def _make_layer(
227
239
blocks : int ,
228
240
stride : int = 1 ,
229
241
dilate : bool = False ,
242
+ p : float = 0.0 ,
230
243
) -> nn .Sequential :
231
244
norm_layer = self ._norm_layer
232
245
downsample = None
@@ -243,7 +256,7 @@ def _make_layer(
243
256
layers = []
244
257
layers .append (
245
258
block (
246
- self .inplanes , planes , stride , downsample , self .groups , self .base_width , previous_dilation , norm_layer
259
+ self .inplanes , planes , stride , downsample , self .groups , self .base_width , previous_dilation , norm_layer , p
247
260
)
248
261
)
249
262
self .inplanes = planes * block .expansion
@@ -256,6 +269,7 @@ def _make_layer(
256
269
base_width = self .base_width ,
257
270
dilation = self .dilation ,
258
271
norm_layer = norm_layer ,
272
+ p = p
259
273
)
260
274
)
261
275
0 commit comments