Register Hooks for ShardedTensor Support #8633

Closed
yifuwang opened this issue Jul 29, 2021 · 4 comments
Labels: checkpointing, feature, help wanted
@yifuwang (Contributor)

🚀 Feature

Motivation

PyTorch is introducing ShardedTensor as the standard way of representing model state in sharded models. For checkpointing purposes, a ShardedTensor is a special tensor that appears in model.state_dict(); the state dict can later be used to restore the original model via model.load_state_dict().

However, for ShardedTensor to work with .state_dict() and .load_state_dict(), two special hooks need to be registered via _register_state_dict_hook() and _register_load_state_dict_pre_hook(). These hooks are no-ops when there's no ShardedTensor in the model.
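For illustration, here is a minimal pure-Python sketch of how such state-dict hooks compose. `MiniModule` and the hook below are hypothetical stand-ins, not PyTorch code; in PyTorch the real (private) entry points are `nn.Module._register_state_dict_hook()` and `_register_load_state_dict_pre_hook()`:

```python
# Hypothetical, simplified model of how state-dict hooks compose.
# This is NOT the PyTorch implementation, just a sketch of the pattern.
class MiniModule:
    def __init__(self, params):
        self._params = dict(params)
        self._state_dict_hooks = []           # run after state_dict() is built
        self._load_state_dict_pre_hooks = []  # run before load_state_dict() applies

    def _register_state_dict_hook(self, hook):
        self._state_dict_hooks.append(hook)

    def _register_load_state_dict_pre_hook(self, hook):
        self._load_state_dict_pre_hooks.append(hook)

    def state_dict(self):
        state = dict(self._params)
        for hook in self._state_dict_hooks:
            result = hook(self, state)
            if result is not None:
                state = result  # hooks may rewrite the state dict
        return state

    def load_state_dict(self, state):
        for hook in self._load_state_dict_pre_hooks:
            hook(self, state)  # hooks may fix up the incoming state dict
        self._params.update(state)


def sharded_tensor_state_dict_hook(module, state):
    # A ShardedTensor-aware hook would rewrite sharded entries here; with
    # no ShardedTensor present it simply returns the state unchanged (no-op).
    return state


m = MiniModule({"weight": 1.0})
m._register_state_dict_hook(sharded_tensor_state_dict_hook)
assert m.state_dict() == {"weight": 1.0}  # no-op when nothing is sharded
```

The key property this sketch demonstrates is the one the pitch relies on: registering the hooks on a model without any ShardedTensor changes nothing.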

Pitch

Since in Lightning the trainer is responsible for obtaining the state dict from a model, as well as restoring a model from a given state dict, Lightning should probably also be responsible for registering these hooks.

Note that the feature is still a WIP in PyTorch. We can either support it now for early adopters who also use Lightning, or defer it until the feature is released.

Alternatives

Additional context

pytorch/pytorch#55207
pytorch/pytorch#62242



@yifuwang added the feature and help wanted labels on Jul 29, 2021
@awaelchli (Contributor)

Considering our CI is already set up for nightly PyTorch, I think it would be nice to explore introducing these hooks.

@ananthsub (Contributor)

@yifuwang some n00b questions:

  • where in the trainer do you recommend adding these hooks?
  • would the trainer need to check if the LightningModule already has the hook registered?
  • why register the hooks via the trainer vs the LightningModule's constructor?

@tchaton (Contributor) commented Aug 3, 2021

Dear @yifuwang @ananthsub,

Should we expect the Trainer to auto-inspect the LightningModule and automatically register those hooks if sharded tensors are discovered?

I was checking '1.10.0.dev20210802+cu111'. The nightly release doesn't contain ChunkShardingSpec yet, so we can't write a ShardedBoringModel yet.

Best,
T.C

@ananthsub added the checkpointing label on Aug 4, 2021
@ananthsub (Contributor) commented Aug 4, 2021

> Should we expect the Trainer to auto-inspect the LightningModule and automatically register those hooks if sharded tensors are discovered?

From @pritamdamania87: if there's no ShardedTensor in the module, the hooks for loading/saving the state dict are no-ops, so we don't need to inspect the LightningModule for sharded tensors. It is safe to always register them.

Ideally this would be enabled by default for all nn.Modules. However, this depends on pytorch/pytorch#62094. Until that is resolved, we need to explicitly register the hooks.
