Commit 1569869

Authored by VirajBagal, with pre-commit-ci[bot] and rohitgr7

Log LR using LearningRateMonitor even when LR Scheduler is not defined. (#9786)

* LR logging works even with no lr scheduler, wrote few extra tests as well
* updated changelog
* modified code as suggested by DeepSource
* added helper functions
* opt with no scheduler
* rename
* chlog
* update test

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rohitgr7 <[email protected]>

1 parent 940b910 commit 1569869
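For context, a minimal usage sketch of what this commit enables: the callback now logs the learning rate even when configure_optimizers returns only an optimizer. The model, data, and hyperparameters below are illustrative assumptions, not part of the commit.

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor


class PlainOptimizerModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        # Intentionally no LR scheduler: before this commit the callback only warned here.
        return torch.optim.SGD(self.parameters(), lr=0.1)


# After this commit a "lr-SGD" value is logged at each logging step.
trainer = pl.Trainer(callbacks=[LearningRateMonitor(logging_interval="step")], max_steps=5)
# trainer.fit(PlainOptimizerModel(), train_dataloaders=...)  # dataloader omitted in this sketch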

File tree: 3 files changed (+253, -82 lines changed)

CHANGELOG.md (3 additions, 0 deletions)

@@ -10,6 +10,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Added
 
 
+- Add support for monitoring the learning rate without schedulers in `LearningRateMonitor` ([#9786](https://github.com/PyTorchLightning/pytorch-lightning/issues/9786))
+
+
 - Register `ShardedTensor` state dict hooks in `LightningModule.__init__` if the pytorch version supports `ShardedTensor` ([#8944](https://github.com/PyTorchLightning/pytorch-lightning/pull/8944))
 
 

pytorch_lightning/callbacks/lr_monitor.py (111 additions, 43 deletions)

@@ -106,18 +106,13 @@ def on_train_start(self, trainer, *args, **kwargs):
                 "Cannot use `LearningRateMonitor` callback with `Trainer` that has no logger."
             )
 
-        if not trainer.lr_schedulers:
-            rank_zero_warn(
-                "You are using `LearningRateMonitor` callback with models that"
-                " have no learning rate schedulers. Please see documentation"
-                " for `configure_optimizers` method.",
-                RuntimeWarning,
-            )
-
         if self.log_momentum:
 
             def _check_no_key(key):
-                return any(key not in sch["scheduler"].optimizer.defaults for sch in trainer.lr_schedulers)
+                if trainer.lr_schedulers:
+                    return any(key not in sch["scheduler"].optimizer.defaults for sch in trainer.lr_schedulers)
+
+                return any(key not in optimizer.defaults for optimizer in trainer.optimizers)
 
             if _check_no_key("momentum") and _check_no_key("betas"):
                 rank_zero_warn(
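The momentum check above now falls back to trainer.optimizers when no schedulers are configured; it only inspects each optimizer's defaults dict. A standalone sketch of that test (the optimizer choices here are arbitrary examples):

import torch

layer = torch.nn.Linear(4, 2)
adam = torch.optim.Adam(layer.parameters(), lr=1e-3)
adagrad = torch.optim.Adagrad(layer.parameters(), lr=1e-2)

# Adam exposes "betas" in its defaults, so momentum logging is meaningful.
print("betas" in adam.defaults)  # True
# Adagrad has neither "momentum" nor "betas"; with log_momentum=True this is the warning case.
print("momentum" in adagrad.defaults, "betas" in adagrad.defaults)  # False False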
@@ -127,7 +122,21 @@ def _check_no_key(key):
                 )
 
         # Find names for schedulers
-        names = self._find_names(trainer.lr_schedulers)
+        names = []
+        (
+            sched_hparam_keys,
+            optimizers_with_scheduler,
+            optimizers_with_scheduler_types,
+        ) = self._find_names_from_schedulers(trainer.lr_schedulers)
+        names.extend(sched_hparam_keys)
+
+        # Find names for leftover optimizers
+        optimizer_hparam_keys, _ = self._find_names_from_optimizers(
+            trainer.optimizers,
+            seen_optimizers=optimizers_with_scheduler,
+            seen_optimizer_types=optimizers_with_scheduler_types,
+        )
+        names.extend(optimizer_hparam_keys)
 
         # Initialize for storing values
         self.lrs = {name: [] for name in names}
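With no scheduler present, the names gathered above are derived from the optimizer class rather than a scheduler entry. A small sketch of the default key that ends up in self.lrs, assuming a single optimizer with one parameter group (multi-group suffixes are handled by _add_suffix and not shown here):

import torch

optimizer = torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=0.01)

# Mirrors `name = "lr-" + optimizer.__class__.__name__` from the diff above.
name = "lr-" + optimizer.__class__.__name__
lrs = {name: []}
print(lrs)  # {'lr-SGD': []}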
@@ -155,26 +164,49 @@ def on_train_epoch_start(self, trainer, *args, **kwargs):
     def _extract_stats(self, trainer, interval: str) -> Dict[str, float]:
         latest_stat = {}
 
-        names = self._find_names(trainer.lr_schedulers, add_lr_sch_names=False)
-        self._remap_keys(names)
+        (
+            scheduler_hparam_keys,
+            optimizers_with_scheduler,
+            optimizers_with_scheduler_types,
+        ) = self._find_names_from_schedulers(trainer.lr_schedulers, add_lr_sch_names=False)
+        self._remap_keys(scheduler_hparam_keys)
 
         for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):
-            if scheduler["interval"] == interval or interval == "any":
+            if interval in [scheduler["interval"], "any"]:
                 opt = scheduler["scheduler"].optimizer
-                param_groups = opt.param_groups
-                use_betas = "betas" in opt.defaults
-
-                for i, pg in enumerate(param_groups):
-                    name_and_suffix = self._add_suffix(name, param_groups, i)
-                    lr = self._extract_lr(pg, name_and_suffix)
-                    latest_stat.update(lr)
-                    momentum = self._extract_momentum(
-                        param_group=pg, name=name_and_suffix.replace(name, f"{name}-momentum"), use_betas=use_betas
-                    )
-                    latest_stat.update(momentum)
+                current_stat = self._get_lr_momentum_stat(opt, name)
+                latest_stat.update(current_stat)
+
+        optimizer_hparam_keys, optimizers_without_scheduler = self._find_names_from_optimizers(
+            trainer.optimizers,
+            seen_optimizers=optimizers_with_scheduler,
+            seen_optimizer_types=optimizers_with_scheduler_types,
+            add_lr_sch_names=False,
+        )
+        self._remap_keys(optimizer_hparam_keys)
+
+        for opt, name in zip(optimizers_without_scheduler, optimizer_hparam_keys):
+            current_stat = self._get_lr_momentum_stat(opt, name)
+            latest_stat.update(current_stat)
 
         return latest_stat
 
+    def _get_lr_momentum_stat(self, optimizer: Optimizer, name: str) -> None:
+        lr_momentum_stat = {}
+        param_groups = optimizer.param_groups
+        use_betas = "betas" in optimizer.defaults
+
+        for i, pg in enumerate(param_groups):
+            name_and_suffix = self._add_suffix(name, param_groups, i)
+            lr = self._extract_lr(pg, name_and_suffix)
+            lr_momentum_stat.update(lr)
+            momentum = self._extract_momentum(
+                param_group=pg, name=name_and_suffix.replace(name, f"{name}-momentum"), use_betas=use_betas
+            )
+            lr_momentum_stat.update(momentum)
+
+        return lr_momentum_stat
+
     def _extract_lr(self, param_group: Dict[str, Any], name: str) -> Dict[str, Any]:
         lr = param_group.get("lr")
         self.lrs[name].append(lr)
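_get_lr_momentum_stat centralizes the per-parameter-group extraction that was previously inlined in _extract_stats. A rough standalone equivalent of what it reads from each param group (Adam is chosen only for illustration; in this sketch momentum is taken from the first beta when the optimizer defines betas):

import torch

optimizer = torch.optim.Adam(torch.nn.Linear(4, 2).parameters(), lr=1e-3, betas=(0.9, 0.999))
use_betas = "betas" in optimizer.defaults

stats = {}
for pg in optimizer.param_groups:
    # The learning rate comes straight from the param group's "lr" entry.
    stats["lr-Adam"] = pg.get("lr")
    # Momentum-like value: first beta coefficient for betas-based optimizers, else "momentum".
    stats["lr-Adam-momentum"] = pg.get("betas")[0] if use_betas else pg.get("momentum", 0)

print(stats)  # {'lr-Adam': 0.001, 'lr-Adam-momentum': 0.9}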
@@ -223,7 +255,7 @@ def _duplicate_param_group_names(self, param_groups: List[Dict]) -> Set[str]:
             return set()
         return {n for n in names if names.count(n) > 1}
 
-    def _find_names(self, lr_schedulers: List, add_lr_sch_names: bool = True) -> List[str]:
+    def _find_names_from_schedulers(self, lr_schedulers: List, add_lr_sch_names: bool = True) -> List[str]:
         # Create unique names in the case we have multiple of the same learning
         # rate scheduler + multiple parameter groups
         names = []
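The renamed _find_names_from_schedulers keeps the existing duplicate check on parameter-group names. A quick sketch of the set-comprehension idea visible in _duplicate_param_group_names above (the group names here are made up):

param_groups = [{"name": "backbone"}, {"name": "backbone"}, {"name": "head"}]
names = [pg["name"] for pg in param_groups]
duplicates = {n for n in names if names.count(n) > 1}
print(duplicates)  # {'backbone'} -> the case that raises MisconfigurationException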
@@ -236,28 +268,64 @@ def _find_names(self, lr_schedulers: List, add_lr_sch_names: bool = True) -> Lis
236268
else:
237269
name = "lr-" + sch.optimizer.__class__.__name__
238270

239-
seen_optimizers.append(sch.optimizer)
240-
optimizer_cls = type(sch.optimizer)
241-
if scheduler["name"] is None:
242-
seen_optimizer_types[optimizer_cls] += 1
243-
244-
# Multiple param groups for the same scheduler
245-
param_groups = sch.optimizer.param_groups
246-
duplicates = self._duplicate_param_group_names(param_groups)
247-
if duplicates:
248-
raise MisconfigurationException(
249-
"A single `Optimizer` cannot have multiple parameter groups with identical "
250-
f"`name` values. {name} has duplicated parameter group names {duplicates}"
251-
)
271+
updated_name = self._check_duplicates_and_update_name(
272+
sch.optimizer, name, seen_optimizers, seen_optimizer_types, scheduler, add_lr_sch_names
273+
)
274+
names.extend(updated_name)
275+
return names, seen_optimizers, seen_optimizer_types
276+
277+
def _find_names_from_optimizers(
278+
self, optimizers, seen_optimizers, seen_optimizer_types, add_lr_sch_names: bool = True
279+
) -> List[str]:
280+
names = []
281+
optimizers_without_scheduler = []
252282

253-
name = self._add_prefix(name, optimizer_cls, seen_optimizer_types)
283+
for optimizer in optimizers:
284+
# Deepspeed optimizer wraps the native optimizer
285+
optimizer = optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer
286+
if optimizer in seen_optimizers:
287+
continue
288+
289+
name = "lr-" + optimizer.__class__.__name__
290+
updated_name = self._check_duplicates_and_update_name(
291+
optimizer, name, seen_optimizers, seen_optimizer_types, None, add_lr_sch_names
292+
)
293+
names.extend(updated_name)
294+
optimizers_without_scheduler.append(optimizer)
295+
return names, optimizers_without_scheduler
296+
297+
def _check_duplicates_and_update_name(
298+
self,
299+
optimizer: Optimizer,
300+
name: str,
301+
seen_optimizers: List,
302+
seen_optimizer_types: List,
303+
scheduler: Dict[str, Any] = None,
304+
add_lr_sch_names: bool = True,
305+
) -> List:
306+
seen_optimizers.append(optimizer)
307+
optimizer_cls = type(optimizer)
308+
if scheduler is not None and scheduler["name"] is None:
309+
seen_optimizer_types[optimizer_cls] += 1
310+
elif scheduler is None:
311+
seen_optimizer_types[optimizer_cls] += 1
312+
313+
# Multiple param groups for the same optimizer
314+
param_groups = optimizer.param_groups
315+
duplicates = self._duplicate_param_group_names(param_groups)
316+
if duplicates:
317+
raise MisconfigurationException(
318+
"A single `Optimizer` cannot have multiple parameter groups with identical "
319+
f"`name` values. {name} has duplicated parameter group names {duplicates}"
320+
)
254321

255-
names.extend(self._add_suffix(name, param_groups, i) for i in range(len(param_groups)))
322+
name = self._add_prefix(name, optimizer_cls, seen_optimizer_types)
323+
name_list = [self._add_suffix(name, param_groups, i) for i in range(len(param_groups))]
256324

257-
if add_lr_sch_names:
258-
self.lr_sch_names.append(name)
325+
if add_lr_sch_names:
326+
self.lr_sch_names.append(name)
259327

260-
return names
328+
return name_list
261329

262330
@staticmethod
263331
def _should_log(trainer) -> bool:
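_find_names_from_optimizers unwraps engine-wrapped optimizers via their optimizer attribute before deduplicating against optimizers already seen through schedulers. A tiny sketch of that unwrapping (the wrapper class below is a made-up stand-in for, e.g., a DeepSpeed engine):

import torch


class FakeEngineOptimizer:
    """Hypothetical wrapper that holds the native optimizer, as DeepSpeed does."""

    def __init__(self, optimizer):
        self.optimizer = optimizer


native = torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)
wrapped = FakeEngineOptimizer(native)

# Same pattern as in the diff: prefer the wrapped `.optimizer` when it exists.
unwrapped = wrapped.optimizer if hasattr(wrapped, "optimizer") else wrapped
assert unwrapped is native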
