Commit e0f2e04

Share the training step output data via ClosureResult (#9349)
1 parent 3118480 commit e0f2e04

16 files changed: +240 −221 lines

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -337,6 +337,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `replace_sampler` missing the batch size under specific conditions ([#9367](https://github.com/PyTorchLightning/pytorch-lightning/pull/9367))


+- Fixed bug where the training step output needed to be `deepcopy`-ed ([#9349](https://github.com/PyTorchLightning/pytorch-lightning/pull/9349))
+
+
 - Fixed freeing data iterators in loop `on_run_end` ([#9386](https://github.com/PyTorchLightning/pytorch-lightning/pull/9386))

pytorch_lightning/loops/batch/manual.py

Lines changed: 25 additions & 10 deletions
@@ -14,12 +14,13 @@
 from typing import Any, Optional

 from pytorch_lightning.loops import Loop
+from pytorch_lightning.loops.closure import ClosureResult
 from pytorch_lightning.loops.utilities import (
     _build_training_step_kwargs,
     _check_training_step_output,
-    _process_training_step_output,
+    _extract_hiddens,
+    check_finite_loss,
 )
-from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection


 class ManualOptimization(Loop):
@@ -35,7 +36,7 @@ def __init__(self) -> None:
         super().__init__()
         self._done: bool = False
         self._hiddens: Optional[Any] = None
-        self._output: Optional[ResultCollection] = None
+        self._output: Optional[ClosureResult] = None

     @property
     def done(self) -> bool:
@@ -52,16 +53,16 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override]
             batch_idx: the index of the current batch
         """
         assert self.trainer is not None
-        ligtning_module = self.trainer.lightning_module
+        lightning_module = self.trainer.lightning_module

         with self.trainer.profiler.profile("model_forward"):

             step_kwargs = _build_training_step_kwargs(
-                ligtning_module, self.trainer.optimizers, batch, batch_idx, opt_idx=None, hiddens=self._hiddens
+                lightning_module, self.trainer.optimizers, batch, batch_idx, opt_idx=None, hiddens=self._hiddens
             )

             # manually capture logged metrics
-            ligtning_module._current_fx_name = "training_step"
+            lightning_module._current_fx_name = "training_step"
             with self.trainer.profiler.profile("training_step"):
                 training_step_output = self.trainer.accelerator.training_step(step_kwargs)
                 self.trainer.accelerator.post_training_step()
@@ -70,14 +71,28 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override]

         training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

-        _check_training_step_output(ligtning_module, training_step_output)
+        _check_training_step_output(lightning_module, training_step_output)

-        result_collection, self._hiddens = _process_training_step_output(self.trainer, training_step_output)
+        self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)
+
+        # TODO: do not use `ClosureResult`
+        result = ClosureResult.from_training_step_output(training_step_output, self.trainer.accumulate_grad_batches)
+
+        if self.trainer.terminate_on_nan:
+            check_finite_loss(result.closure_loss)
+
+        if self.trainer.move_metrics_to_cpu:
+            # hiddens and the training step output are not moved as they are not considered "metrics"
+            # the user might need them on the correct device for an operation in `training_epoch_end`
+            assert self.trainer._results is not None
+            self.trainer._results.cpu()

         self._done = True
-        self._output = result_collection
+        self._output = result

-    def on_run_end(self) -> Optional[ResultCollection]:
+    def on_run_end(self) -> Optional[ClosureResult]:
         """Returns the result of this loop, i.e., the post-processed outputs from the training step."""
         output, self._output = self._output, None  # free memory
+        # #9052 added support for raising `StopIteration` in the `training_step`. If that happens, then `advance`
+        # doesn't finish and `self._output` stays as `None`. If #9415 happens then this would always return a result
         return output
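For context: after this change the manual-optimization path wraps whatever `training_step` returns into a `ClosureResult` via `from_training_step_output`. A minimal sketch of a user module that exercises this path; the module, its layer sizes, and the extra dict key are hypothetical and only illustrate the kind of output that flows in:

```python
import torch
from torch import nn
import pytorch_lightning as pl


class ToyManualModule(pl.LightningModule):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # routes batches through ManualOptimization.advance
        self.layer = nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.layer(batch).mean()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()
        # this dict is what `ClosureResult.from_training_step_output` receives downstream
        return {"loss": loss, "batch_mean": batch.mean().detach()}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```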

pytorch_lightning/loops/batch/training_batch_loop.py

Lines changed: 3 additions & 4 deletions
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from copy import deepcopy
 from typing import Any, List, Optional, Tuple

 import numpy as np
@@ -131,9 +130,9 @@ def advance(self, batch, batch_idx):
                 self.batch_outputs[k].extend(batch_outputs[k])
         else:
             # in manual optimization, hand over execution to the ManualOptimization loop
-            output = self.manual_loop.run(split_batch, batch_idx)
-            if output is not None:
-                self.batch_outputs[0].append(deepcopy(output))
+            result = self.manual_loop.run(split_batch, batch_idx)
+            if result is not None and result.loss is not None:
+                self.batch_outputs[0].append(result.drop_closure_loss())

     def on_run_end(self) -> None:
         self.optimizer_loop._hiddens = None
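The `deepcopy` was only needed because the stored output could still reference the autograd graph; `ClosureResult` detaches `loss` and `extra` up front, so the batch loop now just drops the graph-carrying `closure_loss` before storing the result. A toy, non-Lightning illustration of the difference:

```python
import torch

# Keeping a tensor that has a `grad_fn` keeps its whole autograd graph alive,
# which is why the old code deep-copied the training step output before storing it.
w = torch.randn(3, requires_grad=True)
loss = (w * 2).sum()           # non-leaf tensor with a grad_fn

# New approach: store only detached data and drop the graph-holding tensor.
detached_loss = loss.detach().clone()   # what `ClosureResult.loss` holds
closure_loss = None                     # what `drop_closure_loss()` does before storing
stored = {"loss": detached_loss}        # safe to accumulate across batches, no deepcopy needed

assert not stored["loss"].requires_grad
```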

pytorch_lightning/loops/closure.py

Lines changed: 60 additions & 9 deletions
@@ -11,32 +11,84 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any, Callable, Dict, Optional

 from torch import Tensor

 from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler
-from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.warnings import WarningCache
+from pytorch_lightning.utilities.memory import recursive_detach
+from pytorch_lightning.utilities.types import STEP_OUTPUT
+from pytorch_lightning.utilities.warnings import rank_zero_deprecation, WarningCache


 @dataclass
 class ClosureResult:
     """A container to hold the result of a :class:`AbstractClosure` call.

+    It is created from the output of :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`.
+
     Attributes:
         closure_loss: The loss with a graph attached.
         loss: A detached copy of the closure loss.
-        result_collection: A collection of results returned by the closure.
+        extra: Any keys other than the loss returned.
     """

     closure_loss: Optional[Tensor]
-    loss: Optional[Tensor]
-    result_collection: Optional[ResultCollection]
+    loss: Optional[Tensor] = field(init=False, default=None)
+    extra: Dict[str, Tensor] = field(default_factory=dict)
+
+    def __post_init__(self) -> None:
+        # TODO: remove with the deprecation removal in v1.6
+        ClosureResult._check_extra_detach_deprecation(self.extra)
+        self.extra = recursive_detach(self.extra)
+
+        self._clone_loss()
+
+    def _clone_loss(self) -> None:
+        if self.closure_loss is not None:
+            # the loss will get scaled for amp. avoid any modifications to it
+            self.loss = self.closure_loss.detach().clone()
+
+    @classmethod
+    def from_training_step_output(
+        cls, training_step_output: Optional[STEP_OUTPUT], normalize: int = 1
+    ) -> "ClosureResult":
+        closure_loss, extra = None, {}
+
+        if isinstance(training_step_output, dict):
+            # this should not modify the `training_step_output`, as the user could be using it after `training_step_end`
+            closure_loss = training_step_output.get("loss")
+            extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")}
+        elif isinstance(training_step_output, Tensor):
+            closure_loss = training_step_output
+
+        if closure_loss is not None:
+            # accumulate the loss. If ``accumulate_grad_batches == 1``, no effect
+            closure_loss /= normalize
+
+        return cls(closure_loss, extra=extra)
+
+    @staticmethod
+    def _check_extra_detach_deprecation(extra: Dict[str, Any]) -> None:
+        def check_fn(v: Tensor) -> Tensor:
+            if v.grad_fn is not None:
+                rank_zero_deprecation(
+                    f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"
+                    " but this behaviour will change in v1.6. Please detach it manually:"
+                    " `return {'loss': ..., 'something': something.detach()}`"
+                )
+            return v
+
+        apply_to_collection(extra, Tensor, check_fn)
+
+    def drop_closure_loss(self) -> "ClosureResult":
+        """Return itself without the closure loss which could have a `grad_fn`"""
+        self.closure_loss = None
+        return self


 class AbstractClosure(ABC):
@@ -107,7 +159,7 @@ class Closure(AbstractClosure):

     def __init__(
         self,
-        step_fn: Callable[[], Optional[Dict]],
+        step_fn: Callable[[], ClosureResult],
         backward_fn: Optional[Callable[[Tensor], Tensor]] = None,
         zero_grad_fn: Optional[Callable[[], None]] = None,
         profiler: Optional[BaseProfiler] = None,
@@ -121,7 +173,6 @@ def __init__(
     def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
         with self._profiler.profile("training_step_and_backward"):
             step_output = self._step_fn()
-            step_output = ClosureResult(**step_output) if step_output else ClosureResult(None, None, None)

             if step_output.closure_loss is None:
                 self.warning_cache.warn(
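A usage sketch of the new classmethod, assuming a `pytorch-lightning` checkout that includes this commit; the tensors, keys, and `normalize` value are made up:

```python
import torch

from pytorch_lightning.loops.closure import ClosureResult

loss = torch.tensor(4.0, requires_grad=True) * 2

# dict-style `training_step` output: "loss" becomes the closure loss,
# every other key except "hiddens" lands in `extra` (detached in __post_init__)
result = ClosureResult.from_training_step_output(
    {"loss": loss, "acc": torch.tensor(0.9)}, normalize=2
)
assert result.closure_loss is not None  # divided by `normalize`, still attached to the graph
assert result.loss is not None and not result.loss.requires_grad  # detached clone
assert "acc" in result.extra and "loss" not in result.extra

# a bare tensor output works as well
tensor_result = ClosureResult.from_training_step_output(torch.tensor(1.0))

# `None` (e.g. a skipped step) yields an empty result
empty = ClosureResult.from_training_step_output(None)
assert empty.loss is None and empty.extra == {}
```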

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 7 additions & 6 deletions
@@ -17,6 +17,7 @@

 from pytorch_lightning import loops  # import as loops to avoid circular imports
 from pytorch_lightning.loops.batch import TrainingBatchLoop
+from pytorch_lightning.loops.closure import ClosureResult
 from pytorch_lightning.loops.utilities import _prepare_dataloader_iter
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.trainer.progress import Progress, SchedulerProgress
@@ -283,18 +284,18 @@ def _track_epoch_end_reduce_metrics(

     @staticmethod
     def _prepare_outputs(
-        outputs: List[List[List["ResultCollection"]]], batch_mode: bool
+        outputs: List[List[List[ClosureResult]]], batch_mode: bool
     ) -> Union[List[List[List[Dict]]], List[List[Dict]], List[Dict], Dict]:
         """Extract required information from batch or epoch end results.

         Args:
-            outputs: A 3-dimensional list of ``ResultCollection`` objects with dimensions:
+            outputs: A 3-dimensional list of ``ClosureResult`` objects with dimensions:
                 ``[optimizer outs][batch outs][tbptt steps]``.

             batch_mode: If True, ignore the batch output dimension.

         Returns:
-            The cleaned outputs with ``ResultCollection`` objects converted to dictionaries.
+            The cleaned outputs with ``ClosureResult`` objects converted to dictionaries.
             All list dimensions of size one will be collapsed.
         """
         processed_outputs = []
@@ -311,13 +312,13 @@ def _prepare_outputs(
             for batch_outputs in opt_outputs:
                 processed_tbptt_outputs = []

-                if isinstance(batch_outputs, ResultCollection):
+                if isinstance(batch_outputs, ClosureResult):
                     batch_outputs = [batch_outputs]

                 for tbptt_output in batch_outputs:
                     out = {}
-                    if tbptt_output.minimize is not None:
-                        out["loss"] = tbptt_output.minimize.detach()
+                    if tbptt_output.loss is not None:
+                        out["loss"] = tbptt_output.loss
                     out.update(tbptt_output.extra)
                     processed_tbptt_outputs.append(out)
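`_prepare_outputs` now reads `loss` and `extra` directly off each `ClosureResult` (the loss is already detached). A standalone sketch of the per-output conversion, using a `FakeResult` stand-in instead of the real class:

```python
from dataclasses import dataclass, field
from typing import Dict, Optional

import torch


@dataclass
class FakeResult:  # stand-in for ClosureResult, illustration only
    loss: Optional[torch.Tensor] = None
    extra: Dict[str, torch.Tensor] = field(default_factory=dict)


def to_dict(tbptt_output: FakeResult) -> dict:
    out = {}
    if tbptt_output.loss is not None:
        # no `.detach()` needed here: `ClosureResult.loss` is already a detached clone
        out["loss"] = tbptt_output.loss
    out.update(tbptt_output.extra)
    return out


print(to_dict(FakeResult(loss=torch.tensor(0.5), extra={"acc": torch.tensor(0.9)})))
# {'loss': tensor(0.5000), 'acc': tensor(0.9000)}
```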

pytorch_lightning/loops/optimizer/optimizer_loop.py

Lines changed: 25 additions & 27 deletions
@@ -11,8 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from copy import deepcopy
 from functools import partial
 from typing import Any, Callable, Dict, List, Optional

@@ -27,16 +25,16 @@
     _block_parallel_sync_behavior,
     _build_training_step_kwargs,
     _check_training_step_output,
-    _process_training_step_output,
+    _extract_hiddens,
+    check_finite_loss,
 )
-from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.trainer.progress import OptimizationProgress
-from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm
+from pytorch_lightning.utilities import AMPType, DeviceType, grad_norm
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
 from pytorch_lightning.utilities.imports import _TPU_AVAILABLE

-_OUTPUTS_TYPE = List[List[Optional[ResultCollection]]]
+_OUTPUTS_TYPE = List[List[ClosureResult]]


 class OptimizerLoop(Loop):
@@ -80,8 +78,8 @@ def advance(self, batch: Any, *args, **kwargs) -> None: # type: ignore[override]
             self._optimizers[self.optim_progress.optimizer_idx],
             self.optim_progress.optimizer_idx,
         )
-        if result.result_collection is not None:
-            self.outputs[self.optim_progress.optimizer_idx].append(deepcopy(result.result_collection))
+        if result.loss is not None:
+            self.outputs[self.optim_progress.optimizer_idx].append(result.drop_closure_loss())

         self.optim_progress.optimizer_idx += 1

@@ -168,7 +166,7 @@ def _make_closure(self, split_batch: Any, batch_idx: int, opt_idx: int, optimize
             step_fn=step_fn, backward_fn=backward_fn, zero_grad_fn=zero_grad_fn, profiler=self.trainer.profiler
         )

-    def _make_step_fn(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Callable[[], Optional[AttributeDict]]:
+    def _make_step_fn(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Callable[[], ClosureResult]:
         """Build the step function that runs the `training_step` and processes its output."""
         return partial(self._training_step, split_batch, batch_idx, opt_idx)

@@ -241,7 +239,7 @@ def _optimizer_step(
             train_step_and_backward_closure: the closure function performing the train step and computing the
                 gradients. By default called by the optimizer (if possible)
         """
-        model_ref = self.trainer.lightning_module
+        lightning_module = self.trainer.lightning_module

         is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
         using_native_amp = self.trainer.amp_backend is not None and self.trainer.amp_backend == AMPType.NATIVE
@@ -259,7 +257,7 @@ def _optimizer_step(
         self.optim_progress.optimizer.step.increment_ready()

         # model hook
-        model_ref.optimizer_step(
+        lightning_module.optimizer_step(
             self.trainer.current_epoch,
             batch_idx,
             optimizer,
@@ -293,7 +291,7 @@ def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer,
         self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
         self.optim_progress.optimizer.zero_grad.increment_completed()

-    def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Optional[AttributeDict]:
+    def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> ClosureResult:
         """Performs the actual train step with the tied hooks.

         Args:
@@ -302,19 +300,19 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Opti
             opt_idx: the index of the current optimizer

         Returns:
-            an AttributeDict containing the loss value and the training step output.
+            A ``ClosureResult`` containing the training step output.
         """
         # give the PL module a result for logging
-        model_ref = self.trainer.lightning_module
+        lightning_module = self.trainer.lightning_module

         with self.trainer.profiler.profile("model_forward"):

             step_kwargs = _build_training_step_kwargs(
-                self.trainer.lightning_module, self.trainer.optimizers, split_batch, batch_idx, opt_idx, self._hiddens
+                lightning_module, self.trainer.optimizers, split_batch, batch_idx, opt_idx, self._hiddens
             )

             # manually capture logged metrics
-            model_ref._current_fx_name = "training_step"
+            lightning_module._current_fx_name = "training_step"
             with self.trainer.profiler.profile("training_step"):
                 training_step_output = self.trainer.accelerator.training_step(step_kwargs)
                 self.trainer.accelerator.post_training_step()
@@ -323,20 +321,20 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Opti

         training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

-        _check_training_step_output(self.trainer.lightning_module, training_step_output)
+        _check_training_step_output(lightning_module, training_step_output)
+
+        self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)

-        result_collection, self._hiddens = _process_training_step_output(self.trainer, training_step_output)
-        if result_collection is None:
-            return None
+        result = ClosureResult.from_training_step_output(training_step_output, self.trainer.accumulate_grad_batches)

-        # output validation already done, here loss can't be None
-        assert result_collection.minimize is not None
+        if self.trainer.terminate_on_nan:
+            check_finite_loss(result.closure_loss)
+
+        if self.trainer.move_metrics_to_cpu:
+            # hiddens and the training step output are not moved as they are not considered "metrics"
+            self.trainer._results.cpu()

-        # accumulate loss. if accumulate_grad_batches==1, no effect
-        closure_loss = result_collection.minimize / self.trainer.accumulate_grad_batches
-        # the loss will get scaled for amp. avoid any modifications to it
-        loss = closure_loss.detach().clone()
-        return AttributeDict(closure_loss=closure_loss, loss=loss, result_collection=result_collection)
+        return result

     def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, float]:
         """Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer.
