@@ -24,7 +24,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
24
24
metric_logger .add_meter ("lr" , utils .SmoothedValue (window_size = 1 , fmt = "{value}" ))
25
25
metric_logger .add_meter ("img/s" , utils .SmoothedValue (window_size = 10 , fmt = "{value}" ))
26
26
27
- header = ' Epoch: [{}]' .format (epoch )
27
+ header = " Epoch: [{}]" .format (epoch )
28
28
for i , (image , target ) in enumerate (metric_logger .log_every (data_loader , args .print_freq , header )):
29
29
start_time = time .time ()
30
30
image , target = image .to (device ), target .to (device )
@@ -219,12 +219,18 @@ def main(args):
219
219
220
220
opt_name = args .opt .lower ()
221
221
if opt_name .startswith ("sgd" ):
222
- optimizer = torch .optim .SGD (parameters , lr = args .lr , momentum = args .momentum , weight_decay = args .weight_decay ,
223
- nesterov = "nesterov" in opt_name )
224
- elif opt_name == 'rmsprop' :
225
- optimizer = torch .optim .RMSprop (parameters , lr = args .lr , momentum = args .momentum , weight_decay = args .weight_decay ,
226
- eps = 0.0316 , alpha = 0.9 )
227
- elif opt_name == 'adamw' :
222
+ optimizer = torch .optim .SGD (
223
+ parameters ,
224
+ lr = args .lr ,
225
+ momentum = args .momentum ,
226
+ weight_decay = args .weight_decay ,
227
+ nesterov = "nesterov" in opt_name ,
228
+ )
229
+ elif opt_name == "rmsprop" :
230
+ optimizer = torch .optim .RMSprop (
231
+ parameters , lr = args .lr , momentum = args .momentum , weight_decay = args .weight_decay , eps = 0.0316 , alpha = 0.9
232
+ )
233
+ elif opt_name == "adamw" :
228
234
optimizer = torch .optim .AdamW (parameters , lr = args .lr , weight_decay = args .weight_decay )
229
235
else :
230
236
raise RuntimeError (f"Invalid optimizer { args .opt } . Only SGD, RMSprop and AdamW are supported." )
@@ -285,18 +291,18 @@ def main(args):
285
291
model_ema = utils .ExponentialMovingAverage (model_without_ddp , device = device , decay = 1.0 - alpha )
286
292
287
293
if args .resume :
288
- checkpoint = torch .load (args .resume , map_location = ' cpu' )
289
- model_without_ddp .load_state_dict (checkpoint [' model' ])
294
+ checkpoint = torch .load (args .resume , map_location = " cpu" )
295
+ model_without_ddp .load_state_dict (checkpoint [" model" ])
290
296
if not args .test_only :
291
- optimizer .load_state_dict (checkpoint [' optimizer' ])
292
- lr_scheduler .load_state_dict (checkpoint [' lr_scheduler' ])
293
- args .start_epoch = checkpoint [' epoch' ] + 1
297
+ optimizer .load_state_dict (checkpoint [" optimizer" ])
298
+ lr_scheduler .load_state_dict (checkpoint [" lr_scheduler" ])
299
+ args .start_epoch = checkpoint [" epoch" ] + 1
294
300
if model_ema :
295
301
model_ema .load_state_dict (checkpoint ["model_ema" ])
296
302
297
303
if args .test_only :
298
304
if model_ema :
299
- evaluate (model_ema , criterion , data_loader_test , device = device , log_suffix = ' EMA' )
305
+ evaluate (model_ema , criterion , data_loader_test , device = device , log_suffix = " EMA" )
300
306
else :
301
307
evaluate (model , criterion , data_loader_test , device = device )
302
308
return
@@ -331,42 +337,52 @@ def main(args):
331
337
332
338
def get_args_parser (add_help = True ):
333
339
import argparse
334
- parser = argparse .ArgumentParser (description = 'PyTorch Classification Training' , add_help = add_help )
335
-
336
- parser .add_argument ('--data-path' , default = '/datasets01/imagenet_full_size/061417/' , help = 'dataset' )
337
- parser .add_argument ('--model' , default = 'resnet18' , help = 'model' )
338
- parser .add_argument ('--device' , default = 'cuda' , help = 'device' )
339
- parser .add_argument ('-b' , '--batch-size' , default = 32 , type = int )
340
- parser .add_argument ('--epochs' , default = 90 , type = int , metavar = 'N' ,
341
- help = 'number of total epochs to run' )
342
- parser .add_argument ('-j' , '--workers' , default = 16 , type = int , metavar = 'N' ,
343
- help = 'number of data loading workers (default: 16)' )
344
- parser .add_argument ('--opt' , default = 'sgd' , type = str , help = 'optimizer' )
345
- parser .add_argument ('--lr' , default = 0.1 , type = float , help = 'initial learning rate' )
346
- parser .add_argument ('--momentum' , default = 0.9 , type = float , metavar = 'M' ,
347
- help = 'momentum' )
348
- parser .add_argument ('--wd' , '--weight-decay' , default = 1e-4 , type = float ,
349
- metavar = 'W' , help = 'weight decay (default: 1e-4)' ,
350
- dest = 'weight_decay' )
351
- parser .add_argument ('--norm-weight-decay' , default = None , type = float ,
352
- help = 'weight decay for Normalization layers (default: None, same value as --wd)' )
353
- parser .add_argument ('--label-smoothing' , default = 0.0 , type = float ,
354
- help = 'label smoothing (default: 0.0)' ,
355
- dest = 'label_smoothing' )
356
- parser .add_argument ('--mixup-alpha' , default = 0.0 , type = float , help = 'mixup alpha (default: 0.0)' )
357
- parser .add_argument ('--cutmix-alpha' , default = 0.0 , type = float , help = 'cutmix alpha (default: 0.0)' )
358
- parser .add_argument ('--lr-scheduler' , default = "steplr" , help = 'the lr scheduler (default: steplr)' )
359
- parser .add_argument ('--lr-warmup-epochs' , default = 0 , type = int , help = 'the number of epochs to warmup (default: 0)' )
360
- parser .add_argument ('--lr-warmup-method' , default = "constant" , type = str ,
361
- help = 'the warmup method (default: constant)' )
362
- parser .add_argument ('--lr-warmup-decay' , default = 0.01 , type = float , help = 'the decay for lr' )
363
- parser .add_argument ('--lr-step-size' , default = 30 , type = int , help = 'decrease lr every step-size epochs' )
364
- parser .add_argument ('--lr-gamma' , default = 0.1 , type = float , help = 'decrease lr by a factor of lr-gamma' )
365
- parser .add_argument ('--print-freq' , default = 10 , type = int , help = 'print frequency' )
366
- parser .add_argument ('--output-dir' , default = '.' , help = 'path where to save' )
367
- parser .add_argument ('--resume' , default = '' , help = 'resume from checkpoint' )
368
- parser .add_argument ('--start-epoch' , default = 0 , type = int , metavar = 'N' ,
369
- help = 'start epoch' )
340
+
341
+ parser = argparse .ArgumentParser (description = "PyTorch Classification Training" , add_help = add_help )
342
+
343
+ parser .add_argument ("--data-path" , default = "/datasets01/imagenet_full_size/061417/" , help = "dataset" )
344
+ parser .add_argument ("--model" , default = "resnet18" , help = "model" )
345
+ parser .add_argument ("--device" , default = "cuda" , help = "device" )
346
+ parser .add_argument ("-b" , "--batch-size" , default = 32 , type = int )
347
+ parser .add_argument ("--epochs" , default = 90 , type = int , metavar = "N" , help = "number of total epochs to run" )
348
+ parser .add_argument (
349
+ "-j" , "--workers" , default = 16 , type = int , metavar = "N" , help = "number of data loading workers (default: 16)"
350
+ )
351
+ parser .add_argument ("--opt" , default = "sgd" , type = str , help = "optimizer" )
352
+ parser .add_argument ("--lr" , default = 0.1 , type = float , help = "initial learning rate" )
353
+ parser .add_argument ("--momentum" , default = 0.9 , type = float , metavar = "M" , help = "momentum" )
354
+ parser .add_argument (
355
+ "--wd" ,
356
+ "--weight-decay" ,
357
+ default = 1e-4 ,
358
+ type = float ,
359
+ metavar = "W" ,
360
+ help = "weight decay (default: 1e-4)" ,
361
+ dest = "weight_decay" ,
362
+ )
363
+ parser .add_argument (
364
+ "--norm-weight-decay" ,
365
+ default = None ,
366
+ type = float ,
367
+ help = "weight decay for Normalization layers (default: None, same value as --wd)" ,
368
+ )
369
+ parser .add_argument (
370
+ "--label-smoothing" , default = 0.0 , type = float , help = "label smoothing (default: 0.0)" , dest = "label_smoothing"
371
+ )
372
+ parser .add_argument ("--mixup-alpha" , default = 0.0 , type = float , help = "mixup alpha (default: 0.0)" )
373
+ parser .add_argument ("--cutmix-alpha" , default = 0.0 , type = float , help = "cutmix alpha (default: 0.0)" )
374
+ parser .add_argument ("--lr-scheduler" , default = "steplr" , help = "the lr scheduler (default: steplr)" )
375
+ parser .add_argument ("--lr-warmup-epochs" , default = 0 , type = int , help = "the number of epochs to warmup (default: 0)" )
376
+ parser .add_argument (
377
+ "--lr-warmup-method" , default = "constant" , type = str , help = "the warmup method (default: constant)"
378
+ )
379
+ parser .add_argument ("--lr-warmup-decay" , default = 0.01 , type = float , help = "the decay for lr" )
380
+ parser .add_argument ("--lr-step-size" , default = 30 , type = int , help = "decrease lr every step-size epochs" )
381
+ parser .add_argument ("--lr-gamma" , default = 0.1 , type = float , help = "decrease lr by a factor of lr-gamma" )
382
+ parser .add_argument ("--print-freq" , default = 10 , type = int , help = "print frequency" )
383
+ parser .add_argument ("--output-dir" , default = "." , help = "path where to save" )
384
+ parser .add_argument ("--resume" , default = "" , help = "resume from checkpoint" )
385
+ parser .add_argument ("--start-epoch" , default = 0 , type = int , metavar = "N" , help = "start epoch" )
370
386
parser .add_argument (
371
387
"--cache-dataset" ,
372
388
dest = "cache_dataset" ,
@@ -412,11 +428,17 @@ def get_args_parser(add_help=True):
412
428
"--model-ema" , action = "store_true" , help = "enable tracking Exponential Moving Average of model parameters"
413
429
)
414
430
parser .add_argument (
415
- '--model-ema-steps' , type = int , default = 32 ,
416
- help = 'the number of iterations that controls how often to update the EMA model (default: 32)' )
431
+ "--model-ema-steps" ,
432
+ type = int ,
433
+ default = 32 ,
434
+ help = "the number of iterations that controls how often to update the EMA model (default: 32)" ,
435
+ )
417
436
parser .add_argument (
418
- '--model-ema-decay' , type = float , default = 0.99998 ,
419
- help = 'decay factor for Exponential Moving Average of model parameters (default: 0.99998)' )
437
+ "--model-ema-decay" ,
438
+ type = float ,
439
+ default = 0.99998 ,
440
+ help = "decay factor for Exponential Moving Average of model parameters (default: 0.99998)" ,
441
+ )
420
442
421
443
return parser
422
444
0 commit comments