diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index a37cffb320227..ede6cb2a7e286 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -49,25 +49,29 @@ def __init__(self, trainer: "pl.Trainer", resume_from_checkpoint: Optional[_PATH
         self._loaded_checkpoint: Dict[str, Any] = {}
 
     @property
-    def hpc_resume_path(self) -> Optional[str]:
+    def _hpc_resume_path(self) -> Optional[str]:
         if not os.path.isdir(self.trainer.weights_save_path):
             return None
         dir_path_hpc = str(self.trainer.weights_save_path)
         max_version = self.max_ckpt_version_in_folder(dir_path_hpc, "hpc_ckpt_")
         if max_version is not None:
             return os.path.join(dir_path_hpc, f"hpc_ckpt_{max_version}.ckpt")
-        auto_save_checkpoint = os.path.join(dir_path_hpc, ".pl_auto_save.ckpt")
-        if os.path.exists(auto_save_checkpoint):
-            return auto_save_checkpoint
+
+    @property
+    def _fault_tolerant_auto_resume_path(self) -> Optional[str]:
+        auto_saved_path = os.path.join(str(self.trainer.weights_save_path), ".pl_auto_save.ckpt")
+        if os.path.exists(auto_saved_path):
+            return auto_saved_path
 
     def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
         """Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority:
 
         1. from HPC weights if found
-        2. from `checkpoint_path` file if provided
-        3. don't restore
+        2. from fault-tolerant auto-saved checkpoint if found
+        3. from `checkpoint_path` file if provided
+        4. don't restore
         """
-        self.resume_checkpoint_path = self.hpc_resume_path or checkpoint_path
+        self.resume_checkpoint_path = self._hpc_resume_path or self._fault_tolerant_auto_resume_path or checkpoint_path
         checkpoint_path = self.resume_checkpoint_path
         if not checkpoint_path:
             return
@@ -162,7 +166,7 @@ def restore_model(self) -> None:
 
         # TODO: remove this in v1.8.
         # call hpc specific hook
-        if self.hpc_resume_path is not None:
+        if self._hpc_resume_path is not None:
             model.on_hpc_load(self._loaded_checkpoint)
 
         # restore model state_dict
diff --git a/tests/trainer/connectors/test_checkpoint_connector.py b/tests/trainer/connectors/test_checkpoint_connector.py
index 1643d8f43cb6b..c0153638cc534 100644
--- a/tests/trainer/connectors/test_checkpoint_connector.py
+++ b/tests/trainer/connectors/test_checkpoint_connector.py
@@ -133,7 +133,7 @@ def test_hpc_max_ckpt_version(tmpdir):
     trainer.save_checkpoint(tmpdir / "hpc_ckpt_3.ckpt")
     trainer.save_checkpoint(tmpdir / "hpc_ckpt_33.ckpt")
 
-    assert trainer.checkpoint_connector.hpc_resume_path == str(tmpdir / "hpc_ckpt_33.ckpt")
+    assert trainer.checkpoint_connector._hpc_resume_path == str(tmpdir / "hpc_ckpt_33.ckpt")
     assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir) == 33
     assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir / "not" / "existing") is None
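
For reviewers, the resume priority introduced above can be summarized with a small standalone sketch. This is illustrative only: resolve_resume_path and its parameters are hypothetical stand-ins, not Lightning API, and the real connector delegates the version parsing to max_ckpt_version_in_folder.

import os
from typing import Optional


def resolve_resume_path(weights_save_path: str, checkpoint_path: Optional[str]) -> Optional[str]:
    # Hypothetical sketch of the priority order in the diff, not Lightning's actual code.
    # 1. the highest-numbered HPC checkpoint (e.g. hpc_ckpt_33.ckpt) wins
    if os.path.isdir(weights_save_path):
        versions = [
            int(name[len("hpc_ckpt_"):-len(".ckpt")])
            for name in os.listdir(weights_save_path)
            if name.startswith("hpc_ckpt_")
            and name.endswith(".ckpt")
            and name[len("hpc_ckpt_"):-len(".ckpt")].isdigit()
        ]
        if versions:
            return os.path.join(weights_save_path, f"hpc_ckpt_{max(versions)}.ckpt")
    # 2. otherwise, the fault-tolerant auto-saved checkpoint
    auto_saved = os.path.join(weights_save_path, ".pl_auto_save.ckpt")
    if os.path.exists(auto_saved):
        return auto_saved
    # 3./4. otherwise, the explicitly provided path, which may be None (don't restore)
    return checkpoint_path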