Skip to content

Commit b4435bd

Browse files
bilelomrani1pre-commit-ci[bot]awaelchli
authored
Fix Google Cloud Storage checkpointing (#18088)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 3fd24f9 commit b4435bd

File tree

9 files changed

+128
-15
lines changed

9 files changed

+128
-15
lines changed

src/lightning/fabric/loggers/csv_logs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from torch import Tensor
2222

2323
from lightning.fabric.loggers.logger import Logger, rank_zero_experiment
24-
from lightning.fabric.utilities.cloud_io import get_filesystem
24+
from lightning.fabric.utilities.cloud_io import _is_dir, get_filesystem
2525
from lightning.fabric.utilities.logger import _add_prefix
2626
from lightning.fabric.utilities.rank_zero import rank_zero_only, rank_zero_warn
2727
from lightning.fabric.utilities.types import _PATH
@@ -155,15 +155,15 @@ def finalize(self, status: str) -> None:
155155
def _get_next_version(self) -> int:
156156
versions_root = os.path.join(self._root_dir, self.name)
157157

158-
if not self._fs.isdir(versions_root):
158+
if not _is_dir(self._fs, versions_root, strict=True):
159159
log.warning("Missing logger folder: %s", versions_root)
160160
return 0
161161

162162
existing_versions = []
163163
for d in self._fs.listdir(versions_root):
164164
full_path = d["name"]
165165
name = os.path.basename(full_path)
166-
if self._fs.isdir(full_path) and name.startswith("version_"):
166+
if _is_dir(self._fs, full_path) and name.startswith("version_"):
167167
existing_versions.append(int(name.split("_")[1]))
168168

169169
if len(existing_versions) == 0:

src/lightning/fabric/loggers/tensorboard.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from torch.nn import Module
2323

2424
from lightning.fabric.loggers.logger import Logger, rank_zero_experiment
25-
from lightning.fabric.utilities.cloud_io import get_filesystem
25+
from lightning.fabric.utilities.cloud_io import _is_dir, get_filesystem
2626
from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict
2727
from lightning.fabric.utilities.logger import _sanitize_params as _utils_sanitize_params
2828
from lightning.fabric.utilities.rank_zero import rank_zero_only, rank_zero_warn
@@ -294,7 +294,7 @@ def _get_next_version(self) -> int:
294294
for listing in listdir_info:
295295
d = listing["name"]
296296
bn = os.path.basename(d)
297-
if self._fs.isdir(d) and bn.startswith("version_"):
297+
if _is_dir(self._fs, d) and bn.startswith("version_"):
298298
dir_ver = bn.split("_")[1].replace("/", "")
299299
existing_versions.append(int(dir_ver))
300300
if len(existing_versions) == 0:

src/lightning/fabric/utilities/cloud_io.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Utilities related to data saving/loading."""
15-
1615
import io
1716
from pathlib import Path
1817
from typing import Any, Dict, IO, Union
@@ -21,6 +20,7 @@
2120
import torch
2221
from fsspec.core import url_to_fs
2322
from fsspec.implementations.local import AbstractFileSystem
23+
from lightning_utilities.core.imports import module_available
2424

2525
from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
2626

@@ -70,3 +70,54 @@ def _atomic_save(checkpoint: Dict[str, Any], filepath: Union[str, Path]) -> None
7070
torch.save(checkpoint, bytesbuffer)
7171
with fsspec.open(filepath, "wb") as f:
7272
f.write(bytesbuffer.getvalue())
73+
74+
75+
def _is_object_storage(fs: AbstractFileSystem) -> bool:
    """Return whether ``fs`` is one of the known object-storage filesystems.

    The optional backends (``adlfs``, ``gcsfs``, ``s3fs``) are only imported when the corresponding package is
    available, so none of them is a hard dependency.

    Args:
        fs: The filesystem instance to inspect.

    Returns:
        ``True`` if ``fs`` is an instance of ``AzureBlobFileSystem``, ``GCSFileSystem`` or ``S3FileSystem``,
        ``False`` otherwise (including when none of the optional packages is installed).

    """
    if module_available("adlfs"):
        from adlfs import AzureBlobFileSystem

        if isinstance(fs, AzureBlobFileSystem):
            return True

    if module_available("gcsfs"):
        # ``GCSFileSystem`` is provided by ``gcsfs``; importing it from ``adlfs`` fails with ``ImportError``
        from gcsfs import GCSFileSystem

        if isinstance(fs, GCSFileSystem):
            return True

    if module_available("s3fs"):
        from s3fs import S3FileSystem

        if isinstance(fs, S3FileSystem):
            return True

    return False
95+
96+
97+
def _is_dir(fs: AbstractFileSystem, path: Union[str, Path], strict: bool = False) -> bool:
    """Tell whether ``path`` can be treated as a directory on ``fs``.

    Object storage platforms have no real directories, so this helper adapts the check for them. On any other
    filesystem the result is simply ``fs.isdir(path)``.

    Args:
        fs: The filesystem to check the path against.
        path: The path or URL to be checked.
        strict: Only relevant for object storage. When ``False`` (the default), every path that is not an existing
            file counts as directory-like, since the directory (and any missing parents) will be created on the
            fly. When ``True``, the result of ``fs.isdir(path)`` is returned.

    Returns:
        ``True`` if the path is considered directory-like, ``False`` otherwise.

    """
    # Object storage fsspec implementations do not have real directories, see for instance
    # https://gcsfs.readthedocs.io/en/latest/api.html?highlight=makedirs#gcsfs.core.GCSFileSystem.mkdir
    # `fs.makedirs` is a no-op there, which is why the non-strict mode accepts any path that is not a file.
    # The object-storage check is evaluated first so that it always runs, regardless of `strict`.
    if not _is_object_storage(fs) or strict:
        return fs.isdir(path)

    # Non-strict check on object storage: only a path already taken by a file is rejected
    return not fs.isfile(path)

src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
174174

175175
- `LightningCLI` not saving correctly `seed_everything` when `run=True` and `seed_everything=True` ([#18056](https://github.com/Lightning-AI/lightning/pull/18056))
176176

177+
- Fixed a `Missing folder` exception when using a Google Cloud Storage URL as a `default_root_dir` ([#18088](https://github.com/Lightning-AI/lightning/pull/18088))
177178

178179
- Fixed an issue that prevented the use of custom logger classes without an `experiment` property defined ([#18093](https://github.com/Lightning-AI/lightning/pull/18093))
179180

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from torch import Tensor
3333

3434
import lightning.pytorch as pl
35-
from lightning.fabric.utilities.cloud_io import get_filesystem
35+
from lightning.fabric.utilities.cloud_io import _is_dir, get_filesystem
3636
from lightning.fabric.utilities.types import _PATH
3737
from lightning.pytorch.callbacks import Checkpoint
3838
from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -620,7 +620,7 @@ def _find_last_checkpoints(self, trainer: "pl.Trainer") -> Set[str]:
620620
return set()
621621

622622
def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
623-
if self.save_top_k != 0 and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0:
623+
if self.save_top_k != 0 and _is_dir(self._fs, dirpath, strict=True) and len(self._fs.ls(dirpath)) > 0:
624624
rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
625625

626626
def _get_metric_interpolated_filepath_name(

src/lightning/pytorch/core/saving.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from lightning_utilities.core.apply_func import apply_to_collection
3131

3232
import lightning.pytorch as pl
33+
from lightning.fabric.utilities.cloud_io import _is_dir
3334
from lightning.fabric.utilities.cloud_io import _load as pl_load
3435
from lightning.fabric.utilities.cloud_io import get_filesystem
3536
from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
@@ -252,7 +253,7 @@ def load_hparams_from_tags_csv(tags_csv: _PATH) -> Dict[str, Any]:
252253

253254
def save_hparams_to_tags_csv(tags_csv: _PATH, hparams: Union[dict, Namespace]) -> None:
254255
fs = get_filesystem(tags_csv)
255-
if not fs.isdir(os.path.dirname(tags_csv)):
256+
if not _is_dir(fs, os.path.dirname(tags_csv)):
256257
raise RuntimeError(f"Missing folder: {os.path.dirname(tags_csv)}.")
257258

258259
if isinstance(hparams, Namespace):
@@ -306,7 +307,7 @@ def save_hparams_to_yaml(config_yaml: _PATH, hparams: Union[dict, Namespace], us
306307
307308
"""
308309
fs = get_filesystem(config_yaml)
309-
if not fs.isdir(os.path.dirname(config_yaml)):
310+
if not _is_dir(fs, os.path.dirname(config_yaml)):
310311
raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.")
311312

312313
# convert Namespace or AD to dict

src/lightning/pytorch/loggers/tensorboard.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import lightning.pytorch as pl
2727
from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE, _TENSORBOARDX_AVAILABLE
2828
from lightning.fabric.loggers.tensorboard import TensorBoardLogger as FabricTensorBoardLogger
29+
from lightning.fabric.utilities.cloud_io import _is_dir
2930
from lightning.fabric.utilities.logger import _convert_params
3031
from lightning.fabric.utilities.types import _PATH
3132
from lightning.pytorch.callbacks import ModelCheckpoint
@@ -217,7 +218,7 @@ def save(self) -> None:
217218
hparams_file = os.path.join(dir_path, self.NAME_HPARAMS_FILE)
218219

219220
# save the metatags file if it doesn't exist and the log directory exists
220-
if self._fs.isdir(dir_path) and not self._fs.isfile(hparams_file):
221+
if _is_dir(self._fs, dir_path) and not self._fs.isfile(hparams_file):
221222
save_hparams_to_yaml(hparams_file, self.hparams)
222223

223224
@rank_zero_only
@@ -248,7 +249,7 @@ def _get_next_version(self) -> int:
248249
for listing in listdir_info:
249250
d = listing["name"]
250251
bn = os.path.basename(d)
251-
if self._fs.isdir(d) and bn.startswith("version_"):
252+
if _is_dir(self._fs, d) and bn.startswith("version_"):
252253
dir_ver = bn.split("_")[1].replace("/", "")
253254
existing_versions.append(int(dir_ver))
254255
if len(existing_versions) == 0:

src/lightning/pytorch/trainer/connectors/checkpoint_connector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
import lightning.pytorch as pl
2525
from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
26-
from lightning.fabric.utilities.cloud_io import get_filesystem
26+
from lightning.fabric.utilities.cloud_io import _is_dir, get_filesystem
2727
from lightning.fabric.utilities.types import _PATH
2828
from lightning.pytorch.callbacks import ModelCheckpoint
2929
from lightning.pytorch.plugins.precision import MixedPrecisionPlugin
@@ -55,7 +55,7 @@ def _hpc_resume_path(self) -> Optional[str]:
5555
dir_path_hpc = self.trainer.default_root_dir
5656
dir_path_hpc = str(dir_path_hpc)
5757
fs, path = url_to_fs(dir_path_hpc)
58-
if not fs.isdir(path):
58+
if not _is_dir(fs, path):
5959
return None
6060
max_version = self.__max_ckpt_version_in_folder(dir_path_hpc, "hpc_ckpt_")
6161
if max_version is not None:

tests/tests_fabric/utilities/test_cloud_io.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515

1616
import fsspec
1717
from fsspec.implementations.local import LocalFileSystem
18+
from fsspec.spec import AbstractFileSystem
1819

19-
from lightning.fabric.utilities.cloud_io import get_filesystem
20+
from lightning.fabric.utilities.cloud_io import _is_dir, get_filesystem
2021

2122

2223
def test_get_filesystem_custom_filesystem():
@@ -32,3 +33,61 @@ class DummyFileSystem(LocalFileSystem):
3233

3334
def test_get_filesystem_local_filesystem():
3435
assert isinstance(get_filesystem("tmpdir/tmp_file"), LocalFileSystem)
36+
37+
38+
def test_is_dir_with_local_filesystem(tmp_path):
    """On a local filesystem, an existing directory is a directory and a missing path is not."""
    local_fs = LocalFileSystem()
    missing_directory = tmp_path / "non_existing"

    assert _is_dir(local_fs, tmp_path)
    assert not _is_dir(local_fs, missing_directory)
45+
46+
47+
def test_is_dir_with_object_storage_filesystem():
    """Check ``_is_dir`` against mocked Azure, GCS and S3 filesystems, with and without ``strict``."""

    # NOTE(review): the mocks below derive from ``AbstractFileSystem`` only, so ``_is_object_storage`` presumably
    # does not recognize them and the non-strict ``not fs.isfile(path)`` branch is likely not exercised - confirm.
    def _make_mock_filesystem(protocol):
        scheme = f"{protocol}://"

        class _MockObjectStorageFileSystem(AbstractFileSystem):
            def isdir(self, path):
                return path.startswith(scheme) and not path.endswith(".txt")

            def isfile(self, path):
                return path.startswith(scheme) and path.endswith(".txt")

        return _MockObjectStorageFileSystem

    # protocol -> (directory URL, file URL)
    urls_by_protocol = {
        "azure": ("azure://container/directory/", "azure://container/file.txt"),
        "gcs": ("gcs://bucket/directory/", "gcs://bucket/file.txt"),
        "s3": ("s3://bucket/directory/", "s3://bucket/file.txt"),
    }

    # Register every mock before running any assertion
    for protocol in urls_by_protocol:
        fsspec.register_implementation(protocol, _make_mock_filesystem(protocol), clobber=True)

    for directory_url, file_url in urls_by_protocol.values():
        assert _is_dir(get_filesystem(directory_url), directory_url)
        assert _is_dir(get_filesystem(directory_url), directory_url, strict=True)
        assert not _is_dir(get_filesystem(directory_url), file_url)
        assert not _is_dir(get_filesystem(directory_url), file_url, strict=True)

0 commit comments

Comments
 (0)