Commit 034c401

datumbox authored and facebook-github-bot committed
[fbsync] Additional SOTA ingredients on Classification Recipe (#4493)
Summary:

* Update EMA every X iters.
* Adding AdamW optimizer.
* Adjusting EMA decay scheme.
* Support custom weight decay for Normalization layers.
* Fix indentation bug.
* Change EMA adjustment.
* Quality of life changes to facilitate testing.
* ufmt format.
* Fixing imports.
* Adding FixRes improvement.
* Support EMA in store_model_weights.
* Adding interpolation values.
* Change train_crop_size.
* Add interpolation option.
* Removing hardcoded interpolation and sizes from the scripts.
* Fixing linter.
* Incorporating feedback from code review.

Reviewed By: NicolasHug

Differential Revision: D31916313

fbshipit-source-id: 6136c02dd6d511d0f327b5a72c9056a134abc697
1 parent e6587c2 commit 034c401

File tree

7 files changed: +169 −54 lines changed
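
Taken together, these changes introduce several new training flags (`--model-ema-steps`, `--norm-weight-decay`, `--interpolation`, `--val-resize-size`, `--val-crop-size`, `--train-crop-size`) plus an `adamw` choice for `--opt`. As a rough illustration (the model and hyper-parameter values below are placeholders, not the recipe's tuned settings), a run exercising the new options might look like:

```
torchrun --nproc_per_node=8 train.py --model resnet50 \
    --opt adamw --lr 0.003 --weight-decay 0.05 --norm-weight-decay 0.0 \
    --interpolation bicubic --train-crop-size 176 --val-resize-size 232 --val-crop-size 224 \
    --model-ema --model-ema-steps 32 --model-ema-decay 0.99998
```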

references/classification/README.md

Lines changed: 35 additions & 0 deletions

````diff
@@ -31,6 +31,17 @@ Here `$MODEL` is one of `alexnet`, `vgg11`, `vgg13`, `vgg16` or `vgg19`. Note
 that `vgg11_bn`, `vgg13_bn`, `vgg16_bn`, and `vgg19_bn` include batch
 normalization and thus are trained with the default parameters.
 
+### Inception V3
+
+The weights of the Inception V3 model are ported from the original paper rather than trained from scratch.
+
+Since it expects tensors with a size of N x 3 x 299 x 299, to validate the model use the following command:
+
+```
+torchrun --nproc_per_node=8 train.py --model inception_v3\
+--val-resize-size 342 --val-crop-size 299 --train-crop-size 299 --test-only --pretrained
+```
+
 ### ResNext-50 32x4d
 ```
 torchrun --nproc_per_node=8 train.py\
@@ -79,6 +90,25 @@ The weights of the B0-B4 variants are ported from Ross Wightman's [timm repo](ht
 
 The weights of the B5-B7 variants are ported from Luke Melas' [EfficientNet-PyTorch repo](https://github.com/lukemelas/EfficientNet-PyTorch/blob/1039e009545d9329ea026c9f7541341439712b96/efficientnet_pytorch/utils.py#L562-L564).
 
+All models were trained using Bicubic interpolation and each have custom crop and resize sizes. To validate the models use the following commands:
+```
+torchrun --nproc_per_node=8 train.py --model efficientnet_b0 --interpolation bicubic\
+--val-resize-size 256 --val-crop-size 224 --train-crop-size 224 --test-only --pretrained
+torchrun --nproc_per_node=8 train.py --model efficientnet_b1 --interpolation bicubic\
+--val-resize-size 256 --val-crop-size 240 --train-crop-size 240 --test-only --pretrained
+torchrun --nproc_per_node=8 train.py --model efficientnet_b2 --interpolation bicubic\
+--val-resize-size 288 --val-crop-size 288 --train-crop-size 288 --test-only --pretrained
+torchrun --nproc_per_node=8 train.py --model efficientnet_b3 --interpolation bicubic\
+--val-resize-size 320 --val-crop-size 300 --train-crop-size 300 --test-only --pretrained
+torchrun --nproc_per_node=8 train.py --model efficientnet_b4 --interpolation bicubic\
+--val-resize-size 384 --val-crop-size 380 --train-crop-size 380 --test-only --pretrained
+torchrun --nproc_per_node=8 train.py --model efficientnet_b5 --interpolation bicubic\
+--val-resize-size 456 --val-crop-size 456 --train-crop-size 456 --test-only --pretrained
+torchrun --nproc_per_node=8 train.py --model efficientnet_b6 --interpolation bicubic\
+--val-resize-size 528 --val-crop-size 528 --train-crop-size 528 --test-only --pretrained
+torchrun --nproc_per_node=8 train.py --model efficientnet_b7 --interpolation bicubic\
+--val-resize-size 600 --val-crop-size 600 --train-crop-size 600 --test-only --pretrained
+```
 
 ### RegNet
 
@@ -181,3 +211,8 @@ For post training quant, device is set to CPU. For training, the device is set t
 ```
 python train_quantization.py --device='cpu' --test-only --backend='<backend>' --model='<model_name>'
 ```
+
+For inception_v3 you need to pass the following extra parameters:
+```
+--val-resize-size 342 --val-crop-size 299 --train-crop-size 299
+```
````
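
The validation flags above map onto the reference presets in a straightforward way. The sketch below is a rough equivalent of what `--val-resize-size`, `--val-crop-size` and `--train-crop-size` end up configuring; it assumes the usual resize-then-center-crop eval preset used by these scripts (the exact preset lives in `presets.py`):

```python
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

# Example values taken from the Inception V3 entry above.
val_resize_size, val_crop_size, train_crop_size = 342, 299, 299

# Roughly what ClassificationPresetEval builds for validation.
eval_transform = transforms.Compose([
    transforms.Resize(val_resize_size, interpolation=InterpolationMode.BILINEAR),
    transforms.CenterCrop(val_crop_size),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# Training uses a RandomResizedCrop at the (possibly smaller, FixRes-style) train crop size.
train_crop = transforms.RandomResizedCrop(train_crop_size, interpolation=InterpolationMode.BILINEAR)
```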

references/classification/presets.py

Lines changed: 5 additions & 4 deletions

```diff
@@ -9,21 +9,22 @@ def __init__(
         crop_size,
         mean=(0.485, 0.456, 0.406),
         std=(0.229, 0.224, 0.225),
+        interpolation=InterpolationMode.BILINEAR,
         hflip_prob=0.5,
         auto_augment_policy=None,
         random_erase_prob=0.0,
     ):
-        trans = [transforms.RandomResizedCrop(crop_size)]
+        trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
         if hflip_prob > 0:
             trans.append(transforms.RandomHorizontalFlip(hflip_prob))
         if auto_augment_policy is not None:
             if auto_augment_policy == "ra":
-                trans.append(autoaugment.RandAugment())
+                trans.append(autoaugment.RandAugment(interpolation=interpolation))
             elif auto_augment_policy == "ta_wide":
-                trans.append(autoaugment.TrivialAugmentWide())
+                trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
             else:
                 aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
-                trans.append(autoaugment.AutoAugment(policy=aa_policy))
+                trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
         trans.extend(
             [
                 transforms.PILToTensor(),
```
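
With the hunk above, the interpolation mode is threaded through `RandomResizedCrop` and whichever AutoAugment variant is selected. A minimal usage sketch, assuming the constructor signature shown in the diff and that the script is run from `references/classification/`:

```python
from torchvision.transforms.functional import InterpolationMode

import presets  # references/classification/presets.py

# Bicubic interpolation now reaches both the crop and the augmentation policy
# instead of silently falling back to the bilinear default.
train_preset = presets.ClassificationPresetTrain(
    crop_size=224,
    interpolation=InterpolationMode.BICUBIC,
    auto_augment_policy="ta_wide",
    random_erase_prob=0.1,
)
```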

references/classification/train.py

Lines changed: 74 additions & 46 deletions

```diff
@@ -14,22 +14,20 @@
 from torchvision.transforms.functional import InterpolationMode
 
 
-def train_one_epoch(
-    model, criterion, optimizer, data_loader, device, epoch, print_freq, amp=False, model_ema=None, scaler=None
-):
+def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
     metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))
 
     header = "Epoch: [{}]".format(epoch)
-    for image, target in metric_logger.log_every(data_loader, print_freq, header):
+    for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
         start_time = time.time()
         image, target = image.to(device), target.to(device)
         output = model(image)
 
         optimizer.zero_grad()
-        if amp:
+        if args.amp:
             with torch.cuda.amp.autocast():
                 loss = criterion(output, target)
             scaler.scale(loss).backward()
```

```diff
@@ -40,16 +38,19 @@ def train_one_epoch(
             loss.backward()
             optimizer.step()
 
+        if model_ema and i % args.model_ema_steps == 0:
+            model_ema.update_parameters(model)
+            if epoch < args.lr_warmup_epochs:
+                # Reset ema buffer to keep copying weights during warmup period
+                model_ema.n_averaged.fill_(0)
+
         acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
         batch_size = image.shape[0]
         metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
         metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
         metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
         metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))
 
-        if model_ema:
-            model_ema.update_parameters(model)
-
 
 def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
     model.eval()
```
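
The EMA model is now refreshed only every `args.model_ema_steps` iterations, and during the warmup epochs its `n_averaged` counter is reset so the averaged weights simply track the live model. For reference, a stripped-down sketch of the exponential-moving-average arithmetic itself (illustrative only; not the `utils.ExponentialMovingAverage` implementation):

```python
import torch

@torch.no_grad()
def ema_update(ema_params, model_params, decay):
    # ema <- decay * ema + (1 - decay) * param
    for ema_p, p in zip(ema_params, model_params):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)
```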

```diff
@@ -106,24 +107,8 @@ def _get_cache_path(filepath):
 def load_data(traindir, valdir, args):
     # Data loading code
     print("Loading data")
-    resize_size, crop_size = 256, 224
-    interpolation = InterpolationMode.BILINEAR
-    if args.model == "inception_v3":
-        resize_size, crop_size = 342, 299
-    elif args.model.startswith("efficientnet_"):
-        sizes = {
-            "b0": (256, 224),
-            "b1": (256, 240),
-            "b2": (288, 288),
-            "b3": (320, 300),
-            "b4": (384, 380),
-            "b5": (456, 456),
-            "b6": (528, 528),
-            "b7": (600, 600),
-        }
-        e_type = args.model.replace("efficientnet_", "")
-        resize_size, crop_size = sizes[e_type]
-        interpolation = InterpolationMode.BICUBIC
+    val_resize_size, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size
+    interpolation = InterpolationMode(args.interpolation)
 
     print("Loading training data")
     st = time.time()
```
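
`InterpolationMode` is a string-valued enum, which is what makes the `InterpolationMode(args.interpolation)` conversion above work directly on the new `--interpolation` flag:

```python
from torchvision.transforms.functional import InterpolationMode

assert InterpolationMode("bicubic") is InterpolationMode.BICUBIC
assert InterpolationMode("bilinear") is InterpolationMode.BILINEAR
```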

```diff
@@ -138,7 +123,10 @@ def load_data(traindir, valdir, args):
         dataset = torchvision.datasets.ImageFolder(
             traindir,
             presets.ClassificationPresetTrain(
-                crop_size=crop_size, auto_augment_policy=auto_augment_policy, random_erase_prob=random_erase_prob
+                crop_size=train_crop_size,
+                interpolation=interpolation,
+                auto_augment_policy=auto_augment_policy,
+                random_erase_prob=random_erase_prob,
             ),
         )
     if args.cache_dataset:
```

```diff
@@ -156,7 +144,9 @@ def load_data(traindir, valdir, args):
     else:
         dataset_test = torchvision.datasets.ImageFolder(
             valdir,
-            presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size, interpolation=interpolation),
+            presets.ClassificationPresetEval(
+                crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
+            ),
         )
     if args.cache_dataset:
         print("Saving dataset_test to {}".format(cache_path))
```

```diff
@@ -224,26 +214,30 @@ def main(args):
 
     criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
 
+    if args.norm_weight_decay is None:
+        parameters = model.parameters()
+    else:
+        param_groups = torchvision.ops._utils.split_normalization_params(model)
+        wd_groups = [args.norm_weight_decay, args.weight_decay]
+        parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]
+
     opt_name = args.opt.lower()
     if opt_name.startswith("sgd"):
         optimizer = torch.optim.SGD(
-            model.parameters(),
+            parameters,
             lr=args.lr,
             momentum=args.momentum,
             weight_decay=args.weight_decay,
             nesterov="nesterov" in opt_name,
         )
     elif opt_name == "rmsprop":
         optimizer = torch.optim.RMSprop(
-            model.parameters(),
-            lr=args.lr,
-            momentum=args.momentum,
-            weight_decay=args.weight_decay,
-            eps=0.0316,
-            alpha=0.9,
+            parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
         )
+    elif opt_name == "adamw":
+        optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
     else:
-        raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))
+        raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")
 
     scaler = torch.cuda.amp.GradScaler() if args.amp else None
 
```
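
When `--norm-weight-decay` is set, parameters are split into two groups (normalization-layer parameters first, everything else second, matching the order of `wd_groups` above) and each group carries its own weight decay. A sketch of the same pattern outside the training script, for instance to disable weight decay on normalization layers; `split_normalization_params` is the private helper used here, and the values are illustrative:

```python
import torch
import torchvision

model = torchvision.models.resnet50()

# Returns [norm_params, other_params]; see the new test in test_ops.py.
norm_params, other_params = torchvision.ops._utils.split_normalization_params(model)
parameters = [
    {"params": norm_params, "weight_decay": 0.0},    # e.g. --norm-weight-decay 0.0
    {"params": other_params, "weight_decay": 1e-4},  # --wd 1e-4
]
optimizer = torch.optim.AdamW(parameters, lr=1e-3)
```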


```diff
@@ -288,13 +282,23 @@ def main(args):
 
     model_ema = None
     if args.model_ema:
-        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=args.model_ema_decay)
+        # Decay adjustment that aims to keep the decay independent from other hyper-parameters originally proposed at:
+        # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
+        #
+        # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
+        # We consider constant = Dataset_size for a given dataset/setup and omit it. Thus:
+        # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
+        adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
+        alpha = 1.0 - args.model_ema_decay
+        alpha = min(1.0, alpha * adjust)
+        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)
 
     if args.resume:
         checkpoint = torch.load(args.resume, map_location="cpu")
         model_without_ddp.load_state_dict(checkpoint["model"])
-        optimizer.load_state_dict(checkpoint["optimizer"])
-        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
+        if not args.test_only:
+            optimizer.load_state_dict(checkpoint["optimizer"])
+            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
         args.start_epoch = checkpoint["epoch"] + 1
         if model_ema:
             model_ema.load_state_dict(checkpoint["model_ema"])
```
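
The comment block above normalizes the user-facing decay by the expected number of EMA updates, so `--model-ema-decay` behaves consistently across different world sizes, batch sizes, EMA step intervals and epoch counts. A worked example with purely illustrative numbers (not the recipe's actual hyper-parameters):

```python
# Illustrative values only, not the recipe's tuned settings.
world_size, batch_size, model_ema_steps, epochs = 8, 128, 32, 600
model_ema_decay = 0.99998

adjust = world_size * batch_size * model_ema_steps / epochs  # ~54.6
alpha = min(1.0, (1.0 - model_ema_decay) * adjust)           # ~0.00109
effective_decay = 1.0 - alpha                                # ~0.99891, passed to ExponentialMovingAverage
```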

```diff
@@ -303,18 +307,18 @@ def main(args):
         # We disable the cudnn benchmarking because it can noticeably affect the accuracy
         torch.backends.cudnn.benchmark = False
         torch.backends.cudnn.deterministic = True
-
-        evaluate(model, criterion, data_loader_test, device=device)
+        if model_ema:
+            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
+        else:
+            evaluate(model, criterion, data_loader_test, device=device)
         return
 
     print("Start training")
     start_time = time.time()
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(
-            model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.amp, model_ema, scaler
-        )
+        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
         lr_scheduler.step()
         evaluate(model, criterion, data_loader_test, device=device)
         if model_ema:
```

```diff
@@ -362,6 +366,12 @@ def get_args_parser(add_help=True):
         help="weight decay (default: 1e-4)",
         dest="weight_decay",
     )
+    parser.add_argument(
+        "--norm-weight-decay",
+        default=None,
+        type=float,
+        help="weight decay for Normalization layers (default: None, same value as --wd)",
+    )
     parser.add_argument(
         "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
     )
```

```diff
@@ -415,15 +425,33 @@ def get_args_parser(add_help=True):
     parser.add_argument(
         "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
     )
+    parser.add_argument(
+        "--model-ema-steps",
+        type=int,
+        default=32,
+        help="the number of iterations that controls how often to update the EMA model (default: 32)",
+    )
     parser.add_argument(
         "--model-ema-decay",
         type=float,
-        default=0.9,
-        help="decay factor for Exponential Moving Average of model parameters(default: 0.9)",
+        default=0.99998,
+        help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
     )
     parser.add_argument(
         "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
     )
+    parser.add_argument(
+        "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
+    )
+    parser.add_argument(
+        "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
+    )
+    parser.add_argument(
+        "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
+    )
+    parser.add_argument(
+        "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
+    )
 
     return parser
 
```

references/classification/train_quantization.py

Lines changed: 13 additions & 0 deletions

```diff
@@ -236,6 +236,19 @@ def get_args_parser(add_help=True):
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
     parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training")
 
+    parser.add_argument(
+        "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
+    )
+    parser.add_argument(
+        "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
+    )
+    parser.add_argument(
+        "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
+    )
+    parser.add_argument(
+        "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
+    )
+
     return parser
 
```

references/classification/utils.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -380,6 +380,9 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T
 
     # Load the weights to the model to validate that everything works
    # and remove unnecessary weights (such as auxiliaries, etc)
+    if checkpoint_key == "model_ema":
+        del checkpoint[checkpoint_key]["n_averaged"]
+        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(checkpoint[checkpoint_key], "module.")
     model.load_state_dict(checkpoint[checkpoint_key], strict=strict)
 
     tmp_path = os.path.join(output_dir, str(model.__hash__()))
```
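
With this change, `store_model_weights` can also export EMA weights: the `n_averaged` buffer and the DDP `module.` prefix are stripped before the cleaned state dict is validated and saved. A usage sketch, assuming it is run from `references/classification/`; the checkpoint path is a placeholder:

```python
import torchvision
from utils import store_model_weights  # references/classification/utils.py

model = torchvision.models.resnet50()

# Extract publishable weights from the EMA entry of a training checkpoint.
store_model_weights(model, "./checkpoint.pth", checkpoint_key="model_ema")
```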

test/test_ops.py

Lines changed: 12 additions & 2 deletions

```diff
@@ -9,10 +9,10 @@
 import torch
 from common_utils import needs_cuda, cpu_and_gpu, assert_equal
 from PIL import Image
-from torch import Tensor
+from torch import nn, Tensor
 from torch.autograd import gradcheck
 from torch.nn.modules.utils import _pair
-from torchvision import ops
+from torchvision import models, ops
 
 
 class RoIOpTester(ABC):
@@ -1176,5 +1176,15 @@ def test_stochastic_depth(self, mode, p):
         assert p_value > 0.0001
 
 
+class TestUtils:
+    @pytest.mark.parametrize("norm_layer", [None, nn.BatchNorm2d, nn.LayerNorm])
+    def test_split_normalization_params(self, norm_layer):
+        model = models.mobilenet_v3_large(norm_layer=norm_layer)
+        params = ops._utils.split_normalization_params(model, None if norm_layer is None else [norm_layer])
+
+        assert len(params[0]) == 92
+        assert len(params[1]) == 82
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
```
