Follow up — changes to load model state in checkpoint connector in case of multiple workers #8044 #8515


Closed
wants to merge 22 commits into from

Conversation


@mleshen mleshen commented Jul 21, 2021

What does this PR do?

Fixes #<issue_number>

Does your PR introduce any breaking changes? If yes, please list them.

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)
  • Did you list all the breaking changes introduced by this pull request?

PR review

Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing make sure you have read Review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

Make sure you had fun coding 🙃

@codecov

codecov bot commented Jul 21, 2021

Codecov Report

Merging #8515 (c7f8c8c) into master (366fb39) will increase coverage by 0%.
The diff coverage is n/a.

❗ Current head c7f8c8c differs from pull request most recent head 6388a33. Consider uploading reports for the commit 6388a33 to get more accurate results

@@           Coverage Diff           @@
##           master   #8515    +/-   ##
=======================================
  Coverage      92%     92%            
=======================================
  Files         175     218    +43     
  Lines       14696   14412   -284     
=======================================
- Hits        13508   13308   -200     
+ Misses       1188    1104    -84     

@awaelchli awaelchli added this to the v1.5 milestone Jul 22, 2021
@awaelchli awaelchli added the feature (Is an improvement or enhancement) and design (Includes a design discussion) labels Jul 22, 2021
Comment on lines 188 to 195

    for current_worker in range(self.num_processes):
        if self.local_rank == current_worker:
            checkpoint = super().load_checkpoint_file(checkpoint_path)
            self.lightning_module.on_load_checkpoint(checkpoint)
            self.load_model_state_dict(checkpoint)
            log.info(f"Rank {self.global_rank}: done loading model states from {checkpoint_path}.")
            del checkpoint["state_dict"]
        self.barrier()
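
(For context: the loop above serializes checkpoint loading across ranks, so at most one full, unsharded checkpoint is resident in host memory at a time. Below is a minimal standalone sketch of the same idea, assuming an already-initialized torch.distributed process group; the names are illustrative, not the plugin's API.)

    import torch
    import torch.distributed as dist

    def serialized_load(checkpoint_path: str, model: torch.nn.Module) -> None:
        # Each rank takes its turn deserializing the checkpoint while the others wait
        # at a barrier, so only one full copy of the state dict lives in CPU memory
        # at any moment.
        for current_worker in range(dist.get_world_size()):
            if dist.get_rank() == current_worker:
                checkpoint = torch.load(checkpoint_path, map_location="cpu")
                model.load_state_dict(checkpoint["state_dict"])
                del checkpoint  # free the full state dict before the next rank loads
            dist.barrier()
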
Contributor

this is pretty smart :)

One concern: the responsibility of calling the hooks now shifts to the plugin. Do we want to allow that?
@PyTorchLightning/core-contributors

Member

In this case, I don't see a problem with it. The only thing I am afraid of is that we may set a precedent here, and that might lead to this pattern being used where we don't want it.

What do you think @SeanNaren @awaelchli @ananthsub @carmocca ?

@mleshen mleshen requested a review from edenlightning as a code owner July 23, 2021 18:35
f"FullyShardedDataParallel has {self.num_processes} processes. Serializing model
state restore to avoid CPU OOMs."
)
for current_worker in range(self.num_processes):
Contributor

Would this break if the number of GPUs at checkpoint saving and loading isn't the same? Should we save the world_size + rank in each checkpoint and re-use that information on reload?

Author

Thanks for raising this! Thinking aloud: within the same training session we wouldn't run into this problem. But even in the case where a checkpoint from one training session is fine-tuned in a new session on a different machine, do we envision the world size being different? For our use cases, at least, we have config files that keep variables like Trainer(gpus) constant for the same model training.

Member

I think that is definitely something we need to support. For example, I usually use the output of torch.cuda.device_count() as the number of GPUs.

Contributor

@tchaton this kind of training "metadata" should get saved with the checkpoint. For example, we will also want this for fault tolerance, so we can fail if the trainer configuration has changed between runs and the user is trying to restore mid-batch.

Author

FWIW I ran a test taking a checkpoint from a previous training session and starting a fresh one: I set trainer.resume_from_checkpoint to the old checkpoint and trainer.gpus=4 instead of 8 (the GPU count of the original training session), and loading the checkpoint on 4 GPUs didn't break.

I can add this feature in another pull request! Would a good place to add the metadata be on_save_checkpoint in model_checkpoint.py? E.g. here: https://github.com/PyTorchLightning/pytorch-lightning/blob/c7f8c8c3c82b4f249125885490b2392bf9d3d08b/pytorch_lightning/callbacks/model_checkpoint.py#L341
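
(A hedged sketch of the kind of metadata addition being discussed, written against the callback API of the Lightning version this PR targets; the subclass name is hypothetical, and a concrete version of the two topology keys appears in a later diff in this conversation.)

    from pytorch_lightning.callbacks import ModelCheckpoint

    class TopologyAwareCheckpoint(ModelCheckpoint):
        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            # Keep the callback's usual state and record the training topology so a
            # later restore can detect a mismatched world size or node layout.
            state = super().on_save_checkpoint(trainer, pl_module, checkpoint)
            state["world_size"] = trainer.world_size
            state["node_rank"] = trainer.node_rank
            return state
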

Contributor

Sorry for the late reply! I've filed #9123 to track this


self.trainer.lightning_module.on_load_checkpoint(checkpoint)
self.trainer.training_type_plugin.load_model_state_dict(checkpoint)
if hasattr(self.trainer.training_type_plugin, "serialized_restore_model_state"):
Contributor

Just personal preference, but I would prefer to refactor the code to pass the checkpoint_path to the training_type plugin and let it handle the loading logic.

    checkpoint = self.trainer.training_type_plugin.serialized_restore_model_state(checkpoint_path)
else:
    checkpoint = self.trainer.training_type_plugin.load_checkpoint_file(checkpoint_path)
    self.trainer.lightning_module.on_load_checkpoint(checkpoint)
Contributor

For consistency, should we move self.trainer.lightning_module.on_load_checkpoint(checkpoint) to the training type plugin?

)
for current_worker in range(self.num_processes):
    if self.local_rank == current_worker:
        checkpoint = super().load_checkpoint_file(checkpoint_path)
Contributor

Does this assume the same checkpoint path for all workers?

Comment on lines 195 to 196
self.load_model_state_dict(checkpoint)
del checkpoint["state_dict"]
Contributor

Suggested change
- self.load_model_state_dict(checkpoint)
- del checkpoint["state_dict"]
+ self.load_model_state_dict(checkpoint.pop("state_dict"))

checkpoint = {}
rank_zero_info(
    f"FullyShardedDataParallel has {self.num_processes} processes. Serializing model state restore to avoid CPU OOMs."
Contributor

Suggested change
- state restore to avoid CPU OOMs."
+ state restoration to avoid CPU OOMs."

f"FullyShardedDataParallel has {self.num_processes} processes. Serializing model
state restore to avoid CPU OOMs."
)
for current_worker in range(self.num_processes):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tchaton this kind of training "metadata" should get saved with the checkpoint. For example, we will also want to know this for fault-tolerance to fail if the trainer configuration has changed between runs and the user is trying to restore mid-batch.

@tchaton
Contributor

tchaton commented Jul 29, 2021

Hey @mleshen, any updates on this PR? Do you need some assistance?

@maximsch2
Contributor

@mleshen, and others: this has been an issue for us a few times in the past, and it seems to be a particularly tricky thing to test for. Do we have tests for checkpoint loading that we can amend to include memory tracking, to make sure the memory used doesn't scale with the number of workers?
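
(A hedged sketch of the kind of memory-tracking assertion such a test could make; nothing below exists in the Lightning test suite, and both the psutil dependency and the 2x bound are assumptions.)

    import os

    import psutil
    import torch

    def rss_mb() -> float:
        """Resident set size of the current process, in MiB."""
        return psutil.Process(os.getpid()).memory_info().rss / 2**20

    def test_checkpoint_load_memory(tmp_path):
        state = {"state_dict": {"weight": torch.randn(4096, 4096)}}  # ~64 MiB of weights
        path = os.path.join(str(tmp_path), "ckpt.pt")
        torch.save(state, path)
        del state

        before = rss_mb()
        checkpoint = torch.load(path, map_location="cpu")  # stand-in for the plugin load
        grown = rss_mb() - before
        del checkpoint

        # With rank-serialized restore, each process should only ever hold about one
        # copy of the full state dict, so growth should stay near the checkpoint size
        # rather than scaling with the number of workers.
        assert grown < 2 * 64, f"checkpoint load grew RSS by {grown:.0f} MiB"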

"dirpath": self.dirpath
"dirpath": self.dirpath,
"world_size": trainer.world_size,
"node_rank": trainer.node_rank,
Contributor

Let's add num_nodes too.

f"FullyShardedDataParallel has {self.num_processes} processes. Serializing model
state restoration to avoid CPU OOMs."
)
for current_worker in range(self.num_processes):
Contributor

Mind adding a comment to explain what is happening, in case a new reader reaches this code? :)


- self.trainer.lightning_module.on_load_checkpoint(checkpoint)
- self.trainer.training_type_plugin.load_model_state_dict(checkpoint)
+ checkpoint = self.trainer.training_type_plugin.load_model_state(checkpoint_path)
Contributor

Much cleaner!

@mergify mergify bot removed the has conflicts label Aug 26, 2021
@@ -178,6 +182,24 @@ def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
        # state dict.
        return super().lightning_module_state_dict()

    def load_model_state(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
        checkpoint = {}
Contributor

docstring?
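
(A hedged sketch of what the requested docstring might look like, using the signature from the hunk above; the wording is an assumption, not the PR's text.)

    def load_model_state(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
        """Restore model weights from ``checkpoint_path``, one rank at a time.

        Each process deserializes the full checkpoint in turn and waits at a barrier,
        so only one unsharded copy of the state dict is resident in CPU memory at any
        moment. The returned dict has ``state_dict`` removed so the caller can restore
        the remaining training state without holding the weights twice.
        """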

@tchaton
Contributor

tchaton commented Aug 26, 2021

Dear @mleshen,

Any updates on this PR?
Noob question: does FSDP save only one checkpoint, or a directory with multiple checkpoints as DeepSpeed does?

Best,
T.C

"deleted state_dict from checkpoint."
)
self.barrier()
return checkpoint
Contributor

I guess returning the checkpoint is not required. Also: docstrings.

checkpoint = self.load_checkpoint_file(checkpoint_path)
self.on_load_checkpoint(checkpoint)
self.load_model_state_dict(checkpoint)
return checkpoint
Contributor

same here.

@awaelchli awaelchli modified the milestones: v1.5, v1.6 Nov 1, 2021
@carmocca carmocca removed this from the 1.6 milestone Mar 28, 2022
@stale

stale bot commented Apr 16, 2022

This pull request has been automatically marked as stale because it has not had recent activity. It will be closed in 7 days if no further activity occurs. If you need further help see our docs: https://pytorch-lightning.readthedocs.io/en/latest/generated/CONTRIBUTING.html#pull-request or ask the assistance of a core contributor here or on Slack. Thank you for your contributions.

@stale stale bot added the won't fix (This will not be worked on) label Apr 16, 2022
@stale

stale bot commented Apr 21, 2022

This pull request is going to be closed. Please feel free to reopen it or create a new one from the current master.

@stale stale bot closed this Apr 21, 2022
Labels
design (Includes a design discussion), feature (Is an improvement or enhancement), has conflicts, won't fix (This will not be worked on)

7 participants