diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index 2732762df5f75..81a5d385553e8 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -21,7 +21,7 @@
 import time
 from pathlib import Path
 from time import sleep
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Mapping, Optional, Union
 
 import __main__
 import numpy as np
@@ -50,10 +50,11 @@
     rank_zero_deprecation,
     rank_zero_warn,
 )
-from pytorch_lightning.utilities.distributed import distributed_available
+from pytorch_lightning.utilities.distributed import _info, distributed_available
 from pytorch_lightning.utilities.distributed import group as _group
 from pytorch_lightning.utilities.distributed import (
     init_dist_connection,
+    rank_zero_info,
     rank_zero_only,
     ReduceOp,
     sync_ddp_if_available,
@@ -61,7 +62,7 @@
 from pytorch_lightning.utilities.enums import DistributedType
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed
-from pytorch_lightning.utilities.types import STEP_OUTPUT
+from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
 
 if _FAIRSCALE_AVAILABLE:
     from fairscale.optim import OSS
@@ -113,6 +114,7 @@ def __init__(
         self._pids: Optional[List[int]] = None
         self._sync_dir: Optional[str] = None
         self._rank_0_has_called_call_children_scripts: bool = False
+        self._has_loaded_state_dict: bool = False
         self.set_world_ranks()
 
     @property
@@ -529,3 +531,31 @@ def teardown(self) -> None:
             self.lightning_module.cpu()
             # clean up memory
             torch.cuda.empty_cache()
+
+        self._has_loaded_state_dict = False
+
+    def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
+        if self._has_loaded_state_dict:
+            return
+        self.lightning_module.load_state_dict(checkpoint["state_dict"])
+
+    def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
+        rank_zero_info(
+            f"DistributedDataParallel has {self.num_processes} processes. "
+            "Serializing checkpoint loading to avoid CPU OOMs."
+        )
+        for current_worker in range(self.num_processes):
+            if self.local_rank == current_worker:
+                checkpoint = super().load_checkpoint(checkpoint_path)
+                self.lightning_module.on_load_checkpoint(checkpoint)
+                # call hpc specific hook
+                if self.lightning_module.trainer.checkpoint_connector.hpc_resume_path is not None:
+                    self.lightning_module.on_hpc_load(checkpoint)
+                self.load_model_state_dict(checkpoint)
+                # move the model to the correct device in order to release state_dict memory
+                self.model_to_device()
+                del checkpoint["state_dict"]
+                self._has_loaded_state_dict = True
+                _info(f"Rank {self.global_rank}: done loading model states from {checkpoint_path}.")
+            self.barrier()
+        return checkpoint
diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index a4be698dc6602..6f03393a44849 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -791,6 +791,8 @@ def lightning_restore_optimizer_and_schedulers(self) -> bool:
         return False
 
     def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
+        if self._has_loaded_state_dict:
+            return
         # override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint()`
         if self.load_full_weights and self.zero_stage_3:
             self.model_to_device()
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index a1e55631c2a7d..d04af6c58b261 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -190,6 +190,10 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
         return self.checkpoint_io.load_checkpoint(checkpoint_path)
 
     def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
+        self.lightning_module.on_load_checkpoint(checkpoint)
+        # call hpc specific hook
+        if self.lightning_module.trainer.checkpoint_connector.hpc_resume_path is not None:
+            self.lightning_module.on_hpc_load(checkpoint)
         self.lightning_module.load_state_dict(checkpoint["state_dict"])
 
     def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 921c2e0a7e160..a8003549a8ff0 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -155,15 +155,6 @@ def restore_model(self) -> None:
         if not self._loaded_checkpoint:
             return
 
-        model = self.trainer.lightning_module
-
-        # hook: give user access to checkpoint if needed.
-        model.on_load_checkpoint(self._loaded_checkpoint)
-
-        # call hpc specific hook
-        if self.hpc_resume_path is not None:
-            model.on_hpc_load(self._loaded_checkpoint)
-
         # restore model state_dict
         self.trainer.training_type_plugin.load_model_state_dict(self._loaded_checkpoint)
 
diff --git a/tests/trainer/connectors/test_checkpoint_connector.py b/tests/trainer/connectors/test_checkpoint_connector.py
index 6b408845ed879..6712814a94e86 100644
--- a/tests/trainer/connectors/test_checkpoint_connector.py
+++ b/tests/trainer/connectors/test_checkpoint_connector.py
@@ -77,6 +77,7 @@ def test_preloaded_checkpoint_lifecycle(tmpdir):
     ckpt_path = trainer.checkpoint_callback.best_model_path
     trainer = Trainer(default_root_dir=tmpdir, max_steps=2)
    connector = trainer.checkpoint_connector
+    trainer.accelerator.connect(model)
     connector.resume_start(ckpt_path)
     assert connector.resume_checkpoint_path == ckpt_path
     assert connector._loaded_checkpoint
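
The key pattern in this change is the round-robin load in `DDPPlugin.load_checkpoint`: each local rank takes its turn deserializing the checkpoint while the others wait at a barrier, so a node never materializes more than one extra copy of the state dict in host memory at a time. A minimal standalone sketch of the same pattern outside Lightning might look like the following (it assumes a process group already initialized by a launcher such as `torchrun`, which exports `LOCAL_RANK` and `LOCAL_WORLD_SIZE`; the function name is illustrative, not part of any API):

```python
import os

import torch
import torch.distributed as dist


def load_checkpoint_serialized(checkpoint_path: str) -> dict:
    # assumes torchrun-style env vars; adapt for other launchers
    local_rank = int(os.environ["LOCAL_RANK"])
    local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])

    checkpoint = None
    for current_worker in range(local_world_size):
        if local_rank == current_worker:
            # only one process per node holds the full state dict at a time
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
        # every rank executes the same number of barriers, so the loop cannot deadlock
        dist.barrier()
    return checkpoint
```

Note that the loop only bounds how many copies exist concurrently; what actually releases the host copy in the diff is deleting `checkpoint["state_dict"]` after `load_model_state_dict` and moving the model onto its device.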