API change, expose model's state_dict to accelerator.training_type_plugin
#7470
Labels: checkpointing, feature, help wanted
🚀 Feature
Currently, `CheckpointConnector.dump_checkpoint` extracts the model's state dict itself. However, letting `accelerator.training_type_plugin` control this logic might make more sense, especially for sharded plugins, where we may need to access the local (i.e. sharded) state instead of the whole state.

Motivation
As discussed in #6152 (comment), we would like to build a customized model state dict for a specific training type plugin. We could override the `training_type_plugin.on_save` method to modify the state dict, but this causes a duplicate call to extract the model's state dict.

Pitch
Define a new method on `TrainingTypePlugin` that returns the model's state dict, and have `CheckpointConnector.dump_checkpoint` call it instead of extracting the state dict itself.
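A minimal standalone sketch of the idea (the hook name `lightning_module_state_dict`, the `ShardedPlugin` subclass, and `DummyModel` are illustrative assumptions, not the actual Lightning API):

```python
class TrainingTypePlugin:
    def __init__(self, model):
        self.model = model

    def lightning_module_state_dict(self):
        # Default behaviour: return the wrapped model's full state dict.
        return self.model.state_dict()


class ShardedPlugin(TrainingTypePlugin):
    def lightning_module_state_dict(self):
        # A sharded plugin can override the hook to return only the
        # locally held shard instead of the whole state.
        full = self.model.state_dict()
        return {k: v for k, v in full.items() if k.startswith("shard0.")}


class DummyModel:
    """Stand-in for a LightningModule; only state_dict() matters here."""

    def state_dict(self):
        return {"shard0.weight": [1.0], "shard1.weight": [2.0]}


def dump_checkpoint(plugin):
    # dump_checkpoint asks the plugin for the state dict instead of
    # calling model.state_dict() itself, so there is a single extraction
    # point that each plugin fully controls.
    return {"state_dict": plugin.lightning_module_state_dict()}


print(dump_checkpoint(TrainingTypePlugin(DummyModel()))["state_dict"])
print(dump_checkpoint(ShardedPlugin(DummyModel()))["state_dict"])
```

With this shape, overriding `on_save` to rewrite the already-extracted state dict is no longer necessary; the plugin produces the right dict in the first place.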
Alternatives
Additional context