Remove our AdamW implementation #36177

Merged · 14 commits · Mar 19, 2025
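This PR drops the deprecated `transformers.AdamW` class and switches every remaining caller to `torch.optim.AdamW`. A minimal migration sketch (the model and hyperparameters below are placeholders, not taken from the diff):

```python
import torch
from torch import nn

# Placeholder model and hyperparameters for illustration only.
model = nn.Linear(10, 2)
learning_rate = 5e-5
adam_epsilon = 1e-8

# Before (removed by this PR):
# from transformers import AdamW
# optimizer = AdamW(model.parameters(), lr=learning_rate, eps=adam_epsilon)

# After: use the PyTorch implementation directly.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, eps=adam_epsilon)
```

Note that the defaults differ: the removed class used `eps=1e-6` and `weight_decay=0.0`, while `torch.optim.AdamW` defaults to `eps=1e-8` and `weight_decay=0.01`, so pass these explicitly if you relied on the old defaults.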
3 changes: 0 additions & 3 deletions docs/source/en/main_classes/optimizer_schedules.md
@@ -22,9 +22,6 @@ The `.optimization` module provides:
- several schedules in the form of schedule objects that inherit from `_LRSchedule`:
- a gradient accumulation class to accumulate the gradients of multiple batches

## AdamW (PyTorch)

[[autodoc]] AdamW

## AdaFactor (PyTorch)

4 changes: 0 additions & 4 deletions docs/source/ja/main_classes/optimizer_schedules.md
@@ -22,10 +22,6 @@ rendered properly in your Markdown viewer.
- `_LRSchedule` から継承するスケジュール オブジェクトの形式のいくつかのスケジュール:
- 複数のバッチの勾配を累積するための勾配累積クラス

## AdamW (PyTorch)

[[autodoc]] AdamW

## AdaFactor (PyTorch)

[[autodoc]] Adafactor
4 changes: 0 additions & 4 deletions docs/source/zh/main_classes/optimizer_schedules.md
@@ -22,10 +22,6 @@ rendered properly in your Markdown viewer.
- 继承自 `_LRSchedule` 多个调度器:
- 一个梯度累积类,用于累积多个批次的梯度

## AdamW (PyTorch)

[[autodoc]] AdamW

## AdaFactor (PyTorch)

[[autodoc]] Adafactor
8 changes: 6 additions & 2 deletions examples/legacy/pytorch-lightning/lightning_base.py
@@ -8,7 +8,6 @@
from pytorch_lightning.utilities import rank_zero_info

from transformers import (
AdamW,
AutoConfig,
AutoModel,
AutoModelForPreTraining,
@@ -20,6 +19,7 @@
AutoTokenizer,
PretrainedConfig,
PreTrainedTokenizer,
is_torch_available,
)
from transformers.optimization import (
Adafactor,
@@ -31,6 +31,10 @@
from transformers.utils.versions import require_version


if is_torch_available():
import torch


logger = logging.getLogger(__name__)

require_version("pytorch_lightning>=1.0.4")
@@ -146,7 +150,7 @@ def configure_optimizers(self):
)

else:
optimizer = AdamW(
optimizer = torch.optim.AdamW(
optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon
)
self.opt = optimizer
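The Lightning example also gains a guarded `torch` import so the module only pulls in the backend when it is installed, mirroring the `is_torch_available` check added above. A minimal sketch of that pattern:

```python
from transformers import is_torch_available

# Only import torch when the backend is actually present;
# the rest of the module can still be imported without it.
if is_torch_available():
    import torch
```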
3 changes: 1 addition & 2 deletions examples/legacy/question-answering/run_squad.py
@@ -32,7 +32,6 @@
from transformers import (
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
WEIGHTS_NAME,
AdamW,
AutoConfig,
AutoModelForQuestionAnswering,
AutoTokenizer,
@@ -96,7 +95,7 @@ def train(args, train_dataset, model, tokenizer):
},
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = get_linear_schedule_with_warmup(
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
)
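The legacy scripts all follow the same recipe: split parameters into decay and no-decay groups, hand them to the optimizer, then attach a linear warmup schedule. A condensed sketch of that recipe with the PyTorch optimizer (`model`, `args`, and `t_total` are assumed to exist, as in the script above):

```python
import torch
from transformers import get_linear_schedule_with_warmup

# Biases and LayerNorm weights are conventionally excluded from weight decay.
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": args.weight_decay,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
)
```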
3 changes: 1 addition & 2 deletions examples/legacy/run_openai_gpt.py
@@ -43,7 +43,6 @@
from transformers import (
CONFIG_NAME,
WEIGHTS_NAME,
AdamW,
OpenAIGPTDoubleHeadsModel,
OpenAIGPTTokenizer,
get_linear_schedule_with_warmup,
@@ -236,7 +235,7 @@ def tokenize_and_encode(obj):
},
{"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = get_linear_schedule_with_warmup(
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
)
3 changes: 1 addition & 2 deletions examples/legacy/run_swag.py
@@ -34,7 +34,6 @@
import transformers
from transformers import (
WEIGHTS_NAME,
AdamW,
AutoConfig,
AutoModelForMultipleChoice,
AutoTokenizer,
@@ -298,7 +297,7 @@ def train(args, train_dataset, model, tokenizer):
},
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = get_linear_schedule_with_warmup(
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
)
4 changes: 1 addition & 3 deletions examples/legacy/seq2seq/seq2seq_trainer.py
@@ -22,7 +22,6 @@
from transformers.models.fsmt.configuration_fsmt import FSMTConfig
from transformers.optimization import (
Adafactor,
AdamW,
get_constant_schedule,
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
@@ -102,12 +101,11 @@ def create_optimizer_and_scheduler(self, num_training_steps: int):
"weight_decay": 0.0,
},
]
optimizer_cls = Adafactor if self.args.adafactor else AdamW
if self.args.adafactor:
optimizer_cls = Adafactor
optimizer_kwargs = {"scale_parameter": False, "relative_step": False}
else:
optimizer_cls = AdamW
optimizer_cls = torch.optim.AdamW
optimizer_kwargs = {
"betas": (self.args.adam_beta1, self.args.adam_beta2),
"eps": self.args.adam_epsilon,
@@ -41,7 +41,6 @@

import transformers
from transformers import (
AdamW,
DataCollatorWithPadding,
EvalPrediction,
SchedulerType,
@@ -767,7 +766,7 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
"weight_decay": 0.0,
},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
@@ -33,7 +33,6 @@

import transformers
from transformers import (
AdamW,
SchedulerType,
Wav2Vec2Config,
Wav2Vec2FeatureExtractor,
@@ -583,7 +582,7 @@ def prepare_dataset(batch):
)

# Optimizer
optimizer = AdamW(
optimizer = torch.optim.AdamW(
list(model.parameters()),
lr=args.learning_rate,
betas=[args.adam_beta1, args.adam_beta2],
2 changes: 0 additions & 2 deletions src/transformers/__init__.py
@@ -4104,7 +4104,6 @@
)
_import_structure["optimization"] = [
"Adafactor",
"AdamW",
"get_constant_schedule",
"get_constant_schedule_with_warmup",
"get_cosine_schedule_with_warmup",
@@ -8746,7 +8745,6 @@
# Optimization
from .optimization import (
Adafactor,
AdamW,
get_constant_schedule,
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
117 changes: 1 addition & 116 deletions src/transformers/optimization.py
@@ -17,10 +17,9 @@
import math
import warnings
from functools import partial
from typing import Callable, Iterable, Optional, Tuple, Union
from typing import Optional, Union

import torch
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau

@@ -604,120 +603,6 @@ def scheduler_hook(param):
)


class AdamW(Optimizer):
"""
Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
Regularization](https://arxiv.org/abs/1711.05101).

Parameters:
params (`Iterable[nn.parameter.Parameter]`):
Iterable of parameters to optimize or dictionaries defining parameter groups.
lr (`float`, *optional*, defaults to 0.001):
The learning rate to use.
betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`):
Adam's betas parameters (b1, b2).
eps (`float`, *optional*, defaults to 1e-06):
Adam's epsilon for numerical stability.
weight_decay (`float`, *optional*, defaults to 0.0):
Decoupled weight decay to apply.
correct_bias (`bool`, *optional*, defaults to `True`):
Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
no_deprecation_warning (`bool`, *optional*, defaults to `False`):
A flag used to disable the deprecation warning (set to `True` to disable the warning).
"""

def __init__(
self,
params: Iterable[nn.parameter.Parameter],
lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-6,
weight_decay: float = 0.0,
correct_bias: bool = True,
no_deprecation_warning: bool = False,
):
if not no_deprecation_warning:
warnings.warn(
"This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch"
" implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this"
" warning",
FutureWarning,
)
require_version("torch>=1.5.0") # add_ with alpha
if lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias}
super().__init__(params, defaults)

@torch.no_grad()
def step(self, closure: Callable = None):
"""
Performs a single optimization step.

Arguments:
closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()

for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad
if grad.is_sparse:
raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

state = self.state[p]

# State initialization
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(p)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p)

exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]

state["step"] += 1

# Decay the first and second moment running average coefficient
# In-place operations to update the averages at the same time
exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
denom = exp_avg_sq.sqrt().add_(group["eps"])

step_size = group["lr"]
if group["correct_bias"]: # No bias correction for Bert
bias_correction1 = 1.0 - beta1 ** state["step"]
bias_correction2 = 1.0 - beta2 ** state["step"]
step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

p.addcdiv_(exp_avg, denom, value=-step_size)

# Just adding the square of the weights to the loss function is *not*
# the correct way of using L2 regularization/weight decay with Adam,
# since that will interact with the m and v parameters in strange ways.
#
# Instead we want to decay the weights in a manner that doesn't interact
# with the m/v parameters. This is equivalent to adding the square
# of the weights to the loss with plain (non-momentum) SGD.
# Add weight decay at the end (fixed version)
if group["weight_decay"] > 0.0:
p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))

return loss


class Adafactor(Optimizer):
"""
AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code:
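For reference, the removed class implemented decoupled weight decay (Loshchilov & Hutter, [arXiv:1711.05101](https://arxiv.org/abs/1711.05101)), the same scheme as `torch.optim.AdamW`, with the optional `correct_bias` switch and the decay applied after the Adam step. In the notation of the removed `step()` above:

```latex
% Moment estimates (exp_avg and exp_avg_sq in the removed code)
m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t, \qquad
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2

% Bias-corrected step size (only when correct_bias=True)
\hat{\eta}_t = \eta \, \frac{\sqrt{1 - \beta_2^{\,t}}}{1 - \beta_1^{\,t}}

% Adam step, then decoupled weight decay applied to the updated parameters
\theta'_t = \theta_{t-1} - \hat{\eta}_t \, \frac{m_t}{\sqrt{v_t} + \epsilon}, \qquad
\theta_t = \theta'_t - \eta \lambda \, \theta'_t
```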
5 changes: 0 additions & 5 deletions src/transformers/trainer.py
@@ -1421,11 +1421,6 @@ def optimizer_hook(param):
if args.optim == OptimizerNames.ADAFACTOR:
optimizer_cls = Adafactor
optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
elif args.optim == OptimizerNames.ADAMW_HF:
from .optimization import AdamW

optimizer_cls = AdamW
optimizer_kwargs.update(adam_kwargs)
elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]:
from torch.optim import AdamW
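With the `adamw_hf` branch removed, the PyTorch optimizer is reached through the existing `adamw_torch` path, which is already the default. A minimal usage sketch (the output directory is a placeholder):

```python
from transformers import TrainingArguments

# "adamw_torch" is the default optim value; it is spelled out here for clarity
# and maps to torch.optim.AdamW inside Trainer.
args = TrainingArguments(
    output_dir="out",
    optim="adamw_torch",
    learning_rate=5e-5,
    weight_decay=0.01,
)
```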

5 changes: 2 additions & 3 deletions src/transformers/training_args.py
@@ -146,7 +146,6 @@ class OptimizerNames(ExplicitEnum):
Stores the acceptable string identifiers for optimizers.
"""

ADAMW_HF = "adamw_hf"
ADAMW_TORCH = "adamw_torch"
ADAMW_TORCH_FUSED = "adamw_torch_fused"
ADAMW_TORCH_XLA = "adamw_torch_xla"
@@ -628,7 +627,7 @@ class TrainingArguments:

The options should be separated by whitespaces.
optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_torch"`):
The optimizer to use, such as "adamw_hf", "adamw_torch", "adamw_torch_fused", "adamw_apex_fused", "adamw_anyprecision",
The optimizer to use, such as "adamw_torch", "adamw_torch_fused", "adamw_apex_fused", "adamw_anyprecision",
"adafactor". See `OptimizerNames` in [training_args.py](https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py)
for a full list of optimizers.
optim_args (`str`, *optional*):
@@ -2986,7 +2985,7 @@ def set_optimizer(

Args:
name (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_torch"`):
The optimizer to use: `"adamw_hf"`, `"adamw_torch"`, `"adamw_torch_fused"`, `"adamw_apex_fused"`,
The optimizer to use: `"adamw_torch"`, `"adamw_torch_fused"`, `"adamw_apex_fused"`,
`"adamw_anyprecision"` or `"adafactor"`.
learning_rate (`float`, *optional*, defaults to 5e-5):
The initial learning rate.
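The `set_optimizer` helper accepts the same identifiers; a short sketch, assuming `args` is an existing `TrainingArguments` instance and that the keyword names match the docstring above:

```python
# Equivalent configuration through the convenience setter.
args.set_optimizer(name="adamw_torch", learning_rate=5e-5)
```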
7 changes: 0 additions & 7 deletions src/transformers/utils/dummy_pt_objects.py
@@ -10842,13 +10842,6 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class AdamW(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


def get_constant_schedule(*args, **kwargs):
requires_backends(get_constant_schedule, ["torch"])
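The matching dummy object is dropped as well. The dummies exist so that importing a torch-only class without torch installed raises a clear backend error rather than a bare `ImportError`. A minimal sketch of the pattern (the class name here is hypothetical):

```python
from transformers.utils import DummyObject, requires_backends


class SomeTorchOnlyClass(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        # Raises a helpful "requires the torch backend" error when torch is missing.
        requires_backends(self, ["torch"])
```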

@@ -535,7 +535,6 @@ def _mp_fn(index):
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
AdamW,
AutoConfig,
{{cookiecutter.model_class}},
AutoTokenizer,
@@ -863,7 +862,7 @@ def tokenize_function(examples):
"weight_decay": 0.0,
},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(