 import warnings
 from functools import update_wrapper
 from types import SimpleNamespace
-from typing import Optional
+from typing import Callable, Optional

 import torch
 import torch.utils.data
@@ -62,23 +62,32 @@ def train_one_epoch(
     device: torch.device,
     epoch: int,
     args,
+    log_metrics_fn: Callable[[str, utils.MetricLogger, int, int], None],
     manager=None,
     model_ema=None,
     scaler=None,
 ) -> utils.MetricLogger:
+    accum_steps = args.gradient_accum_steps
+
     model.train()
     metric_logger = utils.MetricLogger(_LOGGER, delimiter=" ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
-    metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))
+    metric_logger.add_meter(
+        "imgs_per_sec", utils.SmoothedValue(window_size=10, fmt="{value}")
+    )
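+    # window_size=accum_steps: the smoothed loss/accuracy meters below span one
+    # full gradient-accumulation cycle, i.e. one optimizer step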
+    metric_logger.add_meter("loss", utils.SmoothedValue(window_size=accum_steps))
+    metric_logger.add_meter("acc1", utils.SmoothedValue(window_size=accum_steps))
+    metric_logger.add_meter("acc5", utils.SmoothedValue(window_size=accum_steps))

     steps_accumulated = 0
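+    # counts completed optimizer updates (one per accum_steps batches); used for
+    # the EMA update and the logging cadence further down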
+    num_optim_steps = 0

     # initial zero grad for gradient accumulation
     optimizer.zero_grad()

     header = f"Epoch: [{epoch}]"
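+    # log_every prints every logging_steps optimizer steps; each optimizer step
+    # consumes accum_steps batches, hence logging_steps * accum_steps below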
-    for i, (image, target) in enumerate(
-        metric_logger.log_every(data_loader, args.print_freq, header)
+    for (image, target) in metric_logger.log_every(
+        data_loader, args.logging_steps * accum_steps, header
     ):
         start_time = time.time()
         image, target = image.to(device), target.to(device)
@@ -89,7 +98,7 @@ def train_one_epoch(
         output = output[0]
         loss = criterion(output, target)

-        if steps_accumulated % args.gradient_accum_steps == 0:
+        if steps_accumulated % accum_steps == 0:
             if manager is not None:
                 loss = manager.loss_update(
                     loss=loss,
@@ -119,9 +128,10 @@ def train_one_epoch(

             # zero grad here to start accumulating next set of gradients
             optimizer.zero_grad()
+            num_optim_steps += 1
         steps_accumulated += 1

-        if model_ema and i % args.model_ema_steps == 0:
+        if model_ema and num_optim_steps % args.model_ema_steps == 0:
             model_ema.update_parameters(model)
             if epoch < args.lr_warmup_epochs:
                 # Reset ema buffer to keep copying weights during warmup period
@@ -132,7 +142,12 @@ def train_one_epoch(
         metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
         metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
         metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
-        metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))
+        metric_logger.meters["imgs_per_sec"].update(
+            batch_size / (time.time() - start_time)
+        )
+
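+        # mid-epoch logging: emit running averages every logging_steps optimizer steps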
+        if num_optim_steps % args.logging_steps == 0:
+            log_metrics_fn("Train", metric_logger, epoch, num_optim_steps)
     return metric_logger


@@ -504,10 +519,17 @@ def collate_fn(batch):
                 criterion,
                 data_loader_test,
                 device,
+                print_freq=args.logging_steps,
                 log_suffix="EMA",
             )
         else:
-            evaluate(model, criterion, data_loader_test, device)
+            evaluate(
+                model,
+                criterion,
+                data_loader_test,
+                device,
+                print_freq=args.logging_steps,
+            )
         return

     if utils.is_main_process():
@@ -523,10 +545,13 @@ def collate_fn(batch):
     else:
         logger = LoggerManager(log_python=False)

-    def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int):
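+    # optimizer steps per epoch: batches per epoch divided by the accumulation factor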
+    steps_per_epoch = len(data_loader) / args.gradient_accum_steps
+
+    def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: int):
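+        # global optimizer step used as the x-axis for every logged scalar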
+        step = int(epoch * steps_per_epoch + epoch_step)
         for metric_name, smoothed_value in metrics.meters.items():
             logger.log_scalar(
-                f"{tag}/{metric_name}", smoothed_value.global_avg, step=epoch
+                f"{tag}/{metric_name}", smoothed_value.global_avg, step=step
             )

     if manager is not None:
@@ -537,7 +562,7 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int):
             distillation_teacher=args.distill_teacher,
         )
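+        # steps_per_epoch is now measured in optimizer steps (batches / gradient_accum_steps)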
         optimizer = manager.modify(
-            model, optimizer, len(data_loader), epoch=args.start_epoch
+            model, optimizer, steps_per_epoch=steps_per_epoch, epoch=args.start_epoch
         )

     lr_scheduler = _get_lr_scheduler(
@@ -570,17 +595,18 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int):
             device,
             epoch,
             args,
+            log_metrics,
             manager=manager,
             model_ema=model_ema,
             scaler=scaler,
         )
-        log_metrics("Train", train_metrics, epoch)
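+        # epoch_step=steps_per_epoch marks the end-of-epoch global step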
+        log_metrics("Train", train_metrics, epoch, steps_per_epoch)

         if lr_scheduler:
             lr_scheduler.step()

         eval_metrics = evaluate(model, criterion, data_loader_test, device)
-        log_metrics("Test", eval_metrics, epoch)
+        log_metrics("Test", eval_metrics, epoch, steps_per_epoch)

         top1_acc = eval_metrics.acc1.global_avg
         if model_ema:
@@ -591,7 +617,7 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int):
                 device,
                 log_suffix="EMA",
             )
-            log_metrics("Test/EMA", ema_eval_metrics, epoch)
+            log_metrics("Test/EMA", ema_eval_metrics, epoch, steps_per_epoch)

         is_new_best = epoch >= args.save_best_after and top1_acc > best_top1_acc
         if is_new_best:
@@ -916,7 +942,13 @@ def new_func(*args, **kwargs):
     type=float,
     help="minimum lr of lr schedule",
 )
-@click.option("--print-freq", default=10, type=int, help="print frequency")
+@click.option("--print-freq", default=None, type=int, help="DEPRECATED. Does nothing.")
+@click.option(
+    "--logging-steps",
+    default=10,
+    type=int,
+    help="Frequency in number of batch updates for logging/printing",
+)
 @click.option("--output-dir", default=".", type=str, help="path to save outputs")
 @click.option("--resume", default=None, type=str, help="path of checkpoint")
 @click.option(