
Commit 01bbae9

Revert "Update setup logic in training type plugins (sharded) [4 / 4] (#10028)"

This reverts commit 4ea72a9.

1 parent 94e2bf5 · commit 01bbae9

File tree: 3 files changed, +34 −82 lines

CHANGELOG.md

Lines changed: 0 additions & 1 deletion
@@ -213,7 +213,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994))
 * Implemented `DataParallelPlugin._setup_model` ([#10010](https://github.com/PyTorchLightning/pytorch-lightning/pull/10010))
 * Implemented `DeepSpeedPlugin._setup_models_and_optimizers` ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009))
-* Implemented `{DDPShardedPlugin,DDPShardedSpawnPlugin}._setup_models_and_optimizers` ([#10028](https://github.com/PyTorchLightning/pytorch-lightning/pull/10028))
 * Added optional `model` argument to the `optimizer_step` methods in accelerators and plugins ([#10023](https://github.com/PyTorchLightning/pytorch-lightning/pull/10023))

pytorch_lightning/plugins/training_type/sharded.py

Lines changed: 19 additions & 44 deletions
@@ -12,11 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Dict, Generator, List, Optional, Tuple, Union
+from typing import Dict, Generator, Optional

 import torch
-from torch.nn import Module
-from torch.optim import Optimizer

 import pytorch_lightning as pl
 from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -35,70 +33,47 @@
 class DDPShardedPlugin(DDPPlugin):
     """Optimizer and gradient sharded training provided by FairScale."""

-    _REDUCE_BUFFER_SIZE_DEFAULT: int = 2 ** 23  # 8M
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._precision = None
+    _REDUCE_BUFFER_SIZE_DEFAULT = 2 ** 23  # 8M

     def configure_ddp(self) -> None:
-        trainer = self.lightning_module.trainer
+        self._wrap_optimizers()
+
         if "reduce_buffer_size" not in self._ddp_kwargs:
             # For multi-node training, enabling bucketing will improve performance.
             self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0

-        [self._model], optimizers = self._setup_models_and_optimizers(
-            models=[LightningShardedDataParallel(self.model)],
-            optimizers=trainer.optimizers,
+        self._model = ShardedDataParallel(
+            LightningShardedDataParallel(self.model),
+            sharded_optimizer=self.lightning_module.trainer.optimizers,
+            **self._ddp_kwargs
         )
-        trainer.optimizers = optimizers
-        trainer.convert_to_lightning_optimizers()
-
-    def _setup_models_and_optimizers(
-        self, models: List[Module], optimizers: List[Optimizer]
-    ) -> Tuple[List[Module], List[Optimizer]]:
-        """Wraps the model and optimizers with fairscale components.
+        setattr(self._model, "require_backward_grad_sync", False)

-        Currently only one model can be setup at once.
-
-        Return:
-            A list with one model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
-            and a list of optimizer wrapped in :class:~`fairscale.optim.OSS`.
-        """
-        if len(models) > 1:
-            raise ValueError(
-                "DDPSharded only supports setting up a single model with one or several optimizers."
-                f" Got {len(models)} models."
-            )
-
-        optimizers = self._wrap_optimizers(optimizers)
-        model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
-        setattr(model, "require_backward_grad_sync", False)  # TODO: needed?
-        return [model], optimizers
-
-    def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
+    def _reinit_optimizers_with_oss(self):
+        optimizers = self.lightning_module.trainer.optimizers
         for x, optimizer in enumerate(optimizers):
             if isinstance(optimizer, LightningOptimizer):
                 optimizer = optimizer._optimizer
             if not isinstance(optimizer, OSS):
                 optim_class = type(optimizer)
                 zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
                 if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
-                    precision = self._precision or self.lightning_module.trainer.precision
+                    precision = self.lightning_module.trainer.precision
                     is_fp16 = precision in ("mixed", 16)
                     # For multi-node training, compressing the model shards in fp16 before broadcasting
                     # improves performance. When using PyTorch AMP, it will not degrade
                     # the model performance.
                     zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1
                 optimizers[x] = zero_optimizer
                 del optimizer
-        return optimizers
-
-    def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
-        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
-            return optimizers
+        trainer = self.lightning_module.trainer
+        trainer.optimizers = optimizers
+        trainer.convert_to_lightning_optimizers()

-        return self._reinit_optimizers_with_oss(optimizers)
+    def _wrap_optimizers(self):
+        if self.model.trainer.state.fn != TrainerFn.FITTING:
+            return
+        self._reinit_optimizers_with_oss()

     def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
         if isinstance(optimizer, LightningOptimizer):
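
For orientation, the restored code path reduces to the two FairScale primitives already visible in this diff: each optimizer is rebuilt as a fairscale.optim.OSS instance and the module is wrapped in fairscale.nn.data_parallel.ShardedDataParallel. The minimal sketch below shows that pattern outside of Lightning; the helper name build_sharded, the toy Linear model, and the Adam optimizer are illustrative assumptions, and an initialized torch.distributed process group is presumed on every rank.

import torch
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS


def build_sharded(module: torch.nn.Module, base_optimizer: torch.optim.Optimizer):
    # Mirrors _reinit_optimizers_with_oss: rebuild the optimizer as OSS so its
    # state is sharded across the ranks of the process group.
    sharded_optimizer = OSS(
        params=base_optimizer.param_groups,
        optim=type(base_optimizer),
        **base_optimizer.defaults,
    )
    # Mirrors configure_ddp: wrap the module so each gradient is reduced to the
    # rank that owns the matching optimizer shard.
    sharded_module = ShardedDataParallel(module, sharded_optimizer=[sharded_optimizer])
    return sharded_module, sharded_optimizer


# Illustrative usage on an already-initialized rank:
#   net = torch.nn.Linear(32, 2)
#   net, opt = build_sharded(net, torch.optim.Adam(net.parameters(), lr=1e-3))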

pytorch_lightning/plugins/training_type/sharded_spawn.py

Lines changed: 15 additions & 37 deletions
@@ -13,11 +13,9 @@
 # limitations under the License.
 from contextlib import contextmanager
 from multiprocessing.queues import SimpleQueue
-from typing import Dict, Generator, List, Optional, Tuple
+from typing import Dict, Generator, Optional

 import torch
-from torch.nn import Module
-from torch.optim import Optimizer

 import pytorch_lightning as pl
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
@@ -38,49 +36,29 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
     """Optimizer sharded training provided by FairScale."""

     def configure_ddp(self) -> None:
-        trainer = self.lightning_module.trainer
-        [self._model], optimizers = self._setup_models_and_optimizers(
-            models=[LightningShardedDataParallel(self.model)],
-            optimizers=trainer.optimizers,
+        self._wrap_optimizers()
+        self._model = ShardedDataParallel(
+            LightningShardedDataParallel(self.model),
+            sharded_optimizer=self.lightning_module.trainer.optimizers,
+            **self._ddp_kwargs
         )
-        trainer.optimizers = optimizers
-
-    def _setup_models_and_optimizers(
-        self, models: List[Module], optimizers: List[Optimizer]
-    ) -> Tuple[List[Module], List[Optimizer]]:
-        """Wraps the model and optimizers with fairscale components.
+        setattr(self._model, "require_backward_grad_sync", False)

-        Currently only one model can be setup at once.
-
-        Return:
-            A list with one model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
-            and a list of optimizer wrapped in :class:~`fairscale.optim.OSS`.
-        """
-        if len(models) > 1:
-            raise ValueError(
-                f"DDPShardedSpawn only supports setting up a single model with one or several optimizers."
-                f" Got {len(models)} models."
-            )
-
-        optimizers = self._wrap_optimizers(optimizers)
-        model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
-        setattr(model, "require_backward_grad_sync", False)  # TODO: needed?
-        return [model], optimizers
-
-    def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]:
+    def _reinit_optimizers_with_oss(self):
+        optimizers = self.lightning_module.trainer.optimizers
         for x, optimizer in enumerate(optimizers):
             if not isinstance(optimizer, OSS):
                 optim_class = type(optimizer)
                 zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
                 optimizers[x] = zero_optimizer
                 del optimizer
-        return optimizers
-
-    def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
-        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
-            return optimizers
+        trainer = self.lightning_module.trainer
+        trainer.optimizers = optimizers

-        return self._reinit_optimizers_with_oss(optimizers)
+    def _wrap_optimizers(self):
+        if self.model.trainer.state.fn != TrainerFn.FITTING:
+            return
+        self._reinit_optimizers_with_oss()

     def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
         if isinstance(optimizer, OSS):
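
In normal use neither plugin is constructed by hand; it is selected through the Trainer, which then calls configure_ddp() during fitting. A hedged usage sketch follows: the "ddp_sharded" / "ddp_sharded_spawn" string aliases and the exact Trainer arguments are assumptions based on this release line and may differ in other versions.

import pytorch_lightning as pl

# Assumed registry aliases for DDPShardedPlugin / DDPSpawnShardedPlugin.
trainer = pl.Trainer(
    gpus=2,
    num_nodes=1,
    precision=16,           # fp16 shard broadcasting only kicks in when num_nodes > 1 (see _reinit_optimizers_with_oss)
    plugins="ddp_sharded",  # or "ddp_sharded_spawn" for the spawn-based variant
)
# trainer.fit(model) then triggers configure_ddp(), which wraps the optimizers in
# OSS and the LightningModule in ShardedDataParallel as shown in the diffs above.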
