Skip to content

Commit c1f1e22

Browse files
authored
Added update_parameters to EMA to fix calculation (#4406)
1 parent 9fa689b commit c1f1e22

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

references/classification/utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,17 @@ def __init__(self, model, decay, device='cpu'):
172172
decay * avg_model_param + (1 - decay) * model_param)
173173
super().__init__(model, device, ema_avg)
174174

175+
def update_parameters(self, model):
176+
for p_swa, p_model in zip(self.module.state_dict().values(), model.state_dict().values()):
177+
device = p_swa.device
178+
p_model_ = p_model.detach().to(device)
179+
if self.n_averaged == 0:
180+
p_swa.detach().copy_(p_model_)
181+
else:
182+
p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_,
183+
self.n_averaged.to(device)))
184+
self.n_averaged += 1
185+
175186

176187
def accuracy(output, target, topk=(1,)):
177188
"""Computes the accuracy over the k top predictions for the specified values of k"""

0 commit comments

Comments
 (0)