 17 |  17 | import math
 18 |  18 | import warnings
 19 |  19 | from functools import partial
 20 |     | -from typing import Callable, Iterable, Optional, Tuple, Union
     |  20 | +from typing import Optional, Union
 21 |  21 |
 22 |  22 | import torch
 23 |     | -from torch import nn
 24 |  23 | from torch.optim import Optimizer
 25 |  24 | from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
 26 |  25 |

@@ -604,120 +603,6 @@ def scheduler_hook(param):
604 | 603 |     )
605 | 604 |
606 | 605 |
607 |     | -class AdamW(Optimizer):
608 |     | -    """
609 |     | -    Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
610 |     | -    Regularization](https://arxiv.org/abs/1711.05101).
611 |     | -
612 |     | -    Parameters:
613 |     | -        params (`Iterable[nn.parameter.Parameter]`):
614 |     | -            Iterable of parameters to optimize or dictionaries defining parameter groups.
615 |     | -        lr (`float`, *optional*, defaults to 0.001):
616 |     | -            The learning rate to use.
617 |     | -        betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`):
618 |     | -            Adam's betas parameters (b1, b2).
619 |     | -        eps (`float`, *optional*, defaults to 1e-06):
620 |     | -            Adam's epsilon for numerical stability.
621 |     | -        weight_decay (`float`, *optional*, defaults to 0.0):
622 |     | -            Decoupled weight decay to apply.
623 |     | -        correct_bias (`bool`, *optional*, defaults to `True`):
624 |     | -            Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
625 |     | -        no_deprecation_warning (`bool`, *optional*, defaults to `False`):
626 |     | -            A flag used to disable the deprecation warning (set to `True` to disable the warning).
627 |     | -    """
628 |     | -
629 |     | -    def __init__(
630 |     | -        self,
631 |     | -        params: Iterable[nn.parameter.Parameter],
632 |     | -        lr: float = 1e-3,
633 |     | -        betas: Tuple[float, float] = (0.9, 0.999),
634 |     | -        eps: float = 1e-6,
635 |     | -        weight_decay: float = 0.0,
636 |     | -        correct_bias: bool = True,
637 |     | -        no_deprecation_warning: bool = False,
638 |     | -    ):
639 |     | -        if not no_deprecation_warning:
640 |     | -            warnings.warn(
641 |     | -                "This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch"
642 |     | -                " implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this"
643 |     | -                " warning",
644 |     | -                FutureWarning,
645 |     | -            )
646 |     | -        require_version("torch>=1.5.0")  # add_ with alpha
647 |     | -        if lr < 0.0:
648 |     | -            raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
649 |     | -        if not 0.0 <= betas[0] < 1.0:
650 |     | -            raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
651 |     | -        if not 0.0 <= betas[1] < 1.0:
652 |     | -            raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
653 |     | -        if not 0.0 <= eps:
654 |     | -            raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
655 |     | -        defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias}
656 |     | -        super().__init__(params, defaults)
657 |     | -
658 |     | -    @torch.no_grad()
659 |     | -    def step(self, closure: Callable = None):
660 |     | -        """
661 |     | -        Performs a single optimization step.
662 |     | -
663 |     | -        Arguments:
664 |     | -            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
665 |     | -        """
666 |     | -        loss = None
667 |     | -        if closure is not None:
668 |     | -            loss = closure()
669 |     | -
670 |     | -        for group in self.param_groups:
671 |     | -            for p in group["params"]:
672 |     | -                if p.grad is None:
673 |     | -                    continue
674 |     | -                grad = p.grad
675 |     | -                if grad.is_sparse:
676 |     | -                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
677 |     | -
678 |     | -                state = self.state[p]
679 |     | -
680 |     | -                # State initialization
681 |     | -                if len(state) == 0:
682 |     | -                    state["step"] = 0
683 |     | -                    # Exponential moving average of gradient values
684 |     | -                    state["exp_avg"] = torch.zeros_like(p)
685 |     | -                    # Exponential moving average of squared gradient values
686 |     | -                    state["exp_avg_sq"] = torch.zeros_like(p)
687 |     | -
688 |     | -                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
689 |     | -                beta1, beta2 = group["betas"]
690 |     | -
691 |     | -                state["step"] += 1
692 |     | -
693 |     | -                # Decay the first and second moment running average coefficient
694 |     | -                # In-place operations to update the averages at the same time
695 |     | -                exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
696 |     | -                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
697 |     | -                denom = exp_avg_sq.sqrt().add_(group["eps"])
698 |     | -
699 |     | -                step_size = group["lr"]
700 |     | -                if group["correct_bias"]:  # No bias correction for Bert
701 |     | -                    bias_correction1 = 1.0 - beta1 ** state["step"]
702 |     | -                    bias_correction2 = 1.0 - beta2 ** state["step"]
703 |     | -                    step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
704 |     | -
705 |     | -                p.addcdiv_(exp_avg, denom, value=-step_size)
706 |     | -
707 |     | -                # Just adding the square of the weights to the loss function is *not*
708 |     | -                # the correct way of using L2 regularization/weight decay with Adam,
709 |     | -                # since that will interact with the m and v parameters in strange ways.
710 |     | -                #
711 |     | -                # Instead we want to decay the weights in a manner that doesn't interact
712 |     | -                # with the m/v parameters. This is equivalent to adding the square
713 |     | -                # of the weights to the loss with plain (non-momentum) SGD.
714 |     | -                # Add weight decay at the end (fixed version)
715 |     | -                if group["weight_decay"] > 0.0:
716 |     | -                    p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
717 |     | -
718 |     | -        return loss
719 |     | -
720 |     | -
721 | 606 | class Adafactor(Optimizer):
722 | 607 |     """
723 | 608 |     AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code:
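The hunk above deletes the long-deprecated `AdamW` class, whose own `FutureWarning` directs users to `torch.optim.AdamW`. As a minimal migration sketch (illustrative only, not part of this commit; the `nn.Linear` model and explicit keyword arguments are assumptions for the example), existing call sites could switch roughly like this:

```python
# Hypothetical migration sketch: replace the removed transformers.AdamW with
# torch.optim.AdamW, as the removed deprecation warning itself suggests.
import torch
from torch import nn

model = nn.Linear(10, 2)  # stand-in for whatever module was being optimized

# Previously (removed in this commit):
#   from transformers import AdamW
#   optimizer = AdamW(model.parameters(), lr=1e-3, eps=1e-6, weight_decay=0.0)

# PyTorch replacement. Pass eps/weight_decay explicitly: torch's defaults
# (eps=1e-8, weight_decay=1e-2) differ from the removed class (eps=1e-6,
# weight_decay=0.0), and bias correction is always applied (no `correct_bias`).
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, eps=1e-6, weight_decay=0.0)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```

Both optimizers use the decoupled ("weight decay fix") update the removed docstring cites, subtracting `lr * weight_decay * p` from the parameter directly rather than adding `weight_decay * p` to the gradient; the removed class applies that decay after the Adam step while `torch.optim.AdamW` applies it before, so trajectories are close but not bit-identical.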