Commit ca679cd

awaelchli authored, with carmocca and tchaton as co-authors
Add ManualOptimization loop (#9266)
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: thomas chaton <[email protected]>
1 parent a79c351 commit ca679cd

File tree

9 files changed: +179, -152 lines

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Added `Closure` and `AbstractClosure` classes ([#8642](https://github.com/PyTorchLightning/pytorch-lightning/pull/8642))
 * Refactored `TrainingBatchLoop` and extracted `OptimizerLoop`, splitting off automatic optimization into its own loop ([#9191](https://github.com/PyTorchLightning/pytorch-lightning/pull/9191))
 * Removed `TrainingBatchLoop.backward()`; manual optimization now calls directly into `Accelerator.backward()` and automatic optimization handles backward in new `OptimizerLoop` ([#9265](https://github.com/PyTorchLightning/pytorch-lightning/pull/9265))
+* Extracted `ManualOptimization` logic from `TrainingBatchLoop` into its own separate loop class ([#9266](https://github.com/PyTorchLightning/pytorch-lightning/pull/9266))

 - Added support for saving and loading state of multiple callbacks of the same type ([#7187](https://github.com/PyTorchLightning/pytorch-lightning/pull/7187))

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@ ignore_errors = "True"
 module = [
     "pytorch_lightning.callbacks.pruning",
     "pytorch_lightning.loops.closure",
+    "pytorch_lightning.loops.batch.manual",
     "pytorch_lightning.loops.optimizer",
     "pytorch_lightning.trainer.evaluation_loop",
     "pytorch_lightning.trainer.connectors.logger_connector.*",

pytorch_lightning/loops/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.

 from pytorch_lightning.loops.base import Loop  # noqa: F401
+from pytorch_lightning.loops.batch import ManualOptimization  # noqa: F401
 from pytorch_lightning.loops.batch import TrainingBatchLoop  # noqa: F401
 from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop  # noqa: F401
 from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop  # noqa: F401

pytorch_lightning/loops/batch/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from pytorch_lightning.loops.batch.manual import ManualOptimization  # noqa: F401
 from pytorch_lightning.loops.batch.training_batch_loop import TrainingBatchLoop  # noqa: F401
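
With the two re-exports above, the new class should be reachable from either package level; a small sketch confirming that the aliases point at the same object:

from pytorch_lightning.loops import ManualOptimization
from pytorch_lightning.loops.batch import ManualOptimization as BatchManualOptimization
from pytorch_lightning.loops.batch.manual import ManualOptimization as ModuleManualOptimization

# all three import paths introduced or touched by this commit resolve to the same class
assert ManualOptimization is BatchManualOptimization is ModuleManualOptimization
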
pytorch_lightning/loops/batch/manual.py

Lines changed: 89 additions & 0 deletions

@@ -0,0 +1,89 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional, Tuple

from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.utilities import (
    _build_training_step_kwargs,
    _check_training_step_output,
    _process_training_step_output,
)
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection


class ManualOptimization(Loop):
    """A special loop implementing what is known in Lightning as Manual Optimization where the optimization happens
    entirely in the :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` and therefore the user
    is responsible for back-propagating gradients and making calls to the optimizers.

    This loop is a trivial case because it performs only a single iteration (calling directly into the module's
    :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`) and passes through the output(s).
    """

    def __init__(self) -> None:
        super().__init__()
        self._done: bool = False
        self._hiddens: Optional[Any] = None
        self._output: Optional[ResultCollection] = None

    @property
    def done(self) -> bool:
        return self._done

    def reset(self) -> None:
        self._done = False

    def advance(self, batch: Any, batch_idx: int, hiddens: Optional[Any] = None) -> None:  # type: ignore[override]
        """Performs the training step for manual optimization.

        Args:
            batch: the current tbptt split of the current batch
            batch_idx: the index of the current batch
            hiddens: the model's hidden state of the previous iteration
        """
        assert self.trainer is not None
        lightning_module = self.trainer.lightning_module

        with self.trainer.profiler.profile("model_forward"):

            step_kwargs = _build_training_step_kwargs(
                lightning_module, self.trainer.optimizers, batch, batch_idx, opt_idx=None, hiddens=hiddens
            )

            # manually capture logged metrics
            lightning_module._current_fx_name = "training_step"
            with self.trainer.profiler.profile("training_step"):
                training_step_output = self.trainer.accelerator.training_step(step_kwargs)
                self.trainer.accelerator.post_training_step()

            del step_kwargs

            training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

            _check_training_step_output(lightning_module, training_step_output)

            result_collection, hiddens = _process_training_step_output(self.trainer, training_step_output)

        self._done = True
        self._hiddens = hiddens
        self._output = result_collection

    def on_run_end(self) -> Tuple[Optional[ResultCollection], Optional[Any]]:
        """Returns the result of this loop, i.e., the post-processed outputs from the training step, and the hidden
        state."""
        output = self._output
        hiddens = self._hiddens
        self._output, self._hiddens = None, None  # free memory
        return output, hiddens
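
For orientation, this is the user-facing pattern the new loop executes: with automatic optimization turned off, `training_step` itself calls `manual_backward` and steps the optimizer, and the loop above simply runs that step once and forwards its output. A minimal illustrative sketch; the module, layer sizes, dataset, and optimizer choice are made up for this example and are not part of the commit:

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule, Trainer


class ManualOptModel(LightningModule):
    def __init__(self):
        super().__init__()
        # opting out of automatic optimization is what routes each batch through ManualOptimization
        self.automatic_optimization = False
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.layer(x).sum()
        # the user handles backward and the optimizer call, exactly as the loop's docstring describes
        self.manual_backward(loss)
        opt.step()
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=16)
    Trainer(max_epochs=1, logger=False, checkpoint_callback=False).fit(ManualOptModel(), data)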

pytorch_lightning/loops/batch/training_batch_loop.py

Lines changed: 14 additions & 97 deletions
@@ -12,22 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from copy import deepcopy
-from functools import partial
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, List, Optional, Tuple

 import numpy as np
 from deprecate import void
 from torch import Tensor
 from torch.optim import Optimizer

 from pytorch_lightning.loops.base import Loop
-from pytorch_lightning.loops.closure import Closure, ClosureResult
+from pytorch_lightning.loops.batch.manual import ManualOptimization
 from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop
-from pytorch_lightning.loops.utilities import (
-    _build_training_step_kwargs,
-    _check_training_step_output,
-    _process_training_step_output,
-)
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.utilities import AttributeDict
 from pytorch_lightning.utilities.types import STEP_OUTPUT

@@ -45,6 +39,7 @@ def __init__(self) -> None:
         # the current split index when the batch gets split into chunks in truncated backprop through time
         self.split_idx: Optional[int] = None
         self.optimizer_loop = OptimizerLoop()
+        self.manual_loop = ManualOptimization()

         self._warning_cache: WarningCache = WarningCache()
         self._hiddens: Optional[Tensor] = None

@@ -63,8 +58,13 @@ def optimizer_freq_cumsum(self) -> int:
             self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies)
         return self._optimizer_freq_cumsum

-    def connect(self, optimizer_loop: "Loop") -> None:
-        self.optimizer_loop = optimizer_loop
+    def connect(
+        self, optimizer_loop: Optional["Loop"] = None, manual_loop: Optional[ManualOptimization] = None
+    ) -> None:
+        if optimizer_loop is not None:
+            self.optimizer_loop = optimizer_loop
+        if manual_loop is not None:
+            self.manual_loop = manual_loop

     def run(self, batch: Any, batch_idx: int) -> AttributeDict:
         """Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks.

@@ -132,10 +132,10 @@ def advance(self, batch, batch_idx):
             for k in range(len(batch_outputs)):
                 self.batch_outputs[k].extend(batch_outputs[k])
         else:
-            # in manual optimization, there is no looping over optimizers
-            result = self._run_optimization(batch_idx, split_batch)
-            if result:
-                self.batch_outputs[0].append(deepcopy(result.result_collection))
+            # in manual optimization, hand over execution to the ManualOptimization loop
+            output, self._hiddens = self.manual_loop.run(split_batch, batch_idx, self._hiddens)
+            if output:
+                self.batch_outputs[0].append(deepcopy(output))

     def teardown(self) -> None:
         # release memory

@@ -145,89 +145,6 @@ def num_active_optimizers(self, batch_idx: Optional[int] = None) -> int:
         """Gets the number of active optimizers based on their frequency."""
         return len(self.get_active_optimizers(batch_idx))

-    def _run_optimization(
-        self,
-        batch_idx: int,
-        split_batch: Any,
-    ) -> Optional[ClosureResult]:
-        """Runs closure (train step + backward) together with optimization if necessary.
-
-        Args:
-            batch_idx: the index of the current batch
-            split_batch: the current tbptt split of the whole batch
-        """
-        # TODO: replace call through closure by direct call (manual optimization)
-        closure = self._make_closure(split_batch, batch_idx, self._hiddens)
-        closure()
-        result = closure.get_result()
-
-        if result:
-            # if no result, user decided to skip optimization
-            # otherwise update running loss + reset accumulated loss
-            self._update_running_loss(result.loss)
-
-        return result
-
-    def _make_closure(
-        self,
-        split_batch: Any,
-        batch_idx: int,
-        hiddens: Any,
-    ) -> Closure:
-        """Build a closure object that captures the given arguments and runs the `training_step` function and
-        optionally other functions such as `backward` and `zero_grad`."""
-        step_fn = self._make_step_fn(split_batch, batch_idx, hiddens)
-        backward_fn = None
-        zero_grad_fn = None
-
-        return Closure(
-            step_fn=step_fn,
-            backward_fn=backward_fn,
-            zero_grad_fn=zero_grad_fn,
-            profiler=self.trainer.profiler,
-        )
-
-    def _make_step_fn(self, split_batch: Any, batch_idx: int, hiddens: Any) -> Callable[[], dict]:
-        """Build the step function that runs the `training_step` and processes its output."""
-        return partial(self._training_step, split_batch, batch_idx, hiddens)
-
-    def _training_step(self, split_batch: Any, batch_idx: int, hiddens: Tensor) -> Optional[AttributeDict]:
-        """Performs the training step for manual optimization.
-
-        Args:
-            split_batch: the current tbptt split of the current batch
-            batch_idx: the index of the current batch
-            hiddens: the model's hidden state of the previous iteration
-
-        Returns:
-            an AttributeDict containing the training step output.
-        """
-        # give the PL module a result for logging
-        model_ref = self.trainer.lightning_module
-
-        with self.trainer.profiler.profile("model_forward"):
-            step_kwargs = _build_training_step_kwargs(
-                model_ref, self.trainer.optimizers, split_batch, batch_idx, opt_idx=None, hiddens=hiddens
-            )
-
-            # manually capture logged metrics
-            model_ref._current_fx_name = "training_step"
-            with self.trainer.profiler.profile("training_step"):
-                training_step_output = self.trainer.accelerator.training_step(step_kwargs)
-                self.trainer.accelerator.post_training_step()
-
-            del step_kwargs
-
-            training_step_output = self.trainer.call_hook("training_step_end", training_step_output)
-
-            _check_training_step_output(self.trainer.lightning_module, training_step_output)
-
-            result_collection, self._hiddens = _process_training_step_output(self.trainer, training_step_output)
-            if result_collection is None:
-                return
-
-        return AttributeDict(closure_loss=None, loss=None, result_collection=result_collection)
-
     def _tbptt_split_batch(self, batch: Any) -> List[Any]:
         """Splits a single batch into a list of sequence steps for tbptt.
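
Since `connect()` now accepts both sub-loops as optional keyword arguments, either one can be replaced without touching the other. A hedged sketch of how a custom manual loop might be plugged in; the `VerboseManualOptimization` subclass is hypothetical, and the `trainer.fit_loop.epoch_loop.batch_loop` attribute path is assumed from the `epoch_loop.batch_loop.manual_loop` keys asserted in the tests below:

from pytorch_lightning import Trainer
from pytorch_lightning.loops import ManualOptimization


class VerboseManualOptimization(ManualOptimization):
    """Hypothetical subclass: logs each manual training step, then defers to the stock behaviour."""

    def advance(self, batch, batch_idx, hiddens=None):
        print(f"manual training_step for batch {batch_idx}")
        return super().advance(batch, batch_idx, hiddens=hiddens)


trainer = Trainer()
# only the manual loop is swapped; optimizer_loop keeps its default because that argument is None
trainer.fit_loop.epoch_loop.batch_loop.connect(manual_loop=VerboseManualOptimization())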

tests/loops/test_loop_state_dict.py

Lines changed: 2 additions & 1 deletion
@@ -56,8 +56,9 @@ def test_loops_state_dict_structure():
             "total": {"ready": 0, "completed": 0},
             "current": {"ready": 0, "completed": 0},
         },
-        "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
         "epoch_loop.batch_loop.state_dict": {},
+        "epoch_loop.batch_loop.manual_loop.state_dict": {},
+        "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
         "epoch_loop.batch_loop.optimizer_loop.optim_progress": {
             "optimizer": {
                 "step": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}},

tests/loops/test_loops.py

Lines changed: 1 addition & 0 deletions
@@ -464,6 +464,7 @@ def configure_optimizers_multiple(self):
             "current": {"ready": be_sch_steps, "completed": be_sch_steps},
         },
         "epoch_loop.batch_loop.state_dict": ANY,
+        "epoch_loop.batch_loop.manual_loop.state_dict": ANY,
         "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
         "epoch_loop.batch_loop.optimizer_loop.optim_progress": {
             "optimizer_idx": stop_optimizer,
