# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
-from functools import partial
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, List, Optional, Tuple

import numpy as np
from deprecate import void
from torch import Tensor
from torch.optim import Optimizer

from pytorch_lightning.loops.base import Loop
-from pytorch_lightning.loops.closure import Closure, ClosureResult
+from pytorch_lightning.loops.batch.manual import ManualOptimization
from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop
-from pytorch_lightning.loops.utilities import (
-    _build_training_step_kwargs,
-    _check_training_step_output,
-    _process_training_step_output,
-)
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import AttributeDict
from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -45,6 +39,7 @@ def __init__(self) -> None:
        # the current split index when the batch gets split into chunks in truncated backprop through time
        self.split_idx: Optional[int] = None
        self.optimizer_loop = OptimizerLoop()
+        self.manual_loop = ManualOptimization()

        self._warning_cache: WarningCache = WarningCache()
        self._hiddens: Optional[Tensor] = None
@@ -63,8 +58,13 @@ def optimizer_freq_cumsum(self) -> int:
63
58
self ._optimizer_freq_cumsum = np .cumsum (self .trainer .optimizer_frequencies )
64
59
return self ._optimizer_freq_cumsum
65
60
66
- def connect (self , optimizer_loop : "Loop" ) -> None :
67
- self .optimizer_loop = optimizer_loop
61
+ def connect (
62
+ self , optimizer_loop : Optional ["Loop" ] = None , manual_loop : Optional [ManualOptimization ] = None
63
+ ) -> None :
64
+ if optimizer_loop is not None :
65
+ self .optimizer_loop = optimizer_loop
66
+ if manual_loop is not None :
67
+ self .manual_loop = manual_loop
68
68
69
69
def run (self , batch : Any , batch_idx : int ) -> AttributeDict :
70
70
"""Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks.
@@ -132,10 +132,10 @@ def advance(self, batch, batch_idx):
            for k in range(len(batch_outputs)):
                self.batch_outputs[k].extend(batch_outputs[k])
        else:
-            # in manual optimization, there is no looping over optimizers
-            result = self._run_optimization(batch_idx, split_batch)
-            if result:
-                self.batch_outputs[0].append(deepcopy(result.result_collection))
+            # in manual optimization, hand over execution to the ManualOptimization loop
+            output, self._hiddens = self.manual_loop.run(split_batch, batch_idx, self._hiddens)
+            if output:
+                self.batch_outputs[0].append(deepcopy(output))

    def teardown(self) -> None:
        # release memory
@@ -145,89 +145,6 @@ def num_active_optimizers(self, batch_idx: Optional[int] = None) -> int:
        """Gets the number of active optimizers based on their frequency."""
        return len(self.get_active_optimizers(batch_idx))

-    def _run_optimization(
-        self,
-        batch_idx: int,
-        split_batch: Any,
-    ) -> Optional[ClosureResult]:
-        """Runs closure (train step + backward) together with optimization if necessary.
-
-        Args:
-            batch_idx: the index of the current batch
-            split_batch: the current tbptt split of the whole batch
-        """
-        # TODO: replace call through closure by direct call (manual optimization)
-        closure = self._make_closure(split_batch, batch_idx, self._hiddens)
-        closure()
-        result = closure.get_result()
-
-        if result:
-            # if no result, user decided to skip optimization
-            # otherwise update running loss + reset accumulated loss
-            self._update_running_loss(result.loss)
-
-        return result
-
-    def _make_closure(
-        self,
-        split_batch: Any,
-        batch_idx: int,
-        hiddens: Any,
-    ) -> Closure:
-        """Build a closure object that captures the given arguments and runs the `training_step` function and
-        optionally other functions such as `backward` and `zero_grad`."""
-        step_fn = self._make_step_fn(split_batch, batch_idx, hiddens)
-        backward_fn = None
-        zero_grad_fn = None
-
-        return Closure(
-            step_fn=step_fn,
-            backward_fn=backward_fn,
-            zero_grad_fn=zero_grad_fn,
-            profiler=self.trainer.profiler,
-        )
-
-    def _make_step_fn(self, split_batch: Any, batch_idx: int, hiddens: Any) -> Callable[[], dict]:
-        """Build the step function that runs the `training_step` and processes its output."""
-        return partial(self._training_step, split_batch, batch_idx, hiddens)
-
-    def _training_step(self, split_batch: Any, batch_idx: int, hiddens: Tensor) -> Optional[AttributeDict]:
-        """Performs the training step for manual optimization.
-
-        Args:
-            split_batch: the current tbptt split of the current batch
-            batch_idx: the index of the current batch
-            hiddens: the model's hidden state of the previous iteration
-
-        Returns:
-            an AttributeDict containing the training step output.
-        """
-        # give the PL module a result for logging
-        model_ref = self.trainer.lightning_module
-
-        with self.trainer.profiler.profile("model_forward"):
-            step_kwargs = _build_training_step_kwargs(
-                model_ref, self.trainer.optimizers, split_batch, batch_idx, opt_idx=None, hiddens=hiddens
-            )
-
-            # manually capture logged metrics
-            model_ref._current_fx_name = "training_step"
-            with self.trainer.profiler.profile("training_step"):
-                training_step_output = self.trainer.accelerator.training_step(step_kwargs)
-                self.trainer.accelerator.post_training_step()
-
-            del step_kwargs
-
-            training_step_output = self.trainer.call_hook("training_step_end", training_step_output)
-
-            _check_training_step_output(self.trainer.lightning_module, training_step_output)
-
-            result_collection, self._hiddens = _process_training_step_output(self.trainer, training_step_output)
-            if result_collection is None:
-                return
-
-        return AttributeDict(closure_loss=None, loss=None, result_collection=result_collection)
-
    def _tbptt_split_batch(self, batch: Any) -> List[Any]:
        """Splits a single batch into a list of sequence steps for tbptt.
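For context, a minimal sketch of how the new `connect` signature could be exercised to swap in a custom manual-optimization loop. The `MyManualOptimization` subclass and the `trainer.fit_loop.epoch_loop.batch_loop` access path are assumptions for illustration, not part of this change:

from pytorch_lightning import Trainer
from pytorch_lightning.loops.batch.manual import ManualOptimization


class MyManualOptimization(ManualOptimization):
    """Hypothetical subclass, used here only to show loop replacement."""


trainer = Trainer()
batch_loop = trainer.fit_loop.epoch_loop.batch_loop  # assumed access path for this version
# Passing only `manual_loop` leaves the existing optimizer loop in place,
# since `connect` now skips arguments that are None.
batch_loop.connect(manual_loop=MyManualOptimization())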