
API change, expose model's state_dict to accelerator.training_type_plugin #7470


Closed
shuyingsunshine21 opened this issue May 10, 2021 · 0 comments · Fixed by #7474
Labels
checkpointing (Related to checkpointing), feature (Is an improvement or enhancement), help wanted (Open to be worked on)
Milestone

Comments

@shuyingsunshine21
Contributor

🚀 Feature

Currently, in CheckpointConnector.dump_checkpoint, we have

model = self.trainer.lightning_module

checkpoint = {
    'epoch': current_epoch,
    'global_step': global_step,
    'pytorch-lightning_version': pytorch_lightning.__version__,
    'state_dict': model.state_dict(),
}

so the model's state dict is extracted directly here. However, letting accelerator.training_type_plugin control this logic might make more sense, especially for sharded plugins, where we might need to access the local (i.e. sharded) state instead of the whole state.

Motivation

#6152 (comment)

We would like to customize the model state dict for specific training type plugins. We could override the training_type_plugin.on_save method to modify the state dict, but this would cause a duplicate call to extract the model state dict.
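The duplicate extraction can be illustrated with a minimal sketch. It uses plain-Python stand-ins rather than the real pytorch_lightning classes: `CountingModule`, `OnSaveShardedPlugin`, the simplified `dump_checkpoint` function, and the list-slicing "shard" are all hypothetical names and simplifications introduced only for illustration, and the `on_save(checkpoint)` shape is assumed.

```python
class CountingModule:
    """Hypothetical stand-in for a LightningModule; counts state_dict() calls."""

    def __init__(self):
        self.weights = {"layer.weight": [1.0, 2.0, 3.0, 4.0]}
        self.state_dict_calls = 0

    def state_dict(self):
        self.state_dict_calls += 1
        return dict(self.weights)


class OnSaveShardedPlugin:
    """Hypothetical plugin that customizes the state dict inside on_save."""

    def __init__(self, module):
        self.lightning_module = module

    def on_save(self, checkpoint):
        # Second extraction: the plugin rebuilds the state dict it wants,
        # discarding the one the checkpoint already contains.
        full = self.lightning_module.state_dict()
        checkpoint["state_dict"] = {name: values[:2] for name, values in full.items()}
        return checkpoint


def dump_checkpoint(module, plugin):
    # First extraction: mirrors the current behavior described in this issue.
    checkpoint = {"epoch": 0, "global_step": 0, "state_dict": module.state_dict()}
    return plugin.on_save(checkpoint)


module = CountingModule()
checkpoint = dump_checkpoint(module, OnSaveShardedPlugin(module))
print(module.state_dict_calls)
```

In this sketch the full state dict is built twice per checkpoint, and the first copy is thrown away. For a large model that is wasted work, and for a sharded model the first call may not even be what the plugin wants.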

Pitch

Define a new method on TrainingTypePlugin:

def state_dict(self) -> dict:
    """Return the model state dict to save; plugins may override this."""
    model = self.lightning_module
    return model.state_dict()

and call it in CheckpointConnector.dump_checkpoint:

checkpoint = {
    'epoch': current_epoch,
    'global_step': global_step,
    'pytorch-lightning_version': pytorch_lightning.__version__,
    'state_dict': self.trainer.accelerator.training_type_plugin.state_dict(),
}
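To show how a plugin could take advantage of this hook, here is a rough, self-contained sketch. Everything in it is a plain-Python stand-in and not the library's implementation: `ToyModule`, `ToyTrainingTypePlugin`, `ToyShardedPlugin`, the `rank`/`world_size` arguments, and the even list-slicing used as a "shard" are assumptions made for illustration only.

```python
class ToyModule:
    """Hypothetical stand-in for a LightningModule."""

    def __init__(self):
        self.weights = {"layer.weight": [1.0, 2.0, 3.0, 4.0]}
        self.state_dict_calls = 0

    def state_dict(self):
        self.state_dict_calls += 1
        return dict(self.weights)


class ToyTrainingTypePlugin:
    """Stand-in for the base plugin: returns the full model state dict."""

    def __init__(self, module):
        self.lightning_module = module

    def state_dict(self):
        return self.lightning_module.state_dict()


class ToyShardedPlugin(ToyTrainingTypePlugin):
    """Hypothetical sharded plugin: returns only this rank's slice."""

    def __init__(self, module, rank, world_size):
        super().__init__(module)
        self.rank = rank
        self.world_size = world_size

    def state_dict(self):
        # Assumption: each parameter is a list that divides evenly by world_size.
        full = self.lightning_module.state_dict()
        shard = {}
        for name, values in full.items():
            size = len(values) // self.world_size
            shard[name] = values[self.rank * size:(self.rank + 1) * size]
        return shard


def dump_checkpoint(plugin, epoch, global_step):
    # The connector no longer touches the model; it asks the plugin instead.
    return {
        "epoch": epoch,
        "global_step": global_step,
        "state_dict": plugin.state_dict(),
    }


full_ckpt = dump_checkpoint(ToyTrainingTypePlugin(ToyModule()), epoch=1, global_step=10)

sharded_module = ToyModule()
sharded_ckpt = dump_checkpoint(
    ToyShardedPlugin(sharded_module, rank=1, world_size=2), epoch=1, global_step=10
)
print(full_ckpt["state_dict"], sharded_ckpt["state_dict"])
```

With the hook in place, the state dict is extracted once per checkpoint, and the choice between full and local state lives entirely in the plugin. The toy sharded plugin still slices a full copy for simplicity; a real sharded plugin would presumably read its local shard directly.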

Alternatives

Additional context

@shuyingsunshine21 shuyingsunshine21 added feature Is an improvement or enhancement help wanted Open to be worked on labels May 10, 2021
@SeanNaren SeanNaren added the checkpointing Related to checkpointing label May 10, 2021
@SeanNaren SeanNaren modified the milestones: v1.3.x, v1.4 May 10, 2021

2 participants