from torchvision.transforms.functional import InterpolationMode


-def train_one_epoch(
-    model, criterion, optimizer, data_loader, device, epoch, print_freq, amp=False, model_ema=None, scaler=None
-):
+def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))

    header = "Epoch: [{}]".format(epoch)
-    for image, target in metric_logger.log_every(data_loader, print_freq, header):
+    for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
        start_time = time.time()
        image, target = image.to(device), target.to(device)
        output = model(image)

        optimizer.zero_grad()
-        if amp:
+        if args.amp:
            with torch.cuda.amp.autocast():
                loss = criterion(output, target)
            scaler.scale(loss).backward()
@@ -40,16 +38,19 @@ def train_one_epoch(
            loss.backward()
            optimizer.step()

+        if model_ema and i % args.model_ema_steps == 0:
+            model_ema.update_parameters(model)
+            if epoch < args.lr_warmup_epochs:
+                # Reset ema buffer to keep copying weights during warmup period
+                model_ema.n_averaged.fill_(0)
+
        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = image.shape[0]
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
        metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))

-        if model_ema:
-            model_ema.update_parameters(model)
-

def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
    model.eval()
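
With this change the EMA model is updated inside the batch loop, but only every `args.model_ema_steps` iterations; during the LR warmup epochs the averaging counter is reset so the EMA weights keep tracking the raw weights. A minimal sketch of the cadence, built on `torch.optim.swa_utils.AveragedModel` (the base class of the script's `utils.ExponentialMovingAverage`), with assumed hyper-parameter values:

```python
# Minimal sketch of the new EMA cadence; model_ema_steps=4, lr_warmup_epochs=1
# and decay=0.999 are assumed values, not the script's defaults.
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel

decay = 0.999  # assumed; the script derives the real value from --model-ema-decay

def ema_avg(avg_param, param, num_averaged):
    # Exponential moving average instead of AveragedModel's default mean
    return decay * avg_param + (1.0 - decay) * param

model = nn.Linear(8, 2)
model_ema = AveragedModel(model, avg_fn=ema_avg)
model_ema_steps, lr_warmup_epochs = 4, 1

for epoch in range(2):
    for i in range(10):  # stand-in for the data_loader loop
        # ... forward / backward / optimizer.step() would run here ...
        if i % model_ema_steps == 0:
            model_ema.update_parameters(model)
            if epoch < lr_warmup_epochs:
                # n_averaged == 0 makes the next update copy the raw weights
                model_ema.n_averaged.fill_(0)
```

Resetting `n_averaged` to zero makes the next `update_parameters` call copy the parameters outright instead of averaging them, which is exactly what is wanted while the weights are still warming up.
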
@@ -106,24 +107,8 @@ def _get_cache_path(filepath):
def load_data(traindir, valdir, args):
    # Data loading code
    print("Loading data")
-    resize_size, crop_size = 256, 224
-    interpolation = InterpolationMode.BILINEAR
-    if args.model == "inception_v3":
-        resize_size, crop_size = 342, 299
-    elif args.model.startswith("efficientnet_"):
-        sizes = {
-            "b0": (256, 224),
-            "b1": (256, 240),
-            "b2": (288, 288),
-            "b3": (320, 300),
-            "b4": (384, 380),
-            "b5": (456, 456),
-            "b6": (528, 528),
-            "b7": (600, 600),
-        }
-        e_type = args.model.replace("efficientnet_", "")
-        resize_size, crop_size = sizes[e_type]
-        interpolation = InterpolationMode.BICUBIC
+    val_resize_size, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size
+    interpolation = InterpolationMode(args.interpolation)

    print("Loading training data")
    st = time.time()
@@ -138,7 +123,10 @@ def load_data(traindir, valdir, args):
    dataset = torchvision.datasets.ImageFolder(
        traindir,
        presets.ClassificationPresetTrain(
-            crop_size=crop_size, auto_augment_policy=auto_augment_policy, random_erase_prob=random_erase_prob
+            crop_size=train_crop_size,
+            interpolation=interpolation,
+            auto_augment_policy=auto_augment_policy,
+            random_erase_prob=random_erase_prob,
        ),
    )
    if args.cache_dataset:
@@ -156,7 +144,9 @@ def load_data(traindir, valdir, args):
    else:
        dataset_test = torchvision.datasets.ImageFolder(
            valdir,
-            presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size, interpolation=interpolation),
+            presets.ClassificationPresetEval(
+                crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
+            ),
        )
    if args.cache_dataset:
        print("Saving dataset_test to {}".format(cache_path))
@@ -224,26 +214,30 @@ def main(args):

    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

+    if args.norm_weight_decay is None:
+        parameters = model.parameters()
+    else:
+        param_groups = torchvision.ops._utils.split_normalization_params(model)
+        wd_groups = [args.norm_weight_decay, args.weight_decay]
+        parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]
+
    opt_name = args.opt.lower()
    if opt_name.startswith("sgd"):
        optimizer = torch.optim.SGD(
-            model.parameters(),
+            parameters,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov="nesterov" in opt_name,
        )
    elif opt_name == "rmsprop":
        optimizer = torch.optim.RMSprop(
-            model.parameters(),
-            lr=args.lr,
-            momentum=args.momentum,
-            weight_decay=args.weight_decay,
-            eps=0.0316,
-            alpha=0.9,
+            parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
        )
+    elif opt_name == "adamw":
+        optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
    else:
-        raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))
+        raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")

    scaler = torch.cuda.amp.GradScaler() if args.amp else None
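
`torchvision.ops._utils.split_normalization_params` is a private helper that separates the parameters of normalization layers from all others, so each group can receive its own weight decay. A hand-rolled equivalent of the resulting optimizer setup, for illustration only (the grouping below is an assumption about the helper's behavior, not its exact code):

```python
# Sketch: give normalization layers their own weight decay (e.g. 0.0)
# while everything else keeps the regular --wd value (e.g. 1e-4).
import torch
from torch import nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(), nn.Flatten())

norm_params, other_params = [], []
for module in model.modules():
    params = list(module.parameters(recurse=False))
    if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm, nn.LayerNorm)):
        norm_params.extend(params)
    else:
        other_params.extend(params)

optimizer = torch.optim.SGD(
    [
        {"params": norm_params, "weight_decay": 0.0},    # --norm-weight-decay
        {"params": other_params, "weight_decay": 1e-4},  # --wd
    ],
    lr=0.1,
    momentum=0.9,
)
```

Note that the list comprehension in the diff additionally drops empty groups (`if p`), which matters for models that have no normalization layers at all.
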
@@ -288,13 +282,23 @@ def main(args):

    model_ema = None
    if args.model_ema:
-        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=args.model_ema_decay)
+        # Decay adjustment that aims to keep the decay independent of other hyper-parameters, originally proposed at:
+        # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
+        #
+        # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
+        # We consider Dataset_size constant for a given dataset/setup and omit it. Thus:
+        # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
+        adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
+        alpha = 1.0 - args.model_ema_decay
+        alpha = min(1.0, alpha * adjust)
+        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
-        optimizer.load_state_dict(checkpoint["optimizer"])
-        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
+        if not args.test_only:
+            optimizer.load_state_dict(checkpoint["optimizer"])
+            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        if model_ema:
            model_ema.load_state_dict(checkpoint["model_ema"])
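
For concreteness, here is the adjustment computed with assumed (but plausible) values: 8 GPUs, 128 images per GPU, the new defaults of `--model-ema-steps 32` and `--model-ema-decay 0.99998`, and 600 epochs.

```python
# Worked example of the decay adjustment above (all values assumed).
world_size, batch_size, model_ema_steps, epochs = 8, 128, 32, 600
model_ema_decay = 0.99998

adjust = world_size * batch_size * model_ema_steps / epochs  # ~54.6
alpha = (1.0 - model_ema_decay) * adjust                     # ~0.00109
alpha = min(1.0, alpha)
print(1.0 - alpha)  # effective decay ~0.99891 instead of the raw 0.99998
```
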
@@ -303,18 +307,18 @@ def main(args):
        # We disable the cudnn benchmarking because it can noticeably affect the accuracy
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
-
-        evaluate(model, criterion, data_loader_test, device=device)
+        if model_ema:
+            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
+        else:
+            evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
-        train_one_epoch(
-            model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.amp, model_ema, scaler
-        )
+        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
        lr_scheduler.step()
        evaluate(model, criterion, data_loader_test, device=device)
        if model_ema:
@@ -362,6 +366,12 @@ def get_args_parser(add_help=True):
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
+    parser.add_argument(
+        "--norm-weight-decay",
+        default=None,
+        type=float,
+        help="weight decay for Normalization layers (default: None, same value as --wd)",
+    )
    parser.add_argument(
        "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
    )
@@ -415,15 +425,33 @@ def get_args_parser(add_help=True):
    parser.add_argument(
        "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
    )
+    parser.add_argument(
+        "--model-ema-steps",
+        type=int,
+        default=32,
+        help="the number of iterations that controls how often to update the EMA model (default: 32)",
+    )
    parser.add_argument(
        "--model-ema-decay",
        type=float,
-        default=0.9,
-        help="decay factor for Exponential Moving Average of model parameters (default: 0.9)",
+        default=0.99998,
+        help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
    )
    parser.add_argument(
        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
    )
+    parser.add_argument(
+        "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
+    )
+    parser.add_argument(
+        "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
+    )
+    parser.add_argument(
+        "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
+    )
+    parser.add_argument(
+        "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
+    )

    return parser
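
Taken together, the four new flags subsume the per-model size table that was removed from `load_data`: for example, the old `efficientnet_b4` settings would presumably now be requested explicitly with `--interpolation bicubic --val-resize-size 384 --val-crop-size 380 --train-crop-size 380`.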