FSDP integration #6152
Closed
Commits
77 commits
78f1eb4
Add initial FSDP integration
c36e00a
Fix error in refactor
59dbb83
update
tchaton 19a1440
Revert "update"
3b38615
Address reviews
5ff06ab
Fix doc string
36434f0
Even moar code review
c61a190
Add deprecation
1c4f011
Merge branch 'master' into feat/fsdp
02599e6
Fix name of test
e79977a
Integrate nesting, fix bugs across implementation
d15d4b5
Merge branch 'master' into feat/fsdp
ebf1818
Formatting types
290e8fd
Add additional tests for accelerator model
5c5f762
Fix import
d28438b
Few test fixes, expose params
ab591a8
Allow training_type_plugin to delay optimizer configure
23ccdb8
Merge branch 'feat/fsdp_2n' into feat/fsdp
a60f2c0
Add missing references to trainer, add a CPU accelerator based test
3d4e6df
Merge branch 'feat/fsdp_2n' into feat/fsdp
516bd04
Update for latest API changes to fairscale
9f8864f
Add base hook for model parallel
eac5344
fix callback signature
kaushikb11 32df0cb
Simplify hook
282a133
Add hook logic
7a94e72
add tests
kaushikb11 8091481
add property setter
kaushikb11 633fc77
add logic for being called once
kaushikb11 c99a36f
Update changelog
kaushikb11 a68c8d7
Merge branch 'master' into feat/model_parallel_hook
kaushikb11 9529a22
Fix
kaushikb11 3c1c782
fix return type
kaushikb11 7daba43
Merge branch 'master' into feat/fsdp
87ec222
Fix property name
966b2e5
Merge branch 'feat/model_parallel_hook' into feat/fsdp
5f6e039
Updaet wrapper, use latest fixes for hooks
b512e72
Swap hook order
8ba82df
Merge branch 'master' into feat/fsdp
1e5ca37
Small changes
936dc1a
Fixes
a6de18e
Remove activation checkpointing
8684f94
Turn off auto wrap by default
76091ae
Move to trainer.model
226d498
fix reference
cd63c10
Merge branch 'master' into feat/fsdp
b881e2f
Remove flag
e8959be
Fix imports
52478ac
Fix versions, update docs
b7f1896
Fix clip gradients
a62f8d8
Merge branch 'master' into feat/fsdp
69c33f1
Merge branch 'master' into feat/fsdp
9fa26c0
Fixes
56f23ce
pull
9ca3f0c
Few changes across the board
b53ba36
Fix imports
0da5249
Set none
90c6479
Swap to warnings
69d8178
Remove fairscale from container
a459d10
pull
a7842d9
Update dockers/base-cuda/Dockerfile
48ee83f
Add defaults, add test to ensure nested wrapper is set correctly
57a696c
Remove deprecation as this will be removed completely
36889b8
Check for nested FSDP wrappers, and omit wrapping algorithm
89b8cb5
Merge branch 'master' into feat/fsdp
0c1d2de
Update pytorch_lightning/trainer/connectors/accelerator_connector.py
592bb28
Address code review points
4e230c9
Merge branch 'master' into feat/fsdp
ca8e586
Add back missing model that was removed from clipping signature
54f501d
Do not pass model through, accelerator does it
02925cc
Merge branch 'master' into feat/fsdp
b67f1a9
Fix merge
132eb64
Fix imports
e6ce3cf
Changes to precision plugin
01153af
Require 2 GPU for multi gpu test
6cfe57d
Merge branch 'master' into feat/fsdp
efa81ab
Use callback in test, swap to DynamicLossScaler from fairscale to tes…
78d52b5
Disable loss scaler for now
pytorch_lightning/plugins/precision/full_sharded_native_amp.py (29 additions, 0 deletions)

@@ -0,0 +1,29 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Union

from torch.optim import Optimizer

from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin


class FullShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
    """Mixed Precision for Full Sharded Training."""

    def clip_gradients(
        self, model: Any, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)
    ):
        # Model manages clipping of gradients
        model.clip_grad_norm_(clip_val, norm_type)
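
For context (not part of the diff): clipping is delegated to the wrapped model because, under full sharding, each rank only holds a shard of the gradients, so a plain local norm would be wrong. Below is a minimal sketch of that contrast; the helper names are hypothetical, and the only library calls assumed are `torch.nn.utils.clip_grad_norm_` and FairScale's `FullyShardedDataParallel.clip_grad_norm_`.

from typing import Iterable, Union

import torch


def clip_unsharded(parameters: Iterable[torch.nn.Parameter], clip_val: Union[int, float]) -> None:
    # Non-sharded path: every gradient is visible locally, so a local norm suffices.
    torch.nn.utils.clip_grad_norm_(parameters, clip_val, norm_type=2.0)


def clip_full_sharded(fsdp_model, clip_val: Union[int, float]) -> None:
    # Sharded path: the FSDP wrapper reduces the gradient norm across ranks before
    # scaling, which is why the precision plugin above defers to the model.
    fsdp_model.clip_grad_norm_(clip_val, norm_type=2.0)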
pytorch_lightning/plugins/training_type/full_sharded.py (150 additions, 0 deletions)

@@ -0,0 +1,150 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional

import torch

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_FULL_SHARDED_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _FAIRSCALE_AVAILABLE:
    from fairscale.nn.data_parallel import FullyShardedDataParallel

    from pytorch_lightning.overrides.fairscale import (
        LightningFullShardedDataParallel,
        unwrap_lightning_module_full_sharded,
    )


class FullShardedPlugin(DDPPlugin):

    def __init__(
        self,
        cpu_offload: bool = True,
        flatten_parameters: bool = False,
        reshard_after_forward: bool = True,
        move_grads_to_cpu: Optional[bool] = None,
        fp32_reduce_scatter: Optional[bool] = None,
        compute_dtype: Optional[torch.dtype] = None,
        bucket_cap_mb: int = 25,
        parallel_devices: Optional[List[torch.device]] = None,
        num_nodes: int = 1,
        cluster_environment: ClusterEnvironment = None,
        sync_batchnorm: Optional[bool] = False
    ):
        """
        Provides capabilities to run training using the Full Sharded Data Parallel implementation provided
        by FairScale.

        Full Sharded Training shards the entire model across all available GPUs, allowing you to scale model
        size whilst using efficient communication to reduce overhead. In practice, this means we can remain
        at parity with PyTorch DDP whilst scaling our model sizes dramatically. The technique is similar
        to ZeRO-Stage 3 but has been modified/adjusted for PyTorch.

        For more information: https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html

        .. warning:: ``FullShardedPlugin`` is in beta and subject to change.

        Defaults have been set to enable CPU Offload, but options have been exposed and may require
        configuration based on your level of memory/speed efficiency.
        We suggest having a look at this PR for more information:
        `https://github.com/facebookresearch/fairscale/pull/413`

        Many of the helpful doc strings below came from the original FairScale documentation:
        `https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html`

        Arguments:

            cpu_offload: Offload FP32 params to CPU. Only usable in precision=16 mode (default: True).

            move_grads_to_cpu: Moves gradient shards to CPU after reduction.
                Only disable if using CPU based optimizers (defaults to ``cpu_offload``).

            flatten_parameters: Flattens parameters into a single contiguous tensor for speed efficiency
                (default: False).

            reshard_after_forward: Reshard parameters after the forward pass, which saves memory but slows
                down training. Only relevant when nesting FullyShardedDataParallel wrappers inside the model
                (default: True).

            fp32_reduce_scatter: Reduce-Scatter gradients in FP32. Only relevant in mixed precision
                (default: None).

            compute_dtype: dtype for full parameters for computation. Defaults to torch.float32,
                unless using mixed precision, in which case it defaults to torch.float16.

            bucket_cap_mb: Bucket parameters so that gradient reduction
                can potentially overlap with backward computation.
                bucket_cap_mb controls the bucket size in MegaBytes (MB).
                Buckets are sub-divided based on world_size,
                so the max shard size is roughly bucket_cap_mb / world_size.
                Values <= 0 disable bucketing (default: 25).
        """
        if not _FAIRSCALE_FULL_SHARDED_AVAILABLE:
            raise MisconfigurationException(
                "Full Sharded Training is not available. Install the latest FairScale via `pip install fairscale -U`"
            )

        if sync_batchnorm:
            raise MisconfigurationException("Currently sync batch norm is not supported by Full Sharded Training.")
        super().__init__(parallel_devices, num_nodes, cluster_environment, sync_batchnorm=sync_batchnorm)
        self.cpu_offload = cpu_offload
        self.move_grads_to_cpu = move_grads_to_cpu
        self.flatten_parameters = flatten_parameters
        self.reshard_after_forward = reshard_after_forward
        self.fp32_reduce_scatter = fp32_reduce_scatter
        self.compute_dtype = compute_dtype
        self.bucket_cap_mb = bucket_cap_mb

    def configure_ddp(self):
        precision = self.lightning_module.trainer.precision
        self.model = FullyShardedDataParallel(
            LightningFullShardedDataParallel(self.model),
            cpu_offload=self.cpu_offload,
            move_grads_to_cpu=self.move_grads_to_cpu,
            flatten_parameters=self.flatten_parameters,
            mixed_precision=precision == "mixed",
            reshard_after_forward=self.reshard_after_forward,
            fp32_reduce_scatter=self.fp32_reduce_scatter,
            compute_dtype=self.compute_dtype,
            bucket_cap_mb=self.bucket_cap_mb,
        )

    @property
    def lightning_module(self) -> LightningModule:
        return unwrap_lightning_module_full_sharded(self.model)

    def model_to_device(self):
        if not self.cpu_offload:
            super().model_to_device()

    def on_save(self, checkpoint: dict) -> dict:
        state_dict = self.collate_state_dict()
        checkpoint['state_dict'] = state_dict
        return checkpoint

    def collate_state_dict(self):
        """
        Collects the model's sharded state dict from all processes before returning.

        Returns: The unsharded model state dict.
        """
        state_dict = self.model.state_dict()
        # Remove the wrapper's ``module.`` prefix from the keys so they match the unwrapped module.
        state_dict = {k.partition('module.')[2]: state_dict[k] for k in state_dict.keys()}
        return state_dict
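
As a usage sketch (not part of the diff): the example below shows how the plugin's options might fit together, assuming the plugin instance is passed through the Trainer's `plugins` argument as with other training-type plugins. `ToyModel` and the random data are illustrative assumptions only.

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.plugins.training_type.full_sharded import FullShardedPlugin


class ToyModel(pl.LightningModule):
    # Minimal model purely for illustration.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(64, 32), torch.randn(64, 2)), batch_size=8)
    trainer = pl.Trainer(
        gpus=2,
        precision=16,  # cpu_offload pairs with mixed precision per the docstring above
        plugins=FullShardedPlugin(
            cpu_offload=True,            # offload FP32 params to CPU between steps
            flatten_parameters=False,
            reshard_after_forward=True,  # trade speed for memory when nesting wrappers
        ),
    )
    trainer.fit(ToyModel(), data)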