
Commit 3c303e9: Add typing for LightningOptimizer (Lightning-AI#9990)

Authored and committed by carmocca and rohitgr7
1 parent: 431919f

File tree: 3 files changed, +60 / -63 lines

  pyproject.toml
  pytorch_lightning/core/lightning.py
  pytorch_lightning/core/optimizer.py

pyproject.toml (1 addition, 0 deletions)

@@ -65,6 +65,7 @@ module = [
     "pytorch_lightning.callbacks.model_summary",
     "pytorch_lightning.callbacks.pruning",
     "pytorch_lightning.callbacks.rich_model_summary",
+    "pytorch_lightning.core.optimizer",
     "pytorch_lightning.loops.optimization.*",
     "pytorch_lightning.loops.evaluation_loop",
     "pytorch_lightning.trainer.connectors.checkpoint_connector",

pytorch_lightning/core/lightning.py (30 additions, 28 deletions)

@@ -21,13 +21,14 @@
 import tempfile
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, overload, Tuple, Union

 import torch
 from torch import ScriptModule, Tensor
 from torch.nn import Module
 from torch.optim.optimizer import Optimizer
 from torchmetrics import Metric
+from typing_extensions import Literal

 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.progress import base as progress_base
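
A side note on the imports: `Literal` is pulled from `typing_extensions` rather than `typing`, presumably because `typing.Literal` only exists on Python 3.8+ and Lightning still supported older interpreters at the time. A common guarded-import sketch of that pattern (illustrative only, not part of this diff):

    import sys

    if sys.version_info >= (3, 8):
        from typing import Literal
    else:
        # older interpreters fall back to the typing_extensions backport
        from typing_extensions import Literal

    Mode = Literal["train", "eval"]
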
@@ -120,6 +121,14 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         # deprecated, will be removed in 1.6
         self._loaded_optimizer_states_dict = {}

+    @overload
+    def optimizers(self, use_pl_optimizer: Literal[True] = True) -> Union[LightningOptimizer, List[LightningOptimizer]]:
+        ...
+
+    @overload
+    def optimizers(self, use_pl_optimizer: Literal[False]) -> Union[Optimizer, List[Optimizer]]:
+        ...
+
     def optimizers(
         self, use_pl_optimizer: bool = True
     ) -> Union[Optimizer, LightningOptimizer, List[Optimizer], List[LightningOptimizer]]:
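
For context, the two `@overload` signatures added above let a static type checker narrow the return type of `optimizers()` from the literal value of `use_pl_optimizer`: the default (or an explicit `True`) resolves to `LightningOptimizer` results, while `optimizers(False)` resolves to the raw `Optimizer` objects. A self-contained sketch of the same pattern with placeholder classes (illustrative, not Lightning code):

    from typing import List, Union, overload

    from typing_extensions import Literal


    class Wrapped:
        ...


    class Plain:
        ...


    @overload
    def get(use_wrapped: Literal[True] = True) -> Union[Wrapped, List[Wrapped]]:
        ...


    @overload
    def get(use_wrapped: Literal[False]) -> Union[Plain, List[Plain]]:
        ...


    def get(use_wrapped: bool = True) -> Union[Plain, Wrapped, List[Plain], List[Wrapped]]:
        # the runtime implementation still accepts any bool
        return Wrapped() if use_wrapped else Plain()


    wrapped = get()     # a checker infers Union[Wrapped, List[Wrapped]]
    plain = get(False)  # a checker infers Union[Plain, List[Plain]]
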
@@ -1426,17 +1435,16 @@ def backward(self, loss, optimizer, optimizer_idx):
         """
         loss.backward(*args, **kwargs)

-    def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
+    def toggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer], optimizer_idx: int) -> None:
         """Makes sure only the gradients of the current optimizer's parameters are calculated in the training step
-        to prevent dangling gradients in multiple-optimizer setup. It works with :meth:`untoggle_optimizer` to make
-        sure ``param_requires_grad_state`` is properly reset. Override for your own behavior.
+        to prevent dangling gradients in multiple-optimizer setup.

-        Args:
-            optimizer: Current optimizer used in the training loop
-            optimizer_idx: Current optimizer idx in the training loop
+        This is only called automatically when automatic optimization is enabled and multiple optimizers are used.
+        It works with :meth:`untoggle_optimizer` to make sure ``param_requires_grad_state`` is properly reset.

-        Note:
-            Only called when using multiple optimizers
+        Args:
+            optimizer: The optimizer to toggle.
+            optimizer_idx: The index of the optimizer to toggle.
         """
         # Iterate over all optimizer parameters to preserve their `requires_grad` information
         # in case these are pre-defined during `configure_optimizers`
@@ -1457,15 +1465,13 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
                 param.requires_grad = param_requires_grad_state[param]
         self._param_requires_grad_state = param_requires_grad_state

-    def untoggle_optimizer(self, optimizer_idx: int):
-        """Resets the state of required gradients that were toggled with :meth:`toggle_optimizer`. Override for
-        your own behavior.
+    def untoggle_optimizer(self, optimizer_idx: int) -> None:
+        """Resets the state of required gradients that were toggled with :meth:`toggle_optimizer`.

-        Args:
-            optimizer_idx: Current optimizer idx in the training loop
+        This is only called automatically when automatic optimization is enabled and multiple optimizers are used.

-        Note:
-            Only called when using multiple_optimizers
+        Args:
+            optimizer_idx: The index of the optimizer to untoggle.
         """
         for opt_idx, opt in enumerate(self.optimizers(use_pl_optimizer=False)):
             if optimizer_idx != opt_idx:
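
The reworded docstrings describe what toggling does; in rough terms, the `requires_grad` flag of every parameter is recorded, parameters exclusive to the other optimizers are frozen for the current step, and `untoggle_optimizer` restores the recorded flags. A simplified standalone sketch of that bookkeeping (not the actual Lightning implementation):

    def toggle(current_opt, all_opts):
        # record every parameter's requires_grad flag, then freeze the
        # parameters that belong exclusively to the other optimizers
        saved = {}
        current_params = {p for g in current_opt.param_groups for p in g["params"]}
        for opt in all_opts:
            for group in opt.param_groups:
                for p in group["params"]:
                    saved.setdefault(p, p.requires_grad)
        for p, requires_grad in saved.items():
            p.requires_grad = requires_grad if p in current_params else False
        return saved


    def untoggle(saved):
        # restore the flags recorded by toggle()
        for p, requires_grad in saved.items():
            p.requires_grad = requires_grad
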
@@ -1568,14 +1574,14 @@ def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_va

     def optimizer_step(
         self,
-        epoch: int = None,
-        batch_idx: int = None,
-        optimizer: Optimizer = None,
-        optimizer_idx: int = None,
-        optimizer_closure: Optional[Callable] = None,
-        on_tpu: bool = None,
-        using_native_amp: bool = None,
-        using_lbfgs: bool = None,
+        epoch: int,
+        batch_idx: int,
+        optimizer: Union[Optimizer, LightningOptimizer],
+        optimizer_idx: int = 0,
+        optimizer_closure: Optional[Callable[[], Any]] = None,
+        on_tpu: bool = False,
+        using_native_amp: bool = False,
+        using_lbfgs: bool = False,
     ) -> None:
         r"""
         Override this method to adjust the default way the
@@ -1584,10 +1590,6 @@ def optimizer_step(
         once per optimizer. This method (and ``zero_grad()``) won't be called during the
         accumulation phase when ``Trainer(accumulate_grad_batches != 1)``.

-        Warning:
-            If you are overriding this method, make sure that you pass the ``optimizer_closure`` parameter
-            to ``optimizer.step()`` function as shown in the examples.
-
         Args:
             epoch: Current epoch
             batch_idx: Index of current batch
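
Although the Warning block is removed above, the point it made still holds in practice: an `optimizer_step` override should forward `optimizer_closure` to `optimizer.step()`, because under automatic optimization the closure is what runs the training step and backward pass. A sketch of such an override against the new signature; the learning-rate warm-up logic and its numbers are purely illustrative, not from this commit:

    import pytorch_lightning as pl


    class LitModel(pl.LightningModule):
        def optimizer_step(
            self,
            epoch,
            batch_idx,
            optimizer,
            optimizer_idx=0,
            optimizer_closure=None,
            on_tpu=False,
            using_native_amp=False,
            using_lbfgs=False,
        ):
            # illustrative: linearly warm up the learning rate over the first 500 steps
            if self.trainer.global_step < 500:
                scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
                for pg in optimizer.param_groups:
                    pg["lr"] = scale * 1e-3  # assumed base learning rate for the sketch

            # forward the closure so training_step/backward/zero_grad still run
            optimizer.step(closure=optimizer_closure)
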

pytorch_lightning/core/optimizer.py (29 additions, 35 deletions)

@@ -12,16 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Callable, Optional
+from typing import Any, Callable, Generator, List, Optional
 from weakref import proxy

 from torch.optim import Optimizer

+import pytorch_lightning as pl
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException


-def do_nothing_closure():
+def do_nothing_closure() -> None:
     return

@@ -44,93 +45,86 @@ def __init__(self, optimizer: Optimizer):
         self.__class__ = type("Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})

         self._optimizer = optimizer
-        self._trainer = None
-        self._optimizer_idx = None
+        self._trainer: Optional["pl.Trainer"] = None
+        self._optimizer_idx = 0

     @property
-    def optimizer(self):
+    def optimizer(self) -> Optimizer:
         return self._optimizer

     @property
-    def defaults(self):
+    def defaults(self) -> dict:
         return self._optimizer.defaults

     @defaults.setter
-    def defaults(self, defaults):
+    def defaults(self, defaults: dict) -> None:
         self._optimizer.defaults = defaults

     @property
-    def state(self):
+    def state(self) -> dict:
         return self._optimizer.state

     @state.setter
-    def state(self, state):
+    def state(self, state: dict) -> None:
         self._optimizer.state = state

     @property
-    def param_groups(self):
+    def param_groups(self) -> List[dict]:
         return self._optimizer.param_groups

     @param_groups.setter
-    def param_groups(self, param_groups):
+    def param_groups(self, param_groups: List[dict]) -> None:
         self._optimizer.param_groups = param_groups

-    def _on_trainer_init(self, trainer):
+    def _on_trainer_init(self, trainer: "pl.Trainer") -> None:
         self._trainer = proxy(trainer)
         for opt_idx, opt in enumerate(trainer.optimizers):
             if opt == self._optimizer:
                 self._optimizer_idx = opt_idx
                 break

     @classmethod
-    def _to_lightning_optimizer(cls, optimizer, trainer, opt_idx):
+    def _to_lightning_optimizer(cls, optimizer: Optimizer, trainer: "pl.Trainer", opt_idx: int) -> "LightningOptimizer":
         # apex overrides .step function and need to be wrapped on each step
-        if trainer.amp_backend == AMPType.APEX:
-            optimizer = cls(optimizer)
-            optimizer._on_trainer_init(trainer)
+        if trainer.amp_backend is not None and trainer.amp_backend == AMPType.APEX:
+            lightning_optimizer = cls(optimizer)
+            lightning_optimizer._on_trainer_init(trainer)
         else:
-            optimizer = trainer.lightning_optimizers[opt_idx]
-        return optimizer
+            lightning_optimizer = trainer.lightning_optimizers[opt_idx]
+        return lightning_optimizer

     @contextmanager
-    def toggle_model(self, sync_grad: bool = True):
+    def toggle_model(self, sync_grad: bool = True) -> Generator[None, None, None]:
         """This function is just a helper for advanced users.

         Considering the current optimizer as A and all other optimizers as B.
         Toggling means all parameters from B exclusive to A will have ``requires_grad`` set to False.

-
         When performing gradient accumulation, there is no need to perform grad synchronization
         during the accumulation phase.
         Setting `sync_grad` to False will block this synchronization and improve performance.
         """
         # local import here to avoid circular import
         from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior

+        assert self._trainer is not None
         lightning_module = self._trainer.lightning_module

         with _block_parallel_sync_behavior(self._trainer, block=(not sync_grad)):
             lightning_module.toggle_optimizer(self, self._optimizer_idx)
             yield
             lightning_module.untoggle_optimizer(self._optimizer_idx)

-    def step(self, closure: Optional[Callable] = None, **kwargs):
-        """Call this directly from your training_step when doing optimizations manually. By using this we can
-        ensure that all the proper scaling when using 16-bit, accelerator etc is been done properly for you.
-
-        .. note:: In Manual Optimization, the user is expected to know when to call zero_grad,
-            perform accumulated_grad_batches, etc ... Lightning will only take care of precision and accelerators
+    def step(self, closure: Optional[Callable[[], Any]] = None, **kwargs: Any) -> None:
+        """Performs a single optimization step (parameter update).

         Args:
-
-            closure: One could provide its own optimizer_closure. Set to None by default.
-
-            kwargs: Any parameters provided to wrapped optimizer.step()
+            closure: An optional optimizer_closure.
+            kwargs: Any additional arguments to the ``optimizer.step()`` call.

         Example::

-            # Scenario for a GAN.
-
+            # Scenario for a GAN using manual optimization
             def training_step(...):
                 opt_gen, opt_dis = self.optimizers()
@@ -152,8 +146,7 @@ def training_step(...):
                 opt_dis.step()


-            # Scenario for a GAN advanced
-
+            # A more advanced example
             def training_step(self, batch, batch_idx, ...):
                 opt_gen, opt_dis = self.optimizers()
@@ -189,10 +182,11 @@ def closure_dis():
         profiler_action += f"_{self._optimizer_idx}"

         trainer = self._trainer
+        assert trainer is not None
         with trainer.profiler.profile(profiler_action):
             trainer.accelerator.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)

-    def __repr__(self):
+    def __repr__(self) -> str:
         groups = [
             {k: round(v, 12) if isinstance(v, float) else v for k, v in sorted(group.items()) if k != "params"}
             for group in self.param_groups
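
As a usage note, `toggle_model` is the manual-optimization counterpart of the hooks typed above: it is meant to wrap the backward and step of one optimizer inside `training_step` (with `self.automatic_optimization = False`), freezing the parameters exclusive to the other optimizers and optionally skipping gradient sync while accumulating. A condensed sketch along the lines of the GAN example in the docstring; the loss helper is a placeholder, not from this commit:

    def training_step(self, batch, batch_idx):
        opt_gen, opt_dis = self.optimizers()

        # only sync gradients on the batch that actually performs the update
        do_step = (batch_idx + 1) % 2 == 0

        with opt_gen.toggle_model(sync_grad=do_step):
            loss_gen = self.compute_generator_loss(batch)  # placeholder helper
            self.manual_backward(loss_gen)
            if do_step:
                opt_gen.step()
                opt_gen.zero_grad()
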
