
Commit 2d9db21

Revert "Support serialized checkpoint loading (#9605)" (#10057)
This reverts commit f0e6f1b.
1 parent: aa15404

File tree: 6 files changed, 10 additions and 39 deletions

pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 3 additions & 33 deletions
@@ -21,7 +21,7 @@
 import time
 from pathlib import Path
 from time import sleep
-from typing import Any, Dict, List, Mapping, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 import __main__
 import numpy as np
@@ -51,16 +51,10 @@
 )
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.distributed import group as _group
-from pytorch_lightning.utilities.distributed import (
-    init_ddp_connection,
-    rank_zero_info,
-    rank_zero_only,
-    ReduceOp,
-    sync_ddp_if_available,
-)
+from pytorch_lightning.utilities.distributed import init_ddp_connection, rank_zero_only, ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed
-from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
+from pytorch_lightning.utilities.types import STEP_OUTPUT

 if _TORCH_GREATER_EQUAL_1_10:
     from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer
@@ -129,7 +123,6 @@ def __init__(
         self._pids: Optional[List[int]] = None
         self._sync_dir: Optional[str] = None
         self._rank_0_has_called_call_children_scripts: bool = False
-        self._has_loaded_state_dict: bool = False
         self.set_world_ranks()

     @property
@@ -540,26 +533,3 @@ def teardown(self) -> None:
             self.lightning_module.cpu()
             # clean up memory
             torch.cuda.empty_cache()
-
-        self._has_loaded_state_dict = False
-
-    def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
-        if "state_dict" not in checkpoint and self._has_loaded_state_dict:
-            return
-        self.lightning_module.load_state_dict(checkpoint["state_dict"])
-
-    def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
-        rank_zero_info(
-            f"DistributedDataParallel has {self.num_processes} processes. "
-            "Serializing checkpoint loading to avoid CPU OOMs."
-        )
-        for current_worker in range(self.num_processes):
-            if self.local_rank == current_worker:
-                checkpoint = super().load_checkpoint(checkpoint_path)
-                self.lightning_module.on_load_checkpoint(checkpoint)
-                self.load_model_state_dict(checkpoint)
-                del checkpoint["state_dict"]
-                self._has_loaded_state_dict = True
-                log.info(f"Rank {self.global_rank}: done loading model states from {checkpoint_path}.")
-            self.barrier()
-        return checkpoint
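
The load_checkpoint override removed above serialized checkpoint loading across the processes of a node, so that only one rank held a full CPU copy of the checkpoint at a time, trading startup time for lower peak host memory. A minimal, framework-agnostic sketch of the same pattern, assuming the checkpoint dict carries a "state_dict" key as Lightning checkpoints do (the function name is illustrative, not Lightning API):

import torch
import torch.distributed as dist


def load_checkpoint_serialized(model, checkpoint_path, local_rank, local_world_size):
    # Each local rank loads the file on its turn; the barrier makes the other
    # ranks wait, so at most one full checkpoint copy sits in host memory per
    # node at any moment.
    checkpoint = None
    for current_worker in range(local_world_size):
        if local_rank == current_worker:
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
            model.load_state_dict(checkpoint["state_dict"])
            del checkpoint["state_dict"]  # free the weights copy once applied
        dist.barrier()
    return checkpoint

With the override gone, the DDP plugin falls back to the base plugin's load_checkpoint, where each rank loads the file independently (see training_type_plugin.py below).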

pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 0 additions & 2 deletions
@@ -781,8 +781,6 @@ def lightning_restore_optimizer_and_schedulers(self) -> bool:
         return False

     def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
-        if "state_dict" not in checkpoint and self._has_loaded_state_dict:
-            return
         # override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint()`
         if self.load_full_weights and self.zero_stage_3:
             self.model_to_device()

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 0 additions & 1 deletion
@@ -195,7 +195,6 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
         return self.checkpoint_io.load_checkpoint(checkpoint_path)

     def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
-        self.lightning_module.on_load_checkpoint(checkpoint)
         self.lightning_module.load_state_dict(checkpoint["state_dict"])

     def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 4 additions & 0 deletions
@@ -141,6 +141,9 @@ def restore_model(self) -> None:

         model = self.trainer.lightning_module

+        # hook: give user access to checkpoint if needed.
+        model.on_load_checkpoint(self._loaded_checkpoint)
+
         # call hpc specific hook
         if self.hpc_resume_path is not None:
             model.on_hpc_load(self._loaded_checkpoint)
@@ -160,6 +163,7 @@ def restore_model_weights(self, checkpoint_path: Optional[_PATH]) -> None:
         if checkpoint_path is not None:
             checkpoint = self._load_and_validate_checkpoint(checkpoint_path)

+        self.trainer.lightning_module.on_load_checkpoint(checkpoint)
         self.trainer.training_type_plugin.load_model_state_dict(checkpoint)

     def restore_training_state(self) -> None:
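
With the revert, the checkpoint connector calls the module's on_load_checkpoint hook directly again (in restore_model and restore_model_weights) instead of relying on the plugin's load_model_state_dict to do it. For reference, a minimal user-side override of that hook; the restored key name is illustrative:

import pytorch_lightning as pl


class MyModule(pl.LightningModule):
    def on_load_checkpoint(self, checkpoint: dict) -> None:
        # Called with the full checkpoint dict when the Trainer restores this
        # module, so extra state saved alongside the weights can be recovered.
        self.extra_state = checkpoint.get("my_extra_state")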

pytorch_lightning/trainer/trainer.py

Lines changed: 3 additions & 2 deletions
@@ -1029,6 +1029,9 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
         self.data_connector.prepare_data()
         self.callback_connector._attach_model_callbacks()

+        if self._ckpt_path and not self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
+            self._load_checkpoint_weights()
+
         # ----------------------------
         # SET UP TRAINING
         # ----------------------------
@@ -1038,8 +1041,6 @@

         # check if we should delay restoring checkpoint till later
         if not self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
-            if self._ckpt_path:
-                self._load_checkpoint_weights()
             self.checkpoint_connector.resume_start()
             self._restore_modules_and_callbacks()

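The trainer change above moves the early _load_checkpoint_weights() call before accelerator setup, still gated on restore_checkpoint_after_pre_dispatch, which lets a plugin defer restoration until after pre-dispatch. A sketch of how a custom plugin could opt into that deferral, assuming a plugin subclass is the right place to override the property the Trainer checks (the class name is illustrative):

from pytorch_lightning.plugins import DDPPlugin


class DelayedRestoreDDPPlugin(DDPPlugin):
    @property
    def restore_checkpoint_after_pre_dispatch(self) -> bool:
        # When True, Trainer._run skips the early _load_checkpoint_weights()
        # call shown above and restores the checkpoint after pre-dispatch.
        return True
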
tests/trainer/connectors/test_checkpoint_connector.py

Lines changed: 0 additions & 1 deletion
@@ -78,7 +78,6 @@ def test_preloaded_checkpoint_lifecycle(tmpdir):
     ckpt_path = trainer.checkpoint_callback.best_model_path
     trainer = Trainer(default_root_dir=tmpdir, max_steps=2, resume_from_checkpoint=ckpt_path)
     connector = trainer.checkpoint_connector
-    trainer.accelerator.connect(model)
     connector.resume_start()
     assert connector.resume_checkpoint_path == ckpt_path
     assert connector._loaded_checkpoint
