
Commit d0b67f7

Clean up last ModelCheckpoint makedirs call to IOPlugin (#11035)
1 parent 7aee00c commit d0b67f7

4 files changed, +5 -11 lines changed

4 files changed

+5
-11
lines changed

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 1 addition & 4 deletions
@@ -249,7 +249,7 @@ def state_key(self) -> str:
         )

     def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        """When pretrain routine starts we build the ckpt dir on the fly."""
+        """When pretrain routine starts we resolve the ckpt dir on the fly."""
         if self._save_on_train_epoch_end is None:
             # if the user runs validation multiple times per training epoch or multiple training epochs without
             # validation, then we run after validation instead of on train epoch end
@@ -600,9 +600,6 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:

         self.dirpath = ckpt_path

-        if not trainer.fast_dev_run and trainer.training_type_plugin.should_rank_save_checkpoint:
-            self._fs.makedirs(self.dirpath, exist_ok=True)
-
     def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
         if self.save_top_k != 0 and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0:
             rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
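After this hunk, __resolve_ckpt_dir only computes and assigns self.dirpath; creating that directory, and deciding which rank does so, is left to the checkpoint IO plugin at write time. A minimal sketch of that pattern, assuming a hypothetical DirCreatingCheckpointIO class and plain torch.save as the writer (not the library's actual code):

import os
from typing import Any, Dict, Optional

import torch
from fsspec.core import url_to_fs


class DirCreatingCheckpointIO:
    """Hypothetical IO object that owns directory creation (sketch only)."""

    def save_checkpoint(
        self, checkpoint: Dict[str, Any], path: str, storage_options: Optional[Any] = None
    ) -> None:
        # Resolve an fsspec filesystem for the target path.
        fs, _ = url_to_fs(str(path))
        # Create the parent directory lazily, right before the write happens.
        fs.makedirs(os.path.dirname(str(path)), exist_ok=True)
        # Write the checkpoint dictionary; torch.save handles local paths.
        torch.save(checkpoint, path)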

pytorch_lightning/plugins/io/xla_plugin.py

Lines changed: 4 additions & 0 deletions
@@ -11,11 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 from typing import Any, Dict, Optional

 from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE
 from pytorch_lightning.utilities.apply_func import apply_to_collection
+from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.types import _PATH

 if _TPU_AVAILABLE:
@@ -36,6 +38,8 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio
             path: write-target path
             storage_options: Optional parameters when saving the model/training states.
         """
+        fs = get_filesystem(path)
+        fs.makedirs(os.path.dirname(path), exist_ok=True)
         # Todo: TypeError: 'mappingproxy' object does not support item assignment
         # Ref: https://github.com/pytorch/xla/issues/2773
         if _OMEGACONF_AVAILABLE:
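These two added lines make XLACheckpointIO create the target folder itself before writing, which is the work the makedirs call removed from ModelCheckpoint.__resolve_ckpt_dir used to do. A small illustration of the fsspec behaviour relied on here, using a hypothetical local path:

import os

from pytorch_lightning.utilities.cloud_io import get_filesystem

path = "/tmp/lightning_ckpts/epoch=0-step=100.ckpt"  # hypothetical example path
fs = get_filesystem(path)
# exist_ok=True makes repeated saves into the same directory idempotent.
fs.makedirs(os.path.dirname(path), exist_ok=True)
assert fs.isdir(os.path.dirname(path))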

pytorch_lightning/plugins/training_type/single_tpu.py

Lines changed: 0 additions & 3 deletions
@@ -75,9 +75,6 @@ def pre_dispatch(self, trainer: "pl.Trainer") -> None:
         self.tpu_local_core_rank = xm.get_local_ordinal()
         self.tpu_global_core_rank = xm.get_ordinal()

-    def save(self, state_dict: Dict, path: _PATH) -> None:
-        xm.save(state_dict, path)
-
     def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: _PATH) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.

pytorch_lightning/trainer/trainer.py

Lines changed: 0 additions & 4 deletions
@@ -1700,10 +1700,6 @@ def world_size(self) -> int:
         # some training types define a world size
         return getattr(self.training_type_plugin, "world_size", 1)

-    @property
-    def should_rank_save_checkpoint(self) -> bool:
-        return self.training_type_plugin.should_rank_save_checkpoint
-
     @property
     def _distrib_type(self) -> _StrategyType:
         return self._accelerator_connector._distrib_type
