Skip to content

add inverse_sqrt lr decay style #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Nov 28, 2022
5 changes: 4 additions & 1 deletion megatron/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ def _add_learning_rate_args(parser):
'and initial warmup, the learning rate at each '
'iteration would be different.')
group.add_argument('--lr-decay-style', type=str, default='linear',
choices=['constant', 'linear', 'cosine'],
choices=['constant', 'linear', 'cosine', 'inverse_sqrt'],
help='Learning rate decay function.')
group.add_argument('--lr-decay-iters', type=int, default=None,
help='number of iterations to decay learning rate over,'
Expand All @@ -621,6 +621,9 @@ def _add_learning_rate_args(parser):
group.add_argument('--lr-decay-tokens', type=int, default=None,
help='number of tokens to decay learning rate over,'
' If not None will override iter/sample-based decay')
group.add_argument('--lr-warmup-style', type=str, default='linear',
choices=['constant', 'linear'], help='Learning rate '
'warmup function.')
group.add_argument('--lr-warmup-fraction', type=float, default=None,
help='fraction of lr-warmup-(iters/samples) to use '
'for warmup (as a float)')
Expand Down
24 changes: 20 additions & 4 deletions megatron/learning_rates.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class AnnealingLR(object):
"""Anneals the learning rate."""

def __init__(self, optimizer, max_lr, min_lr,
warmup_steps, decay_steps, decay_style,
warmup_steps, decay_steps, decay_style, warmup_style,
use_checkpoint_lr_scheduler=True,
override_lr_scheduler=False):
args = get_args()
Expand All @@ -46,6 +46,7 @@ def __init__(self, optimizer, max_lr, min_lr,
self.warmup_tokens = 0

self.decay_style = decay_style
self.warmup_style = warmup_style

self.override_lr_scheduler = override_lr_scheduler
self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
Expand All @@ -63,18 +64,33 @@ def get_lr(self):
"""Learning rate decay functions from:
https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""

# Use linear warmup for the initial part.

# Use warmup for the initial part.
if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
if self.num_steps == self.warmup_steps and \
self.decay_tokens is not None:
self.warmup_tokens = self.num_tokens
return self.max_lr * float(self.num_steps) / \
float(self.warmup_steps)
if self.warmup_style == 'linear':
return self.max_lr * float(self.num_steps) / \
float(self.warmup_steps)
elif self.warmup_style == 'constant':
return self.max_lr
else:
raise ValueError('Unknown warmup style: {}'.format(
self.warmup_style))

# If the learning rate is constant, just return the initial value.
if self.decay_style == 'constant':
return self.max_lr


# If inverse_sqrt decay style
# In warmup phase: lr = max_lr
# In decay phase: lr = max_lr * sqrt(warmup_steps) / sqrt(num_steps)
# Note: To replicate t5x check https://github.com/TurkuNLP/Megatron-DeepSpeed/pull/2
if self.decay_style == 'inverse_sqrt':
return self.max_lr * (max(self.warmup_steps, 1) / max(self.num_steps, 1))**0.5

if self.decay_tokens is None:
# step-based decay

Expand Down
1 change: 1 addition & 0 deletions megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ def get_learning_rate_scheduler(optimizer):
warmup_steps=warmup_steps,
decay_steps=decay_steps,
decay_style=args.lr_decay_style,
warmup_style=args.lr_warmup_style,
use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
override_lr_scheduler=args.override_lr_scheduler)

Expand Down