 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from copy import deepcopy
 from functools import partial
 from typing import Any, Callable, Dict, List, Optional
 
@@ -27,16 +25,16 @@
     _block_parallel_sync_behavior,
     _build_training_step_kwargs,
     _check_training_step_output,
-    _process_training_step_output,
+    _extract_hiddens,
+    check_finite_loss,
 )
-from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.trainer.progress import OptimizationProgress
-from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm
+from pytorch_lightning.utilities import AMPType, DeviceType, grad_norm
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
 from pytorch_lightning.utilities.imports import _TPU_AVAILABLE
 
-_OUTPUTS_TYPE = List[List[Optional[ResultCollection]]]
+_OUTPUTS_TYPE = List[List[ClosureResult]]
 
 
 class OptimizerLoop(Loop):
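Note on the type change above: `ClosureResult` replaces `ResultCollection` in `_OUTPUTS_TYPE`, so `outputs` still holds one inner list per optimizer, but its entries are no longer `Optional`. The hunks below rely only on a small surface of `ClosureResult` (presumably imported in the unchanged import block this excerpt skips). A minimal sketch of that surface, reconstructed purely from the calls in this diff; the `normalize` parameter name and the docstrings are assumptions, not the real class definition:

```python
from typing import Any, Optional, Protocol

import torch


class _ClosureResultLike(Protocol):
    """Only the members this diff touches; a sketch, not the real ClosureResult."""

    closure_loss: Optional[torch.Tensor]  # graph-attached loss consumed by backward
    loss: Optional[torch.Tensor]  # detached value that is safe to keep in `outputs`

    @classmethod
    def from_training_step_output(cls, training_step_output: Any, normalize: int = 1) -> "_ClosureResultLike":
        """Build a result from the raw `training_step` output; the second argument is
        `self.trainer.accumulate_grad_batches` at the call site below."""
        ...

    def drop_closure_loss(self) -> "_ClosureResultLike":
        """Return the result without the graph-attached `closure_loss`; this is what
        `advance` appends to `outputs`."""
        ...
```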
@@ -80,8 +78,8 @@ def advance(self, batch: Any, *args, **kwargs) -> None: # type: ignore[override
             self._optimizers[self.optim_progress.optimizer_idx],
             self.optim_progress.optimizer_idx,
         )
-        if result.result_collection is not None:
-            self.outputs[self.optim_progress.optimizer_idx].append(deepcopy(result.result_collection))
+        if result.loss is not None:
+            self.outputs[self.optim_progress.optimizer_idx].append(result.drop_closure_loss())
 
         self.optim_progress.optimizer_idx += 1
 
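Taken together with the type change above, the new `advance` logic explains why the `Optional` disappears from `_OUTPUTS_TYPE`: a step that produced no loss is simply not appended. The `deepcopy(result.result_collection)` call is gone as well; `result.drop_closure_loss()` presumably strips the graph-attached `closure_loss` so that only detached values are kept in `outputs`, instead of a deep copy of a `ResultCollection`.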
@@ -168,7 +166,7 @@ def _make_closure(self, split_batch: Any, batch_idx: int, opt_idx: int, optimize
             step_fn=step_fn, backward_fn=backward_fn, zero_grad_fn=zero_grad_fn, profiler=self.trainer.profiler
         )
 
-    def _make_step_fn(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Callable[[], Optional[AttributeDict]]:
+    def _make_step_fn(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Callable[[], ClosureResult]:
         """Build the step function that runs the `training_step` and processes its output."""
         return partial(self._training_step, split_batch, batch_idx, opt_idx)
 
@@ -241,7 +239,7 @@ def _optimizer_step(
             train_step_and_backward_closure: the closure function performing the train step and computing the
                 gradients. By default called by the optimizer (if possible)
         """
-        model_ref = self.trainer.lightning_module
+        lightning_module = self.trainer.lightning_module
 
         is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
         using_native_amp = self.trainer.amp_backend is not None and self.trainer.amp_backend == AMPType.NATIVE
@@ -259,7 +257,7 @@ def _optimizer_step(
         self.optim_progress.optimizer.step.increment_ready()
 
         # model hook
-        model_ref.optimizer_step(
+        lightning_module.optimizer_step(
             self.trainer.current_epoch,
             batch_idx,
             optimizer,
@@ -293,7 +291,7 @@ def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer,
         self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
         self.optim_progress.optimizer.zero_grad.increment_completed()
 
-    def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Optional[AttributeDict]:
+    def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> ClosureResult:
         """Performs the actual train step with the tied hooks.
 
         Args:
@@ -302,19 +300,19 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Opti
             opt_idx: the index of the current optimizer
 
         Returns:
-            an AttributeDict containing the loss value and the training step output.
+            A ``ClosureResult`` containing the training step output.
         """
         # give the PL module a result for logging
-        model_ref = self.trainer.lightning_module
+        lightning_module = self.trainer.lightning_module
 
         with self.trainer.profiler.profile("model_forward"):
 
             step_kwargs = _build_training_step_kwargs(
-                self.trainer.lightning_module, self.trainer.optimizers, split_batch, batch_idx, opt_idx, self._hiddens
+                lightning_module, self.trainer.optimizers, split_batch, batch_idx, opt_idx, self._hiddens
             )
 
             # manually capture logged metrics
-            model_ref._current_fx_name = "training_step"
+            lightning_module._current_fx_name = "training_step"
             with self.trainer.profiler.profile("training_step"):
                 training_step_output = self.trainer.accelerator.training_step(step_kwargs)
                 self.trainer.accelerator.post_training_step()
@@ -323,20 +321,20 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Opti
 
             training_step_output = self.trainer.call_hook("training_step_end", training_step_output)
 
-            _check_training_step_output(self.trainer.lightning_module, training_step_output)
+            _check_training_step_output(lightning_module, training_step_output)
+
+            self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)
 
-            result_collection, self._hiddens = _process_training_step_output(self.trainer, training_step_output)
-            if result_collection is None:
-                return None
+            result = ClosureResult.from_training_step_output(training_step_output, self.trainer.accumulate_grad_batches)
 
-            # output validation already done, here loss can't be None
-            assert result_collection.minimize is not None
+            if self.trainer.terminate_on_nan:
+                check_finite_loss(result.closure_loss)
+
+            if self.trainer.move_metrics_to_cpu:
+                # hiddens and the training step output are not moved as they are not considered "metrics"
+                self.trainer._results.cpu()
 
-            # accumulate loss. if accumulate_grad_batches==1, no effect
-            closure_loss = result_collection.minimize / self.trainer.accumulate_grad_batches
-            # the loss will get scaled for amp. avoid any modifications to it
-            loss = closure_loss.detach().clone()
-            return AttributeDict(closure_loss=closure_loss, loss=loss, result_collection=result_collection)
+            return result
 
     def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, float]:
         """Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer.