Commit 1599c77

Fix LearningRateMonitor logging with multiple param groups optimizer with no scheduler (#10044)
Parent: 6aeebf1

File tree: 3 files changed, +76 −24 lines
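For context, here is a minimal sketch of the setup this commit addresses: a single optimizer with multiple parameter groups and no LR scheduler, watched by `LearningRateMonitor`. The model, data, and hyperparameters below are illustrative stand-ins, not code from the repository.

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor


class TwoGroupModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.linear_a = torch.nn.Linear(32, 16)
        self.linear_b = torch.nn.Linear(16, 2)

    def training_step(self, batch, batch_idx):
        # Toy loss, just enough to drive the optimizer.
        return self.linear_b(self.linear_a(batch[0])).sum()

    def configure_optimizers(self):
        # Two param groups, and no LR scheduler is returned.
        return torch.optim.Adam(
            [
                {"params": self.linear_a.parameters()},
                {"params": self.linear_b.parameters()},
            ],
            lr=1e-2,
        )


train_data = DataLoader(TensorDataset(torch.randn(8, 32)), batch_size=4)
lr_monitor = LearningRateMonitor()
trainer = Trainer(max_epochs=1, limit_train_batches=2, callbacks=[lr_monitor])
trainer.fit(TwoGroupModel(), train_data)
# With this fix, the monitor tracks one key per param group:
#   list(lr_monitor.lrs) == ["lr-Adam/pg1", "lr-Adam/pg2"]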


CHANGELOG.md
Lines changed: 3 additions & 0 deletions

@@ -596,6 +596,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `train_dataloader` getting loaded twice when resuming from a checkpoint during `Trainer.fit()` ([#9671](https://github.com/PyTorchLightning/pytorch-lightning/pull/9671))
 
 
+- Fixed `LearningRateMonitor` logging with multiple param groups optimizer with no scheduler ([#10044](https://github.com/PyTorchLightning/pytorch-lightning/pull/10044))
+
+
 
 - Fixed undesired side effects being caused by `Trainer` patching dataloader methods on the `LightningModule` ([#9764](https://github.com/PyTorchLightning/pytorch-lightning/pull/9764))
 

pytorch_lightning/callbacks/lr_monitor.py
Lines changed: 28 additions & 24 deletions

@@ -19,6 +19,7 @@
 Monitor and logs learning rate for lr schedulers during training.
 
 """
+import itertools
 from collections import defaultdict
 from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple, Type
 
@@ -123,7 +124,7 @@ def _check_no_key(key: str) -> bool:
                 )
 
         # Find names for schedulers
-        names: List[str] = []
+        names: List[List[str]] = []
         (
             sched_hparam_keys,
             optimizers_with_scheduler,
@@ -140,8 +141,9 @@ def _check_no_key(key: str) -> bool:
         names.extend(optimizer_hparam_keys)
 
         # Initialize for storing values
-        self.lrs = {name: [] for name in names}
-        self.last_momentum_values = {name + "-momentum": None for name in names}
+        names_flatten = list(itertools.chain.from_iterable(names))
+        self.lrs = {name: [] for name in names_flatten}
+        self.last_momentum_values = {name + "-momentum": None for name in names_flatten}
 
     def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
         if not trainer.logger_connector.should_update_logs:
@@ -172,7 +174,7 @@ def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, float]:
         ) = self._find_names_from_schedulers(trainer.lr_schedulers, add_lr_sch_names=False)
         self._remap_keys(scheduler_hparam_keys)
 
-        for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):
+        for name, scheduler in zip(scheduler_hparam_keys, trainer.lr_schedulers):
             if interval in [scheduler["interval"], "any"]:
                 opt = scheduler["scheduler"].optimizer
                 current_stat = self._get_lr_momentum_stat(opt, name)
@@ -186,23 +188,22 @@ def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, float]:
         )
         self._remap_keys(optimizer_hparam_keys)
 
-        for opt, name in zip(optimizers_without_scheduler, optimizer_hparam_keys):
-            current_stat = self._get_lr_momentum_stat(opt, name)
+        for opt, names in zip(optimizers_without_scheduler, optimizer_hparam_keys):
+            current_stat = self._get_lr_momentum_stat(opt, names)
             latest_stat.update(current_stat)
 
         return latest_stat
 
-    def _get_lr_momentum_stat(self, optimizer: Optimizer, name: str) -> Dict[str, float]:
+    def _get_lr_momentum_stat(self, optimizer: Optimizer, names: List[str]) -> Dict[str, float]:
         lr_momentum_stat = {}
         param_groups = optimizer.param_groups
         use_betas = "betas" in optimizer.defaults
 
-        for i, pg in enumerate(param_groups):
-            name_and_suffix = self._add_suffix(name, param_groups, i)
-            lr = self._extract_lr(pg, name_and_suffix)
+        for pg, name in zip(param_groups, names):
+            lr = self._extract_lr(pg, name)
             lr_momentum_stat.update(lr)
             momentum = self._extract_momentum(
-                param_group=pg, name=name_and_suffix.replace(name, f"{name}-momentum"), use_betas=use_betas
+                param_group=pg, name=name.replace(name, f"{name}-momentum"), use_betas=use_betas
             )
             lr_momentum_stat.update(momentum)
 
@@ -213,14 +214,15 @@ def _extract_lr(self, param_group: Dict[str, Any], name: str) -> Dict[str, Any]:
             self.lrs[name].append(lr)
         return {name: lr}
 
-    def _remap_keys(self, names: List[str], token: str = "/pg1") -> None:
+    def _remap_keys(self, names: List[List[str]], token: str = "/pg1") -> None:
         """This function is used the remap the keys if param groups for a given optimizer increased."""
-        for new_name in names:
-            old_name = new_name.replace(token, "")
-            if token in new_name and old_name in self.lrs:
-                self.lrs[new_name] = self.lrs.pop(old_name)
-            elif new_name not in self.lrs:
-                self.lrs[new_name] = []
+        for group_new_names in names:
+            for new_name in group_new_names:
+                old_name = new_name.replace(token, "")
+                if token in new_name and old_name in self.lrs:
+                    self.lrs[new_name] = self.lrs.pop(old_name)
+                elif new_name not in self.lrs:
+                    self.lrs[new_name] = []
 
     def _extract_momentum(self, param_group: Dict[str, List], name: str, use_betas: bool) -> Dict[str, float]:
         if not self.log_momentum:
@@ -258,7 +260,7 @@ def _duplicate_param_group_names(self, param_groups: List[Dict]) -> Set[str]:
 
     def _find_names_from_schedulers(
         self, lr_schedulers: List, add_lr_sch_names: bool = True
-    ) -> Tuple[List[str], List[Optimizer], DefaultDict[Type[Optimizer], int]]:
+    ) -> Tuple[List[List[str]], List[Optimizer], DefaultDict[Type[Optimizer], int]]:
         # Create unique names in the case we have multiple of the same learning
         # rate scheduler + multiple parameter groups
         names = []
@@ -271,10 +273,11 @@
             else:
                 name = "lr-" + sch.optimizer.__class__.__name__
 
-            updated_name = self._check_duplicates_and_update_name(
+            updated_names = self._check_duplicates_and_update_name(
                 sch.optimizer, name, seen_optimizers, seen_optimizer_types, scheduler, add_lr_sch_names
             )
-            names.extend(updated_name)
+            names.append(updated_names)
+
         return names, seen_optimizers, seen_optimizer_types
 
     def _find_names_from_optimizers(
@@ -283,7 +286,7 @@
         seen_optimizers: List[Optimizer],
         seen_optimizer_types: DefaultDict[Type[Optimizer], int],
         add_lr_sch_names: bool = True,
-    ) -> Tuple[List[str], List[Optimizer]]:
+    ) -> Tuple[List[List[str]], List[Optimizer]]:
         names = []
         optimizers_without_scheduler = []
 
@@ -294,11 +297,12 @@
                 continue
 
             name = "lr-" + optimizer.__class__.__name__
-            updated_name = self._check_duplicates_and_update_name(
+            updated_names = self._check_duplicates_and_update_name(
                 optimizer, name, seen_optimizers, seen_optimizer_types, None, add_lr_sch_names
             )
-            names.extend(updated_name)
+            names.append(updated_names)
             optimizers_without_scheduler.append(optimizer)
+
         return names, optimizers_without_scheduler
 
     def _check_duplicates_and_update_name(

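The core bookkeeping change above: names are now collected as one sub-list per optimizer (`List[List[str]]`), with one entry per param group, and flattened only when the storage dicts are initialized, so every param group keeps its own key. A small standalone illustration of that shape (the optimizer names here are made up):

import itertools

# One sub-list per optimizer, one entry per param group (illustrative values).
names = [["lr-Adam/pg1", "lr-Adam/pg2"], ["lr-SGD"]]

names_flatten = list(itertools.chain.from_iterable(names))
lrs = {name: [] for name in names_flatten}
last_momentum_values = {name + "-momentum": None for name in names_flatten}

print(list(lrs))
# ['lr-Adam/pg1', 'lr-Adam/pg2', 'lr-SGD']
print(list(last_momentum_values))
# ['lr-Adam/pg1-momentum', 'lr-Adam/pg2-momentum', 'lr-SGD-momentum']

Keeping the nesting until `_get_lr_momentum_stat` is what lets that method zip each optimizer's param groups directly with their pre-computed names instead of re-deriving `/pg` suffixes on every call.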

tests/callbacks/test_lr_monitor.py
Lines changed: 45 additions & 0 deletions

@@ -510,3 +510,48 @@ def finetune_function(self, pl_module, epoch: int, optimizer, opt_idx: int):
 
     expected = [0.1, 0.05]
     assert lr_monitor.lrs["lr-Adam-1/pg3"] == expected
+
+
+def test_lr_monitor_multiple_param_groups_no_scheduler(tmpdir):
+    class TestModel(BoringModel):
+        def __init__(self, lr, momentum):
+            super().__init__()
+            self.save_hyperparameters()
+            self.linear_a = torch.nn.Linear(32, 16)
+            self.linear_b = torch.nn.Linear(16, 2)
+
+        def forward(self, x):
+            x = self.linear_a(x)
+            x = self.linear_b(x)
+            return x
+
+        def configure_optimizers(self):
+            param_groups = [
+                {"params": list(self.linear_a.parameters())},
+                {"params": list(self.linear_b.parameters())},
+            ]
+            optimizer = torch.optim.Adam(param_groups, lr=self.hparams.lr, betas=self.hparams.momentum)
+            return optimizer
+
+    lr_monitor = LearningRateMonitor(log_momentum=True)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        limit_val_batches=2,
+        limit_train_batches=2,
+        callbacks=[lr_monitor],
+        enable_progress_bar=False,
+        enable_model_summary=False,
+    )
+
+    lr = 1e-2
+    momentum = 0.7
+    model = TestModel(lr=lr, momentum=(momentum, 0.999))
+    trainer.fit(model)
+
+    assert len(lr_monitor.lrs) == len(trainer.optimizers[0].param_groups)
+    assert list(lr_monitor.lrs.keys()) == ["lr-Adam/pg1", "lr-Adam/pg2"]
+    assert lr_monitor.lr_sch_names == ["lr-Adam"]
+    assert list(lr_monitor.last_momentum_values.keys()) == ["lr-Adam/pg1-momentum", "lr-Adam/pg2-momentum"]
+    assert all(val == momentum for val in lr_monitor.last_momentum_values.values())
+    assert all(all(val == lr for val in lr_monitor.lrs[lr_key]) for lr_key in lr_monitor.lrs)
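A note on the momentum assertions in the new test: the optimizer is Adam, so there is no `momentum` hyperparameter; as the `use_betas = "betas" in optimizer.defaults` check in the callback suggests, the reported `-momentum` value comes from beta1. A rough sketch of that convention (not the callback's actual extraction code):

import torch

opt = torch.optim.Adam(torch.nn.Linear(2, 2).parameters(), lr=1e-2, betas=(0.7, 0.999))
use_betas = "betas" in opt.defaults  # True for Adam
pg = opt.param_groups[0]

# Report beta1 when betas are present, otherwise fall back to "momentum".
momentum = pg["betas"][0] if use_betas else pg.get("momentum", 0)
print(momentum)  # 0.7, the value the test asserts for every "-momentum" key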
