Commit 78f1eb4

Author: SeanNaren
Add initial FSDP integration
1 parent 863a70c

21 files changed: +410 −21 lines

pytorch_lightning/accelerators/accelerator.py

Lines changed: 2 additions & 2 deletions
@@ -290,7 +290,7 @@ def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Opt
     def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
         """clips all the optimizer parameters to the given value"""

-        self.precision_plugin.clip_gradients(optimizer, clip_val)
+        self.precision_plugin.clip_gradients(self.model, optimizer, clip_val)

     def on_train_epoch_end(self, outputs) -> None:
         """Hook to do something on the end of an training epoch
@@ -371,7 +371,7 @@ def optimizer_state(self, optimizer: Optimizer) -> dict:
         return optimizer.state_dict()

     def on_save(self, checkpoint):
-        return checkpoint
+        return self.training_type_plugin.on_save(checkpoint)

     def barrier(self, name: Optional[str] = None) -> None:
         self.training_type_plugin.barrier(name=name)
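
With this change the accelerator hands the (possibly wrapped) model to the precision plugin, so gradient clipping can be delegated to whatever wraps the parameters. A minimal sketch of a custom precision plugin written against the new signature (the class name MyClippingPlugin is hypothetical and not part of this commit):

from typing import Any, Union

from torch.nn.utils import clip_grad_norm_
from torch.optim import Optimizer

from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin


class MyClippingPlugin(PrecisionPlugin):

    def clip_gradients(
        self, model: Any, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = 2.0
    ):
        # `model` is whatever the accelerator holds (e.g. an FSDP-wrapped module),
        # so clipping can be applied to its parameters or delegated to the wrapper.
        clip_grad_norm_(model.parameters(), clip_val, norm_type=norm_type)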

pytorch_lightning/overrides/fairscale.py

Lines changed: 20 additions & 1 deletion
@@ -13,7 +13,7 @@
 # limitations under the License.
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, unwrap_lightning_module
-from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE
+from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_FULL_SHARDED_AVAILABLE

 LightningShardedDataParallel = None
 if _FAIRSCALE_AVAILABLE:
@@ -29,3 +29,22 @@ def unwrap_lightning_module_sharded(wrapped_model) -> LightningModule:
         model = model.module

     return unwrap_lightning_module(model)
+
+
+LightningFullShardedDataParallel = None
+if _FAIRSCALE_FULL_SHARDED_AVAILABLE:
+    from fairscale.nn import FlattenParamsWrapper
+    from fairscale.nn.data_parallel import FullyShardedDataParallel
+
+    class LightningFullShardedDataParallel(_LightningModuleWrapperBase):
+        # Just do this for later docstrings
+        pass
+
+    def unwrap_lightning_module_full_sharded(wrapped_model) -> LightningModule:
+        model = wrapped_model
+        if isinstance(model, FullyShardedDataParallel):
+            model = model.module
+            # Additional check if we're using a flattened parameters buffer
+            if isinstance(model, FlattenParamsWrapper):
+                model = model.module
+        return unwrap_lightning_module(model)
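
For reference, a tiny sketch (assuming FairScale with FSDP support is installed) of what the new helper does: it peels the Lightning wrapper, and when the FSDP and FlattenParamsWrapper layers are present, those too, returning the original LightningModule. ToyModel is a stand-in used only for illustration:

from pytorch_lightning import LightningModule
from pytorch_lightning.overrides.fairscale import (
    LightningFullShardedDataParallel,
    unwrap_lightning_module_full_sharded,
)


class ToyModel(LightningModule):  # stand-in LightningModule for illustration
    pass


model = ToyModel()
# Even without the outer FSDP/FlattenParamsWrapper layers, the helper falls back
# to unwrap_lightning_module and recovers the original module.
wrapped = LightningFullShardedDataParallel(model)
assert unwrap_lightning_module_full_sharded(wrapped) is model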

pytorch_lightning/plugins/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -1,6 +1,9 @@
 from pytorch_lightning.plugins.base_plugin import Plugin  # noqa: F401
 from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin  # noqa: F401
+from pytorch_lightning.plugins.precision.full_sharded_native_amp import (  # noqa: F401
+    FullShardedNativeMixedPrecisionPlugin,
+)
 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin  # noqa: F401
@@ -10,6 +13,7 @@
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.full_sharded import FullShardedPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.rpc import RPCPlugin  # noqa: F401
@@ -29,6 +33,8 @@
     "DDPSpawnPlugin",
     "DeepSpeedPlugin",
     "DeepSpeedPrecisionPlugin",
+    "FullShardedPlugin",
+    "FullShardedNativeMixedPrecisionPlugin",
     "HorovodPlugin",
     "NativeMixedPrecisionPlugin",
     "PrecisionPlugin",

pytorch_lightning/plugins/precision/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -1,5 +1,8 @@
 from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin  # noqa: F401
+from pytorch_lightning.plugins.precision.full_sharded_native_amp import (  # noqa: F401
+    FullShardedNativeMixedPrecisionPlugin,
+)
 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin  # noqa: F401

pytorch_lightning/plugins/precision/deepspeed_precision.py

Lines changed: 4 additions & 2 deletions
@@ -1,4 +1,4 @@
-from typing import Callable, Union
+from typing import Any, Callable, Union

 import torch
 from torch.optim import Optimizer
@@ -54,7 +54,9 @@ def backward(

         return closure_loss

-    def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)):
+    def clip_gradients(
+        self, model: Any, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)
+    ):
         """
         DeepSpeed handles clipping gradients via the training type plugin.
         """
pytorch_lightning/plugins/precision/full_sharded_native_amp.py

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Union
+
+from torch.optim import Optimizer
+
+from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
+
+
+class FullShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
+    """Mixed Precision for Full Sharded Training
+    """
+
+    def clip_gradients(
+        self, model: Any, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)
+    ):
+        # Model manages clipping of gradients
+        model.clip_grad_norm_(clip_val, norm_type)
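
Taken together with the accelerator change above, the intended clipping flow for full sharded training (a sketch of the assumed call chain, not code from this commit) looks like:

# Assumed flow when `gradient_clip_val` is set and full sharded training is used:
#
#   Accelerator.clip_gradients(optimizer, clip_val)
#       -> FullShardedNativeMixedPrecisionPlugin.clip_gradients(self.model, optimizer, clip_val)
#           -> FullyShardedDataParallel.clip_grad_norm_(clip_val, norm_type)
#
# i.e. the precision plugin no longer clips via the optimizer; FSDP clips its own
# (sharded) parameters, which is why the wrapped model is now passed through.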

pytorch_lightning/plugins/precision/precision_plugin.py

Lines changed: 3 additions & 1 deletion
@@ -86,7 +86,9 @@ def pre_optimizer_step(
     def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
         """Hook to do something after each optimizer step."""

-    def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)) -> None:
+    def clip_gradients(
+        self, model: Any, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)
+    ):
         """Clips the gradients to a specific value"""
         # TODO: separate TPU case from here
         if clip_val is None:

pytorch_lightning/plugins/precision/sharded_native_amp.py

Lines changed: 4 additions & 2 deletions
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import cast, Union
+from typing import Any, cast, Union

 from torch.optim import Optimizer

@@ -31,6 +31,8 @@ def __init__(self):
         super().__init__()
         self.scaler = ShardedGradScaler()

-    def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)):
+    def clip_gradients(
+        self, model: Any, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)
+    ):
         optimizer = cast(OSS, optimizer)
         optimizer.clip_grad_norm(clip_val, norm_type=norm_type)

pytorch_lightning/plugins/training_type/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.full_sharded import FullShardedPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.rpc import RPCPlugin  # noqa: F401
pytorch_lightning/plugins/training_type/full_sharded.py

Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List, Optional
+
+import torch
+
+from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
+from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
+from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_FULL_SHARDED_AVAILABLE
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+if _FAIRSCALE_AVAILABLE:
+    from fairscale.nn.data_parallel import FullyShardedDataParallel
+
+    from pytorch_lightning.overrides.fairscale import (
+        LightningFullShardedDataParallel,
+        unwrap_lightning_module_full_sharded,
+    )
+
+
+class FullShardedPlugin(DDPPlugin):
+
+    def __init__(
+        self,
+        cpu_offload: bool = True,
+        flatten_parameters: bool = False,
+        reshard_after_forward: bool = True,
+        move_grads_to_cpu: Optional[bool] = None,
+        fp32_reduce_scatter: Optional[bool] = None,
+        compute_dtype: Optional[torch.dtype] = None,
+        bucket_cap_mb: int = 25,
+        parallel_devices: Optional[List[torch.device]] = None,
+        num_nodes: int = 1,
+        cluster_environment: ClusterEnvironment = None,
+        sync_batchnorm: Optional[bool] = False
+    ):
+        """
+
+        Provides capabilities to run training using the Full Sharded capabilities provided by FairScale.
+
+        Full Sharded Training shards the entire model across all available GPUs, allowing you to scale model
+        size, whilst using efficient communication to reduce overhead. In practice, this means we can remain
+        at parity with PyTorch DDP, whilst scaling our model sizes dramatically. The technique is similar
+        to ZeRO-Stage 3 but has been modified/adjusted for PyTorch.
+
+        `For more information: https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html`.
+
+        .. warning:: ``FullShardedPlugin`` is in beta and subject to change.
+
+        Defaults have been set to enable CPU offload, but options have been exposed and may require configuration
+        based on your level of memory/speed efficiency.
+        We suggest having a look at this PR for more information:
+        `https://github.com/facebookresearch/fairscale/pull/413`
+
+
+        Many of the helpful doc strings below came from the original FairScale documentation:
+        `https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html`
+
+        Arguments:
+
+            cpu_offload: Offload FP32 params to CPU. Only usable in precision=16 mode (default: False).
+
+            move_grads_to_cpu: Moves gradient shards to CPU after reduction.
+                Only disable if using CPU based optimizers (defaults to ``cpu_offload``).
+
+            flatten_parameters: Flattens parameters into a single contiguous tensor for speed efficiency
+                (default: False).
+
+            reshard_after_forward: Reshard parameters after the forward pass, which saves memory but slows
+                down training. Only relevant when nesting FullyShardedDataParallel wrappers inside the model
+                (default: False).
+
+            fp32_reduce_scatter: Reduce-scatter gradients in FP32. Only relevant in mixed precision
+                (default: None).
+
+            compute_dtype: dtype for full parameters for computation. Defaults to torch.float32,
+                unless using mixed precision, in which case it defaults to torch.float16.
+
+            bucket_cap_mb: Bucket parameters so that gradient reduction
+                can potentially overlap with backward computation.
+                bucket_cap_mb controls the bucket size in MegaBytes (MB).
+                Buckets are sub-divided based on world_size,
+                so the max shard size is roughly bucket_cap_mb / world_size.
+                Values <= 0 disable bucketing (default: 25).
+
+        """
+        if not _FAIRSCALE_FULL_SHARDED_AVAILABLE:
+            raise MisconfigurationException(
+                "Full Sharded Training is not available. Install the latest FairScale via `pip install fairscale -U`"
+            )
+
+        if sync_batchnorm:
+            raise MisconfigurationException("Currently sync batch norm is not supported by Full Sharded Training.")
+        super().__init__(parallel_devices, num_nodes, cluster_environment, sync_batchnorm=sync_batchnorm)
+        self.cpu_offload = cpu_offload
+        self.move_grads_to_cpu = move_grads_to_cpu
+        self.flatten_parameters = flatten_parameters
+        self.reshard_after_forward = reshard_after_forward
+        self.fp32_reduce_scatter = fp32_reduce_scatter
+        self.compute_dtype = compute_dtype
+        self.bucket_cap_mb = bucket_cap_mb
+
+    def configure_ddp(self):
+        precision = self.lightning_module.trainer.precision
+        self.model = FullyShardedDataParallel(
+            LightningFullShardedDataParallel(self.model),
+            cpu_offload=self.cpu_offload,
+            move_grads_to_cpu=self.move_grads_to_cpu,
+            flatten_parameters=self.flatten_parameters,
+            mixed_precision=precision == "mixed",
+            reshard_after_forward=self.reshard_after_forward,
+            fp32_reduce_scatter=self.fp32_reduce_scatter,
+            compute_dtype=self.compute_dtype,
+            bucket_cap_mb=self.bucket_cap_mb,
+        )
+
+    @property
+    def lightning_module(self) -> LightningModule:
+        return unwrap_lightning_module_full_sharded(self.model)
+
+    def model_to_device(self):
+        if not self.cpu_offload:
+            super().model_to_device()
+
+    def on_save(self, checkpoint: dict) -> dict:
+        state_dict = self.collate_state_dict()
+        checkpoint['state_dict'] = state_dict
+        return checkpoint
+
+    def collate_state_dict(self):
+        """
+        Collects the model's sharded state dict from all processes before returning.
+        Returns: The unsharded model state dict.
+        """
+        state_dict = self.model.state_dict()
+        # Remove module prefix from state dict as this is the behaviour of state dict.
+        state_dict = {k.partition('module.')[2]: state_dict[k] for k in state_dict.keys()}
+        return state_dict
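
For orientation, a hedged usage sketch (not part of this commit): the plugin is handed to the Trainer like any other training-type plugin. The model class and Trainer flags below are illustrative; the default cpu_offload=True requires 16-bit precision, and the matching FullShardedNativeMixedPrecisionPlugin is assumed to be selected by the accelerator connector when precision=16 is combined with this plugin.

from pytorch_lightning import Trainer

from pytorch_lightning.plugins import FullShardedPlugin

model = MyLightningModel()  # hypothetical LightningModule
trainer = Trainer(
    gpus=4,
    precision=16,  # required for the default cpu_offload=True
    plugins=[FullShardedPlugin(cpu_offload=True, reshard_after_forward=True)],
)
trainer.fit(model)

Checkpointing then goes through on_save above, which calls collate_state_dict to gather the full (unsharded) state dict and strip the 'module.' prefix, so saved checkpoints keep standard LightningModule keys.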

pytorch_lightning/plugins/training_type/rpc_sequential.py

Lines changed: 9 additions & 1 deletion
@@ -25,7 +25,7 @@
 from pytorch_lightning.overrides.distributed import LightningDistributedModule
 from pytorch_lightning.plugins.training_type.rpc import DEFAULT_RPC_TIMEOUT_SEC, RPCPlugin
 from pytorch_lightning.trainer.states import RunningStage
-from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only
+from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

 if _FAIRSCALE_PIPE_AVAILABLE:
@@ -56,6 +56,10 @@ def __init__(

         .. _RPCSequentialPlugin: https://arxiv.org/abs/1811.06965

+        .. warning::
+            This plugin has been deprecated. Please use the ``FullShardedPlugin`` which provides better performance
+            and scaling without pipelining the model.
+
         Pipeline parallelism comes with with checkpointing to reduce peak
         memory required to train while minimizing device under-utilization.
         This is turned on by default and can be turned off via the checkpoint argument.
@@ -87,6 +91,10 @@ def __init__(
            at the same time. Defaults to `True` if
            `get_model_parallel_world_size() > 1`
        """
+        rank_zero_warn(
+            "RPC Sequential Plugin has been deprecated. Please use the `FullShardedPlugin` "
+            "which provides better performance and scaling without pipelining the model."
+        )
         self._check_pipe_available()
         super().__init__(rpc_timeout_sec=rpc_timeout_sec, **kwargs)
