diff --git a/references/classification/utils.py b/references/classification/utils.py index 473f4815265..7f573415c4c 100644 --- a/references/classification/utils.py +++ b/references/classification/utils.py @@ -166,17 +166,7 @@ def __init__(self, model, decay, device="cpu"): def ema_avg(avg_model_param, model_param, num_averaged): return decay * avg_model_param + (1 - decay) * model_param - super().__init__(model, device, ema_avg) - - def update_parameters(self, model): - for p_swa, p_model in zip(self.module.state_dict().values(), model.state_dict().values()): - device = p_swa.device - p_model_ = p_model.detach().to(device) - if self.n_averaged == 0: - p_swa.detach().copy_(p_model_) - else: - p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_, self.n_averaged.to(device))) - self.n_averaged += 1 + super().__init__(model, device, ema_avg, use_buffers=True) def accuracy(output, target, topk=(1,)):