 from pytorch_lightning.loops.utilities import _get_active_optimizers, _is_max_limit_reached, _update_dataloader_iter
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.fetching import AbstractDataFetcher
@@ -443,12 +444,78 @@ def update_lr_schedulers(self, interval: str, update_plateau_schedulers: bool) -
         active_optimizers = _get_active_optimizers(
             self.trainer.optimizers, self.trainer.optimizer_frequencies, self.total_batch_idx
         )
-        self.trainer.optimizer_connector.update_learning_rates(
+        self._update_learning_rates(
             interval=interval,
             update_plateau_schedulers=update_plateau_schedulers,
             opt_indices=[opt_idx for opt_idx, _ in active_optimizers],
         )
 
+    def _update_learning_rates(
+        self, interval: str, update_plateau_schedulers: bool, opt_indices: Optional[List[int]] = None
+    ) -> None:
+        """Update learning rates.
+
+        Args:
+            interval: either 'epoch' or 'step'.
+            update_plateau_schedulers: control whether ``ReduceLROnPlateau`` or non-plateau schedulers get updated.
+                This is used so non-plateau schedulers can be updated before running validation. Checkpoints are
+                commonly saved during validation, however, on-plateau schedulers might monitor a validation metric
+                so they have to be updated separately.
+            opt_indices: indices of the optimizers to update.
+        """
+        if not self.trainer.lr_schedulers or not self.trainer.lightning_module.automatic_optimization:
+            return
+
+        if opt_indices is None:
+            opt_indices = []
+
+        for lr_scheduler in self.trainer.lr_schedulers:
+            if isinstance(lr_scheduler["opt_idx"], int) and lr_scheduler["opt_idx"] not in opt_indices:
+                continue
+
+            if update_plateau_schedulers ^ lr_scheduler["reduce_on_plateau"]:
+                continue
+
+            current_idx = self.trainer.fit_loop.batch_idx if interval == "step" else self.trainer.current_epoch
+            current_idx += 1  # account for both batch and epoch starts from 0
+            # Take step if call to update_learning_rates matches the interval key and
+            # the current step modulo the schedulers frequency is zero
+            if lr_scheduler["interval"] == interval and current_idx % lr_scheduler["frequency"] == 0:
+                monitor_val = None
+                if lr_scheduler["reduce_on_plateau"]:
+                    # If instance of ReduceLROnPlateau, we need a monitor
+                    monitor_key = lr_scheduler["monitor"]
+                    monitor_val = self._get_monitor_value(monitor_key)
+                    if monitor_val is None:
+                        if lr_scheduler.get("strict", True):
+                            avail_metrics = list(self.trainer.callback_metrics)
+                            raise MisconfigurationException(
+                                f"ReduceLROnPlateau conditioned on metric {monitor_key}"
+                                f" which is not available. Available metrics are: {avail_metrics}."
+                                " Condition can be set using `monitor` key in lr scheduler dict"
+                            )
+                        rank_zero_warn(
+                            f"ReduceLROnPlateau conditioned on metric {monitor_key}"
+                            " which is not available but strict is set to `False`."
+                            " Skipping learning rate update.",
+                            RuntimeWarning,
+                        )
+                        continue
+
+                self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_ready()
+
+                # update LR
+                if lr_scheduler["reduce_on_plateau"]:
+                    lr_scheduler["scheduler"].step(monitor_val)
+                else:
+                    lr_scheduler["scheduler"].step()
+
+                self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_completed()
+
+    def _get_monitor_value(self, key: str) -> Any:
+        # this is a separate method to aid in testing
+        return self.trainer.callback_metrics.get(key)
+
     def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
         """Decide if we should run validation."""
         if not self.trainer.enable_validation:
|