
Follow up — changes to load model state in checkpoint connector in case of multiple workers #8044 #8515


Closed · wants to merge 22 commits

Changes from 13 commits

Commits (22)
9cb0246  Update checkpoint_connector.py (mleshen, Jun 20, 2021)
ee7f979  lint (mleshen, Jun 20, 2021)
e86670e  cleanup (mleshen, Jun 20, 2021)
20e429a  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 20, 2021)
b352d24  serialize load_checkpoint_file in fully_sharded (mleshen, Jul 21, 2021)
815c474  Merge branch 'master' of https://github.com/mleshen/pytorch-lightning (mleshen, Jul 21, 2021)
6791642  revert changes to checkpoint_connector (mleshen, Jul 21, 2021)
ec70b80  Merge branch 'PyTorchLightning:master' into master (mleshen, Jul 21, 2021)
f01e705  remove changes (mleshen, Jul 21, 2021)
7f59061  Merge branch 'PyTorchLightning:master' into master (mleshen, Jul 21, 2021)
741bd63  add serialized restore model state (mleshen, Jul 21, 2021)
8138319  Merge branch 'master' of https://github.com/mleshen/pytorch-lightning (mleshen, Jul 21, 2021)
7482b8c  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 21, 2021)
3f47522  address review comments (mleshen, Jul 23, 2021)
9e4d746  Merge branch 'PyTorchLightning:master' into master (mleshen, Jul 23, 2021)
c2e1afb  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 23, 2021)
ef2be98  fix fully_sharded (mleshen, Jul 23, 2021)
d4f2e15  Merge branch 'master' of https://github.com/mleshen/pytorch-lightning (mleshen, Jul 23, 2021)
a9dcd13  add metadata to model checkpoint (mleshen, Aug 2, 2021)
15ed02f  refactor, move loading logic to training type plugin (mleshen, Aug 2, 2021)
6388a33  resolve conflitcs (tchaton, Aug 26, 2021)
7265fda  add docstring (tchaton, Aug 26, 2021)
17 changes: 17 additions & 0 deletions pytorch_lightning/plugins/training_type/fully_sharded.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import logging
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional, Union

import torch
@@ -27,6 +29,8 @@
from fairscale.nn import default_auto_wrap_policy, enable_wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel

log: logging.Logger = logging.getLogger(__name__)


class DDPFullyShardedPlugin(DDPPlugin):

@@ -178,6 +182,19 @@ def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
# state dict.
return super().lightning_module_state_dict()

def serialized_restore_model_state(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
checkpoint = {}
Contributor: docstring?

log.info(f"FullyShardedDataParallel has {self.num_processes} processes. Serializing to avoid CPU OOMs.")
for current_worker in range(self.num_processes):
Contributor: Would this break if the number of GPUs used when the checkpoint was saved isn't the same as when it is loaded? Should we save the world_size + rank in each checkpoint and reuse that information on reload?

Author: Thanks for raising this! Thinking aloud: within the same training session we wouldn't run into this problem. But even in the case where a checkpoint from one training session is fine-tuned in a new session on a different machine, do we expect the world size to be different? (For our use cases at least, we have config files that keep values like Trainer(gpus) constant for a given model.)

Member: I think that is definitely something we need to support. For example, I usually use the output of torch.cuda.device_count() as the number of GPUs.

Contributor: @tchaton this kind of training "metadata" should get saved with the checkpoint. For example, we will also want it for fault tolerance, so we can fail if the trainer configuration has changed between runs and the user is trying to restore mid-batch.

Author: FWIW, I ran a test taking a checkpoint from a previous training session and starting a fresh one: I set trainer.resume_from_checkpoint to the old checkpoint and trainer.gpus=4 instead of 8 (the number of GPUs in the original session), and loading the checkpoint on 4 GPUs didn't break.

I can add this feature in another pull request! Would a good place to add the metadata be on_save_checkpoint in model_checkpoint.py? E.g. here: https://github.com/PyTorchLightning/pytorch-lightning/blob/c7f8c8c3c82b4f249125885490b2392bf9d3d08b/pytorch_lightning/callbacks/model_checkpoint.py#L341

Contributor: Sorry for the late reply! I've filed #9123 to track this.

Contributor: Mind adding a comment explaining what is happening here, in case a new reader reaches this code? :)

if self.local_rank == current_worker:
checkpoint = super().load_checkpoint_file(checkpoint_path)
Contributor: Does this assume the same checkpoint path for all workers?

self.lightning_module.on_load_checkpoint(checkpoint)
self.load_model_state_dict(checkpoint)
log.info(f"Rank {self.global_rank}: done loading model states from {checkpoint_path}.")
del checkpoint["state_dict"]
self.barrier()
Contributor: This is pretty smart :)

One concern: the responsibility for calling the hooks is now shifted to the plugin. Do we want to allow that?
@PyTorchLightning/core-contributors

Member: In this case, I don't see a problem with it. The only thing I am afraid of is that we may set a precedent here, and that could lead to this pattern appearing where we don't want it.

What do you think @SeanNaren @awaelchli @ananthsub @carmocca?

return checkpoint
Contributor: I guess returning the checkpoint is not required. Also, this needs docstrings.


@property
def setup_optimizers_in_pre_dispatch(self) -> bool:
# Setup optimizers after the Fully Sharded Model has been made
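As a companion to the rank-serialized loading loop in serialized_restore_model_state above, here is a minimal standalone sketch of the same pattern written against plain torch.distributed rather than Lightning's plugin API (the function name and structure are illustrative, and an initialized process group is assumed). Only one rank materializes the full checkpoint in CPU memory at a time, which is what avoids the CPU OOMs mentioned in the log message.

```python
import torch
import torch.distributed as dist


def load_checkpoint_serially(model: torch.nn.Module, checkpoint_path: str) -> dict:
    """Load a checkpoint one rank at a time so that at most one full copy of the
    (potentially very large) state dict lives in CPU memory at any moment."""
    checkpoint: dict = {}
    for current_rank in range(dist.get_world_size()):
        if dist.get_rank() == current_rank:
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
            model.load_state_dict(checkpoint["state_dict"])
            # Drop the heavy weights before the next rank starts loading.
            del checkpoint["state_dict"]
        # Every rank waits here, so the loads happen strictly one after another.
        dist.barrier()
    return checkpoint
```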
pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -145,12 +145,20 @@ def restore_model(self) -> None:
def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) -> None:
""" Restore only the model weights. """
checkpoint = self._loaded_checkpoint
if hasattr(self.trainer.training_type_plugin, "serialized_restore_model_state"):
checkpoint = self.trainer.training_type_plugin.serialized_restore_model_state(checkpoint_path)
else:
checkpoint = self.restore_model_state(checkpoint_path)

def restore_model_state(self, checkpoint_path: Optional[Union[str, Path]]) -> dict:
if checkpoint_path is not None:
checkpoint = self.trainer.training_type_plugin.load_checkpoint_file(checkpoint_path)

self.trainer.lightning_module.on_load_checkpoint(checkpoint)
self.trainer.training_type_plugin.load_model_state_dict(checkpoint)

return checkpoint

def restore_training_state(self) -> None:
"""
Restore the trainer state from the pre-loaded checkpoint. This includes the precision settings, loop progress,
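
Finally, a hedged example of how this code path would typically be reached from user code; the GPU count, precision, and checkpoint path are placeholders. With the changes above, the checkpoint connector checks whether the training type plugin exposes serialized_restore_model_state and, for DDPFullyShardedPlugin, restores the model state rank by rank instead of loading the checkpoint on all workers at once.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPFullyShardedPlugin

# Illustrative only: resume a fully sharded run from a previously saved checkpoint.
trainer = Trainer(
    gpus=4,
    precision=16,
    plugins=DDPFullyShardedPlugin(),
    resume_from_checkpoint="path/to/last.ckpt",
)
# trainer.fit(model)  # model: a LightningModule configured for fully sharded training
```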