@@ -260,10 +260,10 @@ def main(args):
260
260
# Decay adjustment that aims to keep the decay independent from other hyper-parameters originally proposed at:
261
261
# https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
262
262
#
263
- # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size * EMA_steps)
264
- # We consider constant = ( Dataset_size / n_GPUs) for a given dataset/setup and ommit it. Thus:
265
- # adjust = 1 / total_ema_updates ~= batch_size * EMA_steps / epochs
266
- adjust = args .batch_size * args .model_ema_steps / args .epochs
263
+ # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
264
+ # We consider constant = Dataset_size for a given dataset/setup and omit it. Thus:
265
+ # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
266
+ adjust = args .world_size * args . batch_size * args .model_ema_steps / args .epochs
267
267
alpha = 1.0 - args .model_ema_decay
268
268
alpha = min (1.0 , alpha * adjust )
269
269
model_ema = utils .ExponentialMovingAverage (model_without_ddp , device = device , decay = 1.0 - alpha )
@@ -397,8 +397,8 @@ def get_args_parser(add_help=True):
397
397
'--model-ema-steps' , type = int , default = 32 ,
398
398
help = 'the number of iterations that controls how often to update the EMA model (default: 32)' )
399
399
parser .add_argument (
400
- '--model-ema-decay' , type = float , default = 0.99999 ,
401
- help = 'decay factor for Exponential Moving Average of model parameters (default: 0.99999 )' )
400
+ '--model-ema-decay' , type = float , default = 0.99998 ,
401
+ help = 'decay factor for Exponential Moving Average of model parameters (default: 0.99998 )' )
402
402
403
403
return parser
404
404
0 commit comments