Support serialized checkpoint loading [redo #9605] #10141
Changes from all commits
```diff
@@ -21,7 +21,7 @@
 import time
 from pathlib import Path
 from time import sleep
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Mapping, Optional, Union

 import __main__
 import numpy as np
@@ -50,18 +50,19 @@
     rank_zero_deprecation,
     rank_zero_warn,
 )
-from pytorch_lightning.utilities.distributed import distributed_available
+from pytorch_lightning.utilities.distributed import _info, distributed_available
 from pytorch_lightning.utilities.distributed import group as _group
 from pytorch_lightning.utilities.distributed import (
     init_dist_connection,
     rank_zero_info,
     rank_zero_only,
     ReduceOp,
     sync_ddp_if_available,
 )
 from pytorch_lightning.utilities.enums import DistributedType
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed
-from pytorch_lightning.utilities.types import STEP_OUTPUT
+from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT

 if _FAIRSCALE_AVAILABLE:
     from fairscale.optim import OSS
@@ -113,6 +114,7 @@ def __init__(
         self._pids: Optional[List[int]] = None
         self._sync_dir: Optional[str] = None
         self._rank_0_has_called_call_children_scripts: bool = False
+        self._has_loaded_state_dict: bool = False
         self.set_world_ranks()

     @property
@@ -529,3 +531,31 @@ def teardown(self) -> None:
             self.lightning_module.cpu()
             # clean up memory
             torch.cuda.empty_cache()
+
+        self._has_loaded_state_dict = False
+
+    def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
+        if self._has_loaded_state_dict:
+            return
+        self.lightning_module.load_state_dict(checkpoint["state_dict"])
+
+    def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
```
Review comment on `load_checkpoint`: can we test this by mocking the hooks called and asserting the order? Also the warning/info message.
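A hook-order test along these lines could look roughly like the sketch below. It is illustrative only: `Loader` is a stand-in for the plugin (the real test would patch the plugin's hooks), and it does not assert the info message, which could be checked separately, e.g. with pytest's `caplog` or by patching `_info`.

```python
from unittest import mock


class Loader:
    """Stand-in for the plugin; only the hook call order matters here."""

    def on_load_checkpoint(self, checkpoint): ...
    def load_model_state_dict(self, checkpoint): ...
    def model_to_device(self): ...

    def load_checkpoint(self, checkpoint_path):
        checkpoint = {"state_dict": {}}
        self.on_load_checkpoint(checkpoint)
        self.load_model_state_dict(checkpoint)
        self.model_to_device()
        return checkpoint


def test_hooks_called_in_order():
    # Attach every patched hook to a single manager mock so that
    # manager.mock_calls records one global call order across all hooks.
    manager = mock.MagicMock()
    with mock.patch.object(Loader, "on_load_checkpoint") as load_hook, \
            mock.patch.object(Loader, "load_model_state_dict") as state_hook, \
            mock.patch.object(Loader, "model_to_device") as device_hook:
        manager.attach_mock(load_hook, "on_load_checkpoint")
        manager.attach_mock(state_hook, "load_model_state_dict")
        manager.attach_mock(device_hook, "model_to_device")
        Loader().load_checkpoint("checkpoint.ckpt")

    assert [c[0] for c in manager.mock_calls] == [
        "on_load_checkpoint",
        "load_model_state_dict",
        "model_to_device",
    ]
```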
Diff, continued:

```diff
+        rank_zero_info(
+            f"DistributedDataParallel has {self.num_processes} processes. "
+            "Serializing checkpoint loading to avoid CPU OOMs."
+        )
+        for current_worker in range(self.num_processes):
+            if self.local_rank == current_worker:
+                checkpoint = super().load_checkpoint(checkpoint_path)
+                self.lightning_module.on_load_checkpoint(checkpoint)
+                # call hpc specific hook
+                if self.lightning_module.trainer.checkpoint_connector.hpc_resume_path is not None:
+                    self.lightning_module.on_hpc_load(checkpoint)
+                self.load_model_state_dict(checkpoint)
+                # move the model to the correct device in order to release state_dict memory
+                self.model_to_device()
+                del checkpoint["state_dict"]
+                self._has_loaded_state_dict = True
+                _info(f"Rank {self.global_rank}: done loading model states from {checkpoint_path}.")
+            self.barrier()
+
+        return checkpoint
```
Review thread:

- Will this hang if the user runs any collective op in their on_load_checkpoint?

- Good point, I think yes ... is this resolvable, though? How could we support a collective op in on_load_checkpoint? If it's not possible (please correct me if I'm wrong), should we control the serialized loading option with a boolean, and note that if serialized loading is set, collective ops are not supported in these hooks?

- @carmocca good catch. Yes, this will hang if the user's on_load_checkpoint runs any collective.
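To make the failure mode concrete, here is a minimal sketch of the loop's control flow with invented names; it is not the PR's code, just an illustration of why a collective issued inside the hook can never complete:

```python
import torch.distributed as dist


def serialized_section(local_rank: int, world_size: int, load_fn) -> None:
    """Illustrative control flow of the serialized loading loop."""
    for current_worker in range(world_size):
        if local_rank == current_worker:
            # If load_fn (e.g. a user's on_load_checkpoint) issues a
            # collective such as dist.all_reduce here, it blocks until
            # every rank joins that same collective ...
            load_fn()
        # ... but all other ranks are already blocked on this barrier,
        # so the collective and the barrier never match up: a deadlock.
        dist.barrier()
```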
- I guess it's more of a problem for models that don't need this serialized loading, because this PR makes serialization the default mechanism with no opt-out.

- This would probably be too ugly to implement. This weak point seems very niche, as you would generally not run a collective op inside these hooks. Some possibilities that come to mind: [...] For this PR, I'd suggest adding a comment here about this weak point. But it's something to remember if we see such hangs. Perhaps this boolean would be easier to implement and manage after the collective refactor. cc @awaelchli
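A boolean opt-out of the kind discussed here might look roughly like the sketch below. The subclass, the `load_checkpoint_serially` flag, and the fallback path are all invented for illustration and are not part of this PR; it assumes the PL 1.5-era plugin API where the plain loading logic lives on the parent plugin class.

```python
from typing import Any, Dict

from pytorch_lightning.plugins import DDPPlugin


class OptOutDDPPlugin(DDPPlugin):
    """Sketch of a hypothetical opt-out flag; not part of this PR."""

    def __init__(self, *args, load_checkpoint_serially: bool = True, **kwargs):
        super().__init__(*args, **kwargs)
        self._load_checkpoint_serially = load_checkpoint_serially

    def load_checkpoint(self, checkpoint_path) -> Dict[str, Any]:
        if self._load_checkpoint_serially:
            # one-rank-at-a-time path, as introduced by this PR
            return super().load_checkpoint(checkpoint_path)
        # concurrent path: skip DDPPlugin's serialized override so every
        # rank loads at once; user hooks may then run collectives again,
        # at the cost of num_processes copies of the checkpoint in CPU RAM
        return super(DDPPlugin, self).load_checkpoint(checkpoint_path)
```

With the default `True` the behavior matches this PR; passing `False` would restore concurrent loading for models that don't need serialization.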
- It's not allowed to call setup() in this hook. With the change we made to the loading logic in #????, we removed the need to call setup() in that hook. But generally I agree, it is going to be hard to debug hangs, as this mechanism may not be obvious. Adding a comment would help minimally.

- That sounds great @yifuwang!!

- This could also resolve the current bug where the checkpoint doesn't exist on some nodes when the filesystem isn't shared.

- @yifuwang, would you mind opening an issue for this proposal?

- Should we close this PR and use the proposed logic?

- @yifuwang @jjenniferdai Who would make a PR with the above proposition?
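The referenced proposal is not quoted in this thread, but the follow-up comments suggest loading the checkpoint on a single rank and distributing it, instead of every rank reading the file. A rank-0-loads-then-broadcasts sketch of that idea, written purely as an assumed illustration (it presumes `torch.distributed` is already initialized), could look like:

```python
import torch
import torch.distributed as dist


def load_and_broadcast(checkpoint_path: str) -> dict:
    """Sketch: only rank 0 reads the file, then broadcasts the checkpoint.

    Because no rank other than 0 ever touches checkpoint_path, this also
    works when the filesystem is not shared across nodes.
    """
    obj = [None]
    if dist.get_rank() == 0:
        obj[0] = torch.load(checkpoint_path, map_location="cpu")
    # broadcast_object_list pickles the object on rank 0 and unpickles it
    # on every other rank, sending it over the process-group transport
    dist.broadcast_object_list(obj, src=0)
    return obj[0]
```

The trade-off is that the checkpoint travels over the process-group transport instead of the filesystem, and rank 0 briefly holds the only full CPU copy while the others wait on the broadcast.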