@@ -106,18 +106,13 @@ def on_train_start(self, trainer, *args, **kwargs):
                 "Cannot use `LearningRateMonitor` callback with `Trainer` that has no logger."
             )
 
-        if not trainer.lr_schedulers:
-            rank_zero_warn(
-                "You are using `LearningRateMonitor` callback with models that"
-                " have no learning rate schedulers. Please see documentation"
-                " for `configure_optimizers` method.",
-                RuntimeWarning,
-            )
-
         if self.log_momentum:
 
             def _check_no_key(key):
-                return any(key not in sch["scheduler"].optimizer.defaults for sch in trainer.lr_schedulers)
+                if trainer.lr_schedulers:
+                    return any(key not in sch["scheduler"].optimizer.defaults for sch in trainer.lr_schedulers)
+
+                return any(key not in optimizer.defaults for optimizer in trainer.optimizers)
 
             if _check_no_key("momentum") and _check_no_key("betas"):
                 rank_zero_warn(
@@ -127,7 +122,21 @@ def _check_no_key(key):
                 )
 
         # Find names for schedulers
-        names = self._find_names(trainer.lr_schedulers)
+        names = []
+        (
+            sched_hparam_keys,
+            optimizers_with_scheduler,
+            optimizers_with_scheduler_types,
+        ) = self._find_names_from_schedulers(trainer.lr_schedulers)
+        names.extend(sched_hparam_keys)
+
+        # Find names for leftover optimizers
+        optimizer_hparam_keys, _ = self._find_names_from_optimizers(
+            trainer.optimizers,
+            seen_optimizers=optimizers_with_scheduler,
+            seen_optimizer_types=optimizers_with_scheduler_types,
+        )
+        names.extend(optimizer_hparam_keys)
 
         # Initialize for storing values
         self.lrs = {name: [] for name in names}
@@ -155,26 +164,49 @@ def on_train_epoch_start(self, trainer, *args, **kwargs):
     def _extract_stats(self, trainer, interval: str) -> Dict[str, float]:
         latest_stat = {}
 
-        names = self._find_names(trainer.lr_schedulers, add_lr_sch_names=False)
-        self._remap_keys(names)
+        (
+            scheduler_hparam_keys,
+            optimizers_with_scheduler,
+            optimizers_with_scheduler_types,
+        ) = self._find_names_from_schedulers(trainer.lr_schedulers, add_lr_sch_names=False)
+        self._remap_keys(scheduler_hparam_keys)
 
         for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):
-            if scheduler["interval"] == interval or interval == "any":
+            if interval in [scheduler["interval"], "any"]:
                 opt = scheduler["scheduler"].optimizer
-                param_groups = opt.param_groups
-                use_betas = "betas" in opt.defaults
-
-                for i, pg in enumerate(param_groups):
-                    name_and_suffix = self._add_suffix(name, param_groups, i)
-                    lr = self._extract_lr(pg, name_and_suffix)
-                    latest_stat.update(lr)
-                    momentum = self._extract_momentum(
-                        param_group=pg, name=name_and_suffix.replace(name, f"{name}-momentum"), use_betas=use_betas
-                    )
-                    latest_stat.update(momentum)
+                current_stat = self._get_lr_momentum_stat(opt, name)
+                latest_stat.update(current_stat)
+
+        optimizer_hparam_keys, optimizers_without_scheduler = self._find_names_from_optimizers(
+            trainer.optimizers,
+            seen_optimizers=optimizers_with_scheduler,
+            seen_optimizer_types=optimizers_with_scheduler_types,
+            add_lr_sch_names=False,
+        )
+        self._remap_keys(optimizer_hparam_keys)
+
+        for opt, name in zip(optimizers_without_scheduler, optimizer_hparam_keys):
+            current_stat = self._get_lr_momentum_stat(opt, name)
+            latest_stat.update(current_stat)
 
         return latest_stat
 
+    def _get_lr_momentum_stat(self, optimizer: Optimizer, name: str) -> Dict[str, float]:
+        lr_momentum_stat = {}
+        param_groups = optimizer.param_groups
+        use_betas = "betas" in optimizer.defaults
+
+        for i, pg in enumerate(param_groups):
+            name_and_suffix = self._add_suffix(name, param_groups, i)
+            lr = self._extract_lr(pg, name_and_suffix)
+            lr_momentum_stat.update(lr)
+            momentum = self._extract_momentum(
+                param_group=pg, name=name_and_suffix.replace(name, f"{name}-momentum"), use_betas=use_betas
+            )
+            lr_momentum_stat.update(momentum)
+
+        return lr_momentum_stat
+
     def _extract_lr(self, param_group: Dict[str, Any], name: str) -> Dict[str, Any]:
         lr = param_group.get("lr")
         self.lrs[name].append(lr)
@@ -223,7 +255,7 @@ def _duplicate_param_group_names(self, param_groups: List[Dict]) -> Set[str]:
             return set()
         return {n for n in names if names.count(n) > 1}
 
-    def _find_names(self, lr_schedulers: List, add_lr_sch_names: bool = True) -> List[str]:
+    def _find_names_from_schedulers(self, lr_schedulers: List, add_lr_sch_names: bool = True) -> List[str]:
         # Create unique names in the case we have multiple of the same learning
         # rate scheduler + multiple parameter groups
         names = []
@@ -236,28 +268,64 @@ def _find_names(self, lr_schedulers: List, add_lr_sch_names: bool = True) -> Lis
         else:
             name = "lr-" + sch.optimizer.__class__.__name__
 
-            seen_optimizers.append(sch.optimizer)
-            optimizer_cls = type(sch.optimizer)
-            if scheduler["name"] is None:
-                seen_optimizer_types[optimizer_cls] += 1
-
-            # Multiple param groups for the same scheduler
-            param_groups = sch.optimizer.param_groups
-            duplicates = self._duplicate_param_group_names(param_groups)
-            if duplicates:
-                raise MisconfigurationException(
-                    "A single `Optimizer` cannot have multiple parameter groups with identical "
-                    f"`name` values. {name} has duplicated parameter group names {duplicates}"
-                )
+            updated_name = self._check_duplicates_and_update_name(
+                sch.optimizer, name, seen_optimizers, seen_optimizer_types, scheduler, add_lr_sch_names
+            )
+            names.extend(updated_name)
+        return names, seen_optimizers, seen_optimizer_types
+
+    def _find_names_from_optimizers(
+        self, optimizers, seen_optimizers, seen_optimizer_types, add_lr_sch_names: bool = True
+    ) -> List[str]:
+        names = []
+        optimizers_without_scheduler = []
 
-            name = self._add_prefix(name, optimizer_cls, seen_optimizer_types)
+        for optimizer in optimizers:
+            # Deepspeed optimizer wraps the native optimizer
+            optimizer = optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer
+            if optimizer in seen_optimizers:
+                continue
+
+            name = "lr-" + optimizer.__class__.__name__
+            updated_name = self._check_duplicates_and_update_name(
+                optimizer, name, seen_optimizers, seen_optimizer_types, None, add_lr_sch_names
+            )
+            names.extend(updated_name)
+            optimizers_without_scheduler.append(optimizer)
+        return names, optimizers_without_scheduler
+
+    def _check_duplicates_and_update_name(
+        self,
+        optimizer: Optimizer,
+        name: str,
+        seen_optimizers: List,
+        seen_optimizer_types: List,
+        scheduler: Dict[str, Any] = None,
+        add_lr_sch_names: bool = True,
+    ) -> List:
+        seen_optimizers.append(optimizer)
+        optimizer_cls = type(optimizer)
+        if scheduler is not None and scheduler["name"] is None:
+            seen_optimizer_types[optimizer_cls] += 1
+        elif scheduler is None:
+            seen_optimizer_types[optimizer_cls] += 1
+
+        # Multiple param groups for the same optimizer
+        param_groups = optimizer.param_groups
+        duplicates = self._duplicate_param_group_names(param_groups)
+        if duplicates:
+            raise MisconfigurationException(
+                "A single `Optimizer` cannot have multiple parameter groups with identical "
+                f"`name` values. {name} has duplicated parameter group names {duplicates}"
+            )
 
-            names.extend(self._add_suffix(name, param_groups, i) for i in range(len(param_groups)))
+        name = self._add_prefix(name, optimizer_cls, seen_optimizer_types)
+        name_list = [self._add_suffix(name, param_groups, i) for i in range(len(param_groups))]
 
-            if add_lr_sch_names:
-                self.lr_sch_names.append(name)
+        if add_lr_sch_names:
+            self.lr_sch_names.append(name)
 
-        return names
+        return name_list
 
     @staticmethod
     def _should_log(trainer) -> bool:
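Below, outside the diff, is a minimal usage sketch of the configuration this change targets: `configure_optimizers` returns a bare optimizer with no learning rate scheduler, and `LearningRateMonitor` now logs its learning rate (under a key such as `lr-SGD`) instead of warning and skipping it. The toy module and trainer settings are illustrative assumptions; only `LearningRateMonitor(logging_interval="step")`, the `configure_optimizers` hook, and the requirement for a logger-equipped `Trainer` come from the library.

import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor


class PlainSGDModel(LightningModule):
    """Toy module (hypothetical) whose optimizer has no scheduler attached."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        # No scheduler is returned; with this change the monitor still picks
        # the optimizer up via trainer.optimizers and logs it as "lr-SGD".
        return torch.optim.SGD(self.parameters(), lr=0.1)


# The callback still requires a Trainer with a logger (see the check kept at
# the top of on_train_start); the default TensorBoard logger satisfies that.
trainer = Trainer(max_epochs=1, callbacks=[LearningRateMonitor(logging_interval="step")])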