 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC, abstractmethod
-from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, Optional
+from dataclasses import dataclass
+from typing import Any, Optional
 
-from torch import Tensor
-
-from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler
-from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.memory import recursive_detach
-from pytorch_lightning.utilities.types import STEP_OUTPUT
-from pytorch_lightning.utilities.warnings import rank_zero_deprecation, WarningCache
 
 
 @dataclass
-class ClosureResult:
-    """A container to hold the result of a :class:`AbstractClosure` call.
-
-    It is created from the output of :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`.
-
-    Attributes:
-        closure_loss: The loss with a graph attached.
-        loss: A detached copy of the closure loss.
-        extra: Any keys other than the loss returned.
-    """
-
-    closure_loss: Optional[Tensor]
-    loss: Optional[Tensor] = field(init=False, default=None)
-    extra: Dict[str, Tensor] = field(default_factory=dict)
-
-    def __post_init__(self) -> None:
-        # TODO: remove with the deprecation removal in v1.6
-        ClosureResult._check_extra_detach_deprecation(self.extra)
-        self.extra = recursive_detach(self.extra)
-
-        self._clone_loss()
-
-    def _clone_loss(self) -> None:
-        if self.closure_loss is not None:
-            # the loss will get scaled for amp. avoid any modifications to it
-            self.loss = self.closure_loss.detach().clone()
-
-    @classmethod
-    def from_training_step_output(
-        cls, training_step_output: Optional[STEP_OUTPUT], normalize: int = 1
-    ) -> "ClosureResult":
-        closure_loss, extra = None, {}
-
-        if isinstance(training_step_output, dict):
-            # this should not modify the `training_step_output`, as the user could be using it after `training_step_end`
-            closure_loss = training_step_output.get("loss")
-            extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")}
-        elif isinstance(training_step_output, Tensor):
-            closure_loss = training_step_output
-
-        if closure_loss is not None:
-            # accumulate the loss. If ``accumulate_grad_batches == 1``, no effect
-            closure_loss /= normalize
-
-        return cls(closure_loss, extra=extra)
-
-    @staticmethod
-    def _check_extra_detach_deprecation(extra: Dict[str, Any]) -> None:
-        def check_fn(v: Tensor) -> Tensor:
-            if v.grad_fn is not None:
-                rank_zero_deprecation(
-                    f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"
-                    " but this behaviour will change in v1.6. Please detach it manually:"
-                    " `return {'loss': ..., 'something': something.detach()}`"
-                )
-            return v
-
-        apply_to_collection(extra, Tensor, check_fn)
-
-    def drop_closure_loss(self) -> "ClosureResult":
-        """Return itself without the closure loss which could have a `grad_fn`"""
-        self.closure_loss = None
-        return self
+class OutputResult:
+    ...
 
 
 class AbstractClosure(ABC):
@@ -99,14 +31,14 @@ class AbstractClosure(ABC):
     object which later can call it like a function but without requiring to pass in any arguments.
 
     This class provides a simple abstraction making the instance of this class callable like a function while capturing
-    the :class:`ClosureResult` and caching it.
+    the :class:`OutputResult` and caching it.
     """
 
     def __init__(self) -> None:
         super().__init__()
-        self._result: Optional[ClosureResult] = None
+        self._result: Optional[OutputResult] = None
 
-    def consume_result(self) -> ClosureResult:
+    def consume_result(self) -> OutputResult:
         """The cached result from the last time the closure was called.
 
         Once accessed, the internal reference gets reset and the consumer will have to hold on to the reference as long
@@ -122,69 +54,10 @@ def consume_result(self) -> ClosureResult:
         return result
 
     @abstractmethod
-    def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
+    def closure(self, *args: Any, **kwargs: Any) -> OutputResult:
         """Implements the behavior of the closure once it is getting called."""
         pass
 
-    def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
+    def __call__(self, *args: Any, **kwargs: Any) -> "AbstractClosure":
         self._result = self.closure(*args, **kwargs)
-        return self._result.loss
-
-
-class Closure(AbstractClosure):
-    """An implementation of a :class:`AbstractClosure` for optimization in Lightning that combines three elementary
-    closures into one: ``training_step``, ``backward`` and ``zero_grad``.
-
-    The Closure gets created by the training loop(s) and is then passed to the
-    :meth:`torch.optim.Optimizer.step` method. An optimizer is responsible for calling the closure and optionally
-    do something with the output.
-
-    Args:
-        step_fn: This is typically the :meth:`pytorch_lightning.core.lightning.LightningModule.training_step
-            wrapped with processing for its outputs
-        backward_fn: A function that takes a loss value as input, performs back-propagation and returns the loss value.
-            Can be set to ``None`` to skip the backward operation.
-        zero_grad_fn: A function that zeroes the gradients. Can be set to ``None`` to skip zero_grad, for example
-            when accumulating gradients.
-        profiler: A profiler for profiling the actions of the passed in closure functions.
-
-    Example:
-
-        closure = Closure()
-        optimizer = torch.optim.Adam(...)
-        optimizer.step(closure)
-    """
-
-    warning_cache = WarningCache()
-
-    def __init__(
-        self,
-        step_fn: Callable[[], ClosureResult],
-        backward_fn: Optional[Callable[[Tensor], Tensor]] = None,
-        zero_grad_fn: Optional[Callable[[], None]] = None,
-        profiler: Optional[BaseProfiler] = None,
-    ):
-        super().__init__()
-        self._step_fn = step_fn
-        self._backward_fn = backward_fn
-        self._zero_grad_fn = zero_grad_fn
-        self._profiler = PassThroughProfiler() if profiler is None else profiler
-
-    def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
-        with self._profiler.profile("training_step_and_backward"):
-            step_output = self._step_fn()
-
-        if step_output.closure_loss is None:
-            self.warning_cache.warn(
-                "`training_step` returned `None`. If this was on purpose, ignore this warning..."
-            )
-
-        if self._zero_grad_fn is not None:
-            with self._profiler.profile("zero_grad"):
-                self._zero_grad_fn()
-
-        if self._backward_fn is not None and step_output.closure_loss is not None:
-            with self._profiler.profile("backward"):
-                step_output.closure_loss = self._backward_fn(step_output.closure_loss)
-
-        return step_output
+        return self
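For orientation, here is a minimal, standalone sketch of the protocol this diff moves to: `__call__` now caches the step output and returns the closure itself instead of a loss tensor, and the caller fetches the output once via `consume_result`. The `MyClosure` subclass, the `RuntimeError`, and the empty `OutputResult` payload are illustrative stand-ins, not the library code; the concrete result fields and the error type are assumed to be filled in elsewhere in the PR.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class OutputResult:
    # placeholder mirroring the stub introduced in this diff; the concrete
    # payload (loss, extras, ...) is assumed to be added in follow-up changes
    ...


class AbstractClosure(ABC):
    def __init__(self) -> None:
        self._result: Optional[OutputResult] = None

    def consume_result(self) -> OutputResult:
        # hand the cached result to the caller exactly once, then reset it
        if self._result is None:
            # stand-in for the library's MisconfigurationException
            raise RuntimeError("The closure has not been called yet.")
        result, self._result = self._result, None
        return result

    @abstractmethod
    def closure(self, *args: Any, **kwargs: Any) -> OutputResult:
        ...

    def __call__(self, *args: Any, **kwargs: Any) -> "AbstractClosure":
        # cache the result and return the closure itself, not a loss tensor
        self._result = self.closure(*args, **kwargs)
        return self


class MyClosure(AbstractClosure):  # hypothetical subclass for illustration
    def closure(self, *args: Any, **kwargs: Any) -> OutputResult:
        return OutputResult()


if __name__ == "__main__":
    c = MyClosure()
    c()                       # __call__ returns the closure, chaining is possible
    out = c.consume_result()  # the consumer fetches (and clears) the cached result
    print(out)
```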