
Support saving and loading ShardedTensor. #62242

Closed

Conversation


pritamdamania87 (Contributor) commented on Jul 27, 2021

Stack from ghstack:

  1. Add a state_dict hook to ensure ShardedTensors are added to a state_dict.
  2. Add a pre-load state_dict hook to ensure ShardedTensors are added back to a module at load time.
  3. Add a `with_load_process_group` context manager for load time.
  4. Add ser-de (serialization/deserialization) capability to ShardedTensor.

Differential Revision: D29927881
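
For context, here is a minimal sketch of how these pieces might fit together. It is illustrative only: the `torch.distributed._sharded_tensor` import path and the `state_dict_hook`, `pre_load_state_dict_hook`, and `with_load_process_group` names are assumptions based on the description above, not confirmed API from the diff.

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed import _sharded_tensor as sharded_tensor  # assumed module path


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # In practice a sharding API would replace this attribute with a
        # ShardedTensor before saving; a plain Parameter stands in here.
        self.weight = nn.Parameter(torch.randn(4, 4))


# Assumes torch.distributed has already been initialized via init_process_group().
model = MyModel()

# Saving: the state_dict hook copies ShardedTensor attributes into the
# state_dict, which nn.Module.state_dict() would otherwise skip.
model._register_state_dict_hook(sharded_tensor.state_dict_hook)  # assumed name
torch.save(model.state_dict(), f"checkpoint_rank{dist.get_rank()}.pt")

# Loading: the pre-load hook attaches ShardedTensors back onto the module, and
# the context manager selects the process group used during deserialization.
model._register_load_state_dict_pre_hook(sharded_tensor.pre_load_state_dict_hook)  # assumed name
with sharded_tensor.with_load_process_group(dist.new_group(ranks=[0, 1, 2, 3])):
    state_dict = torch.load(f"checkpoint_rank{dist.get_rank()}.pt")
    model.load_state_dict(state_dict, strict=False)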

facebook-github-bot commented Jul 27, 2021

💊 CI failures summary and remediations

As of commit c1af3a0 (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-scanned failure(s)

ci.pytorch.org: 1 failed


@with_comms
@skip_if_lt_x_gpu(4)
@requires_nccl()
def test_state_dict(self):

pritamdamania87 (Contributor, Author) commented:

bookmark

pritamdamania87 pushed a commit that referenced this pull request Jul 27, 2021
Pull Request resolved: #62242

ghstack-source-id: 134381847

Differential Revision: [D29927881](https://our.internmc.facebook.com/intern/diff/D29927881/)
pritamdamania87 requested a review from wanchaol on July 27, 2021 at 20:32

wanchaol (Collaborator) left a comment:

I have some concerns about using the hook to update the state_dict; otherwise this looks good to me!

pritamdamania87 pushed a commit that referenced this pull request Jul 28, 2021
Pull Request resolved: #62242

ghstack-source-id: 134574329

Differential Revision: [D29927881](https://our.internmc.facebook.com/intern/diff/D29927881/)
pritamdamania87 requested a review from wanchaol on July 28, 2021 at 23:55

wanchaol (Collaborator) left a comment:

Looks great! Just one suggestion about adding tests for exception handling.

pritamdamania87 pushed a commit that referenced this pull request Jul 30, 2021
Pull Request resolved: #62242

ghstack-source-id: 134741074

Differential Revision: [D29927881](https://our.internmc.facebook.com/intern/diff/D29927881/)
pritamdamania87 pushed a commit that referenced this pull request Jul 31, 2021
Pull Request resolved: #62242

ghstack-source-id: 134775358

Differential Revision: [D29927881](https://our.internmc.facebook.com/intern/diff/D29927881/)
pritamdamania87 pushed a commit that referenced this pull request Aug 2, 2021
Pull Request resolved: #62242

ghstack-source-id: 134860967

Differential Revision: [D29927881](https://our.internmc.facebook.com/intern/diff/D29927881/)

facebook-github-bot (Contributor) commented:

This pull request has been merged in c07a123.

elif self.memory_format == torch.channels_last:
    mem_format_encoding = 1
elif self.memory_format == torch.preserve_format:
    mem_format_encoding = 1

Review comment (Contributor):

I think I've found a copy-paste error: judging from the __setstate__ logic, mem_format_encoding should be 2 in this case.
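
For reference, a sketch of what the corrected branch would presumably look like; the surrounding structure is assumed here, and the point is only that each memory format needs its own encoding:

if self.memory_format == torch.contiguous_format:
    mem_format_encoding = 0
elif self.memory_format == torch.channels_last:
    mem_format_encoding = 1
elif self.memory_format == torch.preserve_format:
    # Distinct value instead of reusing the channels_last encoding.
    mem_format_encoding = 2
else:
    raise RuntimeError(f'Invalid torch.memory_format: {self.memory_format}')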

Comment on lines +100 to +120
def _recurse_update_module(module, state_dict, prefix):
    for attr_name, attr in module.__dict__.items():
        key = prefix + attr_name
        if key in state_dict:
            if isinstance(state_dict[key], ShardedTensor):
                setattr(module, attr_name, state_dict[key])

    for submodule_name, submodule in module.named_modules():
        key = prefix + submodule_name
        if submodule_name:
            _recurse_update_module(submodule, state_dict, key + '.')


def _recurse_update_dict(module, destination, prefix):
    for attr_name, attr in module.__dict__.items():
        if isinstance(attr, ShardedTensor):
            destination[prefix + attr_name] = attr

    for submodule_name, submodule in module.named_modules():
        if submodule_name != '':
            _recurse_update_dict(submodule, destination, prefix + submodule_name + '.')

Review comment (Contributor):

I posted issue #68805 for a potential improvement: I believe the recursion here is unnecessary and makes retrieving the state dict inefficient.
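
For illustration, a flattened variant along the lines of what #68805 suggests could rely on named_modules() alone, since it already walks the module tree; this is a sketch with a hypothetical helper name, not the change that was actually made:

def _update_dict_flat(module, destination, prefix):
    # named_modules() yields the module itself (with an empty name) and every
    # submodule with its fully qualified name, so one pass replaces the recursion.
    for submodule_name, submodule in module.named_modules():
        qualified_prefix = (prefix + submodule_name + '.') if submodule_name else prefix
        for attr_name, attr in submodule.__dict__.items():
            if isinstance(attr, ShardedTensor):
                destination[qualified_prefix + attr_name] = attr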

Labels
cla signed · Merged · oncall: distributed
6 participants