
Commit 13d6d7b

Remove optimizer_connector.py (#10120)
1 parent 21a5867 commit 13d6d7b

File tree: 5 files changed (+79, -104 lines)


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -557,6 +557,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed automatic patching of `{train,val,test,predict}_dataloader()` on the `LightningModule` ([#9764](https://github.com/PyTorchLightning/pytorch-lightning/pull/9764))
 
 
+- Removed `pytorch_lightning.trainer.connectors.OptimizerConnector` ([#10120](https://github.com/PyTorchLightning/pytorch-lightning/pull/10120))
+
+
 ### Fixed
 
 
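Code that imported the removed connector will break on versions containing this commit. A minimal, illustrative check (not part of the repository) of whether the old module path is still importable:

# Hedged sketch: probes the module path removed in #10120; illustrative only.
try:
    from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector
except ImportError:
    OptimizerConnector = None  # the module no longer exists after this commit

print("OptimizerConnector importable:", OptimizerConnector is not None)
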
pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 68 additions & 1 deletion
@@ -23,6 +23,7 @@
 from pytorch_lightning.loops.utilities import _get_active_optimizers, _is_max_limit_reached, _update_dataloader_iter
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.fetching import AbstractDataFetcher
@@ -443,12 +444,78 @@ def update_lr_schedulers(self, interval: str, update_plateau_schedulers: bool) -
         active_optimizers = _get_active_optimizers(
             self.trainer.optimizers, self.trainer.optimizer_frequencies, self.total_batch_idx
         )
-        self.trainer.optimizer_connector.update_learning_rates(
+        self._update_learning_rates(
             interval=interval,
             update_plateau_schedulers=update_plateau_schedulers,
             opt_indices=[opt_idx for opt_idx, _ in active_optimizers],
         )
 
+    def _update_learning_rates(
+        self, interval: str, update_plateau_schedulers: bool, opt_indices: Optional[List[int]] = None
+    ) -> None:
+        """Update learning rates.
+
+        Args:
+            interval: either 'epoch' or 'step'.
+            update_plateau_schedulers: control whether ``ReduceLROnPlateau`` or non-plateau schedulers get updated.
+                This is used so non-plateau schedulers can be updated before running validation. Checkpoints are
+                commonly saved during validation, however, on-plateau schedulers might monitor a validation metric
+                so they have to be updated separately.
+            opt_indices: indices of the optimizers to update.
+        """
+        if not self.trainer.lr_schedulers or not self.trainer.lightning_module.automatic_optimization:
+            return
+
+        if opt_indices is None:
+            opt_indices = []
+
+        for lr_scheduler in self.trainer.lr_schedulers:
+            if isinstance(lr_scheduler["opt_idx"], int) and lr_scheduler["opt_idx"] not in opt_indices:
+                continue
+
+            if update_plateau_schedulers ^ lr_scheduler["reduce_on_plateau"]:
+                continue
+
+            current_idx = self.trainer.fit_loop.batch_idx if interval == "step" else self.trainer.current_epoch
+            current_idx += 1  # account for both batch and epoch starts from 0
+            # Take step if call to update_learning_rates matches the interval key and
+            # the current step modulo the schedulers frequency is zero
+            if lr_scheduler["interval"] == interval and current_idx % lr_scheduler["frequency"] == 0:
+                monitor_val = None
+                if lr_scheduler["reduce_on_plateau"]:
+                    # If instance of ReduceLROnPlateau, we need a monitor
+                    monitor_key = lr_scheduler["monitor"]
+                    monitor_val = self._get_monitor_value(monitor_key)
+                    if monitor_val is None:
+                        if lr_scheduler.get("strict", True):
+                            avail_metrics = list(self.trainer.callback_metrics)
+                            raise MisconfigurationException(
+                                f"ReduceLROnPlateau conditioned on metric {monitor_key}"
+                                f" which is not available. Available metrics are: {avail_metrics}."
+                                " Condition can be set using `monitor` key in lr scheduler dict"
+                            )
+                        rank_zero_warn(
+                            f"ReduceLROnPlateau conditioned on metric {monitor_key}"
+                            " which is not available but strict is set to `False`."
+                            " Skipping learning rate update.",
+                            RuntimeWarning,
+                        )
+                        continue
+
+                self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_ready()
+
+                # update LR
+                if lr_scheduler["reduce_on_plateau"]:
+                    lr_scheduler["scheduler"].step(monitor_val)
+                else:
+                    lr_scheduler["scheduler"].step()
+
+                self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_completed()
+
+    def _get_monitor_value(self, key: str) -> Any:
+        # this is a separate method to aid in testing
+        return self.trainer.callback_metrics.get(key)
+
     def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
         """Decide if we should run validation."""
         if not self.trainer.enable_validation:

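The docstring of the new `_update_learning_rates` method notes that `ReduceLROnPlateau` schedulers need a `monitor` key in the lr scheduler dict. A minimal sketch of a module producing such a dict, using only the keys the loop above reads; the class name, layer sizes, and metric name are illustrative and not part of this commit:

import torch
from pytorch_lightning import LightningModule


class PlateauModel(LightningModule):
    """Hypothetical module; shows the lr_scheduler dict consumed by _update_learning_rates."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def validation_step(self, batch, batch_idx):
        self.log("val_loss", self.layer(batch).sum())  # ends up in trainer.callback_metrics

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",    # compared against the `interval` argument
                "frequency": 1,         # stepped only when current_idx % frequency == 0
                "monitor": "val_loss",  # looked up via _get_monitor_value
                "strict": True,         # a missing metric raises MisconfigurationException
            },
        }
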
pytorch_lightning/trainer/connectors/optimizer_connector.py

Lines changed: 0 additions & 95 deletions
This file was deleted.

pytorch_lightning/trainer/trainer.py

Lines changed: 3 additions & 3 deletions
@@ -58,7 +58,6 @@
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
-from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector
 from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
 from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector
 from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
@@ -430,7 +429,6 @@ def __init__(
 
         # init connectors
         self._data_connector = DataConnector(self, multiple_trainloader_mode)
-        self.optimizer_connector = OptimizerConnector(self)
 
         self._accelerator_connector = AcceleratorConnector(
             num_processes,
@@ -517,7 +515,9 @@ def __init__(
         self.on_init_start()
 
         # init optimizer + lr scheduler related flags
-        self.optimizer_connector.on_trainer_init()
+        self.lr_schedulers = []
+        self.optimizers = []
+        self.optimizer_frequencies = []
 
         # init data flags
         self._data_connector.on_trainer_init(

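The call to `optimizer_connector.on_trainer_init()` is replaced by assigning the three optimizer-related flags directly in `Trainer.__init__`. A small sketch, assuming only what the added lines show, of the state of a freshly constructed `Trainer`:

from pytorch_lightning import Trainer

trainer = Trainer(max_epochs=1)

# Initial values assigned inline in __init__ after this commit; they are filled in
# later during fit() from the optimizers/schedulers returned by configure_optimizers().
assert trainer.optimizers == []
assert trainer.lr_schedulers == []
assert trainer.optimizer_frequencies == []
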
tests/checkpointing/test_model_checkpoint.py

Lines changed: 5 additions & 5 deletions
@@ -63,17 +63,17 @@ def validation_epoch_end(self, outputs):
         self.log("val_acc", outs)
 
 
-def mock_optimizer_connector(trainer):
+def mock_training_epoch_loop(trainer):
     # do not use `unittest.Mock` because we need to store the return value
     calls = {}
-    old_get_monitor_value = trainer.optimizer_connector._get_monitor_value
+    old_get_monitor_value = trainer.fit_loop.epoch_loop._get_monitor_value
 
     def mock(key):
         value = old_get_monitor_value(key)
         calls[trainer.current_epoch] = {key: value}
         return value
 
-    trainer.optimizer_connector._get_monitor_value = mock
+    trainer.fit_loop.epoch_loop._get_monitor_value = mock
     return calls
 
 
@@ -150,7 +150,7 @@ def on_validation_epoch_end(self):
         max_epochs=max_epochs,
         enable_progress_bar=False,
     )
-    calls = mock_optimizer_connector(trainer)
+    calls = mock_training_epoch_loop(trainer)
    trainer.fit(model)
 
     ckpt_files = list(Path(tmpdir).glob("*.ckpt"))
@@ -248,7 +248,7 @@ def configure_optimizers(self):
         enable_progress_bar=False,
         num_sanity_val_steps=0,
     )
-    calls = mock_optimizer_connector(trainer)
+    calls = mock_training_epoch_loop(trainer)
     trainer.fit(model)
 
     def _make_assertions(epoch, ix):

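For orientation, a sketch of how the renamed `mock_training_epoch_loop` helper can be exercised inside this test file; `TinyModel`, the dataloader, and all sizes are hypothetical stand-ins for the fixtures the real tests build:

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class TinyModel(LightningModule):  # hypothetical stand-in for the test's model fixture
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch[0]).sum()

    def validation_step(self, batch, batch_idx):
        self.log("val_loss", self.layer(batch[0]).sum())

    def configure_optimizers(self):
        opt = torch.optim.SGD(self.parameters(), lr=0.1)
        sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt)
        return {"optimizer": opt, "lr_scheduler": {"scheduler": sched, "monitor": "val_loss"}}


loader = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4)
trainer = Trainer(max_epochs=2, limit_train_batches=2, limit_val_batches=2,
                  num_sanity_val_steps=0, enable_progress_bar=False)
calls = mock_training_epoch_loop(trainer)  # helper defined in this test file
trainer.fit(TinyModel(), train_dataloaders=loader, val_dataloaders=loader)

# one {monitor_key: value} entry is recorded per epoch by the patched _get_monitor_value
assert all("val_loss" in monitored for monitored in calls.values())
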