
Commit 23450e2

Add custom logic to each OutputResult subclass [2/2] (#9424)
1 parent ad36a32 commit 23450e2

13 files changed: +151 -219 lines

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Refactored `TrainingBatchLoop` and extracted `OptimizerLoop`, splitting off automatic optimization into its own loop ([#9191](https://github.com/PyTorchLightning/pytorch-lightning/pull/9191))
 * Removed `TrainingBatchLoop.backward()`; manual optimization now calls directly into `Accelerator.backward()` and automatic optimization handles backward in new `OptimizerLoop` ([#9265](https://github.com/PyTorchLightning/pytorch-lightning/pull/9265))
 * Extracted `ManualOptimization` logic from `TrainingBatchLoop` into its own separate loop class ([#9266](https://github.com/PyTorchLightning/pytorch-lightning/pull/9266))
-* Added `OutputResult` and `ManualResult` classes ([#9437](https://github.com/PyTorchLightning/pytorch-lightning/pull/9437))
+* Added `OutputResult` and `ManualResult` classes ([#9437](https://github.com/PyTorchLightning/pytorch-lightning/pull/9437), [#9424](https://github.com/PyTorchLightning/pytorch-lightning/pull/9424))


 - Added support for saving and loading state of multiple callbacks of the same type ([#7187](https://github.com/PyTorchLightning/pytorch-lightning/pull/7187))

pytorch_lightning/loops/batch/training_batch_loop.py

Lines changed: 2 additions & 2 deletions
@@ -130,8 +130,8 @@ def advance(self, batch, batch_idx):
         else:
             # in manual optimization, hand over execution to the ManualOptimization loop
             result = self.manual_loop.run(split_batch, batch_idx)
-            if result is not None and result.loss is not None:
-                self.batch_outputs[0].append(result.drop_closure_loss())
+            if result:
+                self.batch_outputs[0].append(result)

     def on_run_end(self) -> None:
         self.optimizer_loop._hiddens = None

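Note: `ManualOptimization.run()` now returns the plain dict built by `ManualResult.asdict()`, so the truthiness check above is enough to skip batches where `training_step` returned nothing. A minimal illustrative sketch; the `append_if_any` helper and module-level `batch_outputs` list are hypothetical stand-ins, not Trainer internals:

from typing import Any, Dict, List

batch_outputs: List[List[Dict[str, Any]]] = [[]]

def append_if_any(result: Dict[str, Any]) -> None:
    # an empty dict means `training_step` returned None, so nothing is recorded
    if result:
        batch_outputs[0].append(result)

append_if_any({})                          # skipped
append_if_any({"loss": 0.5, "acc": 0.9})   # appended
assert batch_outputs == [[{"loss": 0.5, "acc": 0.9}]]
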
pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 1 addition & 12 deletions
@@ -312,18 +312,7 @@ def _prepare_outputs(
                 opt_outputs = [opt_outputs]

             for batch_outputs in opt_outputs:
-                processed_tbptt_outputs = []
-
-                if isinstance(batch_outputs, OutputResult):
-                    batch_outputs = [batch_outputs]
-
-                for tbptt_output in batch_outputs:
-                    out = {}
-                    if tbptt_output.loss is not None:
-                        out["loss"] = tbptt_output.loss
-                    out.update(tbptt_output.extra)
-                    processed_tbptt_outputs.append(out)
-
+                processed_tbptt_outputs = batch_outputs if isinstance(batch_outputs, list) else [batch_outputs]
                 # if there was only one tbptt step then we can collapse that dimension
                 if len(processed_tbptt_outputs) == 1:
                     processed_tbptt_outputs = processed_tbptt_outputs[0]

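With batch outputs stored as plain dicts, `_prepare_outputs` only has to normalize the list shape and collapse the single-TBPTT-step dimension. A rough standalone sketch of that collapse; the `collapse_tbptt` name is illustrative, not the loop's actual helper:

from typing import Any, Dict, List, Union

def collapse_tbptt(batch_outputs: Union[Dict[str, Any], List[Dict[str, Any]]]) -> Any:
    # mirror of the simplified logic: wrap a lone dict, then collapse a length-1 list
    processed = batch_outputs if isinstance(batch_outputs, list) else [batch_outputs]
    if len(processed) == 1:
        processed = processed[0]
    return processed

assert collapse_tbptt({"loss": 1.0}) == {"loss": 1.0}
assert collapse_tbptt([{"loss": 1.0}]) == {"loss": 1.0}
assert collapse_tbptt([{"loss": 1.0}, {"loss": 2.0}]) == [{"loss": 1.0}, {"loss": 2.0}]
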
pytorch_lightning/loops/optimization/closure.py

Lines changed: 24 additions & 3 deletions
@@ -13,16 +13,37 @@
 # limitations under the License.
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Any, Generic, Optional, TypeVar
+from typing import Any, Dict, Generic, Optional, TypeVar

+from torch import Tensor
+
+from pytorch_lightning.utilities import rank_zero_deprecation
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

 T = TypeVar("T")


 @dataclass
 class OutputResult:
-    ...
+    @staticmethod
+    def _check_extra_detach_deprecation(extra: Dict[str, Any]) -> Dict[str, Any]:
+        # TODO: remove with the deprecation removal in v1.6
+        # this is only here to avoid duplication
+        def check_fn(v: Tensor) -> Tensor:
+            if v.grad_fn is not None:
+                rank_zero_deprecation(
+                    f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"
+                    " but this behaviour will change in v1.6. Please detach it manually:"
+                    " `return {'loss': ..., 'something': something.detach()}`"
+                )
+                return v.detach()
+            return v
+
+        return apply_to_collection(extra, Tensor, check_fn)
+
+    def asdict(self) -> Dict[str, Any]:
+        raise NotImplementedError


 class AbstractClosure(ABC, Generic[T]):
@@ -33,7 +54,7 @@ class AbstractClosure(ABC, Generic[T]):
     object which later can call it like a function but without requiring to pass in any arguments.

     This class provides a simple abstraction making the instance of this class callable like a function while capturing
-    the :class:`OutputResult` and caching it.
+    the closure result and caching it.
     """

     def __init__(self) -> None:

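The deduplicated `_check_extra_detach_deprecation` now warns and detaches in one pass over the extras. A standalone sketch of the same pattern, assuming only `torch` and substituting `warnings.warn` plus a plain dict walk for Lightning's `rank_zero_deprecation` and `apply_to_collection`:

import warnings
from typing import Any, Dict

import torch
from torch import Tensor

def check_extra(extra: Dict[str, Any]) -> Dict[str, Any]:
    # detach any tensor that still carries a graph, warning for each offending value
    def check_fn(v: Any) -> Any:
        if isinstance(v, Tensor) and v.grad_fn is not None:
            warnings.warn(f"One of the returned values {set(extra)} has a `grad_fn`; detaching it.")
            return v.detach()
        return v
    return {k: check_fn(v) for k, v in extra.items()}

x = torch.ones(1, requires_grad=True)
cleaned = check_extra({"logits": x * 2, "count": 3})  # warns about and detaches "logits"
assert cleaned["logits"].grad_fn is None and cleaned["count"] == 3
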
pytorch_lightning/loops/optimization/manual_loop.py

Lines changed: 26 additions & 59 deletions
@@ -18,16 +18,9 @@

 from pytorch_lightning.loops import Loop
 from pytorch_lightning.loops.optimization.closure import OutputResult
-from pytorch_lightning.loops.utilities import (
-    _build_training_step_kwargs,
-    _check_training_step_output,
-    _extract_hiddens,
-    check_finite_loss,
-)
-from pytorch_lightning.utilities.apply_func import apply_to_collection
-from pytorch_lightning.utilities.memory import recursive_detach
+from pytorch_lightning.loops.utilities import _build_training_step_kwargs, _extract_hiddens
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.types import STEP_OUTPUT
-from pytorch_lightning.utilities.warnings import rank_zero_deprecation


 @dataclass
@@ -37,66 +30,45 @@ class ManualResult(OutputResult):
     It is created from the output of :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`.

     Attributes:
-        closure_loss: The loss with a graph attached.
-        loss: A detached copy of the closure loss.
-        extra: Any keys other than the loss returned.
+        extra: Anything returned by the ``training_step``.
     """

-    closure_loss: Optional[Tensor]
-    loss: Optional[Tensor] = field(init=False, default=None)
-    extra: Dict[str, Tensor] = field(default_factory=dict)
+    extra: Dict[str, Any] = field(default_factory=dict)

     def __post_init__(self) -> None:
         # TODO: remove with the deprecation removal in v1.6
-        self._check_extra_detach_deprecation(self.extra)
-        self.extra = recursive_detach(self.extra)
-
-        self._clone_loss()
-
-    def _clone_loss(self) -> None:
-        if self.closure_loss is not None:
-            # the loss will get scaled for amp. avoid any modifications to it
-            self.loss = self.closure_loss.detach().clone()
+        self.extra = self._check_extra_detach_deprecation(self.extra)

     @classmethod
     def from_training_step_output(
         cls, training_step_output: Optional[STEP_OUTPUT], normalize: int = 1
     ) -> "ManualResult":
-        closure_loss, extra = None, {}
-
+        extra = {}
         if isinstance(training_step_output, dict):
-            # this should not modify the `training_step_output`, as the user could be using it after `training_step_end`
-            closure_loss = training_step_output.get("loss")
-            extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")}
+            extra = {k: v for k, v in training_step_output.items() if k != "hiddens"}
         elif isinstance(training_step_output, Tensor):
-            closure_loss = training_step_output
+            extra = {"loss": training_step_output}
+        elif training_step_output is not None:
+            raise MisconfigurationException(
+                "In manual optimization, `training_step` must either return a Tensor, "
+                "a dict with extras to pass to `training_epoch_end` or have no return."
+            )

-        if closure_loss is not None:
-            # accumulate the loss. If ``accumulate_grad_batches == 1``, no effect
-            closure_loss /= normalize
+        if "loss" in extra:
+            # accumulate the loss. If `accumulate_grad_batches == 1`, no effect.
+            # we detach manually as it's expected that it will have a `grad_fn`
+            extra["loss"] = extra["loss"].detach().div(normalize)

-        return cls(closure_loss, extra=extra)
+        return cls(extra=extra)

-    @staticmethod
-    def _check_extra_detach_deprecation(extra: Dict[str, Any]) -> None:
-        def check_fn(v: Tensor) -> Tensor:
-            if v.grad_fn is not None:
-                rank_zero_deprecation(
-                    f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"
-                    " but this behaviour will change in v1.6. Please detach it manually:"
-                    " `return {'loss': ..., 'something': something.detach()}`"
-                )
-            return v
+    def asdict(self) -> Dict[str, Any]:
+        return self.extra

-        apply_to_collection(extra, Tensor, check_fn)

-    def drop_closure_loss(self) -> "ManualResult":
-        """Return itself without the closure loss which could have a `grad_fn`"""
-        self.closure_loss = None
-        return self
+_OUTPUTS_TYPE = Dict[str, Any]


-class ManualOptimization(Loop):
+class ManualOptimization(Loop[_OUTPUTS_TYPE]):
     """A special loop implementing what is known in Lightning as Manual Optimization where the optimization happens
     entirely in the :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` and therefore the user
     is responsible for back-propagating gradients and making calls to the optimizers.
@@ -109,7 +81,7 @@ def __init__(self) -> None:
         super().__init__()
         self._done: bool = False
         self._hiddens: Optional[Any] = None
-        self._output: Optional[ManualResult] = None
+        self._output: _OUTPUTS_TYPE = {}

     @property
     def done(self) -> bool:
@@ -144,27 +116,22 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override]

         training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

-        _check_training_step_output(lightning_module, training_step_output)
-
         self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)

         result = ManualResult.from_training_step_output(training_step_output, self.trainer.accumulate_grad_batches)

-        if self.trainer.terminate_on_nan:
-            check_finite_loss(result.closure_loss)
-
         if self.trainer.move_metrics_to_cpu:
             # hiddens and the training step output are not moved as they are not considered "metrics"
             # the user might need them on the correct device for an operation in `training_epoch_end`
             assert self.trainer._results is not None
             self.trainer._results.cpu()

         self._done = True
-        self._output = result
+        self._output = result.asdict()

-    def on_run_end(self) -> Optional[ManualResult]:
+    def on_run_end(self) -> _OUTPUTS_TYPE:
         """Returns the result of this loop, i.e., the post-processed outputs from the training step."""
-        output, self._output = self._output, None  # free memory
+        output, self._output = self._output, {}  # free memory
         # #9052 added support for raising `StopIteration` in the `training_step`. If that happens, then `advance`
         # doesn't finish and `self._output` stays as `None`. If #9415 happens then this would always return a result
         return output

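Under the new contract, `ManualResult` simply carries whatever `training_step` returned (minus `"hiddens"`), detaching and normalizing a `"loss"` key when present, and `asdict()` hands back that dict. A rough usage sketch, assuming a PyTorch Lightning checkout that includes this commit:

import torch
from pytorch_lightning.loops.optimization.manual_loop import ManualResult

# dict return: every key except "hiddens" is kept; "loss" is detached and divided by `normalize`
loss = torch.tensor(4.0, requires_grad=True)
out = ManualResult.from_training_step_output({"loss": loss, "foo": 1, "hiddens": None}, normalize=2).asdict()
assert set(out) == {"loss", "foo"}
assert out["loss"].item() == 2.0 and out["loss"].grad_fn is None

# Tensor return: wrapped as {"loss": ...}; no return at all: an empty dict
assert ManualResult.from_training_step_output(torch.tensor(3.0)).asdict()["loss"].item() == 3.0
assert ManualResult.from_training_step_output(None).asdict() == {}
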
pytorch_lightning/loops/optimization/optimizer_loop.py

Lines changed: 18 additions & 28 deletions
@@ -25,20 +25,17 @@
 from pytorch_lightning.loops.utilities import (
     _block_parallel_sync_behavior,
     _build_training_step_kwargs,
-    _check_training_step_output,
     _extract_hiddens,
     check_finite_loss,
 )
 from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler
 from pytorch_lightning.trainer.progress import OptimizationProgress
 from pytorch_lightning.utilities import AMPType, DeviceType, grad_norm
-from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
 from pytorch_lightning.utilities.imports import _TPU_AVAILABLE
-from pytorch_lightning.utilities.memory import recursive_detach
 from pytorch_lightning.utilities.types import STEP_OUTPUT
-from pytorch_lightning.utilities.warnings import rank_zero_deprecation, WarningCache
+from pytorch_lightning.utilities.warnings import WarningCache


 @dataclass
@@ -55,12 +52,11 @@ class ClosureResult(OutputResult):

     closure_loss: Optional[Tensor]
     loss: Optional[Tensor] = field(init=False, default=None)
-    extra: Dict[str, Tensor] = field(default_factory=dict)
+    extra: Dict[str, Any] = field(default_factory=dict)

     def __post_init__(self) -> None:
         # TODO: remove with the deprecation removal in v1.6
-        ClosureResult._check_extra_detach_deprecation(self.extra)
-        self.extra = recursive_detach(self.extra)
+        self.extra = self._check_extra_detach_deprecation(self.extra)

         self._clone_loss()

@@ -78,33 +74,27 @@ def from_training_step_output(
         if isinstance(training_step_output, dict):
             # this should not modify the `training_step_output`, as the user could be using it after `training_step_end`
             closure_loss = training_step_output.get("loss")
+            if closure_loss is None:
+                raise MisconfigurationException(
+                    "In automatic_optimization, when `training_step` returns a dict, the 'loss' key needs to be present"
+                )
             extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")}
         elif isinstance(training_step_output, Tensor):
             closure_loss = training_step_output
+        elif training_step_output is not None:
+            raise MisconfigurationException(
+                "In automatic optimization, `training_step` must return a Tensor, "
+                "a dict, or None (where the step will be skipped)."
+            )

         if closure_loss is not None:
             # accumulate the loss. If ``accumulate_grad_batches == 1``, no effect
             closure_loss /= normalize

         return cls(closure_loss, extra=extra)

-    @staticmethod
-    def _check_extra_detach_deprecation(extra: Dict[str, Any]) -> None:
-        def check_fn(v: Tensor) -> Tensor:
-            if v.grad_fn is not None:
-                rank_zero_deprecation(
-                    f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"
-                    " but this behaviour will change in v1.6. Please detach it manually:"
-                    " `return {'loss': ..., 'something': something.detach()}`"
-                )
-            return v
-
-        apply_to_collection(extra, Tensor, check_fn)
-
-    def drop_closure_loss(self) -> "ClosureResult":
-        """Return itself without the closure loss which could have a `grad_fn`"""
-        self.closure_loss = None
-        return self
+    def asdict(self) -> Dict[str, Any]:
+        return {"loss": self.loss, **self.extra}


 class Closure(AbstractClosure[ClosureResult]):
@@ -170,7 +160,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
         return self._result.loss


-_OUTPUTS_TYPE = List[List[ClosureResult]]
+_OUTPUTS_TYPE = List[List[Dict[str, Any]]]


 class OptimizerLoop(Loop):
@@ -222,7 +212,9 @@ def advance(self, batch: Any, *args: Any, **kwargs: Any) -> None: # type: ignore[override]
             self.optimizer_idx,
         )
         if result.loss is not None:
-            self.outputs[self.optimizer_idx].append(result.drop_closure_loss())
+            # automatic optimization assumes a loss needs to be returned for extras to be considered as the batch
+            # would be skipped otherwise
+            self.outputs[self.optimizer_idx].append(result.asdict())
         self.optim_progress.optimizer_position += 1

     def on_run_end(self) -> _OUTPUTS_TYPE:
@@ -467,8 +459,6 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> ClosureResult:

         training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

-        _check_training_step_output(lightning_module, training_step_output)
-
         self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)

         result = ClosureResult.from_training_step_output(training_step_output, self.trainer.accumulate_grad_batches)
pytorch_lightning/loops/utilities.py

Lines changed: 1 addition & 26 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 from collections import OrderedDict
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterator, Mapping, Optional, Sequence
+from typing import Any, Dict, Generator, Iterator, Optional, Sequence

 import torch
 from torch.optim import Optimizer
@@ -37,31 +37,6 @@ def check_finite_loss(loss: Optional[torch.Tensor]) -> None:
         raise ValueError(f"The loss returned in `training_step` is {loss}.")


-def _check_training_step_output(model: "pl.LightningModule", training_step_output: STEP_OUTPUT) -> None:
-    """Sanity checks that training produced a valid output.
-
-    Args:
-        model: a reference to the trainer
-        training_step_output: the output of the training step (before wrapping in an AttributeDict)
-    """
-    if (
-        isinstance(training_step_output, torch.Tensor)
-        and not model.automatic_optimization
-        and training_step_output.grad_fn is None
-    ):
-        # TODO: in manual optimization, anything returned should be considered an `extra`
-        raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor")
-    if model.automatic_optimization and not (
-        isinstance(training_step_output, torch.Tensor)
-        or (isinstance(training_step_output, Mapping) and "loss" in training_step_output)
-        or training_step_output is None
-    ):
-        raise MisconfigurationException(
-            "In automatic optimization, `training_step` must either return a Tensor, "
-            "a dict with key 'loss' or None (where the step will be skipped)."
-        )
-
-
 def _extract_hiddens(training_step_output: STEP_OUTPUT, truncated_bptt_steps: int) -> Optional[Any]:
     """Get the hidden state if present from the training step output.
