Skip to content

Commit c5c5afc

Browse files
committed
Revert "Revert "Support serialized checkpoint loading (#9605)" (#10057)"
This reverts commit 2d9db21.
1 parent 5a7fbb6 commit c5c5afc

File tree

6 files changed

+39
-10
lines changed

6 files changed

+39
-10
lines changed

pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import time
2222
from pathlib import Path
2323
from time import sleep
24-
from typing import Any, Dict, List, Optional, Union
24+
from typing import Any, Dict, List, Mapping, Optional, Union
2525

2626
import __main__
2727
import numpy as np
@@ -51,10 +51,16 @@
5151
)
5252
from pytorch_lightning.utilities.distributed import distributed_available
5353
from pytorch_lightning.utilities.distributed import group as _group
54-
from pytorch_lightning.utilities.distributed import init_ddp_connection, rank_zero_only, ReduceOp, sync_ddp_if_available
54+
from pytorch_lightning.utilities.distributed import (
55+
init_ddp_connection,
56+
rank_zero_info,
57+
rank_zero_only,
58+
ReduceOp,
59+
sync_ddp_if_available,
60+
)
5561
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
5662
from pytorch_lightning.utilities.seed import reset_seed
57-
from pytorch_lightning.utilities.types import STEP_OUTPUT
63+
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
5864

5965
if _TORCH_GREATER_EQUAL_1_10:
6066
from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer
@@ -123,6 +129,7 @@ def __init__(
123129
self._pids: Optional[List[int]] = None
124130
self._sync_dir: Optional[str] = None
125131
self._rank_0_has_called_call_children_scripts: bool = False
132+
self._has_loaded_state_dict: bool = False
126133
self.set_world_ranks()
127134

128135
@property
@@ -533,3 +540,26 @@ def teardown(self) -> None:
533540
self.lightning_module.cpu()
534541
# clean up memory
535542
torch.cuda.empty_cache()
543+
544+
self._has_loaded_state_dict = False
545+
546+
def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
547+
if "state_dict" not in checkpoint and self._has_loaded_state_dict:
548+
return
549+
self.lightning_module.load_state_dict(checkpoint["state_dict"])
550+
551+
def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
552+
rank_zero_info(
553+
f"DistributedDataParallel has {self.num_processes} processes. "
554+
"Serializing checkpoint loading to avoid CPU OOMs."
555+
)
556+
for current_worker in range(self.num_processes):
557+
if self.local_rank == current_worker:
558+
checkpoint = super().load_checkpoint(checkpoint_path)
559+
self.lightning_module.on_load_checkpoint(checkpoint)
560+
self.load_model_state_dict(checkpoint)
561+
del checkpoint["state_dict"]
562+
self._has_loaded_state_dict = True
563+
log.info(f"Rank {self.global_rank}: done loading model states from {checkpoint_path}.")
564+
self.barrier()
565+
return checkpoint

pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -781,6 +781,8 @@ def lightning_restore_optimizer_and_schedulers(self) -> bool:
781781
return False
782782

783783
def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
784+
if "state_dict" not in checkpoint and self._has_loaded_state_dict:
785+
return
784786
# override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint()`
785787
if self.load_full_weights and self.zero_stage_3:
786788
self.model_to_device()

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
195195
return self.checkpoint_io.load_checkpoint(checkpoint_path)
196196

197197
def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
198+
self.lightning_module.on_load_checkpoint(checkpoint)
198199
self.lightning_module.load_state_dict(checkpoint["state_dict"])
199200

200201
def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,6 @@ def restore_model(self) -> None:
141141

142142
model = self.trainer.lightning_module
143143

144-
# hook: give user access to checkpoint if needed.
145-
model.on_load_checkpoint(self._loaded_checkpoint)
146-
147144
# call hpc specific hook
148145
if self.hpc_resume_path is not None:
149146
model.on_hpc_load(self._loaded_checkpoint)
@@ -163,7 +160,6 @@ def restore_model_weights(self, checkpoint_path: Optional[_PATH]) -> None:
163160
if checkpoint_path is not None:
164161
checkpoint = self._load_and_validate_checkpoint(checkpoint_path)
165162

166-
self.trainer.lightning_module.on_load_checkpoint(checkpoint)
167163
self.trainer.training_type_plugin.load_model_state_dict(checkpoint)
168164

169165
def restore_training_state(self) -> None:

pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,9 +1029,6 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
10291029
self.data_connector.prepare_data()
10301030
self.callback_connector._attach_model_callbacks()
10311031

1032-
if self._ckpt_path and not self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
1033-
self._load_checkpoint_weights()
1034-
10351032
# ----------------------------
10361033
# SET UP TRAINING
10371034
# ----------------------------
@@ -1041,6 +1038,8 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
10411038

10421039
# check if we should delay restoring checkpoint till later
10431040
if not self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
1041+
if self._ckpt_path:
1042+
self._load_checkpoint_weights()
10441043
self.checkpoint_connector.resume_start()
10451044
self._restore_modules_and_callbacks()
10461045

tests/trainer/connectors/test_checkpoint_connector.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def test_preloaded_checkpoint_lifecycle(tmpdir):
7878
ckpt_path = trainer.checkpoint_callback.best_model_path
7979
trainer = Trainer(default_root_dir=tmpdir, max_steps=2, resume_from_checkpoint=ckpt_path)
8080
connector = trainer.checkpoint_connector
81+
trainer.accelerator.connect(model)
8182
connector.resume_start()
8283
assert connector.resume_checkpoint_path == ckpt_path
8384
assert connector._loaded_checkpoint

0 commit comments

Comments
 (0)