Support serialized checkpoint loading [redo #9605] #10141

Closed
36 changes: 33 additions & 3 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -21,7 +21,7 @@
import time
from pathlib import Path
from time import sleep
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Mapping, Optional, Union

import __main__
import numpy as np
@@ -50,18 +50,19 @@
    rank_zero_deprecation,
    rank_zero_warn,
)
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.distributed import _info, distributed_available
from pytorch_lightning.utilities.distributed import group as _group
from pytorch_lightning.utilities.distributed import (
    init_dist_connection,
    rank_zero_info,
    rank_zero_only,
    ReduceOp,
    sync_ddp_if_available,
)
from pytorch_lightning.utilities.enums import DistributedType
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT

if _FAIRSCALE_AVAILABLE:
    from fairscale.optim import OSS
@@ -113,6 +114,7 @@ def __init__(
        self._pids: Optional[List[int]] = None
        self._sync_dir: Optional[str] = None
        self._rank_0_has_called_call_children_scripts: bool = False
        self._has_loaded_state_dict: bool = False
        self.set_world_ranks()

    @property
Expand Down Expand Up @@ -529,3 +531,31 @@ def teardown(self) -> None:
self.lightning_module.cpu()
# clean up memory
torch.cuda.empty_cache()

self._has_loaded_state_dict = False

def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
if self._has_loaded_state_dict:
return
self.lightning_module.load_state_dict(checkpoint["state_dict"])

def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
Contributor:

Can we test this by mocking the hooks that get called and asserting their order? Also the warning/info message. (A rough sketch of such a test follows this file's diff.)

        rank_zero_info(
            f"DistributedDataParallel has {self.num_processes} processes. "
            "Serializing checkpoint loading to avoid CPU OOMs."
        )
        for current_worker in range(self.num_processes):
            if self.local_rank == current_worker:
                checkpoint = super().load_checkpoint(checkpoint_path)
                self.lightning_module.on_load_checkpoint(checkpoint)
                # call hpc specific hook
                if self.lightning_module.trainer.checkpoint_connector.hpc_resume_path is not None:
                    self.lightning_module.on_hpc_load(checkpoint)
                self.load_model_state_dict(checkpoint)
                # move the model to the correct device in order to release state_dict memory
                self.model_to_device()
                del checkpoint["state_dict"]
                self._has_loaded_state_dict = True
                _info(f"Rank {self.global_rank}: done loading model states from {checkpoint_path}.")
            self.barrier()
Contributor:

Will this hang if the user runs any collective op in their on_load_checkpoint hook?

jjenniferdai (Contributor, Author), Oct 27, 2021:

Good point, I think yes... Is this resolvable though? How could we support a collective op in model.on_load_checkpoint(checkpoint) if it's impossible to load concurrently and then pass the checkpoint for large models that would OOM?

If it's not possible (please correct me if I'm wrong), should we control the serialized loading option with a boolean, and note that if serialized loading is set, model.on_load_checkpoint would not support collective ops?

Contributor:

@carmocca good catch. Yes, this will hang if the user's on_load_checkpoint runs any collective.

carmocca (Contributor), Oct 27, 2021:

> for large models that OOM?

I guess it's more of a problem for models that don't need this serialized loading, because this PR makes serialization the default mechanism with no opt-out.

> should we control serialized loading option with a boolean

This would probably be too ugly to implement.

This weak point seems very niche, as you would generally not run a collective op inside these hooks. Some possibilities that come to mind:

  • IIRC, FB has mentioned in the past that they call model.setup from model.on_load_checkpoint in some cases (sharded only?). I'd say setup is more likely to run such ops.
  • Naive property accesses (like self.trainer.log_dir) can run a broadcast.

For this PR, I'd suggest adding a comment here about this weak point. But it's something to remember if we see such hangs.

Perhaps this boolean would be easier to implement and manage after the collective refactor.

cc @awaelchli
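To make the trade-off concrete: with N processes per node each deserializing the full checkpoint at once, peak host memory is roughly N copies of the state dict, which is what the serialized loop above avoids; the price is that a rank taking its turn must not run a collective, because the other ranks are parked at the barrier. The standalone sketch below is an editorial illustration, not part of this PR: it uses plain torch.distributed with two hypothetical processes to reproduce that interleaving, and fails with a timeout instead of hanging forever.

```python
# Illustration only: mimics the serialized loop with a collective inside a
# rank's "turn". Rank 0 blocks in the broadcast while rank 1 waits at the
# barrier for its turn, so neither makes progress.
import datetime
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    # short timeout so the demo raises instead of hanging indefinitely
    dist.init_process_group(
        "gloo", rank=rank, world_size=world_size, timeout=datetime.timedelta(seconds=10)
    )
    for current_worker in range(world_size):
        if rank == current_worker:
            # stands in for a user `on_load_checkpoint` that runs a collective,
            # directly or indirectly (e.g. via `self.trainer.log_dir`)
            dist.broadcast(torch.zeros(1), src=rank)
        # the other ranks wait here for their turn and never join the broadcast
        dist.barrier()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)
```

This is exactly the interleaving a collective inside on_load_checkpoint would produce under the loop in this PR.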

Contributor:

> IIRC, FB has mentioned in the past that they call model.setup from model.on_load_checkpoint in some cases (sharded only?). I'd say setup is more likely to run such ops.

It's not allowed to call setup() in this hook. With the change we made to the loading logic in #????, we removed the need to call setup() in that hook.

But generally I agree: it is going to be hard to debug such hangs, as this mechanism may not be obvious. Adding a comment would help minimally.

Contributor:

That sounds great, @yifuwang!!

Contributor:

This could also resolve the current bug where the checkpoint doesn't exist on some nodes when the filesystem isn't shared.

Contributor:

@yifuwang, would you mind opening an issue for this proposal?

Contributor:

Should we close this PR and use the proposed logic?

Contributor:

@yifuwang @jjenniferdai Who would make a PR with the above proposal?

        return checkpoint
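Regarding the test suggested in the first review comment on this file, something along these lines could assert both the hook order and the log message. It is only a sketch: the BoringModel import path, the single-process DDPPlugin wiring, and the caplog handling are assumptions rather than the PR's actual test code.

```python
# Sketch only: checks that the serialized path runs the hooks in order,
# releases the state_dict, and emits the info message.
import contextlib
import logging
from unittest import mock

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin
from tests.helpers import BoringModel  # assumed test-helper location


def test_ddp_serialized_checkpoint_loading(tmpdir, caplog):
    # produce a checkpoint with a plain single-device run
    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, max_steps=1)
    trainer.fit(model)
    ckpt_path = trainer.checkpoint_callback.best_model_path

    # wire a fresh trainer around a (single-process) DDP plugin so the
    # serialized loop and the barrier are trivial
    trainer = Trainer(default_root_dir=tmpdir, max_steps=2, plugins=DDPPlugin())
    trainer.accelerator.connect(model)
    plugin = trainer.training_type_plugin

    parent = mock.MagicMock()
    with contextlib.ExitStack() as stack:
        stack.enter_context(mock.patch.object(model, "on_load_checkpoint", parent.on_load_checkpoint))
        stack.enter_context(mock.patch.object(plugin, "load_model_state_dict", parent.load_model_state_dict))
        stack.enter_context(mock.patch.object(plugin, "model_to_device", parent.model_to_device))
        stack.enter_context(caplog.at_level(logging.INFO))
        checkpoint = plugin.load_checkpoint(ckpt_path)

    # hooks must run in the order implemented in `load_checkpoint` above
    assert [c[0] for c in parent.mock_calls] == [
        "on_load_checkpoint",
        "load_model_state_dict",
        "model_to_device",
    ]
    assert "state_dict" not in checkpoint  # released after loading
    assert "Serializing checkpoint loading to avoid CPU OOMs" in caplog.text
```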
2 changes: 2 additions & 0 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -791,6 +791,8 @@ def lightning_restore_optimizer_and_schedulers(self) -> bool:
        return False

    def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        if self._has_loaded_state_dict:
            return
        # override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint()`
        if self.load_full_weights and self.zero_stage_3:
            self.model_to_device()
pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -190,6 +190,10 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
        return self.checkpoint_io.load_checkpoint(checkpoint_path)

    def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        self.lightning_module.on_load_checkpoint(checkpoint)
        # call hpc specific hook
        if self.lightning_module.trainer.checkpoint_connector.hpc_resume_path is not None:
            self.lightning_module.on_hpc_load(checkpoint)
        self.lightning_module.load_state_dict(checkpoint["state_dict"])

    def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -155,15 +155,6 @@ def restore_model(self) -> None:
        if not self._loaded_checkpoint:
            return

        model = self.trainer.lightning_module

        # hook: give user access to checkpoint if needed.
        model.on_load_checkpoint(self._loaded_checkpoint)

        # call hpc specific hook
        if self.hpc_resume_path is not None:
            model.on_hpc_load(self._loaded_checkpoint)

        # restore model state_dict
        self.trainer.training_type_plugin.load_model_state_dict(self._loaded_checkpoint)
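Taken together with the plugin-side additions above, this removal means the user-facing hooks keep their signatures; only their call site moves from the connector into the training type plugin's loading path. A minimal sketch of the unchanged user contract (the BoringModel import path is an assumption):

```python
# Editorial sketch, not PR code: the hooks a user overrides are unaffected;
# under DDP they now run inside the plugin's serialized `load_checkpoint`
# loop (one rank at a time) instead of `CheckpointConnector.restore_model`.
from tests.helpers import BoringModel  # assumed test-helper location


class LitModel(BoringModel):
    def on_save_checkpoint(self, checkpoint):
        checkpoint["my_extra_state"] = 123

    def on_load_checkpoint(self, checkpoint):
        # still receives the full checkpoint dict, before the weights are loaded
        self.my_extra_state = checkpoint["my_extra_state"]
```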

1 change: 1 addition & 0 deletions tests/trainer/connectors/test_checkpoint_connector.py
@@ -77,6 +77,7 @@ def test_preloaded_checkpoint_lifecycle(tmpdir):
    ckpt_path = trainer.checkpoint_callback.best_model_path
    trainer = Trainer(default_root_dir=tmpdir, max_steps=2)
    connector = trainer.checkpoint_connector
    trainer.accelerator.connect(model)
    connector.resume_start(ckpt_path)
    assert connector.resume_checkpoint_path == ckpt_path
    assert connector._loaded_checkpoint