Support serialized checkpoint loading #9406
Labels: checkpointing · feature · help wanted · let's do it!
🚀 Feature
Motivation
Currently, all processes load the checkpoint at the same time. For large models, this concurrent loading can cause CPU OOMs. Such use cases, especially with things like mixture of experts, may require serialized loading of the checkpoint dict across ranks (i.e., load the checkpoint one rank at a time per node). Could we enable this for DDP?
Prior work: #8515
Pitch
This would be controlled per training type plugin. Example pseudocode: https://gist.github.com/ananthsub/4ceedff56b2049a63bbb05ccd283b919
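For concreteness, a minimal sketch of the serialization idea (this is not the gist's code; the function name, environment variables, and the `state_dict` key handling are assumptions): each local rank takes its turn to `torch.load` the full checkpoint on CPU, copies the weights into the module, and frees the dict before the next rank on the node starts.

```python
import os

import torch
import torch.distributed as dist


def load_checkpoint_serialized(checkpoint_path: str, lightning_module: torch.nn.Module) -> None:
    """Load model weights one local rank at a time per node.

    Assumes the default process group is already initialized and that
    LOCAL_RANK / LOCAL_WORLD_SIZE are set (e.g. by torchrun or Lightning).
    """
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))

    for turn in range(local_world_size):
        if turn == local_rank:
            # Only one process per node holds the full dict at this point.
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
            lightning_module.load_state_dict(checkpoint["state_dict"])
            del checkpoint  # free the full dict before the next rank's turn
        # Global barrier: every node advances to its next local rank together,
        # so rank 0 of each node loads in parallel, but never two ranks on the
        # same node at once.
        dist.barrier()
```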
To work through:
- Should `LightningModule.on_load_checkpoint` and checkpoint loading be handled by the training type plugin instead of the Trainer/connector? This would make sense, as the TTP "owns" the LightningModule inside of the Trainer, and it already offers `load_model_state_dict`: https://github.com/PyTorchLightning/pytorch-lightning/blob/41ba639859cf6c6bf319eb33e5b3394504315962/pytorch_lightning/plugins/training_type/training_type_plugin.py#L159-L160
- DeepSpeed already eschews most of the checkpoint connector logic when it comes to loading the LightningModule state. This could be a gap for metrics, and it means we could end up calling `on_load_checkpoint` multiple times with certain plugins. In my opinion, this points to all LightningModule state load/save/alteration logic needing to sit inside the training type plugin; a rough sketch of what that could look like follows below.
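To make the ownership question concrete, here is a rough sketch of a plugin-owned serialized load. The class, the hook name, and the one-process-per-GPU assumption are all hypothetical; only `load_model_state_dict` is the existing TTP method linked above.

```python
import torch
import torch.distributed as dist
from pytorch_lightning.plugins import DDPPlugin


class SerializedLoadDDPPlugin(DDPPlugin):
    """Hypothetical DDP plugin that owns serialized checkpoint loading."""

    def load_checkpoint_serialized(self, checkpoint_path: str) -> None:
        # Hypothetical hook: load and restore module state rank-by-rank so at
        # most one full checkpoint dict lives in host memory per node.
        local_world_size = torch.cuda.device_count()  # assumes 1 process per GPU
        for turn in range(local_world_size):
            if turn == self.local_rank:
                checkpoint = torch.load(checkpoint_path, map_location="cpu")
                # Existing TTP hook that restores the LightningModule weights;
                # ``on_load_checkpoint`` would then be invoked exactly once,
                # either here or from the checkpoint connector.
                self.load_model_state_dict(checkpoint)
                del checkpoint
            dist.barrier()
        # Restoring optimizer / loop state is intentionally elided here.
```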
Alternatives

Additional context
If you enjoy Lightning, check out our other projects! ⚡
- Metrics: Machine learning metrics for distributed, scalable PyTorch applications.
- Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, finetuning, and solving problems with deep learning.
- Bolts: Pretrained SOTA deep learning models, callbacks, and more for research and production with PyTorch Lightning and PyTorch.
- Lightning Transformers: Flexible interface for high-performance research using SOTA Transformers, leveraging PyTorch Lightning, Transformers, and Hydra.