move get_active_optimizers to utilities #9581

Merged
10 commits merged on Sep 25, 2021
5 changes: 4 additions & 1 deletion pytorch_lightning/callbacks/finetuning.py
@@ -285,7 +285,10 @@ def _store(

def on_train_epoch_start(self, trainer, pl_module):
"""Called when the epoch begins."""
for opt_idx, optimizer in trainer.fit_loop.epoch_loop.batch_loop.get_active_optimizers():
# import is here to avoid circular imports
from pytorch_lightning.loops.utilities import _get_active_optimizers

for opt_idx, optimizer in _get_active_optimizers(trainer.optimizers, trainer.optimizer_frequencies):
num_param_groups = len(optimizer.param_groups)
self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx)
current_param_groups = optimizer.param_groups
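
The new helper lives in pytorch_lightning.loops.utilities, and the comment in the hunk above notes that importing it at module level from the callbacks package would create an import cycle with the loops package, so the import is deferred into the hook. A minimal, hypothetical sketch of that deferred-import pattern (made-up module names, not part of this PR):

    # pkg/a.py -- imports b at module level
    from pkg import b

    VALUE = 1


    # pkg/b.py -- a top-level "from pkg import a" here would complete the cycle
    # a -> b -> a; deferring it into the function body avoids the problem because
    # both modules are fully initialized by the time read_value() is called.
    def read_value() -> int:
        from pkg import a
        return a.VALUE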
38 changes: 4 additions & 34 deletions pytorch_lightning/loops/batch/training_batch_loop.py
@@ -11,16 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Tuple
from typing import Any, List, Optional

import numpy as np
from deprecate import void
from torch import Tensor
from torch.optim import Optimizer

from pytorch_lightning.loops.base import Loop
from pytorch_lightning.loops.optimization.manual_loop import ManualOptimization
from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop
from pytorch_lightning.loops.utilities import _get_active_optimizers
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import AttributeDict
from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -41,21 +40,13 @@ def __init__(self) -> None:
self.manual_loop = ManualOptimization()

self._warning_cache: WarningCache = WarningCache()
self._optimizer_freq_cumsum: Optional[int] = None
self._remaining_splits: Optional[List[Any]] = None

@property
def done(self) -> bool:
"""Returns if all batch splits have been processed already."""
return len(self._remaining_splits) == 0

@property
def optimizer_freq_cumsum(self) -> int:
"""Returns the cumulated sum of optimizer frequencies."""
if self._optimizer_freq_cumsum is None:
self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies)
return self._optimizer_freq_cumsum

def connect(
self, optimizer_loop: Optional["Loop"] = None, manual_loop: Optional[ManualOptimization] = None
) -> None:
@@ -123,7 +114,8 @@ def advance(self, batch, batch_idx):

if self.trainer.lightning_module.automatic_optimization:
# in automatic optimization, hand over execution to the OptimizerLoop
batch_outputs = self.optimizer_loop.run(split_batch, self.get_active_optimizers(batch_idx), batch_idx)
optimizers = _get_active_optimizers(self.trainer.optimizers, self.trainer.optimizer_frequencies, batch_idx)
batch_outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
# combine outputs from each optimizer
for k in range(len(batch_outputs)):
self.batch_outputs[k].extend(batch_outputs[k])
@@ -142,10 +134,6 @@ def teardown(self) -> None:
# release memory
self._remaining_splits = None

def num_active_optimizers(self, batch_idx: Optional[int] = None) -> int:
"""Gets the number of active optimizers based on their frequency."""
return len(self.get_active_optimizers(batch_idx))

def _tbptt_split_batch(self, batch: Any) -> List[Any]:
"""Splits a single batch into a list of sequence steps for tbptt.

@@ -175,21 +163,3 @@ def _update_running_loss(self, current_loss: Tensor) -> None:

# reset for next set of accumulated grads
self.accumulated_loss.reset()

def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[int, Optimizer]]:
"""Returns the currently active optimizers. When multiple optimizers are used with different frequencies,
only one of the optimizers is active at a time.

Returns:
A list of tuples (opt_idx, optimizer) of currently active optimizers.
"""
if not self.trainer.optimizer_frequencies:
# call training_step once per optimizer
return list(enumerate(self.trainer.optimizers))

optimizers_loop_length = self.optimizer_freq_cumsum[-1]
current_place_in_loop = batch_idx % optimizers_loop_length

# find optimzier index by looking for the first {item > current_place} in the cumsum list
opt_idx = np.searchsorted(self.optimizer_freq_cumsum, current_place_in_loop, side="right")
return [(opt_idx, self.trainer.optimizers[opt_idx])]
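
For code outside the loops that still calls the removed batch-loop helpers, the substitution mirrors the one made in this PR: pass the trainer's optimizers and optimizer frequencies to the free function. A hedged migration sketch (the wrapper names below are made up for illustration):

    from typing import List, Tuple

    from torch.optim import Optimizer

    import pytorch_lightning as pl
    from pytorch_lightning.loops.utilities import _get_active_optimizers


    def active_optimizers(trainer: "pl.Trainer", batch_idx: int) -> List[Tuple[int, Optimizer]]:
        # previously: trainer.fit_loop.epoch_loop.batch_loop.get_active_optimizers(batch_idx)
        return _get_active_optimizers(trainer.optimizers, trainer.optimizer_frequencies, batch_idx)


    def num_active_optimizers(trainer: "pl.Trainer", batch_idx: int) -> int:
        # previously: trainer.fit_loop.epoch_loop.batch_loop.num_active_optimizers(batch_idx)
        return len(active_optimizers(trainer, batch_idx))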
13 changes: 10 additions & 3 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -18,7 +18,7 @@
from pytorch_lightning import loops # import as loops to avoid circular imports
from pytorch_lightning.loops.batch import TrainingBatchLoop
from pytorch_lightning.loops.optimization.closure import OutputResult
from pytorch_lightning.loops.utilities import _prepare_dataloader_iter
from pytorch_lightning.loops.utilities import _get_active_optimizers, _prepare_dataloader_iter
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -97,7 +97,7 @@ def reset(self) -> None:
self.batch_loop.optimizer_loop.optim_progress.reset_on_restart()

# track epoch output
self._epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))]
self._epoch_output = [[] for _ in range(self._num_active_optimizers(self.total_batch_idx))]

if not self.restarting or self._num_training_batches_reached():
self.batch_progress.reset_on_epoch()
@@ -334,10 +334,13 @@ def update_lr_schedulers(self, interval: str, update_plateau_schedulers: bool) -
"""updates the lr schedulers based on the given interval."""
if interval == "step" and self._should_accumulate():
return
active_optimizers = _get_active_optimizers(
self.trainer.optimizers, self.trainer.optimizer_frequencies, self.total_batch_idx
)
self.trainer.optimizer_connector.update_learning_rates(
interval=interval,
update_plateau_schedulers=update_plateau_schedulers,
opt_indices=[opt_idx for opt_idx, _ in self.batch_loop.get_active_optimizers(self.total_batch_idx)],
opt_indices=[opt_idx for opt_idx, _ in active_optimizers],
)

def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
@@ -371,3 +374,7 @@ def _save_loggers_on_train_batch_end(self) -> None:
should_flush_logs = self.trainer.logger_connector.should_flush_logs
if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None:
self.trainer.logger.save()

def _num_active_optimizers(self, batch_idx: Optional[int] = None) -> int:
"""Gets the number of active optimizers based on their frequency."""
return len(_get_active_optimizers(self.trainer.optimizers, self.trainer.optimizer_frequencies, batch_idx))
31 changes: 30 additions & 1 deletion pytorch_lightning/loops/utilities.py
@@ -13,8 +13,10 @@
# limitations under the License.
from collections import OrderedDict
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, Optional, Sequence
from functools import lru_cache
from typing import Any, Dict, Generator, Iterator, List, Optional, Sequence, Tuple

import numpy as np
import torch
from torch.optim import Optimizer

@@ -139,3 +141,30 @@ def _block_parallel_sync_behavior(trainer: "pl.Trainer", block: bool = True) ->
yield None
else:
yield None


@lru_cache(1)
def _cumulative_optimizer_frequencies(frequencies: Tuple[int]):
return np.cumsum(frequencies)


def _get_active_optimizers(
optimizers: List[Optimizer], frequencies: List[int], batch_idx: Optional[int] = None
) -> List[Tuple[int, Optimizer]]:
"""Returns the currently active optimizers. When multiple optimizers are used with different frequencies, only
one of the optimizers is active at a time.

Returns:
A list of tuples (opt_idx, optimizer) of currently active optimizers.
"""
if not frequencies:
# call training_step once per optimizer
return list(enumerate(optimizers))

freq_cumsum = _cumulative_optimizer_frequencies(tuple(frequencies))
optimizers_loop_length = freq_cumsum[-1]
current_place_in_loop = batch_idx % optimizers_loop_length

# find optimizer index by looking for the first {item > current_place} in the cumsum list
opt_idx = np.searchsorted(freq_cumsum, current_place_in_loop, side="right")
return [(opt_idx, optimizers[opt_idx])]
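
A small standalone illustration of the new helper (not part of the diff; _get_active_optimizers is a private utility added above, so this only demonstrates the frequency cycling with ad-hoc torch optimizers):

    import torch
    from pytorch_lightning.loops.utilities import _get_active_optimizers

    param = torch.nn.Parameter(torch.zeros(1))
    optimizers = [torch.optim.SGD([param], lr=0.1), torch.optim.Adam([param], lr=0.01)]

    # with frequencies [2, 1]: batches 0-1 use optimizer 0, batch 2 uses optimizer 1, repeat
    for batch_idx in range(6):
        active = _get_active_optimizers(optimizers, [2, 1], batch_idx)
        print(batch_idx, [opt_idx for opt_idx, _ in active])  # 0 [0], 1 [0], 2 [1], 3 [0], ...

    # with no frequencies, every optimizer is active on every batch
    print([opt_idx for opt_idx, _ in _get_active_optimizers(optimizers, [])])  # [0, 1]

Because _cumulative_optimizer_frequencies is wrapped in lru_cache(1) and keyed on the frequencies tuple, the cumulative sum is computed once for the most recently seen frequency configuration rather than on every batch.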