Skip to content

Commit 7fed7a1

Browse files
rohitgr7pre-commit-ci[bot]Felonious-Spellfireawaelchli
authored
Add LRFinder callback (#13802)
* add BatchSizeFinderCallback callback * enable fast_dev_run test * keep tune and remove early_exit * move exception to setup * Apply suggestions from code review Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Laverne Henderson <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 5f10695 commit 7fed7a1

File tree

10 files changed

+226
-64
lines changed

10 files changed

+226
-64
lines changed

docs/source-pytorch/api_references.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ callbacks
3636
EarlyStopping
3737
GradientAccumulationScheduler
3838
LambdaCallback
39+
LearningRateFinder
3940
LearningRateMonitor
4041
ModelCheckpoint
4142
ModelPruning

docs/source-pytorch/extensions/callbacks.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ Lightning has a few built-in callbacks.
9393
EarlyStopping
9494
GradientAccumulationScheduler
9595
LambdaCallback
96+
LearningRateFinder
9697
LearningRateMonitor
9798
ModelCheckpoint
9899
ModelPruning

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1212
- Added `BatchSizeFinder` callback ([#11089](https://github.com/PyTorchLightning/pytorch-lightning/pull/11089))
1313

1414

15+
- Added `LearningRateFinder` callback ([#13802](https://github.com/PyTorchLightning/pytorch-lightning/pull/13802))
16+
17+
1518
- Tuner now supports a new `method` argument which will determine when to run the `BatchSizeFinder`: one of `fit`, `validate`, `test` or `predict` ([#11089](https://github.com/PyTorchLightning/pytorch-lightning/pull/11089))
1619

1720

src/pytorch_lightning/callbacks/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from pytorch_lightning.callbacks.finetuning import BackboneFinetuning, BaseFinetuning
2020
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
2121
from pytorch_lightning.callbacks.lambda_function import LambdaCallback
22+
from pytorch_lightning.callbacks.lr_finder import LearningRateFinder
2223
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
2324
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
2425
from pytorch_lightning.callbacks.model_summary import ModelSummary
@@ -41,6 +42,7 @@
4142
"EarlyStopping",
4243
"GradientAccumulationScheduler",
4344
"LambdaCallback",
45+
"LearningRateFinder",
4446
"LearningRateMonitor",
4547
"ModelCheckpoint",
4648
"ModelPruning",
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
r"""
15+
LearningRateFinder
16+
==================
17+
18+
Finds optimal learning rate
19+
"""
20+
from typing import Optional
21+
22+
import pytorch_lightning as pl
23+
from pytorch_lightning.callbacks.callback import Callback
24+
from pytorch_lightning.tuner.lr_finder import _LRFinder, lr_find
25+
from pytorch_lightning.utilities.exceptions import _TunerExitException, MisconfigurationException
26+
from pytorch_lightning.utilities.seed import isolate_rng
27+
28+
29+
class LearningRateFinder(Callback):
30+
"""The ``LearningRateFinder`` callback enables the user to do a range test of good initial learning rates, to
31+
reduce the amount of guesswork in picking a good starting learning rate.
32+
33+
Args:
34+
min_lr: Minimum learning rate to investigate
35+
36+
max_lr: Maximum learning rate to investigate
37+
38+
num_training_steps: Number of learning rates to test
39+
40+
mode: Search strategy to update learning rate after each batch:
41+
42+
- ``'exponential'`` (default): Increases the learning rate exponentially.
43+
- ``'linear'``: Increases the learning rate linearly.
44+
45+
early_stop_threshold: Threshold for stopping the search. If the
46+
loss at any point is larger than early_stop_threshold*best_loss
47+
then the search is stopped. To disable, set to None.
48+
49+
update_attr: Whether to update the learning rate attribute or not.
50+
51+
Raises:
52+
MisconfigurationException:
53+
If learning rate/lr in ``model`` or ``model.hparams`` isn't overridden when ``auto_lr_find=True``,
54+
or if you are using more than one optimizer.
55+
"""
56+
57+
SUPPORTED_MODES = ("linear", "exponential")
58+
59+
def __init__(
60+
self,
61+
min_lr: float = 1e-8,
62+
max_lr: float = 1,
63+
num_training_steps: int = 100,
64+
mode: str = "exponential",
65+
early_stop_threshold: float = 4.0,
66+
update_attr: bool = False,
67+
) -> None:
68+
mode = mode.lower()
69+
if mode not in self.SUPPORTED_MODES:
70+
raise MisconfigurationException(f"`mode` should be either of {self.SUPPORTED_MODES}")
71+
72+
self._min_lr = min_lr
73+
self._max_lr = max_lr
74+
self._num_training_steps = num_training_steps
75+
self._mode = mode
76+
self._early_stop_threshold = early_stop_threshold
77+
self._update_attr = update_attr
78+
79+
self._early_exit = False
80+
self.lr_finder: Optional[_LRFinder] = None
81+
82+
def lr_find(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
83+
with isolate_rng():
84+
self.optimal_lr = lr_find(
85+
trainer,
86+
pl_module,
87+
min_lr=self._min_lr,
88+
max_lr=self._max_lr,
89+
num_training=self._num_training_steps,
90+
mode=self._mode,
91+
early_stop_threshold=self._early_stop_threshold,
92+
update_attr=self._update_attr,
93+
)
94+
95+
if self._early_exit:
96+
raise _TunerExitException()
97+
98+
def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
99+
self.lr_find(trainer, pl_module)

src/pytorch_lightning/trainer/connectors/callback_connector.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
TQDMProgressBar,
3030
)
3131
from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder
32+
from pytorch_lightning.callbacks.lr_finder import LearningRateFinder
3233
from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary
3334
from pytorch_lightning.callbacks.timer import Timer
3435
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -246,7 +247,7 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
246247
checkpoint_callbacks: List[Callback] = []
247248

248249
for cb in callbacks:
249-
if isinstance(cb, BatchSizeFinder):
250+
if isinstance(cb, (BatchSizeFinder, LearningRateFinder)):
250251
tuner_callbacks.append(cb)
251252
elif isinstance(cb, Checkpoint):
252253
checkpoint_callbacks.append(cb)

src/pytorch_lightning/tuner/lr_finder.py

Lines changed: 58 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import logging
1616
import os
1717
import uuid
18-
from functools import wraps
18+
from copy import deepcopy
1919
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, TYPE_CHECKING, Union
2020

2121
import numpy as np
@@ -25,8 +25,6 @@
2525

2626
import pytorch_lightning as pl
2727
from pytorch_lightning.callbacks import Callback
28-
from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers, _set_scheduler_opt_idx
29-
from pytorch_lightning.loggers.logger import DummyLogger
3028
from pytorch_lightning.utilities.exceptions import MisconfigurationException
3129
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
3230
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
@@ -92,7 +90,7 @@ class _LRFinder:
9290
lr = lr_finder.suggestion()
9391
"""
9492

95-
def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int):
93+
def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int) -> None:
9694
assert mode in ("linear", "exponential"), "mode should be either `linear` or `exponential`"
9795

9896
self.mode = mode
@@ -104,38 +102,33 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int):
104102
self._total_batch_idx = 0 # for debug purpose
105103

106104
def _exchange_scheduler(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> Callable[["pl.Trainer"], None]:
105+
# TODO: update docs here
107106
"""Decorate `trainer.strategy.setup_optimizers` method such that it sets the user's originally specified
108107
optimizer together with a new scheduler that takes care of the learning rate search."""
109-
setup_optimizers = trainer.strategy.setup_optimizers
108+
from pytorch_lightning.core.optimizer import _set_scheduler_opt_idx
110109

111-
@wraps(setup_optimizers)
112-
def func(trainer: "pl.Trainer") -> None:
113-
# Decide the structure of the output from _init_optimizers_and_lr_schedulers
114-
optimizers, _, _ = _init_optimizers_and_lr_schedulers(trainer.lightning_module)
110+
optimizers = trainer.strategy.optimizers
115111

116-
if len(optimizers) != 1:
117-
raise MisconfigurationException(
118-
f"`model.configure_optimizers()` returned {len(optimizers)}, but"
119-
" learning rate finder only works with single optimizer"
120-
)
121-
122-
optimizer = optimizers[0]
112+
if len(optimizers) != 1:
113+
raise MisconfigurationException(
114+
f"`model.configure_optimizers()` returned {len(optimizers)}, but"
115+
" learning rate finder only works with single optimizer"
116+
)
123117

124-
new_lrs = [self.lr_min] * len(optimizer.param_groups)
125-
for param_group, new_lr in zip(optimizer.param_groups, new_lrs):
126-
param_group["lr"] = new_lr
127-
param_group["initial_lr"] = new_lr
118+
optimizer = optimizers[0]
128119

129-
args = (optimizer, self.lr_max, self.num_training)
130-
scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)
131-
scheduler = cast(pl.utilities.types._LRScheduler, scheduler)
120+
new_lrs = [self.lr_min] * len(optimizer.param_groups)
121+
for param_group, new_lr in zip(optimizer.param_groups, new_lrs):
122+
param_group["lr"] = new_lr
123+
param_group["initial_lr"] = new_lr
132124

133-
trainer.strategy.optimizers = [optimizer]
134-
trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)]
135-
trainer.strategy.optimizer_frequencies = []
136-
_set_scheduler_opt_idx(trainer.optimizers, trainer.lr_scheduler_configs)
125+
args = (optimizer, self.lr_max, self.num_training)
126+
scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)
127+
scheduler = cast(pl.utilities.types._LRScheduler, scheduler)
137128

138-
return func
129+
trainer.strategy.optimizers = [optimizer]
130+
trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step", opt_idx=0)]
131+
_set_scheduler_opt_idx(trainer.optimizers, trainer.lr_scheduler_configs)
139132

140133
def plot(self, suggest: bool = False, show: bool = False) -> Optional["plt.Figure"]:
141134
"""Plot results from lr_find run
@@ -225,23 +218,25 @@ def lr_find(
225218
# Save initial model, that is loaded after learning rate is found
226219
ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt")
227220
trainer.save_checkpoint(ckpt_path)
221+
222+
# Arguments we adjust during the lr finder, save for restoring
228223
params = __lr_finder_dump_params(trainer)
229224

230225
# Set to values that are required by the algorithm
231226
__lr_finder_reset_params(trainer, num_training, early_stop_threshold)
232227

233-
# Initialize lr finder object (stores results)
234-
lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)
235-
236228
# Disable standard progress bar for fit
237229
if trainer.progress_bar_callback:
238230
trainer.progress_bar_callback.disable()
239231

232+
# Initialize lr finder object (stores results)
233+
lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)
234+
240235
# Configure optimizer and scheduler
241-
trainer.strategy.setup_optimizers = lr_finder._exchange_scheduler(trainer, model) # type: ignore[assignment]
236+
lr_finder._exchange_scheduler(trainer, model)
242237

243238
# Fit, lr & loss logged in callback
244-
trainer.tuner._run(model)
239+
_try_loop_run(trainer, params)
245240

246241
# Prompt if we stopped early
247242
if trainer.global_step != num_training:
@@ -274,31 +269,48 @@ def lr_find(
274269

275270
def __lr_finder_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]:
276271
return {
277-
"auto_lr_find": trainer.auto_lr_find,
272+
"optimizers": trainer.strategy.optimizers,
273+
"lr_scheduler_configs": trainer.strategy.lr_scheduler_configs,
274+
"optimizer_frequencies": trainer.strategy.optimizer_frequencies,
278275
"callbacks": trainer.callbacks,
279-
"logger": trainer.logger,
276+
"loggers": trainer.loggers,
277+
# TODO: check if this is required
278+
"auto_lr_find": trainer.auto_lr_find,
280279
"max_steps": trainer.fit_loop.max_steps,
281-
"setup_optimizers": trainer.strategy.setup_optimizers,
280+
"limit_val_batches": trainer.limit_val_batches,
281+
"loop_state_dict": deepcopy(trainer.fit_loop.state_dict()),
282282
}
283283

284284

285285
def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_stop_threshold: float) -> None:
286+
from pytorch_lightning.loggers.logger import DummyLogger
287+
288+
trainer.strategy.lr_scheduler_configs = []
289+
trainer.strategy.optimizer_frequencies = []
286290
# avoid lr find being called multiple times
287291
trainer.auto_lr_find = False
288292
# Use special lr logger callback
289293
trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)]
290294
# No logging
291-
trainer.loggers = [DummyLogger()] if trainer.loggers else []
295+
trainer.logger = DummyLogger() if trainer.logger is not None else None
292296
# Max step set to number of iterations
293297
trainer.fit_loop.max_steps = num_training
298+
trainer.limit_val_batches = num_training
294299

295300

296301
def __lr_finder_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) -> None:
302+
trainer.strategy.optimizers = params["optimizers"]
303+
trainer.strategy.lr_scheduler_configs = params["lr_scheduler_configs"]
304+
trainer.strategy.optimizer_frequencies = params["optimizer_frequencies"]
297305
trainer.auto_lr_find = params["auto_lr_find"]
298306
trainer.callbacks = params["callbacks"]
299-
trainer.logger = params["logger"]
307+
trainer.loggers = params["loggers"]
300308
trainer.fit_loop.max_steps = params["max_steps"]
301-
trainer.strategy.setup_optimizers = params["setup_optimizers"] # type: ignore[assignment]
309+
trainer.limit_val_batches = params["limit_val_batches"]
310+
311+
loop = trainer.fit_loop
312+
loop.load_state_dict(deepcopy(params["loop_state_dict"]))
313+
loop.restarting = False
302314

303315

304316
class _LRCallback(Callback):
@@ -453,3 +465,10 @@ def get_lr(self) -> List[float]: # type: ignore[override]
453465
@property
454466
def lr(self) -> Union[float, List[float]]:
455467
return self._lr
468+
469+
470+
def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None:
471+
loop = trainer.fit_loop
472+
loop.load_state_dict(deepcopy(params["loop_state_dict"]))
473+
loop.restarting = False
474+
loop.run()

0 commit comments

Comments
 (0)