-
Notifications
You must be signed in to change notification settings - Fork 24.4k
Support saving and loading ShardedTensor. #62242
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support saving and loading ShardedTensor. #62242
Conversation
1) Add a state_dict hook to ensure ShardedTensors are added to a state_dict. 2) Add a pre load state_dict hook to ensure ShardedTensor are added back to a module at load time. 3) Add a `with_load_process_group` context manager for load time. 4) Added ser-de capability to ShardedTensor. Differential Revision: [D29927881](https://our.internmc.facebook.com/intern/diff/D29927881/) [ghstack-poisoned]
🔗 Helpful links
💊 CI failures summary and remediationsAs of commit c1af3a0 (more details on the Dr. CI page):
ci.pytorch.org: 1 failedThis comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.Please report bugs/suggestions to the (internal) Dr. CI Users group. |
@with_comms | ||
@skip_if_lt_x_gpu(4) | ||
@requires_nccl() | ||
def test_state_dict(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bookmark
1) Add a state_dict hook to ensure ShardedTensors are added to a state_dict. 2) Add a pre load state_dict hook to ensure ShardedTensor are added back to a module at load time. 3) Add a `with_load_process_group` context manager for load time. 4) Added ser-de capability to ShardedTensor. Differential Revision: [D29927881](https://our.internmc.facebook.com/intern/diff/D29927881/) [ghstack-poisoned]
1) Add a state_dict hook to ensure ShardedTensors are added to a state_dict. 2) Add a pre load state_dict hook to ensure ShardedTensor are added back to a module at load time. 3) Add a `with_load_process_group` context manager for load time. 4) Added ser-de capability to ShardedTensor. Differential Revision: [D29927881](https://our.internmc.facebook.com/intern/diff/D29927881/) [ghstack-poisoned]
Pull Request resolved: #62242 1) Add a state_dict hook to ensure ShardedTensors are added to a state_dict. 2) Add a pre load state_dict hook to ensure ShardedTensor are added back to a module at load time. 3) Add a `with_load_process_group` context manager for load time. 4) Added ser-de capability to ShardedTensor. ghstack-source-id: 134381847 Differential Revision: [D29927881](https://our.internmc.facebook.com/intern/diff/D29927881/)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have some concerns on using the hook to update the state_dict, otherwise looks good to me!
1) Add a state_dict hook to ensure ShardedTensors are added to a state_dict. 2) Add a pre load state_dict hook to ensure ShardedTensor are added back to a module at load time. 3) Add a `with_load_process_group` context manager for load time. 4) Added ser-de capability to ShardedTensor. Differential Revision: [D29927881](https://our.internmc.facebook.com/intern/diff/D29927881/) [ghstack-poisoned]
1) Add a state_dict hook to ensure ShardedTensors are added to a state_dict. 2) Add a pre load state_dict hook to ensure ShardedTensor are added back to a module at load time. 3) Add a `with_load_process_group` context manager for load time. 4) Added ser-de capability to ShardedTensor. Differential Revision: [D29927881](https://our.internmc.facebook.com/intern/diff/D29927881/) [ghstack-poisoned]
Pull Request resolved: #62242 1) Add a state_dict hook to ensure ShardedTensors are added to a state_dict. 2) Add a pre load state_dict hook to ensure ShardedTensor are added back to a module at load time. 3) Add a `with_load_process_group` context manager for load time. 4) Added ser-de capability to ShardedTensor. ghstack-source-id: 134574329 Differential Revision: [D29927881](https://our.internmc.facebook.com/intern/diff/D29927881/)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks great! just one suggestion about adding the tests for exception handling.
1) Add a state_dict hook to ensure ShardedTensors are added to a state_dict. 2) Add a pre load state_dict hook to ensure ShardedTensor are added back to a module at load time. 3) Add a `with_load_process_group` context manager for load time. 4) Added ser-de capability to ShardedTensor. Differential Revision: [D29927881](https://our.internmc.facebook.com/intern/diff/D29927881/) [ghstack-poisoned]
Pull Request resolved: #62242 1) Add a state_dict hook to ensure ShardedTensors are added to a state_dict. 2) Add a pre load state_dict hook to ensure ShardedTensor are added back to a module at load time. 3) Add a `with_load_process_group` context manager for load time. 4) Added ser-de capability to ShardedTensor. ghstack-source-id: 134741074 Differential Revision: [D29927881](https://our.internmc.facebook.com/intern/diff/D29927881/)
1) Add a state_dict hook to ensure ShardedTensors are added to a state_dict. 2) Add a pre load state_dict hook to ensure ShardedTensor are added back to a module at load time. 3) Add a `with_load_process_group` context manager for load time. 4) Added ser-de capability to ShardedTensor. Differential Revision: [D29927881](https://our.internmc.facebook.com/intern/diff/D29927881/) [ghstack-poisoned]
Pull Request resolved: #62242 1) Add a state_dict hook to ensure ShardedTensors are added to a state_dict. 2) Add a pre load state_dict hook to ensure ShardedTensor are added back to a module at load time. 3) Add a `with_load_process_group` context manager for load time. 4) Added ser-de capability to ShardedTensor. ghstack-source-id: 134775358 Differential Revision: [D29927881](https://our.internmc.facebook.com/intern/diff/D29927881/)
1) Add a state_dict hook to ensure ShardedTensors are added to a state_dict. 2) Add a pre load state_dict hook to ensure ShardedTensor are added back to a module at load time. 3) Add a `with_load_process_group` context manager for load time. 4) Added ser-de capability to ShardedTensor. Differential Revision: [D29927881](https://our.internmc.facebook.com/intern/diff/D29927881/) [ghstack-poisoned]
Pull Request resolved: #62242 1) Add a state_dict hook to ensure ShardedTensors are added to a state_dict. 2) Add a pre load state_dict hook to ensure ShardedTensor are added back to a module at load time. 3) Add a `with_load_process_group` context manager for load time. 4) Added ser-de capability to ShardedTensor. ghstack-source-id: 134860967 Differential Revision: [D29927881](https://our.internmc.facebook.com/intern/diff/D29927881/)
This pull request has been merged in c07a123. |
elif self.memory_format == torch.channels_last: | ||
mem_format_encoding = 1 | ||
elif self.memory_format == torch.preserve_format: | ||
mem_format_encoding = 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I've found a copy-paste error, it seems like from the setstate logic, mem_format_encoding should be = 2 in this case.
def _recurse_update_module(module, state_dict, prefix): | ||
for attr_name, attr in module.__dict__.items(): | ||
key = prefix + attr_name | ||
if key in state_dict: | ||
if isinstance(state_dict[key], ShardedTensor): | ||
setattr(module, attr_name, state_dict[key]) | ||
|
||
for submodule_name, submodule in module.named_modules(): | ||
key = prefix + submodule_name | ||
if submodule_name: | ||
_recurse_update_module(submodule, state_dict, key + '.') | ||
|
||
|
||
def _recurse_update_dict(module, destination, prefix): | ||
for attr_name, attr in module.__dict__.items(): | ||
if isinstance(attr, ShardedTensor): | ||
destination[prefix + attr_name] = attr | ||
|
||
for submodule_name, submodule in module.named_modules(): | ||
if submodule_name != '': | ||
_recurse_update_dict(submodule, destination, prefix + submodule_name + '.') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I posted an issue here #68805 for a potential improvement. I believe the recursion here is not necessary and causes inefficiency when retrieving the state dict.
Stack from ghstack:
added to a state_dict.
module at load time.
with_load_process_group
context manager for load time.Differential Revision: D29927881