@@ -49,25 +49,29 @@ def __init__(self, trainer: "pl.Trainer", resume_from_checkpoint: Optional[_PATH
49
49
self ._loaded_checkpoint : Dict [str , Any ] = {}
50
50
51
51
@property
52
- def hpc_resume_path (self ) -> Optional [str ]:
52
+ def _hpc_resume_path (self ) -> Optional [str ]:
53
53
if not os .path .isdir (self .trainer .weights_save_path ):
54
54
return None
55
55
dir_path_hpc = str (self .trainer .weights_save_path )
56
56
max_version = self .max_ckpt_version_in_folder (dir_path_hpc , "hpc_ckpt_" )
57
57
if max_version is not None :
58
58
return os .path .join (dir_path_hpc , f"hpc_ckpt_{ max_version } .ckpt" )
59
- auto_save_checkpoint = os .path .join (dir_path_hpc , ".pl_auto_save.ckpt" )
60
- if os .path .exists (auto_save_checkpoint ):
61
- return auto_save_checkpoint
59
+
60
+ @property
61
+ def _fault_tolerant_auto_resume_path (self ) -> Optional [str ]:
62
+ auto_saved_path = os .path .join (str (self .trainer .weights_save_path ), ".pl_auto_save.ckpt" )
63
+ if os .path .exists (auto_saved_path ):
64
+ return auto_saved_path
62
65
63
66
def resume_start (self , checkpoint_path : Optional [_PATH ] = None ) -> None :
64
67
"""Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority:
65
68
66
69
1. from HPC weights if found
67
- 2. from `checkpoint_path` file if provided
68
- 3. don't restore
70
+ 2. from fault-tolerant auto-saved checkpoint if found
71
+ 3. from `checkpoint_path` file if provided
72
+ 4. don't restore
69
73
"""
70
- self .resume_checkpoint_path = self .hpc_resume_path or checkpoint_path
74
+ self .resume_checkpoint_path = self ._hpc_resume_path or self . _fault_tolerant_auto_resume_path or checkpoint_path
71
75
checkpoint_path = self .resume_checkpoint_path
72
76
if not checkpoint_path :
73
77
return
@@ -162,7 +166,7 @@ def restore_model(self) -> None:
162
166
163
167
# TODO: remove this in v1.8.
164
168
# call hpc specific hook
165
- if self .hpc_resume_path is not None :
169
+ if self ._hpc_resume_path is not None :
166
170
model .on_hpc_load (self ._loaded_checkpoint )
167
171
168
172
# restore model state_dict
0 commit comments