Advanced GPU Documentation #7259
Conversation
# Conflicts:
#   docs/source/advanced/multi_gpu.rst
Codecov Report
@@           Coverage Diff            @@
##           master   #7259   +/-   ##
=======================================
  Coverage      91%     92%
=======================================
  Files         199     200      +1
  Lines       12779   12982    +203
=======================================
+ Hits        11679   11896    +217
+ Misses       1100    1086     -14
fairscale part looks great to me. Thanks for adding this great doc!
.. code-block:: python

    # train using Sharded DDP
    trainer = Trainer(plugins='ddp_sharded')
I suggest adding a plugin alias of "sdp"; it is a bit easier to type and fits the group of names like "ddp", "sdp" and "fsdp".
When not using Fully Sharded, these wrap functions are a no-op. This means that once the changes have been made, there is no need to remove the changes for other plugins.
This is a requirement for really large models and also saves on instantiation time as modules are sharded instantly, rather than after the entire model is created in memory. |
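For illustration, a minimal sketch of what this could look like in a LightningModule, assuming the plugin activates fairscale's ``wrap``/``auto_wrap`` inside ``configure_sharded_model`` (the module names and layer sizes here are made up):

.. code-block:: python

    import torch.nn as nn
    from pytorch_lightning import LightningModule
    from fairscale.nn import wrap, auto_wrap  # assumes fairscale is installed

    class MyModel(LightningModule):
        def configure_sharded_model(self):
            # layers created here are sharded as soon as they are wrapped,
            # rather than after the entire model exists in memory; when not
            # using Fully Sharded, these wrap calls are a no-op
            self.backbone = auto_wrap(nn.Sequential(nn.Linear(32, 32), nn.ReLU()))
            self.head = wrap(nn.Linear(32, 2))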
I don't know if something needs to be said about model weight init here? Is that taken care of by lightning? If users get to control it, they need to make sure all workers init the same weights, or the shards will be from different weight init values at each worker.
Later we will try to add a way to sync params from rank 0; in that case, we can remove this restriction.
Thanks @min-xu-ai, could you go into more details as to what is different here compared to DDP?
Sure. With DDP, you can have this:

    rank 0               rank 1
    m = model()          m = model()   <------ two ranks may have different weights due to different random seeds
    train(m)             train(m)      <------ weights are synced by ddp

With FSDP, since m is sharded, parts of the weights will be from rank 0 and parts of the weights will be from rank 1 when sharding happens. That can break the weight init assumptions, like zero mean and unit stddev, etc.
Therefore, until FSDP can sync weights between ranks, weight init needs to be very careful with FSDP.
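As an aside, one simple way to "be careful" here, sketched below purely for illustration, is to seed every rank identically before the module is instantiated, so each rank starts from the same weights and the assembled shards reflect one consistent initialization (``MyModel`` is a hypothetical LightningModule):

    import pytorch_lightning as pl

    # identical seed on every rank -> identical weight init before sharding,
    # so the shards gathered from different ranks agree with each other
    pl.seed_everything(42)
    model = MyModel()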
What would be a potential solution here? Use torch.distributed communications to sync global stats across all shards when initializing the model? It would be good to have a solution in place in the docs for users to have an example!
Just following up here because it might be a solution @min-xu-ai, but using SummonFullParams may be a way to init the model locally if the model would fit into memory, and then broadcast results. I'll add this into the FSDP docs in time when we merge the feature in!
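A rough, hypothetical sketch of that idea, assuming fairscale's FSDP exposes a ``summon_full_params`` context manager that materializes the full parameters on each rank and persists in-place changes back to the shards (``fsdp_model`` stands for an already-wrapped FullyShardedDataParallel module):

    import torch
    import torch.distributed as dist

    # hypothetical: re-initialize the full weights on rank 0, then broadcast
    # so every rank leaves the context with the same values
    with fsdp_model.summon_full_params():
        if dist.get_rank() == 0:
            for m in fsdp_model.modules():
                if isinstance(m, torch.nn.Linear):
                    torch.nn.init.xavier_uniform_(m.weight)
                    if m.bias is not None:
                        torch.nn.init.zeros_(m.bias)
        for p in fsdp_model.parameters():
            dist.broadcast(p.data, src=0)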
Appreciate the comments @min-xu-ai, will address the last comments ASAP!
Maybe an advanced tutorials section?
Agreed, I actually plan on doing something similar in terms of layout to this, which would be closer to an actual tutorial: https://pytorch.org/tutorials/intermediate/ddp_tutorial.html For example, if I wanted to train a transformer model on my data, what are the steps I should take? What should the process look like?
@SeanNaren another setting from DDP to enable memory savings is to set gradient_as_bucket_view=True: https://pytorch.org/docs/master/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel
This should save an extra ~10-15% of peak memory usage and can be an intermediary option for users who don't need the sharded/fully sharded/deepspeed engines.
cc @zhaojuanmao
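For reference, a small sketch of how this could be wired up in Lightning, assuming DDPPlugin forwards extra keyword arguments straight through to torch's DistributedDataParallel (otherwise the flag can be set on DistributedDataParallel directly):

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DDPPlugin

    # gradients share storage with DDP's communication buckets,
    # trimming peak memory as described in the comment above
    trainer = Trainer(
        gpus=2,
        accelerator="ddp",
        plugins=[DDPPlugin(gradient_as_bucket_view=True)],
    )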
I have added both the DDP Comm hooks (need to test it myself) + the suggestion above. Once FSDP is merged, we can merge this PR.
@SeanNaren these are fantastic docs!!! These would be super useful to even merge now (minus FSDP), and then we can add back the FSDP section once #6152 is merged. What do you think?
Thanks @ananthsub, let me remove the FSDP stuff and merge :) Was hoping we'd get the FSDP stuff merged first, but I'll separate it out to get this in ASAP.
I have as follow-ups to this PR:
EDIT: I also dropped the Sequential RPC Plugin, as this will be removed entirely once FSDP is merged (which should be merged soonish).
Great docs @SeanNaren, very high quality like everything you do. Some sections have larger separation than others. Maybe it was intentional, but I don't see the pattern.
Great Work!
Awesome work!
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Justus Schock <[email protected]>
Thanks so much guys, should've addressed all points (thanks @carmocca for cleaning up!)
What does this PR do?
Introduces a new advanced multi-GPU section, with more explanation and details. Cleanup of old APIs + addition of Fully Sharded and activation checkpointing.
A lot of the high-level points may need actual data to back them up, but they are collated from the DeepSpeed/FairScale teams. It's more important right now to highlight the high-level points, and then trickle down to data points via visualizations if possible.
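As a flavour of the activation checkpointing mentioned above, a minimal sketch using fairscale's checkpoint_wrapper (the class name and layer sizes are made up for illustration):

    import torch.nn as nn
    from pytorch_lightning import LightningModule
    from fairscale.nn import checkpoint_wrapper  # assumes fairscale is installed

    class CheckpointedModel(LightningModule):
        def __init__(self):
            super().__init__()
            # activations of this block are recomputed during the backward pass
            # instead of being stored, trading compute for memory
            self.block = checkpoint_wrapper(nn.Sequential(nn.Linear(32, 32), nn.ReLU()))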
If anyone has any suggestions on better naming than Advanced GPU Optimized Training, let me know!

cc @ananthsub @shuyingsunshine21 @min-xu-ai

TODO:
- Memory Optimized Multi-GPU Training (how about Advanced Multi-GPU Training)?
- self.trainer.model when doing configure_optimizers
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing, make sure you have read the Review guidelines. In short, see the following bullet list:
Did you have fun?
Make sure you had fun coding 🙃