@@ -15,7 +15,7 @@
 import logging
 import os
 import uuid
-from functools import wraps
+from copy import deepcopy
 from typing import Any, Callable, cast, Dict, List, Optional, Sequence, TYPE_CHECKING, Union
 
 import numpy as np
@@ -25,8 +25,6 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import Callback
-from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers, _set_scheduler_opt_idx
-from pytorch_lightning.loggers.logger import DummyLogger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
@@ -92,7 +90,7 @@ class _LRFinder:
         lr = lr_finder.suggestion()
     """
 
-    def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int):
+    def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int) -> None:
         assert mode in ("linear", "exponential"), "mode should be either `linear` or `exponential`"
 
         self.mode = mode
@@ -104,38 +102,33 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int):
104
102
self ._total_batch_idx = 0 # for debug purpose
105
103
106
104
def _exchange_scheduler (self , trainer : "pl.Trainer" , model : "pl.LightningModule" ) -> Callable [["pl.Trainer" ], None ]:
105
+ # TODO: update docs here
107
106
"""Decorate `trainer.strategy.setup_optimizers` method such that it sets the user's originally specified
108
107
optimizer together with a new scheduler that takes care of the learning rate search."""
109
- setup_optimizers = trainer . strategy . setup_optimizers
108
+ from pytorch_lightning . core . optimizer import _set_scheduler_opt_idx
110
109
111
- @wraps (setup_optimizers )
112
- def func (trainer : "pl.Trainer" ) -> None :
113
- # Decide the structure of the output from _init_optimizers_and_lr_schedulers
114
- optimizers , _ , _ = _init_optimizers_and_lr_schedulers (trainer .lightning_module )
110
+ optimizers = trainer .strategy .optimizers
115
111
116
- if len (optimizers ) != 1 :
117
- raise MisconfigurationException (
118
- f"`model.configure_optimizers()` returned { len (optimizers )} , but"
119
- " learning rate finder only works with single optimizer"
120
- )
121
-
122
- optimizer = optimizers [0 ]
112
+ if len (optimizers ) != 1 :
113
+ raise MisconfigurationException (
114
+ f"`model.configure_optimizers()` returned { len (optimizers )} , but"
115
+ " learning rate finder only works with single optimizer"
116
+ )
123
117
124
- new_lrs = [self .lr_min ] * len (optimizer .param_groups )
125
- for param_group , new_lr in zip (optimizer .param_groups , new_lrs ):
126
- param_group ["lr" ] = new_lr
127
- param_group ["initial_lr" ] = new_lr
118
+ optimizer = optimizers [0 ]
128
119
129
- args = (optimizer , self .lr_max , self .num_training )
130
- scheduler = _LinearLR (* args ) if self .mode == "linear" else _ExponentialLR (* args )
131
- scheduler = cast (pl .utilities .types ._LRScheduler , scheduler )
120
+ new_lrs = [self .lr_min ] * len (optimizer .param_groups )
121
+ for param_group , new_lr in zip (optimizer .param_groups , new_lrs ):
122
+ param_group ["lr" ] = new_lr
123
+ param_group ["initial_lr" ] = new_lr
132
124
133
- trainer .strategy .optimizers = [optimizer ]
134
- trainer .strategy .lr_scheduler_configs = [LRSchedulerConfig (scheduler , interval = "step" , opt_idx = 0 )]
135
- trainer .strategy .optimizer_frequencies = []
136
- _set_scheduler_opt_idx (trainer .optimizers , trainer .lr_scheduler_configs )
125
+ args = (optimizer , self .lr_max , self .num_training )
126
+ scheduler = _LinearLR (* args ) if self .mode == "linear" else _ExponentialLR (* args )
127
+ scheduler = cast (pl .utilities .types ._LRScheduler , scheduler )
137
128
138
- return func
129
+ trainer .strategy .optimizers = [optimizer ]
130
+ trainer .strategy .lr_scheduler_configs = [LRSchedulerConfig (scheduler , interval = "step" , opt_idx = 0 )]
131
+ _set_scheduler_opt_idx (trainer .optimizers , trainer .lr_scheduler_configs )
139
132
140
133
def plot (self , suggest : bool = False , show : bool = False ) -> Optional ["plt.Figure" ]:
141
134
"""Plot results from lr_find run
@@ -225,23 +218,25 @@ def lr_find(
     # Save initial model, that is loaded after learning rate is found
     ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt")
     trainer.save_checkpoint(ckpt_path)
+
+    # Arguments we adjust during the lr finder, save for restoring
     params = __lr_finder_dump_params(trainer)
 
     # Set to values that are required by the algorithm
     __lr_finder_reset_params(trainer, num_training, early_stop_threshold)
 
-    # Initialize lr finder object (stores results)
-    lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)
-
     # Disable standard progress bar for fit
     if trainer.progress_bar_callback:
         trainer.progress_bar_callback.disable()
 
+    # Initialize lr finder object (stores results)
+    lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)
+
     # Configure optimizer and scheduler
-    trainer.strategy.setup_optimizers = lr_finder._exchange_scheduler(trainer, model)  # type: ignore[assignment]
+    lr_finder._exchange_scheduler(trainer, model)
 
     # Fit, lr & loss logged in callback
-    trainer.tuner._run(model)
+    _try_loop_run(trainer, params)
 
     # Prompt if we stopped early
     if trainer.global_step != num_training:
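For reference, this is roughly the user-facing path these hunks sit under, following the `lr_finder.suggestion()` example from the `_LRFinder` docstring. A sketch assuming the pre-2.0 `trainer.tuner.lr_find` API; `TinyModule` is a made-up stand-in module:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl

class TinyModule(pl.LightningModule):
    def __init__(self, lr: float = 1e-3):
        super().__init__()
        self.lr = lr  # the attribute lr_find can update via lightning_setattr
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.lr)

ds = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
model = TinyModule()
trainer = pl.Trainer(max_epochs=1)
lr_finder = trainer.tuner.lr_find(model, train_dataloaders=DataLoader(ds, batch_size=8))
print(lr_finder.suggestion())
```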
@@ -274,31 +269,48 @@ def lr_find(
 
 def __lr_finder_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]:
     return {
-        "auto_lr_find": trainer.auto_lr_find,
+        "optimizers": trainer.strategy.optimizers,
+        "lr_scheduler_configs": trainer.strategy.lr_scheduler_configs,
+        "optimizer_frequencies": trainer.strategy.optimizer_frequencies,
         "callbacks": trainer.callbacks,
-        "logger": trainer.logger,
+        "loggers": trainer.loggers,
+        # TODO: check if this is required
+        "auto_lr_find": trainer.auto_lr_find,
         "max_steps": trainer.fit_loop.max_steps,
-        "setup_optimizers": trainer.strategy.setup_optimizers,
+        "limit_val_batches": trainer.limit_val_batches,
+        "loop_state_dict": deepcopy(trainer.fit_loop.state_dict()),
     }
 
 
 def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_stop_threshold: float) -> None:
+    from pytorch_lightning.loggers.logger import DummyLogger
+
+    trainer.strategy.lr_scheduler_configs = []
+    trainer.strategy.optimizer_frequencies = []
     # avoid lr find being called multiple times
     trainer.auto_lr_find = False
     # Use special lr logger callback
     trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)]
     # No logging
-    trainer.loggers = [DummyLogger()] if trainer.loggers else []
+    trainer.logger = DummyLogger() if trainer.logger is not None else None
     # Max step set to number of iterations
     trainer.fit_loop.max_steps = num_training
+    trainer.limit_val_batches = num_training
 
 
 def __lr_finder_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) -> None:
+    trainer.strategy.optimizers = params["optimizers"]
+    trainer.strategy.lr_scheduler_configs = params["lr_scheduler_configs"]
+    trainer.strategy.optimizer_frequencies = params["optimizer_frequencies"]
     trainer.auto_lr_find = params["auto_lr_find"]
     trainer.callbacks = params["callbacks"]
-    trainer.logger = params["logger"]
+    trainer.loggers = params["loggers"]
     trainer.fit_loop.max_steps = params["max_steps"]
-    trainer.strategy.setup_optimizers = params["setup_optimizers"]  # type: ignore[assignment]
+    trainer.limit_val_batches = params["limit_val_batches"]
+
+    loop = trainer.fit_loop
+    loop.load_state_dict(deepcopy(params["loop_state_dict"]))
+    loop.restarting = False
 
 
 class _LRCallback(Callback):
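The three helpers above form a plain snapshot/mutate/restore pattern. The one subtlety is the fit loop: its progress is captured as a `deepcopy` of `state_dict()` because the loop mutates that state in place during the sweep, and it is restored with `restarting = False` so the rewound loop runs fresh rather than resuming. The same pattern in isolation, as a hypothetical helper (not part of the library):

```python
from contextlib import contextmanager

@contextmanager
def borrowed_attrs(obj, **overrides):
    # Snapshot the attributes we are about to change, apply the overrides,
    # and put the originals back no matter how the body exits.
    saved = {name: getattr(obj, name) for name in overrides}
    for name, value in overrides.items():
        setattr(obj, name, value)
    try:
        yield obj
    finally:
        for name, value in saved.items():
            setattr(obj, name, value)

# Shape of what __lr_finder_reset_params/__lr_finder_restore_params do, e.g.:
# with borrowed_attrs(trainer, logger=None, callbacks=[lr_callback]):
#     ...run the sweep...
```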
@@ -453,3 +465,10 @@ def get_lr(self) -> List[float]:  # type: ignore[override]
     @property
     def lr(self) -> Union[float, List[float]]:
         return self._lr
+
+
+def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None:
+    loop = trainer.fit_loop
+    loop.load_state_dict(deepcopy(params["loop_state_dict"]))
+    loop.restarting = False
+    loop.run()
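`_try_loop_run` is what makes calling `fit_loop.run()` directly safe: it first rewinds the loop to the progress snapshot taken in `__lr_finder_dump_params`, then clears the `restarting` flag that `load_state_dict` sets. A toy stand-in (hypothetical class, not the Lightning loop) showing why that order matters:

```python
from copy import deepcopy

class ToyLoop:
    def __init__(self) -> None:
        self.step = 0
        self.restarting = False

    def state_dict(self) -> dict:
        return {"step": self.step}

    def load_state_dict(self, state: dict) -> None:
        self.step = state["step"]
        self.restarting = True  # mirrors Lightning loops flagging a resume

    def run(self, max_steps: int) -> None:
        while self.step < max_steps:
            self.step += 1

loop = ToyLoop()
snapshot = deepcopy(loop.state_dict())    # taken before the sweep, as in dump_params
loop.run(100)                             # the lr sweep advances the loop
loop.load_state_dict(deepcopy(snapshot))  # rewind to the pre-sweep progress
loop.restarting = False                   # run fresh instead of resuming mid-run
loop.run(5)
assert loop.step == 5 and not loop.restarting
```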