Serialize checkpoint loading on each node #7509
Conversation
Codecov Report
@@ Coverage Diff @@
## master #7509 +/- ##
=======================================
- Coverage 92% 88% -5%
=======================================
Files 197 197
Lines 12878 12884 +6
=======================================
- Hits 11899 11314 -585
- Misses 979 1570 +591
Hello @maximsch2! Thanks for updating this PR. There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻 Comment last updated at 2021-05-17 17:58:59 UTC
@@ -143,7 +143,8 @@ def __init__(
         distributed_backend: Optional[str] = None,
         move_metrics_to_cpu: bool = False,
         multiple_trainloader_mode: str = 'max_size_cycle',
-        stochastic_weight_avg: bool = False
+        stochastic_weight_avg: bool = False,
+        serialize_checkpoint_loading: bool = False
Do you think a trainer flag is necessary here? Is it too slow to always serialize?
What about making the choice automatically, by comparing the available RAM to the model size?
It can potentially be unsafe to always serialize, specifically if you are doing any NCCL communication in on_load_checkpoint and assume it happens at the same time on all hosts - with serialization those collectives will deadlock, hence having this off by default.
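To make the failure mode concrete, here is a hypothetical hook (illustrative only, not code from this PR; the use of the checkpoint's epoch entry is an assumption) that would hang under serialized loading:

import torch
import torch.distributed as dist
import pytorch_lightning as pl


class MyModel(pl.LightningModule):
    def on_load_checkpoint(self, checkpoint):
        # This collective assumes every rank enters on_load_checkpoint at the
        # same time. Under serialized loading, rank 0 reaches the all_reduce
        # while the other local ranks are still waiting for their turn to
        # load, so the collective never completes and the job deadlocks.
        loaded_epoch = torch.tensor([checkpoint["epoch"]], device=self.device)
        dist.all_reduce(loaded_epoch)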
@@ -143,7 +143,8 @@ def __init__(
         distributed_backend: Optional[str] = None,
         move_metrics_to_cpu: bool = False,
         multiple_trainloader_mode: str = 'max_size_cycle',
-        stochastic_weight_avg: bool = False
+        stochastic_weight_avg: bool = False,
+        serialize_checkpoint_loading: bool = False
IMO, sequential_checkpoint_loading would be easier to understand.
    ckpt_path, map_location=lambda storage, loc: storage
)
# Serialize checkpoint loading to avoid OOMs
if self.serialize_checkpoint_loading and self.num_gpus > 0:
Why not leave this responsibility to the training_type_plugin? I guess this is useful only for DDP and its derivatives right now.
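For reference, a rough sketch of what a plugin-level version could look like (the subclass and the load_checkpoint_file hook name are assumptions for illustration, assuming DDPPlugin's barrier(), local_rank and num_processes behave as in Lightning 1.3-era plugins):

import torch
from pytorch_lightning.plugins import DDPPlugin


class SerializedLoadingDDPPlugin(DDPPlugin):
    # Hypothetical override: make local workers take turns loading the file,
    # so only one copy of the checkpoint sits in host RAM per node at a time.
    def load_checkpoint_file(self, checkpoint_path):
        checkpoint = None
        for turn in range(self.num_processes):
            if turn == self.local_rank:
                checkpoint = torch.load(checkpoint_path, map_location="cpu")
            self.barrier()
        return checkpoint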
This pull request has been automatically marked as stale because it has not had recent activity. It will be closed in 7 days if no further activity occurs. If you need further help see our docs: https://pytorch-lightning.readthedocs.io/en/latest/generated/CONTRIBUTING.html#pull-request or ask the assistance of a core contributor here or on Slack. Thank you for your contributions.
This pull request is going to be closed. Please feel free to reopen it or create a new one from the current master.
What does this PR do?
Loading large checkpoints across multiple workers on the same host can lead to OOMs. An easy case to imagine: model_size * num_gpus < total RAM < 2 * model_size * num_gpus, where we pay a 2x memory penalty because every worker loads the checkpoint into host RAM before setting it into the model's state_dict. Serializing the process helps because only one local worker loads at a time.
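A minimal sketch of the idea (assumptions: a process group is already initialized, every node runs the same number of workers, and the function and parameter names are illustrative rather than this PR's actual implementation):

import torch
import torch.distributed as dist


def load_checkpoint_serialized(ckpt_path, local_rank, local_world_size):
    # Each process on a node waits for its turn, ordered by local rank, so
    # only one worker per node holds the full checkpoint in host RAM at once.
    checkpoint = None
    for turn in range(local_world_size):
        if turn == local_rank:
            checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
        # Everyone waits until the worker whose turn it is has finished loading.
        dist.barrier()
    return checkpoint

Before submitting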
PR review
Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing, make sure you have read the Review guidelines.
Did you have fun?
Make sure you had fun coding 🙃