@@ -12,11 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Dict, Generator, List, Optional, Tuple, Union
+from typing import Dict, Generator, Optional
 
 import torch
-from torch.nn import Module
-from torch.optim import Optimizer
 
 import pytorch_lightning as pl
 from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -35,70 +33,47 @@
 class DDPShardedPlugin(DDPPlugin):
     """Optimizer and gradient sharded training provided by FairScale."""
 
-    _REDUCE_BUFFER_SIZE_DEFAULT: int = 2 ** 23  # 8M
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._precision = None
+    _REDUCE_BUFFER_SIZE_DEFAULT = 2 ** 23  # 8M
 
     def configure_ddp(self) -> None:
-        trainer = self.lightning_module.trainer
+        self._wrap_optimizers()
+
         if "reduce_buffer_size" not in self._ddp_kwargs:
             # For multi-node training, enabling bucketing will improve performance.
             self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0
 
-        [self._model], optimizers = self._setup_models_and_optimizers(
-            models=[LightningShardedDataParallel(self.model)],
-            optimizers=trainer.optimizers,
+        self._model = ShardedDataParallel(
+            LightningShardedDataParallel(self.model),
+            sharded_optimizer=self.lightning_module.trainer.optimizers,
+            **self._ddp_kwargs
         )
-        trainer.optimizers = optimizers
-        trainer.convert_to_lightning_optimizers()
-
-    def _setup_models_and_optimizers(
-        self, models: List[Module], optimizers: List[Optimizer]
-    ) -> Tuple[List[Module], List[Optimizer]]:
-        """Wraps the model and optimizers with fairscale components.
+        setattr(self._model, "require_backward_grad_sync", False)
 
-        Currently only one model can be setup at once.
-
-        Return:
-            A list with one model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
-            and a list of optimizer wrapped in :class:~`fairscale.optim.OSS`.
-        """
-        if len(models) > 1:
-            raise ValueError(
-                "DDPSharded only supports setting up a single model with one or several optimizers."
-                f" Got {len(models)} models."
-            )
-
-        optimizers = self._wrap_optimizers(optimizers)
-        model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
-        setattr(model, "require_backward_grad_sync", False)  # TODO: needed?
-        return [model], optimizers
-
-    def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
+    def _reinit_optimizers_with_oss(self):
+        optimizers = self.lightning_module.trainer.optimizers
         for x, optimizer in enumerate(optimizers):
             if isinstance(optimizer, LightningOptimizer):
                 optimizer = optimizer._optimizer
             if not isinstance(optimizer, OSS):
                 optim_class = type(optimizer)
                 zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
                 if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
-                    precision = self._precision or self.lightning_module.trainer.precision
+                    precision = self.lightning_module.trainer.precision
                     is_fp16 = precision in ("mixed", 16)
                     # For multi-node training, compressing the model shards in fp16 before broadcasting
                     # improves performance. When using PyTorch AMP, it will not degrade
                     # the model performance.
                     zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1
                 optimizers[x] = zero_optimizer
                 del optimizer
-        return optimizers
-
-    def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
-        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
-            return optimizers
+        trainer = self.lightning_module.trainer
+        trainer.optimizers = optimizers
+        trainer.convert_to_lightning_optimizers()
 
-        return self._reinit_optimizers_with_oss(optimizers)
+    def _wrap_optimizers(self):
+        if self.model.trainer.state.fn != TrainerFn.FITTING:
+            return
+        self._reinit_optimizers_with_oss()
 
     def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
         if isinstance(optimizer, LightningOptimizer):
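For context (not part of the diff): the plugin delegates the actual sharding to FairScale, wrapping each optimizer in `OSS` and the module in `ShardedDataParallel`. Below is a minimal standalone sketch of that same wrapping, assuming a `torch.distributed` process group is already initialized and using a hypothetical stand-in model and optimizer; it is an illustration, not the plugin's code path.

```python
# Minimal sketch of the FairScale wrapping performed above.
# Assumptions: torch.distributed.init_process_group(...) has already been called,
# and `torch.nn.Linear(32, 2)` is a hypothetical stand-in for the LightningModule.
import torch
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

model = torch.nn.Linear(32, 2)

# OSS shards the optimizer state across ranks; it receives the optimizer *class*
# and its defaults, mirroring `OSS(params=..., optim=optim_class, **optimizer.defaults)`.
optimizer = OSS(params=model.parameters(), optim=torch.optim.Adam, lr=1e-3)

# ShardedDataParallel reduces gradients directly into the shard owned by each rank,
# mirroring `ShardedDataParallel(LightningShardedDataParallel(self.model), sharded_optimizer=...)`.
model = ShardedDataParallel(model, sharded_optimizer=optimizer)
```

Within Lightning itself the plugin is normally selected through the Trainer rather than instantiated by hand, e.g. `Trainer(strategy="ddp_sharded")` on recent releases (older releases used `plugins="ddp_sharded"`); the exact flag depends on the Lightning version.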