Follow up — changes to load model state in checkpoint connector in case of multiple workers #8044 #8515
Conversation
Codecov Report
@@            Coverage Diff            @@
##           master    #8515     +/-   ##
=========================================
  Coverage      92%      92%
=========================================
  Files         175      218      +43
  Lines       14696    14412     -284
=========================================
- Hits        13508    13308     -200
+ Misses       1188     1104      -84
for current_worker in range(self.num_processes):
    if self.local_rank == current_worker:
        checkpoint = super().load_checkpoint_file(checkpoint_path)
        self.lightning_module.on_load_checkpoint(checkpoint)
        self.load_model_state_dict(checkpoint)
        log.info(f"Rank {self.global_rank}: done loading model states from {checkpoint_path}.")
        del checkpoint["state_dict"]
    self.barrier()
This is pretty smart :)
One concern: the responsibility of calling the hooks is now shifted to the plugin. Do we want to allow that?
@PyTorchLightning/core-contributors
In this case, I don't see a problem with it. The only thing I'm afraid of is that we may set a precedent here, and that might lead to this pattern spreading to places where we don't want it.
What do you think @SeanNaren @awaelchli @ananthsub @carmocca?
f"FullyShardedDataParallel has {self.num_processes} processes. Serializing model | ||
state restore to avoid CPU OOMs." | ||
) | ||
for current_worker in range(self.num_processes): |
Would this break if the number of GPUs at saving / checkpointing isn't the same as at load time? Should we save the world_size + rank in each checkpoint and re-use that information on reload?
Thanks for raising this! Thinking aloud: within the same training session we wouldn't run into this problem. But even when a model checkpoint from a different training session is being fine-tuned in a new session on a different machine, do we envision the world size being different? (For our use cases at least, we have config files that keep variables like Trainer(gpus) constant for a given model.)
I think that is definitely something we need to support. For example, I usually use the output of torch.cuda.device_count() as the number of GPUs.
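For illustration only, a minimal sketch of the pattern described here, using the Trainer(gpus=...) argument as it existed at the time of this PR:

```python
import torch
from pytorch_lightning import Trainer

# The device count comes from whatever hardware the current machine has, so the
# resulting world size can legitimately differ between the run that saved a
# checkpoint and the run that restores it.
trainer = Trainer(gpus=torch.cuda.device_count())
```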
@tchaton this kind of training "metadata" should get saved with the checkpoint. For example, we will also want it for fault tolerance, so we can fail if the trainer configuration has changed between runs while the user is trying to restore mid-batch.
FWIW, I ran a test taking a checkpoint from a previous training session and starting a fresh one: setting trainer.resume_from_checkpoint to the old checkpoint and trainer.gpus=4 instead of 8 (the number of GPUs in the original session), and loading the checkpoint on 4 GPUs didn't break.
I can add this feature in another pull request! Would a good place to add the metadata be on_save_checkpoint in model_checkpoint.py? E.g. here: https://github.com/PyTorchLightning/pytorch-lightning/blob/c7f8c8c3c82b4f249125885490b2392bf9d3d08b/pytorch_lightning/callbacks/model_checkpoint.py#L341
Sorry for the late reply! I've filed #9123 to track this
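To make the fault-tolerance idea above concrete, here is a minimal sketch of the kind of consistency check that could run on restore. The function name and the exact location of the metadata inside the checkpoint dict are assumptions, not part of this PR:

```python
from typing import Any, Dict


def check_trainer_metadata(checkpoint: Dict[str, Any], world_size: int, num_nodes: int) -> None:
    """Fail fast if the current Trainer topology differs from what the checkpoint recorded."""
    saved = (checkpoint.get("world_size"), checkpoint.get("num_nodes"))
    if saved != (world_size, num_nodes):
        raise RuntimeError(
            f"Checkpoint was written with (world_size, num_nodes)={saved}, "
            f"but the current run uses {(world_size, num_nodes)}."
        )
```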
self.trainer.lightning_module.on_load_checkpoint(checkpoint)
self.trainer.training_type_plugin.load_model_state_dict(checkpoint)
if hasattr(self.trainer.training_type_plugin, "serialized_restore_model_state"):
Just personal preference, but I would prefer to refactor this and pass the checkpoint_path to the training type plugin, letting it handle the loading logic.
    checkpoint = self.trainer.training_type_plugin.serialized_restore_model_state(checkpoint_path)
else:
    checkpoint = self.trainer.training_type_plugin.load_checkpoint_file(checkpoint_path)
self.trainer.lightning_module.on_load_checkpoint(checkpoint)
For consistency, should we move self.trainer.lightning_module.on_load_checkpoint(checkpoint) to the training type plugin?
)
for current_worker in range(self.num_processes):
    if self.local_rank == current_worker:
        checkpoint = super().load_checkpoint_file(checkpoint_path)
Does this assume the same checkpoint path for all workers?
self.load_model_state_dict(checkpoint)
del checkpoint["state_dict"]
Suggested change:
- self.load_model_state_dict(checkpoint)
- del checkpoint["state_dict"]
+ self.load_model_state_dict(checkpoint.pop("state_dict"))
checkpoint = {}
rank_zero_info(
    f"FullyShardedDataParallel has {self.num_processes} processes. Serializing model "
    "state restore to avoid CPU OOMs."
Suggested change:
-     "state restore to avoid CPU OOMs."
+     "state restoration to avoid CPU OOMs."
f"FullyShardedDataParallel has {self.num_processes} processes. Serializing model | ||
state restore to avoid CPU OOMs." | ||
) | ||
for current_worker in range(self.num_processes): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@tchaton this kind of training "metadata" should get saved with the checkpoint. For example, we will also want to know this for fault-tolerance to fail if the trainer configuration has changed between runs and the user is trying to restore mid-batch.
Hey @mleshen, any updates on this PR? Do you need some assistance?
@mleshen and others: this has been an issue for us a few times in the past and seems to be a particularly tricky thing to test for. Do we have tests for checkpoint loading that we can amend to include memory tracking, to make sure the memory used doesn't scale with the number of workers?
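A rough sketch of what such a memory check could look like; the helper is hypothetical (not an existing Lightning test utility), and in a real multi-GPU test one would run it per rank and assert that the per-process increase stays roughly flat as the number of workers grows:

```python
import os

import psutil
import torch


def rss_increase_during_checkpoint_load(checkpoint_path: str) -> int:
    """Return the growth in resident set size (bytes) while loading a checkpoint on CPU."""
    process = psutil.Process(os.getpid())
    rss_before = process.memory_info().rss
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    rss_after = process.memory_info().rss
    del checkpoint
    return rss_after - rss_before
```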
"dirpath": self.dirpath | ||
"dirpath": self.dirpath, | ||
"world_size": trainer.world_size, | ||
"node_rank": trainer.node_rank, |
Let's add num_nodes too.
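For reference, the callback state with the extra field could look roughly like this. The key names follow the quoted diff, num_nodes is the addition suggested here, and the rest of ModelCheckpoint's state is omitted for brevity:

```python
def on_save_checkpoint(self, trainer, pl_module, checkpoint):
    # Sketch only: record enough of the launch topology alongside the checkpoint
    # to detect a mismatched Trainer configuration on restore.
    return {
        "dirpath": self.dirpath,
        "world_size": trainer.world_size,
        "node_rank": trainer.node_rank,
        "num_nodes": trainer.num_nodes,
    }
```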
f"FullyShardedDataParallel has {self.num_processes} processes. Serializing model | ||
state restoration to avoid CPU OOMs." | ||
) | ||
for current_worker in range(self.num_processes): |
Mind adding a comment to explain what is happening, in case a new reader reaches this code :)
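For instance, the loop could carry a comment along these lines (a sketch assembled from the quoted diffs; the comment wording is only a suggestion, not code from the PR):

```python
def load_model_state(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
    checkpoint = {}
    rank_zero_info(
        f"FullyShardedDataParallel has {self.num_processes} processes. Serializing model "
        "state restoration to avoid CPU OOMs."
    )
    # Each rank has to read the full, unsharded checkpoint into CPU memory before
    # restoring its own shard. If every local rank did this at once, host memory
    # use would scale with the number of workers, so the ranks take turns: one
    # worker loads per iteration while the others wait at the barrier.
    for current_worker in range(self.num_processes):
        if self.local_rank == current_worker:
            checkpoint = super().load_checkpoint_file(checkpoint_path)
            self.lightning_module.on_load_checkpoint(checkpoint)
            self.load_model_state_dict(checkpoint)
            log.info(f"Rank {self.global_rank}: done loading model states from {checkpoint_path}.")
            # Drop the large state dict as soon as it has been consumed, so the
            # full CPU copy is released before the next rank starts loading.
            del checkpoint["state_dict"]
        self.barrier()
    return checkpoint
```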
-    self.trainer.lightning_module.on_load_checkpoint(checkpoint)
-    self.trainer.training_type_plugin.load_model_state_dict(checkpoint)
+    checkpoint = self.trainer.training_type_plugin.load_model_state(checkpoint_path)
Much cleaner!
@@ -178,6 +182,24 @@ def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
        # state dict.
        return super().lightning_module_state_dict()

    def load_model_state(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
        checkpoint = {}
Docstring?
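One possible docstring, phrased from the quoted diffs (the wording is a suggestion, not text from the PR):

```python
def load_model_state(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
    """Restore the model weights from ``checkpoint_path``, one local rank at a time.

    Only a single worker per node holds the full checkpoint in CPU memory at any
    moment, which keeps host memory usage bounded regardless of the number of
    processes. The ``state_dict`` entry is removed from the returned checkpoint
    once the weights have been loaded.
    """
```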
Dear @mleshen, any updates on this PR? Best,
"deleted state_dict from checkpoint." | ||
) | ||
self.barrier() | ||
return checkpoint |
I guess returning the checkpoint is not required. Plus docstrings are missing.
checkpoint = self.load_checkpoint_file(checkpoint_path)
self.on_load_checkpoint(checkpoint)
self.load_model_state_dict(checkpoint)
return checkpoint
same here.
This pull request has been automatically marked as stale because it has not had recent activity. It will be closed in 7 days if no further activity occurs. If you need further help see our docs: https://pytorch-lightning.readthedocs.io/en/latest/generated/CONTRIBUTING.html#pull-request or ask the assistance of a core contributor here or on Slack. Thank you for your contributions.
This pull request is going to be closed. Please feel free to reopen it or create a new one from the current master.
What does this PR do?
Fixes #<issue_number>
Does your PR introduce any breaking changes? If yes, please list them.
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing, make sure you have read the Review guidelines. In short, see the following bullet list:
Did you have fun?
Make sure you had fun coding 🙃