
Commit 4e76a42

Patch scripts for training dropblock resnet
1 parent 8d89128 commit 4e76a42

File tree

3 files changed: +26 −4 lines


references/classification/README.md

Lines changed: 7 additions & 0 deletions
````diff
@@ -53,6 +53,13 @@ torchrun --nproc_per_node=8 train.py --model $MODEL
 
 Here `$MODEL` is one of `resnet18`, `resnet34`, `resnet50`, `resnet101` or `resnet152`.
 
+### ResNet with dropblock
+```
+torchrun --nproc_per_node=8 train.py --model resnet50 -b 128 --lr 0.4 --epochs 270
+```
+
+
+
 ### ResNext
 ```
 torchrun --nproc_per_node=8 train.py\
````
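
With `--nproc_per_node=8` and a per-process batch size of `-b 128`, the effective global batch size is 8 × 128 = 1024. The `--epochs 270` run is meant to pair with the MultiStepLR milestones (epochs 125, 200, 250) that this commit patches into `train.py` below, which appears to follow the extended training schedule used in the DropBlock paper.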

references/classification/train.py

Lines changed: 1 addition & 0 deletions
````diff
@@ -288,6 +288,7 @@ def main(args):
             f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
             "are supported."
         )
+    main_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[125, 200, 250], gamma=0.1)
 
     if args.lr_warmup_epochs > 0:
         if args.lr_warmup_method == "linear":
````
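
As placed, the added line runs unconditionally after the scheduler-selection chain, so it replaces whatever `--lr-scheduler` selected with this fixed MultiStepLR schedule. A minimal sketch of what that schedule does over the 270-epoch run (toy optimizer for illustration only, not the training script itself):

```python
import torch

# Toy parameter/optimizer purely to illustrate the patched schedule.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[125, 200, 250], gamma=0.1)

for epoch in range(270):
    optimizer.step()   # stand-in for one epoch of training
    scheduler.step()   # lr: 0.4 until epoch 125, then 0.04, 0.004 at 200, 0.0004 at 250
```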

torchvision/models/resnet.py

Lines changed: 18 additions & 4 deletions
````diff
@@ -6,6 +6,7 @@
 
 from .._internally_replaced_utils import load_state_dict_from_url
 from ..utils import _log_api_usage_once
+from ..ops import DropBlock2d
 
 
 __all__ = [
@@ -122,6 +123,7 @@ def __init__(
         base_width: int = 64,
         dilation: int = 1,
         norm_layer: Optional[Callable[..., nn.Module]] = None,
+        p: float = 0.0,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -130,31 +132,40 @@ def __init__(
         # Both self.conv2 and self.downsample layers downsample the input when stride != 1
         self.conv1 = conv1x1(inplanes, width)
         self.bn1 = norm_layer(width)
+        # we won't be doing scheduled p
+        self.drop1 = DropBlock2d(p, 7)
         self.conv2 = conv3x3(width, width, stride, groups, dilation)
         self.bn2 = norm_layer(width)
+        self.drop2 = DropBlock2d(p, 7)
         self.conv3 = conv1x1(width, planes * self.expansion)
         self.bn3 = norm_layer(planes * self.expansion)
+        self.drop3 = DropBlock2d(p, 7)
         self.relu = nn.ReLU(inplace=True)
         self.downsample = downsample
+        self.drop4 = DropBlock2d(p, 7)
         self.stride = stride
 
     def forward(self, x: Tensor) -> Tensor:
         identity = x
-
+        # as in https://github.com/tensorflow/tpu/blob/b24729de804fdb751b06467d3dce0637fa652060/models/official/resnet/resnet_model.py#L545-L579
         out = self.conv1(x)
         out = self.bn1(out)
         out = self.relu(out)
+        out = self.drop1(out)
 
         out = self.conv2(out)
         out = self.bn2(out)
         out = self.relu(out)
+        out = self.drop2(out)
 
         out = self.conv3(out)
         out = self.bn3(out)
+        out = self.drop3(out)
 
         if self.downsample is not None:
             identity = self.downsample(x)
 
+        identity = self.drop4(identity)
         out += identity
         out = self.relu(out)
 
@@ -198,8 +209,9 @@ def __init__(
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
         self.layer1 = self._make_layer(block, 64, layers[0])
         self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
-        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
-        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
+        # https://github.com/tensorflow/tpu/blob/b24729de804fdb751b06467d3dce0637fa652060/models/official/resnet/resnet_main.py#L393-L394
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1], p=0.1 / 4)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2], p=0.1)
         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
         self.fc = nn.Linear(512 * block.expansion, num_classes)
 
@@ -227,6 +239,7 @@ def _make_layer(
         blocks: int,
         stride: int = 1,
         dilate: bool = False,
+        p: float = 0.0,
     ) -> nn.Sequential:
         norm_layer = self._norm_layer
         downsample = None
@@ -243,7 +256,7 @@ def _make_layer(
         layers = []
         layers.append(
             block(
-                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
+                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer, p
             )
         )
         self.inplanes = planes * block.expansion
@@ -256,6 +269,7 @@ def _make_layer(
                     base_width=self.base_width,
                     dilation=self.dilation,
                     norm_layer=norm_layer,
+                    p=p
                 )
             )
 
````
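
For reference, `DropBlock2d(p, block_size)` from `torchvision.ops` zeroes out contiguous `block_size × block_size` regions of the feature map during training and rescales the surviving activations, rather than dropping independent elements like plain dropout; it is a no-op in eval mode. The patch wires it only into `Bottleneck` (so it takes effect for `resnet50` and deeper), with a drop rate of 0.1 in `layer4` and a quarter of that in `layer3`, following the linked TF TPU reference. A minimal standalone usage sketch, separate from the patch:

```python
import torch
from torchvision.ops import DropBlock2d

# Same constructor arguments as in the patch: drop probability p and block size 7.
drop = DropBlock2d(p=0.1, block_size=7)

x = torch.randn(2, 512, 14, 14)  # (N, C, H, W) feature map, e.g. inside layer4
drop.train()
y = drop(x)   # zeroes contiguous 7x7 regions, rescales the rest
drop.eval()
z = drop(x)   # identity outside training
assert torch.equal(z, x)
```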
