
Commit cf64f34

Refactor Strategy._move_optimizer_states as utility functions (#11758)

Author: ananthsub
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: thomas chaton <[email protected]>
1 parent d613719 · commit cf64f34

9 files changed (+78 -25 lines)

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -117,6 +117,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `Accelerator.is_available` to check device availability ([#11797](https://github.com/PyTorchLightning/pytorch-lightning/pull/11797))

+- Added utility functions for moving optimizers to devices ([#11758](https://github.com/PyTorchLightning/pytorch-lightning/pull/11758))
+
 ### Changed

 - Implemented a new native and rich format in `_print_results` method of the `EvaluationLoop` ([#11332](https://github.com/PyTorchLightning/pytorch-lightning/pull/11332))

docs/source/api_references.rst

Lines changed: 1 addition & 0 deletions
@@ -288,6 +288,7 @@ Utilities API
     finite_checks
     memory
     model_summary
+    optimizer
     parsing
     rank_zero
     seed

pytorch_lightning/strategies/deepspeed.py

Lines changed: 2 additions & 1 deletion
@@ -39,6 +39,7 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
 from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.rank_zero import rank_zero_info
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, LRSchedulerTypeUnion, STEP_OUTPUT

@@ -349,7 +350,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self.accelerator.setup(trainer)
         self.setup_optimizers(trainer)
         self.setup_precision_plugin()
-        self._move_optimizer_state()
+        optimizers_to_device(self.optimizers, self.root_device)
         self.init_deepspeed()
         self.barrier()

pytorch_lightning/strategies/fully_sharded.py

Lines changed: 2 additions & 1 deletion
@@ -25,6 +25,7 @@
 from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
 from pytorch_lightning.utilities.enums import PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.types import STEP_OUTPUT

 if _FAIRSCALE_FULLY_SHARDED_AVAILABLE:

@@ -136,7 +137,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self.accelerator.setup(trainer)
         self.setup_optimizers(trainer)
         self.setup_precision_plugin()
-        self._move_optimizer_state()
+        optimizers_to_device(self.optimizers, self.root_device)

         if self.sync_batchnorm:
             self.model = self.configure_sync_batchnorm(self.model)

pytorch_lightning/strategies/strategy.py

Lines changed: 5 additions & 11 deletions
@@ -30,9 +30,10 @@
 from pytorch_lightning.strategies.launchers.base import _Launcher
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import rank_zero_deprecation
-from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
+from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.optimizer import optimizer_to_device, optimizers_to_device
 from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, STEP_OUTPUT

 TBroadcast = TypeVar("TBroadcast")

@@ -138,7 +139,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self.accelerator.setup(trainer)
         self.setup_optimizers(trainer)
         self.setup_precision_plugin()
-        self._move_optimizer_state()
+        optimizers_to_device(self.optimizers, self.root_device)

     def setup_precision_plugin(self) -> None:
         """Attaches the precision plugin to the accelerator."""

@@ -149,14 +150,6 @@ def setup_precision_plugin(self) -> None:
         self.optimizers = optimizers
         self.lr_scheduler_configs = lr_scheduler_configs

-    def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
-        """Moves the state of the optimizers to the appropriate device if needed."""
-        for opt in self.optimizers:
-            for p, v in opt.state.items():
-                # `self.root_device` would raise error if called outside the spawn process
-                # while training on 8 and more cores.
-                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device or self.root_device)
-
     def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
         """Returns state of an optimizer.

@@ -330,6 +323,7 @@ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
         optimizer_states = checkpoint["optimizer_states"]
         for optimizer, opt_state in zip(self.optimizers, optimizer_states):
             optimizer.load_state_dict(opt_state)
+            optimizer_to_device(optimizer, self.root_device)

     def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
         """The actual training step.

@@ -445,7 +439,7 @@ def teardown(self) -> None:

         It is the right place to release memory and free other resources.
         """
-        self._move_optimizer_state(torch.device("cpu"))
+        optimizers_to_device(self.optimizers, torch.device("cpu"))
         self.precision_plugin.teardown()

     @classmethod

pytorch_lightning/strategies/tpu_spawn.py

Lines changed: 2 additions & 1 deletion
@@ -31,6 +31,7 @@
 from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.rank_zero import rank_zero_only
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT

@@ -126,7 +127,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self.accelerator.setup(trainer)
         self.setup_optimizers(trainer)
         self.setup_precision_plugin()
-        self._move_optimizer_state()
+        optimizers_to_device(self.optimizers, self.root_device)

         if self.debug:
             os.environ["PT_XLA_DEBUG"] = str(1)

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 0 additions & 11 deletions
@@ -296,17 +296,6 @@ def restore_optimizers(self) -> None:

         # restore the optimizers
         self.trainer.strategy.load_optimizer_state_dict(self._loaded_checkpoint)
-        for optimizer in self.trainer.optimizers:
-            # move optimizer to GPU 1 weight at a time
-            # avoids OOM
-            if self.trainer.root_gpu is not None:
-                for param, state in optimizer.state.items():
-                    if isinstance(state, dict):
-                        for k, v in state.items():
-                            if isinstance(v, torch.Tensor):
-                                state[k] = v.cuda(self.trainer.root_gpu)
-                    elif isinstance(state, torch.Tensor):
-                        optimizer.state[param] = state.cuda(self.trainer.root_gpu)

     def restore_lr_schedulers(self) -> None:
         """Restores the learning rate scheduler states from the pre-loaded checkpoint."""
pytorch_lightning/utilities/optimizer.py

Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Iterable
+
+import torch
+from torch.optim import Optimizer
+
+from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
+from pytorch_lightning.utilities.types import _DEVICE
+
+
+def optimizers_to_device(optimizers: Iterable[Optimizer], device: _DEVICE) -> None:
+    """Moves optimizer states for a sequence of optimizers to the device."""
+    for opt in optimizers:
+        optimizer_to_device(opt, device)
+
+
+def optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
+    """Moves the state of a single optimizer to the device."""
+    for p, v in optimizer.state.items():
+        optimizer.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device)
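
For orientation, here is a minimal usage sketch of the two new helpers, mirroring the call sites above (`optimizers_to_device(self.optimizers, self.root_device)` in `Strategy.setup` and the move back to CPU in `Strategy.teardown`). The `SGD` optimizer and tensor shapes below are illustrative only and not part of this commit:

    import torch
    from torch.optim import SGD

    from pytorch_lightning.utilities.optimizer import optimizer_to_device, optimizers_to_device

    # Illustrative model and optimizer; any torch.optim.Optimizer with populated state works.
    model = torch.nn.Linear(8, 1)
    optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)

    # Take one step so the optimizer accumulates state (momentum buffers) that can be moved.
    model(torch.randn(4, 8)).sum().backward()
    optimizer.step()

    # Mirrors the Strategy.setup call sites: place the optimizer state on the training device.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    optimizers_to_device([optimizer], device)

    # Mirrors Strategy.teardown: release accelerator memory by moving the state back to CPU.
    optimizer_to_device(optimizer, torch.device("cpu"))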

tests/utilities/test_optimizer.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+import collections
+
+import torch
+
+from pytorch_lightning.utilities.optimizer import optimizer_to_device
+
+
+def test_optimizer_to_device():
+    class TestOptimizer(torch.optim.SGD):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.state["dummy"] = torch.tensor(0)
+
+    layer = torch.nn.Linear(32, 2)
+    opt = TestOptimizer(layer.parameters(), lr=0.1)
+    optimizer_to_device(opt, "cpu")
+    if torch.cuda.is_available():
+        optimizer_to_device(opt, "cuda")
+        assert_opt_parameters_on_device(opt, "cuda")
+
+
+def assert_opt_parameters_on_device(opt, device: str):
+    for param in opt.state.values():
+        # The state may hold plain tensors (e.g. the "dummy" entry) or per-parameter dicts.
+        if isinstance(param, torch.Tensor):
+            assert param.data.device.type == device
+        elif isinstance(param, collections.Mapping):
+            for subparam in param.values():
+                if isinstance(subparam, torch.Tensor):
+                    assert subparam.data.device.type == device
