Register Hooks for ShardedTensor Support #8633
Labels: checkpointing, feature, help wanted
🚀 Feature
Motivation
PyTorch is introducing `ShardedTensor` as the standard way of representing model state in sharded models. For checkpointing purposes, a `ShardedTensor` is a special tensor that appears in `model.state_dict()`. The state dict can then be used to restore the original model via `model.load_state_dict()`.

However, in order for `ShardedTensor` to work with `.state_dict()` and `.load_state_dict()`, two special hooks need to be registered via `_register_state_dict_hook()` and `_register_load_state_dict_pre_hook()`. These hooks are no-ops when there is no `ShardedTensor` in the model.

Pitch
Since in Lightning the Trainer is responsible for obtaining the state dict from a model, as well as restoring a model from a given state dict, Lightning should probably be responsible for registering these hooks.

Note that the feature is still a work in progress in PyTorch. We can either support it now for early adopters who also use Lightning, or defer it until the feature is released.
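A minimal sketch of what that registration could look like, assuming the hook functions are importable from `torch.distributed._shard.sharded_tensor` (the module path has moved between PyTorch releases while the feature is WIP) and using the private `nn.Module` registration methods mentioned above; `register_sharded_tensor_hooks` is a hypothetical helper name, not an existing API:

```python
import torch.nn as nn

# Assumed import path; earlier prototypes exposed these functions under
# torch.distributed._sharded_tensor, so adjust for the installed PyTorch version.
from torch.distributed._shard.sharded_tensor import (
    pre_load_state_dict_hook,
    state_dict_hook,
)


def register_sharded_tensor_hooks(module: nn.Module) -> None:
    """Register the ShardedTensor state-dict hooks on a module.

    Both hooks are no-ops when the module holds no ShardedTensor, so the
    Trainer could call this unconditionally before saving or restoring a
    checkpoint.
    """
    # Makes ShardedTensor attributes appear in module.state_dict().
    module._register_state_dict_hook(state_dict_hook)
    # Restores ShardedTensor attributes when module.load_state_dict() is
    # called; with_module=True passes the module as the hook's first argument.
    module._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
```

Because the hooks are no-ops for non-sharded models, registering them unconditionally would cost plain models nothing while sharded models would checkpoint correctly without extra code in the `LightningModule`.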
Alternatives
Additional context
pytorch/pytorch#55207
pytorch/pytorch#62242
If you enjoy Lightning, check out our other projects! ⚡
Metrics: Machine learning metrics for distributed, scalable PyTorch applications.
Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, finetuning and solving problems with deep learning
Bolts: Pretrained SOTA Deep Learning models, callbacks and more for research and production with PyTorch Lightning and PyTorch
Lightning Transformers: Flexible interface for high-performance research using SOTA Transformers leveraging PyTorch Lightning, Transformers, and Hydra.