
Commit 999ff93

Move results to their files
1 parent d46831e commit 999ff93

File tree

8 files changed: +270 -182 lines changed


pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 5 additions & 5 deletions
@@ -17,7 +17,7 @@
 
 from pytorch_lightning import loops  # import as loops to avoid circular imports
 from pytorch_lightning.loops.batch import TrainingBatchLoop
-from pytorch_lightning.loops.closure import ClosureResult
+from pytorch_lightning.loops.optimization.closure import OutputResult
 from pytorch_lightning.loops.utilities import _prepare_dataloader_iter
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.trainer.progress import Progress, SchedulerProgress
@@ -284,18 +284,18 @@ def _track_epoch_end_reduce_metrics(
 
     @staticmethod
     def _prepare_outputs(
-        outputs: List[List[List[ClosureResult]]], batch_mode: bool
+        outputs: List[List[List[OutputResult]]], batch_mode: bool
     ) -> Union[List[List[List[Dict]]], List[List[Dict]], List[Dict], Dict]:
         """Extract required information from batch or epoch end results.
 
         Args:
-            outputs: A 3-dimensional list of ``ClosureResult`` objects with dimensions:
+            outputs: A 3-dimensional list of ``OutputResult`` objects with dimensions:
                 ``[optimizer outs][batch outs][tbptt steps]``.
 
             batch_mode: If True, ignore the batch output dimension.
 
         Returns:
-            The cleaned outputs with ``ClosureResult`` objects converted to dictionaries.
+            The cleaned outputs with ``OutputResult`` objects converted to dictionaries.
             All list dimensions of size one will be collapsed.
         """
         processed_outputs = []
@@ -312,7 +312,7 @@ def _prepare_outputs(
             for batch_outputs in opt_outputs:
                 processed_tbptt_outputs = []
 
-                if isinstance(batch_outputs, ClosureResult):
+                if isinstance(batch_outputs, OutputResult):
                     batch_outputs = [batch_outputs]
 
                 for tbptt_output in batch_outputs:
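
The only change in this file is the type rename in ``_prepare_outputs``; the docstring still describes a 3-dimensional ``[optimizer outs][batch outs][tbptt steps]`` structure whose size-one dimensions get collapsed. The short Python sketch below illustrates just that collapse rule on plain dictionaries; the function name and implementation are illustrative assumptions, not the actual Lightning code.

from typing import Any, Dict, List


def collapse_singletons(outputs: List[List[List[Dict[str, Any]]]]) -> Any:
    """Collapse every list dimension of size one, as described in the docstring above."""

    def collapse(obj: Any) -> Any:
        if isinstance(obj, list):
            collapsed = [collapse(item) for item in obj]
            return collapsed[0] if len(collapsed) == 1 else collapsed
        return obj

    return collapse(outputs)


# One optimizer, two batches, one tbptt step each:
outs = [[[{"loss": 0.5}], [{"loss": 0.3}]]]
print(collapse_singletons(outs))  # -> [{'loss': 0.5}, {'loss': 0.3}]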

pytorch_lightning/loops/optimization/closure.py

Lines changed: 10 additions & 137 deletions
@@ -12,83 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC, abstractmethod
-from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, Optional
+from dataclasses import dataclass
+from typing import Any, Optional
 
-from torch import Tensor
-
-from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler
-from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.memory import recursive_detach
-from pytorch_lightning.utilities.types import STEP_OUTPUT
-from pytorch_lightning.utilities.warnings import rank_zero_deprecation, WarningCache
 
 
 @dataclass
-class ClosureResult:
-    """A container to hold the result of a :class:`AbstractClosure` call.
-
-    It is created from the output of :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`.
-
-    Attributes:
-        closure_loss: The loss with a graph attached.
-        loss: A detached copy of the closure loss.
-        extra: Any keys other than the loss returned.
-    """
-
-    closure_loss: Optional[Tensor]
-    loss: Optional[Tensor] = field(init=False, default=None)
-    extra: Dict[str, Tensor] = field(default_factory=dict)
-
-    def __post_init__(self) -> None:
-        # TODO: remove with the deprecation removal in v1.6
-        ClosureResult._check_extra_detach_deprecation(self.extra)
-        self.extra = recursive_detach(self.extra)
-
-        self._clone_loss()
-
-    def _clone_loss(self) -> None:
-        if self.closure_loss is not None:
-            # the loss will get scaled for amp. avoid any modifications to it
-            self.loss = self.closure_loss.detach().clone()
-
-    @classmethod
-    def from_training_step_output(
-        cls, training_step_output: Optional[STEP_OUTPUT], normalize: int = 1
-    ) -> "ClosureResult":
-        closure_loss, extra = None, {}
-
-        if isinstance(training_step_output, dict):
-            # this should not modify the `training_step_output`, as the user could be using it after `training_step_end`
-            closure_loss = training_step_output.get("loss")
-            extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")}
-        elif isinstance(training_step_output, Tensor):
-            closure_loss = training_step_output
-
-        if closure_loss is not None:
-            # accumulate the loss. If ``accumulate_grad_batches == 1``, no effect
-            closure_loss /= normalize
-
-        return cls(closure_loss, extra=extra)
-
-    @staticmethod
-    def _check_extra_detach_deprecation(extra: Dict[str, Any]) -> None:
-        def check_fn(v: Tensor) -> Tensor:
-            if v.grad_fn is not None:
-                rank_zero_deprecation(
-                    f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"
-                    " but this behaviour will change in v1.6. Please detach it manually:"
-                    " `return {'loss': ..., 'something': something.detach()}`"
-                )
-            return v
-
-        apply_to_collection(extra, Tensor, check_fn)
-
-    def drop_closure_loss(self) -> "ClosureResult":
-        """Return itself without the closure loss which could have a `grad_fn`"""
-        self.closure_loss = None
-        return self
+class OutputResult:
+    ...
 
 
 class AbstractClosure(ABC):
@@ -99,14 +31,14 @@ class AbstractClosure(ABC):
     object which later can call it like a function but without requiring to pass in any arguments.
 
     This class provides a simple abstraction making the instance of this class callable like a function while capturing
-    the :class:`ClosureResult` and caching it.
+    the :class:`OutputResult` and caching it.
     """
 
     def __init__(self) -> None:
         super().__init__()
-        self._result: Optional[ClosureResult] = None
+        self._result: Optional[OutputResult] = None
 
-    def consume_result(self) -> ClosureResult:
+    def consume_result(self) -> OutputResult:
         """The cached result from the last time the closure was called.
 
         Once accessed, the internal reference gets reset and the consumer will have to hold on to the reference as long
@@ -122,69 +54,10 @@ def consume_result(self) -> ClosureResult:
         return result
 
     @abstractmethod
-    def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
+    def closure(self, *args: Any, **kwargs: Any) -> OutputResult:
         """Implements the behavior of the closure once it is getting called."""
         pass
 
-    def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
+    def __call__(self, *args: Any, **kwargs: Any) -> "AbstractClosure":
         self._result = self.closure(*args, **kwargs)
-        return self._result.loss
-
-
-class Closure(AbstractClosure):
-    """An implementation of a :class:`AbstractClosure` for optimization in Lightning that combines three elementary
-    closures into one: ``training_step``, ``backward`` and ``zero_grad``.
-
-    The Closure gets created by the training loop(s) and is then passed to the
-    :meth:`torch.optim.Optimizer.step` method. An optimizer is responsible for calling the closure and optionally
-    do something with the output.
-
-    Args:
-        step_fn: This is typically the :meth:`pytorch_lightning.core.lightning.LightningModule.training_step
-            wrapped with processing for its outputs
-        backward_fn: A function that takes a loss value as input, performs back-propagation and returns the loss value.
-            Can be set to ``None`` to skip the backward operation.
-        zero_grad_fn: A function that zeroes the gradients. Can be set to ``None`` to skip zero_grad, for example
-            when accumulating gradients.
-        profiler: A profiler for profiling the actions of the passed in closure functions.
-
-    Example:
-
-        closure = Closure()
-        optimizer = torch.optim.Adam(...)
-        optimizer.step(closure)
-    """
-
-    warning_cache = WarningCache()
-
-    def __init__(
-        self,
-        step_fn: Callable[[], ClosureResult],
-        backward_fn: Optional[Callable[[Tensor], Tensor]] = None,
-        zero_grad_fn: Optional[Callable[[], None]] = None,
-        profiler: Optional[BaseProfiler] = None,
-    ):
-        super().__init__()
-        self._step_fn = step_fn
-        self._backward_fn = backward_fn
-        self._zero_grad_fn = zero_grad_fn
-        self._profiler = PassThroughProfiler() if profiler is None else profiler
-
-    def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
-        with self._profiler.profile("training_step_and_backward"):
-            step_output = self._step_fn()
-
-        if step_output.closure_loss is None:
-            self.warning_cache.warn(
-                "`training_step` returned `None`. If this was on purpose, ignore this warning..."
-            )
-
-        if self._zero_grad_fn is not None:
-            with self._profiler.profile("zero_grad"):
-                self._zero_grad_fn()
-
-        if self._backward_fn is not None and step_output.closure_loss is not None:
-            with self._profiler.profile("backward"):
-                step_output.closure_loss = self._backward_fn(step_output.closure_loss)
-
-        return step_output
+        return self
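
Taken together, this diff leaves ``closure.py`` with only the empty ``OutputResult`` placeholder and the abstract ``AbstractClosure``, whose ``__call__`` now caches the result and returns the closure object instead of the loss. A minimal sketch of that calling contract follows, using a toy subclass; ``DummyClosure`` and ``DummyResult`` are invented for illustration, and the error handling is simplified compared to the real ``consume_result``.

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class OutputResult:
    ...


@dataclass
class DummyResult(OutputResult):
    loss: float = 0.0


class AbstractClosure(ABC):
    def __init__(self) -> None:
        self._result: Optional[OutputResult] = None

    def consume_result(self) -> OutputResult:
        # the cached result is handed over exactly once, then the reference is reset
        result, self._result = self._result, None
        assert result is not None, "the closure has not been called yet"
        return result

    @abstractmethod
    def closure(self, *args: Any, **kwargs: Any) -> OutputResult:
        """Implements the behavior of the closure once it is getting called."""

    def __call__(self, *args: Any, **kwargs: Any) -> "AbstractClosure":
        self._result = self.closure(*args, **kwargs)
        return self  # the closure itself is returned now, not the loss


class DummyClosure(AbstractClosure):
    def closure(self, *args: Any, **kwargs: Any) -> DummyResult:
        return DummyResult(loss=1.23)


closure = DummyClosure()
closure()                        # runs the step and caches the result
print(closure.consume_result())  # DummyResult(loss=1.23)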

pytorch_lightning/loops/optimization/manual_loop.py

Lines changed: 78 additions & 6 deletions
@@ -11,16 +11,89 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional
+from dataclasses import dataclass, field
+from typing import Any, Dict, Optional
+
+from torch import Tensor
 
 from pytorch_lightning.loops import Loop
-from pytorch_lightning.loops.closure import ClosureResult
+from pytorch_lightning.loops.optimization.closure import OutputResult
 from pytorch_lightning.loops.utilities import (
     _build_training_step_kwargs,
     _check_training_step_output,
     _extract_hiddens,
     check_finite_loss,
 )
+from pytorch_lightning.utilities.apply_func import apply_to_collection
+from pytorch_lightning.utilities.memory import recursive_detach
+from pytorch_lightning.utilities.types import STEP_OUTPUT
+from pytorch_lightning.utilities.warnings import rank_zero_deprecation
+
+
+@dataclass
+class ManualResult(OutputResult):
+    """A container to hold the result returned by the ``ManualLoop``.
+
+    It is created from the output of :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`.
+
+    Attributes:
+        closure_loss: The loss with a graph attached.
+        loss: A detached copy of the closure loss.
+        extra: Any keys other than the loss returned.
+    """
+
+    closure_loss: Optional[Tensor]
+    loss: Optional[Tensor] = field(init=False, default=None)
+    extra: Dict[str, Tensor] = field(default_factory=dict)
+
+    def __post_init__(self) -> None:
+        # TODO: remove with the deprecation removal in v1.6
+        self._check_extra_detach_deprecation(self.extra)
+        self.extra = recursive_detach(self.extra)
+
+        self._clone_loss()
+
+    def _clone_loss(self) -> None:
+        if self.closure_loss is not None:
+            # the loss will get scaled for amp. avoid any modifications to it
+            self.loss = self.closure_loss.detach().clone()
+
+    @classmethod
+    def from_training_step_output(
+        cls, training_step_output: Optional[STEP_OUTPUT], normalize: int = 1
+    ) -> "ManualResult":
+        closure_loss, extra = None, {}
+
+        if isinstance(training_step_output, dict):
+            # this should not modify the `training_step_output`, as the user could be using it after `training_step_end`
+            closure_loss = training_step_output.get("loss")
+            extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")}
+        elif isinstance(training_step_output, Tensor):
+            closure_loss = training_step_output
+
+        if closure_loss is not None:
+            # accumulate the loss. If ``accumulate_grad_batches == 1``, no effect
+            closure_loss /= normalize
+
+        return cls(closure_loss, extra=extra)
+
+    @staticmethod
+    def _check_extra_detach_deprecation(extra: Dict[str, Any]) -> None:
+        def check_fn(v: Tensor) -> Tensor:
+            if v.grad_fn is not None:
+                rank_zero_deprecation(
+                    f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"
+                    " but this behaviour will change in v1.6. Please detach it manually:"
+                    " `return {'loss': ..., 'something': something.detach()}`"
+                )
+            return v
+
+        apply_to_collection(extra, Tensor, check_fn)
+
+    def drop_closure_loss(self) -> "ManualResult":
+        """Return itself without the closure loss which could have a `grad_fn`"""
+        self.closure_loss = None
+        return self
 
 
 class ManualOptimization(Loop):
@@ -36,7 +109,7 @@ def __init__(self) -> None:
         super().__init__()
         self._done: bool = False
         self._hiddens: Optional[Any] = None
-        self._output: Optional[ClosureResult] = None
+        self._output: Optional[ManualResult] = None
 
     @property
     def done(self) -> bool:
@@ -75,8 +148,7 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override]
 
         self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)
 
-        # TODO: do not use `ClosureResult`
-        result = ClosureResult.from_training_step_output(training_step_output, self.trainer.accumulate_grad_batches)
+        result = ManualResult.from_training_step_output(training_step_output, self.trainer.accumulate_grad_batches)
 
         if self.trainer.terminate_on_nan:
             check_finite_loss(result.closure_loss)
@@ -90,7 +162,7 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override]
             self._done = True
             self._output = result
 
-    def on_run_end(self) -> Optional[ClosureResult]:
+    def on_run_end(self) -> Optional[ManualResult]:
         """Returns the result of this loop, i.e., the post-processed outputs from the training step."""
         output, self._output = self._output, None  # free memory
         # #9052 added support for raising `StopIteration` in the `training_step`. If that happens, then `advance`
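
``ManualResult.from_training_step_output`` (the logic moved here from the old ``ClosureResult``) splits the ``training_step`` return value into the loss and any extra keys, divides the loss by the accumulation factor, and stores a detached clone under ``loss``. A small usage sketch follows, assuming a pytorch_lightning checkout that includes this commit; the tensors and printed values are illustrative.

import torch

from pytorch_lightning.loops.optimization.manual_loop import ManualResult

# A fake ``training_step`` output: a non-leaf loss with a graph plus one extra key.
weight = torch.ones(1, requires_grad=True)
step_output = {
    "loss": (weight * 4.0).sum(),
    "acc": torch.tensor(0.9),  # anything besides "loss"/"hiddens" ends up in ``extra``
}

# ``normalize`` plays the role of ``trainer.accumulate_grad_batches`` in the loop above.
result = ManualResult.from_training_step_output(step_output, normalize=2)

print(result.closure_loss)  # tensor(2., grad_fn=...)  still attached to the graph
print(result.loss)          # tensor(2.)               detached clone
print(result.extra)         # {'acc': tensor(0.9000)}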
