@@ -19,6 +19,7 @@
 Monitors and logs the learning rate for lr schedulers during training.
 
 """
+import itertools
 from collections import defaultdict
 from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple, Type
 
@@ -123,7 +124,7 @@ def _check_no_key(key: str) -> bool:
             )
 
         # Find names for schedulers
-        names: List[str] = []
+        names: List[List[str]] = []
         (
            sched_hparam_keys,
            optimizers_with_scheduler,
@@ -140,8 +141,9 @@ def _check_no_key(key: str) -> bool:
         names.extend(optimizer_hparam_keys)
 
         # Initialize for storing values
-        self.lrs = {name: [] for name in names}
-        self.last_momentum_values = {name + "-momentum": None for name in names}
+        names_flatten = list(itertools.chain.from_iterable(names))
+        self.lrs = {name: [] for name in names_flatten}
+        self.last_momentum_values = {name + "-momentum": None for name in names_flatten}
 
     def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
         if not trainer.logger_connector.should_update_logs:
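For context, a minimal standalone sketch (with made-up key names) of the new grouped `List[List[str]]` layout and how `itertools.chain.from_iterable` flattens it before the tracking dicts are seeded:

```python
import itertools

# Hypothetical names: one inner list per optimizer, one entry per param group.
names = [["lr-Adam/pg1", "lr-Adam/pg2"], ["lr-SGD"]]

names_flatten = list(itertools.chain.from_iterable(names))
# ['lr-Adam/pg1', 'lr-Adam/pg2', 'lr-SGD']

lrs = {name: [] for name in names_flatten}
last_momentum_values = {name + "-momentum": None for name in names_flatten}
```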
@@ -172,7 +174,7 @@ def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, float]:
         ) = self._find_names_from_schedulers(trainer.lr_schedulers, add_lr_sch_names=False)
         self._remap_keys(scheduler_hparam_keys)
 
-        for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):
+        for name, scheduler in zip(scheduler_hparam_keys, trainer.lr_schedulers):
             if interval in [scheduler["interval"], "any"]:
                 opt = scheduler["scheduler"].optimizer
                 current_stat = self._get_lr_momentum_stat(opt, name)
@@ -186,23 +188,22 @@ def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, float]:
         )
         self._remap_keys(optimizer_hparam_keys)
 
-        for opt, name in zip(optimizers_without_scheduler, optimizer_hparam_keys):
-            current_stat = self._get_lr_momentum_stat(opt, name)
+        for opt, names in zip(optimizers_without_scheduler, optimizer_hparam_keys):
+            current_stat = self._get_lr_momentum_stat(opt, names)
             latest_stat.update(current_stat)
 
         return latest_stat
 
-    def _get_lr_momentum_stat(self, optimizer: Optimizer, name: str) -> Dict[str, float]:
+    def _get_lr_momentum_stat(self, optimizer: Optimizer, names: List[str]) -> Dict[str, float]:
         lr_momentum_stat = {}
         param_groups = optimizer.param_groups
         use_betas = "betas" in optimizer.defaults
 
-        for i, pg in enumerate(param_groups):
-            name_and_suffix = self._add_suffix(name, param_groups, i)
-            lr = self._extract_lr(pg, name_and_suffix)
+        for pg, name in zip(param_groups, names):
+            lr = self._extract_lr(pg, name)
             lr_momentum_stat.update(lr)
             momentum = self._extract_momentum(
-                param_group=pg, name=name_and_suffix.replace(name, f"{name}-momentum"), use_betas=use_betas
+                param_group=pg, name=name.replace(name, f"{name}-momentum"), use_betas=use_betas
             )
             lr_momentum_stat.update(momentum)
 
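A rough sketch of the new pairing in `_get_lr_momentum_stat`: instead of enumerating param groups and re-deriving a suffix per index, each group is zipped against its pre-computed name. The optimizer setup and names below are illustrative (assumes torch is installed); note also that `name.replace(name, f"{name}-momentum")` always matches the whole string, so it simply appends the `-momentum` suffix.

```python
import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(
    [{"params": [model.weight], "lr": 0.1}, {"params": [model.bias], "lr": 0.01}]
)
names = ["lr-SGD/pg1", "lr-SGD/pg2"]  # hypothetical, one name per param group

# Each param group is visited alongside its own name; no index bookkeeping.
for pg, name in zip(opt.param_groups, names):
    print(name, pg["lr"])
# lr-SGD/pg1 0.1
# lr-SGD/pg2 0.01
```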
@@ -213,14 +214,15 @@ def _extract_lr(self, param_group: Dict[str, Any], name: str) -> Dict[str, Any]:
         self.lrs[name].append(lr)
         return {name: lr}
 
-    def _remap_keys(self, names: List[str], token: str = "/pg1") -> None:
+    def _remap_keys(self, names: List[List[str]], token: str = "/pg1") -> None:
         """This function is used to remap the keys if param groups for a given optimizer increased."""
-        for new_name in names:
-            old_name = new_name.replace(token, "")
-            if token in new_name and old_name in self.lrs:
-                self.lrs[new_name] = self.lrs.pop(old_name)
-            elif new_name not in self.lrs:
-                self.lrs[new_name] = []
+        for group_new_names in names:
+            for new_name in group_new_names:
+                old_name = new_name.replace(token, "")
+                if token in new_name and old_name in self.lrs:
+                    self.lrs[new_name] = self.lrs.pop(old_name)
+                elif new_name not in self.lrs:
+                    self.lrs[new_name] = []
 
     def _extract_momentum(self, param_group: Dict[str, List], name: str, use_betas: bool) -> Dict[str, float]:
         if not self.log_momentum:
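To illustrate the remapping, a self-contained sketch (hypothetical keys) of what `_remap_keys` does when an optimizer that used to log under a single key gains a second param group mid-training:

```python
lrs = {"lr-Adam": [0.1, 0.1]}  # history logged before a second param group appeared

names = [["lr-Adam/pg1", "lr-Adam/pg2"]]  # new grouped names for that optimizer
token = "/pg1"

for group_new_names in names:
    for new_name in group_new_names:
        old_name = new_name.replace(token, "")
        if token in new_name and old_name in lrs:
            lrs[new_name] = lrs.pop(old_name)  # move old history to the /pg1 key
        elif new_name not in lrs:
            lrs[new_name] = []  # start a fresh history for the new group

assert lrs == {"lr-Adam/pg1": [0.1, 0.1], "lr-Adam/pg2": []}
```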
@@ -258,7 +260,7 @@ def _duplicate_param_group_names(self, param_groups: List[Dict]) -> Set[str]:
 
     def _find_names_from_schedulers(
         self, lr_schedulers: List, add_lr_sch_names: bool = True
-    ) -> Tuple[List[str], List[Optimizer], DefaultDict[Type[Optimizer], int]]:
+    ) -> Tuple[List[List[str]], List[Optimizer], DefaultDict[Type[Optimizer], int]]:
         # Create unique names in the case we have multiple of the same learning
         # rate scheduler + multiple parameter groups
         names = []
@@ -271,10 +273,11 @@ def _find_names_from_schedulers(
             else:
                 name = "lr-" + sch.optimizer.__class__.__name__
 
-            updated_name = self._check_duplicates_and_update_name(
+            updated_names = self._check_duplicates_and_update_name(
                 sch.optimizer, name, seen_optimizers, seen_optimizer_types, scheduler, add_lr_sch_names
             )
-            names.extend(updated_name)
+            names.append(updated_names)
+
         return names, seen_optimizers, seen_optimizer_types
 
     def _find_names_from_optimizers(
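The switch from `extend` to `append` is what produces the grouped shape; a small sketch with hypothetical return values:

```python
# _check_duplicates_and_update_name now returns all names for one optimizer.
updated_names = ["lr-Adam/pg1", "lr-Adam/pg2"]  # e.g. two param groups

names: list = []
names.append(updated_names)  # [['lr-Adam/pg1', 'lr-Adam/pg2']], keeps the grouping
# names.extend(updated_names) would flatten it back to ['lr-Adam/pg1', 'lr-Adam/pg2']
```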
@@ -283,7 +286,7 @@ def _find_names_from_optimizers(
         seen_optimizers: List[Optimizer],
         seen_optimizer_types: DefaultDict[Type[Optimizer], int],
         add_lr_sch_names: bool = True,
-    ) -> Tuple[List[str], List[Optimizer]]:
+    ) -> Tuple[List[List[str]], List[Optimizer]]:
         names = []
         optimizers_without_scheduler = []
 
@@ -294,11 +297,12 @@ def _find_names_from_optimizers(
                 continue
 
             name = "lr-" + optimizer.__class__.__name__
-            updated_name = self._check_duplicates_and_update_name(
+            updated_names = self._check_duplicates_and_update_name(
                 optimizer, name, seen_optimizers, seen_optimizer_types, None, add_lr_sch_names
             )
-            names.extend(updated_name)
+            names.append(updated_names)
             optimizers_without_scheduler.append(optimizer)
+
         return names, optimizers_without_scheduler
 
     def _check_duplicates_and_update_name(
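For reference, the callback whose internals this patch reworks is attached in the usual way; a standard usage sketch (assumes pytorch_lightning is installed):

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor

# Log the lr (and momentum, if enabled) of every optimizer/param group each step.
lr_monitor = LearningRateMonitor(logging_interval="step", log_momentum=True)
trainer = pl.Trainer(callbacks=[lr_monitor])
```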