Skip to content

Commit 31f39c9

Browse files
authored
Move CheckpointConnector.fault_tolerant_auto_save_path out of CheckpointConnector.hpc_resume_path (#11092)
1 parent 787f41e commit 31f39c9

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,25 +49,29 @@ def __init__(self, trainer: "pl.Trainer", resume_from_checkpoint: Optional[_PATH
4949
self._loaded_checkpoint: Dict[str, Any] = {}
5050

5151
@property
52-
def hpc_resume_path(self) -> Optional[str]:
52+
def _hpc_resume_path(self) -> Optional[str]:
5353
if not os.path.isdir(self.trainer.weights_save_path):
5454
return None
5555
dir_path_hpc = str(self.trainer.weights_save_path)
5656
max_version = self.max_ckpt_version_in_folder(dir_path_hpc, "hpc_ckpt_")
5757
if max_version is not None:
5858
return os.path.join(dir_path_hpc, f"hpc_ckpt_{max_version}.ckpt")
59-
auto_save_checkpoint = os.path.join(dir_path_hpc, ".pl_auto_save.ckpt")
60-
if os.path.exists(auto_save_checkpoint):
61-
return auto_save_checkpoint
59+
60+
@property
61+
def _fault_tolerant_auto_resume_path(self) -> Optional[str]:
62+
auto_saved_path = os.path.join(str(self.trainer.weights_save_path), ".pl_auto_save.ckpt")
63+
if os.path.exists(auto_saved_path):
64+
return auto_saved_path
6265

6366
def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
6467
"""Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority:
6568
6669
1. from HPC weights if found
67-
2. from `checkpoint_path` file if provided
68-
3. don't restore
70+
2. from fault-tolerant auto-saved checkpoint if found
71+
3. from `checkpoint_path` file if provided
72+
4. don't restore
6973
"""
70-
self.resume_checkpoint_path = self.hpc_resume_path or checkpoint_path
74+
self.resume_checkpoint_path = self._hpc_resume_path or self._fault_tolerant_auto_resume_path or checkpoint_path
7175
checkpoint_path = self.resume_checkpoint_path
7276
if not checkpoint_path:
7377
return
@@ -162,7 +166,7 @@ def restore_model(self) -> None:
162166

163167
# TODO: remove this in v1.8.
164168
# call hpc specific hook
165-
if self.hpc_resume_path is not None:
169+
if self._hpc_resume_path is not None:
166170
model.on_hpc_load(self._loaded_checkpoint)
167171

168172
# restore model state_dict

tests/trainer/connectors/test_checkpoint_connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def test_hpc_max_ckpt_version(tmpdir):
133133
trainer.save_checkpoint(tmpdir / "hpc_ckpt_3.ckpt")
134134
trainer.save_checkpoint(tmpdir / "hpc_ckpt_33.ckpt")
135135

136-
assert trainer.checkpoint_connector.hpc_resume_path == str(tmpdir / "hpc_ckpt_33.ckpt")
136+
assert trainer.checkpoint_connector._hpc_resume_path == str(tmpdir / "hpc_ckpt_33.ckpt")
137137
assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir) == 33
138138
assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir / "not" / "existing") is None
139139

0 commit comments

Comments
 (0)