
Add typing for LightningOptimizer #9990

Merged
merged 9 commits on Oct 18, 2021
1 change: 1 addition & 0 deletions pyproject.toml
@@ -65,6 +65,7 @@ module = [
"pytorch_lightning.callbacks.model_summary",
"pytorch_lightning.callbacks.pruning",
"pytorch_lightning.callbacks.rich_model_summary",
"pytorch_lightning.core.optimizer",
"pytorch_lightning.loops.optimization.*",
"pytorch_lightning.loops.evaluation_loop",
"pytorch_lightning.trainer.connectors.checkpoint_connector",
58 changes: 30 additions & 28 deletions pytorch_lightning/core/lightning.py
@@ -21,13 +21,14 @@
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, overload, Tuple, Union

import torch
from torch import ScriptModule, Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer
from torchmetrics import Metric
from typing_extensions import Literal

import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress import base as progress_base
@@ -120,6 +121,14 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
# deprecated, will be removed in 1.6
self._loaded_optimizer_states_dict = {}

@overload
def optimizers(self, use_pl_optimizer: Literal[True] = True) -> Union[LightningOptimizer, List[LightningOptimizer]]:
...

@overload
def optimizers(self, use_pl_optimizer: Literal[False]) -> Union[Optimizer, List[Optimizer]]:
Contributor

I believe we need to add a third overload here for this to work as expected.

E.g.

@overload
def optimizers(self, use_pl_optimizer: bool) -> Union[...]

See the error that is raised as currently written here: https://mypy-play.net/?mypy=latest&python=3.8&gist=183e84ea2331260b77cf58c0bfbf1151

Contributor Author

You are correct; this only works if `True` or `False` is passed explicitly.

I'll open a PR with the fix.

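For context, the sketch below is a standalone toy illustrating the three-overload pattern discussed here; `PlainOpt` and `WrappedOpt` are hypothetical stand-ins for the real optimizer classes. Without the final `bool` overload, mypy rejects calls that pass a non-literal `bool` variable.

```python
from typing import List, Union, overload

from typing_extensions import Literal


class PlainOpt:  # stand-in for torch.optim.Optimizer
    ...


class WrappedOpt:  # stand-in for LightningOptimizer
    ...


class Module:
    @overload
    def optimizers(self, use_pl_optimizer: Literal[True] = True) -> Union[WrappedOpt, List[WrappedOpt]]:
        ...

    @overload
    def optimizers(self, use_pl_optimizer: Literal[False]) -> Union[PlainOpt, List[PlainOpt]]:
        ...

    @overload
    def optimizers(self, use_pl_optimizer: bool) -> Union[PlainOpt, WrappedOpt, List[PlainOpt], List[WrappedOpt]]:
        ...

    def optimizers(self, use_pl_optimizer: bool = True):
        # Runtime behaviour does not matter for the typing demonstration.
        return WrappedOpt() if use_pl_optimizer else PlainOpt()


flag: bool = True
opts = Module().optimizers(flag)  # only accepted by mypy once the third overload exists
```
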
...

def optimizers(
self, use_pl_optimizer: bool = True
) -> Union[Optimizer, LightningOptimizer, List[Optimizer], List[LightningOptimizer]]:
@@ -1426,17 +1435,16 @@ def backward(self, loss, optimizer, optimizer_idx):
"""
loss.backward(*args, **kwargs)

def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
def toggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer], optimizer_idx: int) -> None:
"""Makes sure only the gradients of the current optimizer's parameters are calculated in the training step
to prevent dangling gradients in multiple-optimizer setup. It works with :meth:`untoggle_optimizer` to make
sure ``param_requires_grad_state`` is properly reset. Override for your own behavior.
to prevent dangling gradients in multiple-optimizer setup.

Args:
optimizer: Current optimizer used in the training loop
optimizer_idx: Current optimizer idx in the training loop
This is only called automatically when automatic optimization is enabled and multiple optimizers are used.
It works with :meth:`untoggle_optimizer` to make sure ``param_requires_grad_state`` is properly reset.

Note:
Only called when using multiple optimizers
Args:
optimizer: The optimizer to toggle.
optimizer_idx: The index of the optimizer to toggle.
"""
# Iterate over all optimizer parameters to preserve their `requires_grad` information
# in case these are pre-defined during `configure_optimizers`
@@ -1457,15 +1465,13 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
param.requires_grad = param_requires_grad_state[param]
self._param_requires_grad_state = param_requires_grad_state

def untoggle_optimizer(self, optimizer_idx: int):
"""Resets the state of required gradients that were toggled with :meth:`toggle_optimizer`. Override for
your own behavior.
def untoggle_optimizer(self, optimizer_idx: int) -> None:
"""Resets the state of required gradients that were toggled with :meth:`toggle_optimizer`.

Args:
optimizer_idx: Current optimizer idx in the training loop
This is only called automatically when automatic optimization is enabled and multiple optimizers are used.

Note:
Only called when using multiple_optimizers
Args:
optimizer_idx: The index of the optimizer to untoggle.
"""
for opt_idx, opt in enumerate(self.optimizers(use_pl_optimizer=False)):
if optimizer_idx != opt_idx:
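As a rough, self-contained illustration of the behaviour documented above (not the Lightning implementation itself): toggling disables ``requires_grad`` on parameters that belong exclusively to the other optimizers, and untoggling restores the saved state.

```python
import torch

# Two parameters, each owned by a different optimizer (e.g. generator / discriminator).
gen_param = torch.nn.Parameter(torch.zeros(1))
dis_param = torch.nn.Parameter(torch.zeros(1))
opt_gen = torch.optim.SGD([gen_param], lr=0.1)
opt_dis = torch.optim.SGD([dis_param], lr=0.1)

# "Toggle" opt_gen: remember the requires_grad state of the other optimizer's
# parameters, then turn it off so the backward pass below cannot leave
# dangling gradients on them.
saved_state = {dis_param: dis_param.requires_grad}
dis_param.requires_grad = False

loss = (2 * gen_param + dis_param).sum()
loss.backward()
assert gen_param.grad is not None
assert dis_param.grad is None  # no dangling gradient was accumulated

# "Untoggle": restore the saved requires_grad state.
dis_param.requires_grad = saved_state[dis_param]
```
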
@@ -1568,14 +1574,14 @@ def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_va

def optimizer_step(
self,
epoch: int = None,
batch_idx: int = None,
optimizer: Optimizer = None,
optimizer_idx: int = None,
optimizer_closure: Optional[Callable] = None,
on_tpu: bool = None,
using_native_amp: bool = None,
using_lbfgs: bool = None,
epoch: int,
batch_idx: int,
optimizer: Union[Optimizer, LightningOptimizer],
optimizer_idx: int = 0,
optimizer_closure: Optional[Callable[[], Any]] = None,
on_tpu: bool = False,
using_native_amp: bool = False,
using_lbfgs: bool = False,
) -> None:
r"""
Override this method to adjust the default way the
@@ -1584,10 +1590,6 @@ def optimizer_step(
once per optimizer. This method (and ``zero_grad()``) won't be called during the
accumulation phase when ``Trainer(accumulate_grad_batches != 1)``.

Warning:
If you are overriding this method, make sure that you pass the ``optimizer_closure`` parameter
to ``optimizer.step()`` function as shown in the examples.

Args:
epoch: Current epoch
batch_idx: Index of current batch
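Since the hook's parameters lose their ``None`` defaults in this diff, an override should follow the new signature. Below is a minimal sketch, not the library's reference implementation: the parameter list mirrors the signature shown above, and the body only forwards the closure, which is what runs ``training_step`` and ``backward`` in automatic optimization.

```python
from typing import Any, Callable, Optional, Union

from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer


class LitModel(pl.LightningModule):
    def optimizer_step(
        self,
        epoch: int,
        batch_idx: int,
        optimizer: Union[Optimizer, LightningOptimizer],
        optimizer_idx: int = 0,
        optimizer_closure: Optional[Callable[[], Any]] = None,
        on_tpu: bool = False,
        using_native_amp: bool = False,
        using_lbfgs: bool = False,
    ) -> None:
        # Forward the closure: skipping it would skip training_step and backward.
        optimizer.step(closure=optimizer_closure)
```
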
64 changes: 29 additions & 35 deletions pytorch_lightning/core/optimizer.py
@@ -12,16 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Callable, Optional
from typing import Any, Callable, Generator, List, Optional
from weakref import proxy

from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def do_nothing_closure():
def do_nothing_closure() -> None:
return


@@ -44,93 +45,86 @@ def __init__(self, optimizer: Optimizer):
self.__class__ = type("Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})

self._optimizer = optimizer
self._trainer = None
self._optimizer_idx = None
self._trainer: Optional["pl.Trainer"] = None
self._optimizer_idx = 0

@property
def optimizer(self):
def optimizer(self) -> Optimizer:
return self._optimizer

@property
def defaults(self):
def defaults(self) -> dict:
return self._optimizer.defaults

@defaults.setter
def defaults(self, defaults):
def defaults(self, defaults: dict) -> None:
self._optimizer.defaults = defaults

@property
def state(self):
def state(self) -> dict:
return self._optimizer.state

@state.setter
def state(self, state):
def state(self, state: dict) -> None:
self._optimizer.state = state

@property
def param_groups(self):
def param_groups(self) -> List[dict]:
return self._optimizer.param_groups

@param_groups.setter
def param_groups(self, param_groups):
def param_groups(self, param_groups: List[dict]) -> None:
self._optimizer.param_groups = param_groups

def _on_trainer_init(self, trainer):
def _on_trainer_init(self, trainer: "pl.Trainer") -> None:
self._trainer = proxy(trainer)
for opt_idx, opt in enumerate(trainer.optimizers):
if opt == self._optimizer:
self._optimizer_idx = opt_idx
break

@classmethod
def _to_lightning_optimizer(cls, optimizer, trainer, opt_idx):
def _to_lightning_optimizer(cls, optimizer: Optimizer, trainer: "pl.Trainer", opt_idx: int) -> "LightningOptimizer":
# apex overrides .step function and need to be wrapped on each step
if trainer.amp_backend == AMPType.APEX:
optimizer = cls(optimizer)
optimizer._on_trainer_init(trainer)
if trainer.amp_backend is not None and trainer.amp_backend == AMPType.APEX:
lightning_optimizer = cls(optimizer)
lightning_optimizer._on_trainer_init(trainer)
else:
optimizer = trainer.lightning_optimizers[opt_idx]
return optimizer
lightning_optimizer = trainer.lightning_optimizers[opt_idx]
return lightning_optimizer

@contextmanager
def toggle_model(self, sync_grad: bool = True):
def toggle_model(self, sync_grad: bool = True) -> Generator[None, None, None]:
"""This function is just a helper for advanced users.

Considering the current optimizer as A and all other optimizers as B.
Toggling means all parameters from B exclusive to A will have ``requires_grad`` set to False.


When performing gradient accumulation, there is no need to perform grad synchronization
during the accumulation phase.
Setting `sync_grad` to False will block this synchronization and improve performance.
"""
# local import here to avoid circular import
from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior

assert self._trainer is not None
lightning_module = self._trainer.lightning_module

with _block_parallel_sync_behavior(self._trainer, block=(not sync_grad)):
lightning_module.toggle_optimizer(self, self._optimizer_idx)
yield
lightning_module.untoggle_optimizer(self._optimizer_idx)

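A possible usage sketch of ``toggle_model`` in manual optimization with gradient accumulation; ``compute_loss`` and the accumulation factor are hypothetical, only ``toggle_model``, ``manual_backward``, ``step`` and ``zero_grad`` come from the Lightning API.

```python
import pytorch_lightning as pl


class LitGAN(pl.LightningModule):
    accumulate_batches = 4  # hypothetical accumulation factor

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()  # a LightningOptimizer by default
        loss = self.compute_loss(batch)  # hypothetical helper
        do_step = (batch_idx + 1) % self.accumulate_batches == 0

        # Skip gradient synchronization on intermediate accumulation steps;
        # only the step that updates the weights needs synced gradients.
        with opt.toggle_model(sync_grad=do_step):
            self.manual_backward(loss / self.accumulate_batches)

        if do_step:
            opt.step()
            opt.zero_grad()
```
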
def step(self, closure: Optional[Callable] = None, **kwargs):
"""Call this directly from your training_step when doing optimizations manually. By using this we can
ensure that all the proper scaling when using 16-bit, accelerator etc is been done properly for you.

.. note:: In Manual Optimization, the user is expected to know when to call zero_grad,
perform accumulated_grad_batches, etc ... Lightning will only take care of precision and accelerators
def step(self, closure: Optional[Callable[[], Any]] = None, **kwargs: Any) -> None:
"""Performs a single optimization step (parameter update).

Args:

closure: One could provide its own optimizer_closure. Set to None by default.

kwargs: Any parameters provided to wrapped optimizer.step()
closure: An optional optimizer_closure.
kwargs: Any additional arguments to the ``optimizer.step()`` call.

Example::

# Scenario for a GAN.

# Scenario for a GAN using manual optimization
def training_step(...):
opt_gen, opt_dis = self.optimizers()

@@ -152,8 +146,7 @@ def training_step(...):
opt_dis.step()


# Scenario for a GAN advanced

# A more advanced example
def training_step(self, batch, batch_idx, ...):
opt_gen, opt_dis = self.optimizers()

@@ -189,10 +182,11 @@ def closure_dis():
profiler_action += f"_{self._optimizer_idx}"

trainer = self._trainer
assert trainer is not None
with trainer.profiler.profile(profiler_action):
trainer.accelerator.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)

def __repr__(self):
def __repr__(self) -> str:
groups = [
{k: round(v, 12) if isinstance(v, float) else v for k, v in sorted(group.items()) if k != "params"}
for group in self.param_groups