 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Callable, Optional
+from typing import Any, Callable, Generator, List, Optional
 from weakref import proxy

 from torch.optim import Optimizer

+import pytorch_lightning as pl
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException


-def do_nothing_closure():
+def do_nothing_closure() -> None:
     return


@@ -44,93 +45,86 @@ def __init__(self, optimizer: Optimizer):
         self.__class__ = type("Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})

         self._optimizer = optimizer
-        self._trainer = None
-        self._optimizer_idx = None
+        self._trainer: Optional["pl.Trainer"] = None
+        self._optimizer_idx = 0

     @property
-    def optimizer(self):
+    def optimizer(self) -> Optimizer:
         return self._optimizer

     @property
-    def defaults(self):
+    def defaults(self) -> dict:
         return self._optimizer.defaults

     @defaults.setter
-    def defaults(self, defaults):
+    def defaults(self, defaults: dict) -> None:
         self._optimizer.defaults = defaults

     @property
-    def state(self):
+    def state(self) -> dict:
         return self._optimizer.state

     @state.setter
-    def state(self, state):
+    def state(self, state: dict) -> None:
         self._optimizer.state = state

     @property
-    def param_groups(self):
+    def param_groups(self) -> List[dict]:
         return self._optimizer.param_groups

     @param_groups.setter
-    def param_groups(self, param_groups):
+    def param_groups(self, param_groups: List[dict]) -> None:
         self._optimizer.param_groups = param_groups

-    def _on_trainer_init(self, trainer):
+    def _on_trainer_init(self, trainer: "pl.Trainer") -> None:
         self._trainer = proxy(trainer)
         for opt_idx, opt in enumerate(trainer.optimizers):
             if opt == self._optimizer:
                 self._optimizer_idx = opt_idx
                 break

     @classmethod
-    def _to_lightning_optimizer(cls, optimizer, trainer, opt_idx):
+    def _to_lightning_optimizer(cls, optimizer: Optimizer, trainer: "pl.Trainer", opt_idx: int) -> "LightningOptimizer":
         # apex overrides .step function and need to be wrapped on each step
-        if trainer.amp_backend == AMPType.APEX:
-            optimizer = cls(optimizer)
-            optimizer._on_trainer_init(trainer)
+        if trainer.amp_backend is not None and trainer.amp_backend == AMPType.APEX:
+            lightning_optimizer = cls(optimizer)
+            lightning_optimizer._on_trainer_init(trainer)
         else:
-            optimizer = trainer.lightning_optimizers[opt_idx]
-        return optimizer
+            lightning_optimizer = trainer.lightning_optimizers[opt_idx]
+        return lightning_optimizer

     @contextmanager
-    def toggle_model(self, sync_grad: bool = True):
+    def toggle_model(self, sync_grad: bool = True) -> Generator[None, None, None]:
         """This function is just a helper for advanced users.

         Considering the current optimizer as A and all other optimizers as B.
         Toggling means all parameters from B exclusive to A will have ``requires_grad`` set to False.

-
         When performing gradient accumulation, there is no need to perform grad synchronization
         during the accumulation phase.
         Setting `sync_grad` to False will block this synchronization and improve performance.
         """
         # local import here to avoid circular import
         from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior

+        assert self._trainer is not None
         lightning_module = self._trainer.lightning_module

         with _block_parallel_sync_behavior(self._trainer, block=(not sync_grad)):
             lightning_module.toggle_optimizer(self, self._optimizer_idx)
             yield
             lightning_module.untoggle_optimizer(self._optimizer_idx)

-    def step(self, closure: Optional[Callable] = None, **kwargs):
-        """Call this directly from your training_step when doing optimizations manually. By using this we can
-        ensure that all the proper scaling when using 16-bit, accelerator etc is been done properly for you.
-
-        .. note:: In Manual Optimization, the user is expected to know when to call zero_grad,
-            perform accumulated_grad_batches, etc ... Lightning will only take care of precision and accelerators
+    def step(self, closure: Optional[Callable[[], Any]] = None, **kwargs: Any) -> None:
+        """Performs a single optimization step (parameter update).

         Args:
-
-            closure: One could provide its own optimizer_closure. Set to None by default.
-
-            kwargs: Any parameters provided to wrapped optimizer.step()
+            closure: An optional optimizer_closure.
+            kwargs: Any additional arguments to the ``optimizer.step()`` call.

         Example::

-            # Scenario for a GAN.
-
+            # Scenario for a GAN using manual optimization
             def training_step(...):
                 opt_gen, opt_dis = self.optimizers()

@@ -152,8 +146,7 @@ def training_step(...):
                 opt_dis.step()


-            # Scenario for a GAN advanced
-
+            # A more advanced example
             def training_step(self, batch, batch_idx, ...):
                 opt_gen, opt_dis = self.optimizers()

@@ -189,10 +182,11 @@ def closure_dis():
             profiler_action += f"_{self._optimizer_idx}"

         trainer = self._trainer
+        assert trainer is not None
         with trainer.profiler.profile(profiler_action):
             trainer.accelerator.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)

-    def __repr__(self):
+    def __repr__(self) -> str:
         groups = [
             {k: round(v, 12) if isinstance(v, float) else v for k, v in sorted(group.items()) if k != "params"}
             for group in self.param_groups
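
For reviewers who want to see the `toggle_model` context manager from this diff in context, here is a minimal usage sketch under manual optimization. The `ToggleSketch` class, its `generator`/`discriminator` submodules, the stand-in losses, and the accumulate-every-second-batch schedule are all hypothetical illustrations; only `automatic_optimization = False`, `self.optimizers()`, `self.manual_backward()`, and `toggle_model(sync_grad=...)` come from the Lightning API touched above.

```python
import torch
import pytorch_lightning as pl


class ToggleSketch(pl.LightningModule):
    """Hypothetical two-optimizer module; only the optimizer calls mirror the API above."""

    def __init__(self, generator: torch.nn.Module, discriminator: torch.nn.Module):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
        # manual optimization: training_step drives zero_grad/backward/step itself
        self.automatic_optimization = False

    def configure_optimizers(self):
        return (
            torch.optim.Adam(self.generator.parameters(), lr=2e-4),
            torch.optim.Adam(self.discriminator.parameters(), lr=2e-4),
        )

    def training_step(self, batch, batch_idx):
        opt_gen, opt_dis = self.optimizers()  # LightningOptimizer wrappers

        # generator update every batch; default sync_grad=True
        with opt_gen.toggle_model():
            loss_gen = self.generator(batch).mean()  # stand-in loss
            opt_gen.zero_grad()
            self.manual_backward(loss_gen)
            opt_gen.step()

        # illustrative schedule: accumulate discriminator grads over two batches
        accumulating = batch_idx % 2 == 0

        # inside the block, parameters exclusive to the other optimizer have
        # requires_grad=False; sync_grad=False skips distributed grad sync
        # while we are only accumulating
        with opt_dis.toggle_model(sync_grad=not accumulating):
            loss_dis = self.discriminator(batch).mean()  # stand-in loss
            self.manual_backward(loss_dis)
            if not accumulating:
                opt_dis.step()
                opt_dis.zero_grad()
```

The pattern mirrors the GAN example in the `step` docstring; the only addition is gating `sync_grad` on whether the current batch actually steps.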
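The advanced docstring example is cut off by the hunk boundary at `@@ -152,8 +146,7 @@`, so here is a rough companion sketch of the closure-based form of the same `training_step`, continuing the hypothetical module above: the closure re-runs zero_grad/forward/backward, and `step(closure=...)` executes it through the precision and accelerator plugins, as the wrapped `step` implementation in the diff shows. The stand-in losses are again made up for illustration.

```python
    def training_step(self, batch, batch_idx):
        opt_gen, opt_dis = self.optimizers()

        def closure_gen():
            # called from inside opt_gen.step(); recomputes the loss and
            # populates the generator gradients
            opt_gen.zero_grad()
            loss_gen = self.generator(batch).mean()  # stand-in loss
            self.manual_backward(loss_gen)

        with opt_gen.toggle_model():
            opt_gen.step(closure=closure_gen)

        def closure_dis():
            opt_dis.zero_grad()
            loss_dis = self.discriminator(batch).mean()  # stand-in loss
            self.manual_backward(loss_dis)

        with opt_dis.toggle_model():
            opt_dis.step(closure=closure_dis)
```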