Skip to content

Commit d521f2b

Browse files
store: mock/fixture home (#16536)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 879701f commit d521f2b

File tree

7 files changed

+95
-84
lines changed

7 files changed

+95
-84
lines changed

src/lightning/store/save.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import os
1919
import shutil
2020
import tarfile
21-
from pathlib import PurePath
21+
from pathlib import Path, PurePath
2222

2323
import requests
2424
import torch
@@ -30,9 +30,11 @@
3030

3131
logging.basicConfig(level=logging.INFO)
3232

33-
_LIGHTNING_DIR = os.path.join(os.path.expanduser("~"), ".lightning")
34-
_LIGHTNING_STORAGE_FILE = os.path.join(_LIGHTNING_DIR, ".model_storage")
35-
_LIGHTNING_STORAGE_DIR = os.path.join(_LIGHTNING_DIR, "model_store")
33+
_LIGHTNING_DIR = os.path.join(Path.home(), ".lightning")
34+
__STORAGE_FILE_NAME = ".model_storage"
35+
_LIGHTNING_STORAGE_FILE = os.path.join(_LIGHTNING_DIR, __STORAGE_FILE_NAME)
36+
__STORAGE_DIR_NAME = "model_store"
37+
_LIGHTNING_STORAGE_DIR = os.path.join(_LIGHTNING_DIR, __STORAGE_DIR_NAME)
3638

3739

3840
def _check_id(id: str) -> str:

tests/tests_cloud/conftest.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import importlib
2+
import os
3+
import tempfile
4+
import types
5+
from pathlib import Path
6+
7+
import pytest
8+
9+
import lightning
10+
import lightning.store
11+
12+
13+
def reload_package(package):
14+
# credit: https://stackoverflow.com/a/28516918/4521646
15+
assert hasattr(package, "__package__")
16+
fn = package.__file__
17+
fn_dir = os.path.dirname(fn) + os.sep
18+
module_visit = {fn}
19+
del fn
20+
21+
def reload_recursive_ex(module):
22+
importlib.reload(module)
23+
24+
for module_child in vars(module).values():
25+
if not isinstance(module_child, types.ModuleType):
26+
continue
27+
fn_child = getattr(module_child, "__file__", None)
28+
if (fn_child is not None) and fn_child.startswith(fn_dir) and fn_child not in module_visit:
29+
# print("reloading:", fn_child, "from", module)
30+
module_visit.add(fn_child)
31+
reload_recursive_ex(module_child)
32+
33+
return reload_recursive_ex(package)
34+
35+
36+
@pytest.fixture(scope="function", autouse=True)
37+
def lit_home(monkeypatch):
38+
with tempfile.TemporaryDirectory() as tmp_dirname:
39+
monkeypatch.setattr(Path, "home", lambda: tmp_dirname)
40+
# we need to reload whole subpackage to apply the mock/fixture
41+
reload_package(lightning.store)
42+
yield os.path.join(tmp_dirname, ".lightning")

tests/tests_cloud/helpers.py

Lines changed: 0 additions & 12 deletions
This file was deleted.

tests/tests_cloud/test_model.py

Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,48 @@
11
import os
22

3+
import pytest
34
from tests_cloud import _API_KEY, _PROJECT_ID, _USERNAME
4-
from tests_cloud.helpers import cleanup
55

66
import pytorch_lightning as pl
77
from lightning.store import download_from_cloud, load_model, upload_to_cloud
8-
from lightning.store.save import _LIGHTNING_STORAGE_DIR
8+
from lightning.store.save import __STORAGE_DIR_NAME
99
from pytorch_lightning.demos.boring_classes import BoringModel
1010

1111

12-
def test_model(model_name: str = "boring_model", version: str = "latest"):
13-
cleanup()
14-
12+
@pytest.mark.parametrize("pbar", [True, False])
13+
def test_model(lit_home, pbar, model_name: str = "boring_model", version: str = "latest"):
1514
upload_to_cloud(model_name, model=BoringModel(), api_key=_API_KEY, project_id=_PROJECT_ID)
1615

17-
download_from_cloud(f"{_USERNAME}/{model_name}")
18-
assert os.path.isdir(os.path.join(_LIGHTNING_STORAGE_DIR, _USERNAME, model_name, version))
16+
download_from_cloud(f"{_USERNAME}/{model_name}", progress_bar=pbar)
17+
assert os.path.isdir(os.path.join(lit_home, __STORAGE_DIR_NAME, _USERNAME, model_name, version))
1918

2019
model = load_model(f"{_USERNAME}/{model_name}")
2120
assert model is not None
2221

2322

24-
def test_model_without_progress_bar(model_name: str = "boring_model", version: str = "latest"):
25-
cleanup()
26-
27-
upload_to_cloud(model_name, model=BoringModel(), api_key=_API_KEY, project_id=_PROJECT_ID, progress_bar=False)
28-
29-
download_from_cloud(f"{_USERNAME}/{model_name}", progress_bar=False)
30-
assert os.path.isdir(os.path.join(_LIGHTNING_STORAGE_DIR, _USERNAME, model_name, version))
31-
32-
model = load_model(f"{_USERNAME}/{model_name}")
33-
assert model is not None
34-
35-
36-
def test_only_weights(model_name: str = "boring_model_only_weights", version: str = "latest"):
37-
cleanup()
38-
23+
def test_only_weights(lit_home, model_name: str = "boring_model_only_weights", version: str = "latest"):
3924
model = BoringModel()
4025
trainer = pl.Trainer(fast_dev_run=True)
4126
trainer.fit(model)
4227
upload_to_cloud(model_name, model=model, weights_only=True, api_key=_API_KEY, project_id=_PROJECT_ID)
4328

4429
download_from_cloud(f"{_USERNAME}/{model_name}")
45-
assert os.path.isdir(os.path.join(_LIGHTNING_STORAGE_DIR, _USERNAME, model_name, version))
30+
assert os.path.isdir(os.path.join(lit_home, __STORAGE_DIR_NAME, _USERNAME, model_name, version))
4631

4732
model_with_weights = load_model(f"{_USERNAME}/{model_name}", load_weights=True, model=model)
4833
assert model_with_weights is not None
4934
assert model_with_weights.state_dict() is not None
5035

5136

52-
def test_checkpoint_path(model_name: str = "boring_model_only_checkpoint_path", version: str = "latest"):
53-
cleanup()
54-
37+
def test_checkpoint_path(lit_home, model_name: str = "boring_model_only_checkpoint_path", version: str = "latest"):
5538
model = BoringModel()
5639
trainer = pl.Trainer(fast_dev_run=True)
5740
trainer.fit(model)
5841
trainer.save_checkpoint("tmp.ckpt")
5942
upload_to_cloud(model_name, checkpoint_path="tmp.ckpt", api_key=_API_KEY, project_id=_PROJECT_ID)
6043

6144
download_from_cloud(f"{_USERNAME}/{model_name}")
62-
assert os.path.isdir(os.path.join(_LIGHTNING_STORAGE_DIR, _USERNAME, model_name, version))
45+
assert os.path.isdir(os.path.join(lit_home, __STORAGE_DIR_NAME, _USERNAME, model_name, version))
6346

6447
ckpt = load_model(f"{_USERNAME}/{model_name}", load_checkpoint=True, model=model)
6548
assert ckpt is not None

tests/tests_cloud/test_requirements.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
11
import os
22

33
from tests_cloud import _API_KEY, _PROJECT_ID, _USERNAME
4-
from tests_cloud.helpers import cleanup
54

65
from lightning.store import download_from_cloud, upload_to_cloud
7-
from lightning.store.save import _LIGHTNING_STORAGE_DIR
6+
from lightning.store.save import __STORAGE_DIR_NAME
87
from pytorch_lightning.demos.boring_classes import BoringModel
98

109

11-
def test_requirements(version: str = "1.0.0", model_name: str = "boring_model"):
12-
cleanup()
13-
10+
def test_requirements(lit_home, version: str = "1.0.0", model_name: str = "boring_model"):
1411
requirements_list = ["pytorch_lightning==1.7.7", "lightning"]
1512

1613
upload_to_cloud(
@@ -24,7 +21,7 @@ def test_requirements(version: str = "1.0.0", model_name: str = "boring_model"):
2421

2522
download_from_cloud(f"{_USERNAME}/{model_name}", version=version)
2623

27-
req_folder_path = os.path.join(_LIGHTNING_STORAGE_DIR, _USERNAME, model_name, version)
24+
req_folder_path = os.path.join(lit_home, __STORAGE_DIR_NAME, _USERNAME, model_name, version)
2825
assert os.path.isdir(req_folder_path), "missing: %s" % req_folder_path
2926
assert "requirements.txt" in os.listdir(req_folder_path), "among files: %r" % os.listdir(req_folder_path)
3027

tests/tests_cloud/test_source_code.py

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,20 @@
33
import tempfile
44

55
from tests_cloud import _API_KEY, _PROJECT_ID, _PROJECT_ROOT, _TEST_ROOT, _USERNAME
6-
from tests_cloud.helpers import cleanup
76

87
from lightning.store import download_from_cloud, upload_to_cloud
9-
from lightning.store.save import _LIGHTNING_STORAGE_DIR
8+
from lightning.store.save import __STORAGE_DIR_NAME
109
from pytorch_lightning.demos.boring_classes import BoringModel
1110

1211

13-
def test_source_code_implicit(model_name: str = "model_test_source_code_implicit"):
14-
cleanup()
15-
12+
def test_source_code_implicit(lit_home, model_name: str = "model_test_source_code_implicit"):
1613
upload_to_cloud(model_name, model=BoringModel(), api_key=_API_KEY, project_id=_PROJECT_ID)
1714

1815
download_from_cloud(f"{_USERNAME}/{model_name}")
1916
assert os.path.isfile(
2017
os.path.join(
21-
_LIGHTNING_STORAGE_DIR,
18+
lit_home,
19+
__STORAGE_DIR_NAME,
2220
_USERNAME,
2321
model_name,
2422
"latest",
@@ -27,15 +25,14 @@ def test_source_code_implicit(model_name: str = "model_test_source_code_implicit
2725
)
2826

2927

30-
def test_source_code_saving_disabled(model_name: str = "model_test_source_code_dont_save"):
31-
cleanup()
32-
28+
def test_source_code_saving_disabled(lit_home, model_name: str = "model_test_source_code_dont_save"):
3329
upload_to_cloud(model_name, model=BoringModel(), api_key=_API_KEY, project_id=_PROJECT_ID, save_code=False)
3430

3531
download_from_cloud(f"{_USERNAME}/{model_name}")
3632
assert not os.path.isfile(
3733
os.path.join(
38-
_LIGHTNING_STORAGE_DIR,
34+
lit_home,
35+
__STORAGE_DIR_NAME,
3936
_USERNAME,
4037
model_name,
4138
"latest",
@@ -44,26 +41,29 @@ def test_source_code_saving_disabled(model_name: str = "model_test_source_code_d
4441
)
4542

4643

47-
def test_source_code_explicit_relative_folder(model_name: str = "model_test_source_code_explicit_relative"):
48-
cleanup()
49-
50-
dir_upload_path = _TEST_ROOT
44+
def test_source_code_explicit_relative_folder(lit_home, model_name: str = "model_test_source_code_explicit_relative"):
5145
upload_to_cloud(
52-
model_name, model=BoringModel(), source_code_path=dir_upload_path, api_key=_API_KEY, project_id=_PROJECT_ID
46+
model_name, model=BoringModel(), source_code_path=_TEST_ROOT, api_key=_API_KEY, project_id=_PROJECT_ID
5347
)
5448

5549
download_from_cloud(f"{_USERNAME}/{model_name}")
5650

5751
assert os.path.isdir(
5852
os.path.join(
59-
_LIGHTNING_STORAGE_DIR, _USERNAME, model_name, "latest", os.path.basename(os.path.abspath(dir_upload_path))
53+
lit_home,
54+
__STORAGE_DIR_NAME,
55+
_USERNAME,
56+
model_name,
57+
"latest",
58+
os.path.basename(os.path.abspath(_TEST_ROOT)),
6059
)
6160
)
6261

6362

64-
def test_source_code_explicit_absolute_folder(model_name: str = "model_test_source_code_explicit_absolute_path"):
65-
cleanup()
66-
63+
def test_source_code_explicit_absolute_folder(
64+
lit_home, model_name: str = "model_test_source_code_explicit_absolute_path"
65+
):
66+
# TODO: unify with above `test_source_code_explicit_relative_folder`
6767
with tempfile.TemporaryDirectory() as tmpdir:
6868
dir_upload_path = os.path.abspath(tmpdir)
6969
upload_to_cloud(
@@ -74,14 +74,17 @@ def test_source_code_explicit_absolute_folder(model_name: str = "model_test_sour
7474

7575
assert os.path.isdir(
7676
os.path.join(
77-
_LIGHTNING_STORAGE_DIR, _USERNAME, model_name, "latest", os.path.basename(os.path.abspath(dir_upload_path))
77+
lit_home,
78+
__STORAGE_DIR_NAME,
79+
_USERNAME,
80+
model_name,
81+
"latest",
82+
os.path.basename(os.path.abspath(dir_upload_path)),
7883
)
7984
)
8085

8186

82-
def test_source_code_explicit_file(model_name: str = "model_test_source_code_explicit_file"):
83-
cleanup()
84-
87+
def test_source_code_explicit_file(lit_home, model_name: str = "model_test_source_code_explicit_file"):
8588
file_name = os.path.join(_PROJECT_ROOT, "setup.py")
8689
upload_to_cloud(
8790
model_name, model=BoringModel(), source_code_path=file_name, api_key=_API_KEY, project_id=_PROJECT_ID
@@ -91,7 +94,8 @@ def test_source_code_explicit_file(model_name: str = "model_test_source_code_exp
9194

9295
assert os.path.isfile(
9396
os.path.join(
94-
_LIGHTNING_STORAGE_DIR,
97+
lit_home,
98+
__STORAGE_DIR_NAME,
9599
_USERNAME,
96100
model_name,
97101
"latest",

tests/tests_cloud/test_versioning.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,14 @@
33

44
import pytest
55
from tests_cloud import _API_KEY, _PROJECT_ID, _USERNAME
6-
from tests_cloud.helpers import cleanup
76

87
from lightning.store.cloud_api import download_from_cloud, upload_to_cloud
9-
from lightning.store.save import _LIGHTNING_STORAGE_DIR
8+
from lightning.store.save import __STORAGE_DIR_NAME
109
from pytorch_lightning.demos.boring_classes import BoringModel
1110

1211

13-
def assert_download_successful(username, model_name, version):
14-
folder_name = os.path.join(_LIGHTNING_STORAGE_DIR, username, model_name, version)
12+
def assert_download_successful(lit_home, username, model_name, version):
13+
folder_name = os.path.join(lit_home, __STORAGE_DIR_NAME, username, model_name, version)
1514
assert os.path.isdir(folder_name), f"Folder name: {folder_name} doesn't exist."
1615
assert len(os.listdir(folder_name)) != 0
1716

@@ -30,12 +29,10 @@ def assert_download_successful(username, model_name, version):
3029
]
3130
),
3231
)
33-
def test_versioning_valid_case(case, expected_case, model_name: str = "boring_model_versioning"):
34-
cleanup()
35-
32+
def test_versioning_valid_case(lit_home, case, expected_case, model_name: str = "boring_model_versioning"):
3633
upload_to_cloud(model_name, version=case, model=BoringModel(), api_key=_API_KEY, project_id=_PROJECT_ID)
3734
download_from_cloud(f"{_USERNAME}/{model_name}", version=case)
38-
assert_download_successful(_USERNAME, model_name, expected_case)
35+
assert_download_successful(lit_home, _USERNAME, model_name, expected_case)
3936

4037

4138
@pytest.mark.parametrize(
@@ -50,9 +47,7 @@ def test_versioning_valid_case(case, expected_case, model_name: str = "boring_mo
5047
]
5148
),
5249
)
53-
def test_versioning_invalid_case(case, model_name: str = "boring_model_versioning"):
54-
cleanup()
55-
50+
def test_versioning_invalid_case(lit_home, case, model_name: str = "boring_model_versioning"):
5651
with pytest.raises(ConnectionRefusedError):
5752
upload_to_cloud(model_name, version=case, model=BoringModel(), api_key=_API_KEY, project_id=_PROJECT_ID)
5853

0 commit comments

Comments
 (0)