
Support serialized checkpoint loading #9406


Closed
ananthsub opened this issue Sep 9, 2021 · 2 comments · Fixed by #9605
Labels
checkpointing (Related to checkpointing) · feature (Is an improvement or enhancement) · help wanted (Open to be worked on) · let's do it! (approved to implement)

Comments

@ananthsub
Contributor

ananthsub commented Sep 9, 2021

🚀 Feature

Motivation

Currently, all processes load the checkpoint at the same time. For large models, these concurrent loads can cause CPU OOMs. Use cases such as mixture of experts might require serialized loading of checkpoint dicts across ranks (i.e., loading the checkpoint one rank at a time per node). Could we enable this for DDP?

Prior work: #8515

Pitch

This would be controlled per training type plugin. Example pseudocode: https://gist.github.com/ananthsub/4ceedff56b2049a63bbb05ccd283b919
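The gist's pseudocode is not reproduced here. As a rough, independent sketch of one way serialized loading could work, assuming every rank can call a global barrier: ranks take turns by local rank, so in round `i` only local rank `i` loads, and every rank hits the barrier once per round so the collective calls stay balanced. Because the barrier is global, ranks with the same local rank on different nodes load concurrently, which gives one load at a time per node. The helper name `serialized_load` and its parameters are hypothetical, not Lightning API.

```python
from typing import Any, Callable


def serialized_load(
    load_fn: Callable[[], Any],
    local_rank: int,
    local_world_size: int,
    barrier: Callable[[], object],
) -> Any:
    """Run ``load_fn`` on one local rank at a time (hypothetical helper).

    Every rank calls ``barrier`` exactly ``local_world_size`` times, so the
    collective stays balanced; local rank ``i`` loads only in round ``i``.
    """
    if not 0 <= local_rank < local_world_size:
        raise ValueError("local_rank must be in [0, local_world_size)")
    result = None
    for turn in range(local_world_size):
        if turn == local_rank:
            # Only the rank whose turn it is touches the checkpoint file.
            result = load_fn()
        # Nobody starts the next round until this round's loader has finished.
        barrier()
    return result
```

In a DDP job one would presumably pass `torch.distributed.barrier` as `barrier`, take `local_rank` and `local_world_size` from, for example, the `LOCAL_RANK` and `LOCAL_WORLD_SIZE` environment variables that `torchrun` sets, and use a `load_fn` that wraps `torch.load(path, map_location="cpu")`. For the memory savings to hold, `load_fn` should apply the loaded state and release the raw dict itself, since in this sketch the return value stays alive until every round finishes. The trade-off is load time: a node performs roughly `local_world_size` sequential loads instead of one concurrent round. If `load_fn` raises on one rank, the other ranks will block at the barrier.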

To work through:

  1. Should the TrainingTypePlugin have the responsibility of calling LightningModule.on_load_checkpoint instead of the Trainer/checkpoint connector? This would make sense because the TTP "owns" the LightningModule inside the Trainer, and it already offers load_model_state_dict: https://github.com/PyTorchLightning/pytorch-lightning/blob/41ba639859cf6c6bf319eb33e5b3394504315962/pytorch_lightning/plugins/training_type/training_type_plugin.py#L159-L160

DeepSpeed already bypasses most of the checkpoint connector logic when it comes to loading the LightningModule state. This could be a gap for metrics, and it means we could end up calling on_load_checkpoint multiple times with certain plugins. In my opinion, this points to moving all LightningModule state loading, saving, and alteration inside the training type plugin.
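To illustrate the direction proposed above, here is a minimal, hypothetical sketch in plain Python (no Lightning imports) of a plugin that is the single entry point for restoring LightningModule state. The class names and `restore_lightning_module` are invented for illustration; only the `on_load_checkpoint` hook, `load_state_dict` (from `torch.nn.Module`), and the `load_model_state_dict` name come from Lightning/PyTorch or the link above. A DeepSpeed-like subclass overrides only how weights are loaded, so the hook is still called exactly once.

```python
from typing import Any, Dict, Mapping


class HypotheticalTrainingTypePlugin:
    """Hypothetical TTP that owns all LightningModule state restoration."""

    def __init__(self, lightning_module: Any) -> None:
        self.lightning_module = lightning_module

    def restore_lightning_module(self, checkpoint: Mapping[str, Any]) -> None:
        # Single entry point: the hook runs exactly once, before the weights are
        # restored, so in this sketch it can still inspect the checkpoint.
        self.lightning_module.on_load_checkpoint(checkpoint)
        self.load_model_state_dict(checkpoint)

    def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        # Default behaviour: a plain state-dict load.
        self.lightning_module.load_state_dict(checkpoint["state_dict"])


class HypotheticalDeepSpeedLikePlugin(HypotheticalTrainingTypePlugin):
    """Overrides only *how* weights are loaded; the hook call is inherited."""

    def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        # Stand-in for engine-specific loading: strip a hypothetical key prefix
        # that an engine wrapper might have added when saving.
        prefix = "engine."
        weights: Dict[str, Any] = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in checkpoint["state_dict"].items()
        }
        self.lightning_module.load_state_dict(weights)
```

The design point is that subclasses customize `load_model_state_dict` but never re-implement the restore sequence, so hook calls cannot be duplicated or skipped per plugin.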

Alternatives

Additional context



ananthsub added the feature, help wanted, and checkpointing labels Sep 9, 2021
ananthsub self-assigned this Sep 9, 2021
@carmocca
Contributor

carmocca commented Sep 9, 2021

More previous work: #7509

tchaton added the let's do it! (approved to implement) label Sep 10, 2021
@jjenniferdai
Contributor

I'm planning to work on this this week, if that's OK!

4 participants