diff --git a/src/lightning/store/save.py b/src/lightning/store/save.py index e179cfc05ca04..9029a52aed8f7 100644 --- a/src/lightning/store/save.py +++ b/src/lightning/store/save.py @@ -18,7 +18,7 @@ import os import shutil import tarfile -from pathlib import PurePath +from pathlib import Path, PurePath import requests import torch @@ -30,9 +30,11 @@ logging.basicConfig(level=logging.INFO) -_LIGHTNING_DIR = os.path.join(os.path.expanduser("~"), ".lightning") -_LIGHTNING_STORAGE_FILE = os.path.join(_LIGHTNING_DIR, ".model_storage") -_LIGHTNING_STORAGE_DIR = os.path.join(_LIGHTNING_DIR, "model_store") +_LIGHTNING_DIR = os.path.join(Path.home(), ".lightning") +__STORAGE_FILE_NAME = ".model_storage" +_LIGHTNING_STORAGE_FILE = os.path.join(_LIGHTNING_DIR, __STORAGE_FILE_NAME) +__STORAGE_DIR_NAME = "model_store" +_LIGHTNING_STORAGE_DIR = os.path.join(_LIGHTNING_DIR, __STORAGE_DIR_NAME) def _check_id(id: str) -> str: diff --git a/tests/tests_cloud/conftest.py b/tests/tests_cloud/conftest.py new file mode 100644 index 0000000000000..673d2bccd7848 --- /dev/null +++ b/tests/tests_cloud/conftest.py @@ -0,0 +1,42 @@ +import importlib +import os +import tempfile +import types +from pathlib import Path + +import pytest + +import lightning +import lightning.store + + +def reload_package(package): + # credit: https://stackoverflow.com/a/28516918/4521646 + assert hasattr(package, "__package__") + fn = package.__file__ + fn_dir = os.path.dirname(fn) + os.sep + module_visit = {fn} + del fn + + def reload_recursive_ex(module): + importlib.reload(module) + + for module_child in vars(module).values(): + if not isinstance(module_child, types.ModuleType): + continue + fn_child = getattr(module_child, "__file__", None) + if (fn_child is not None) and fn_child.startswith(fn_dir) and fn_child not in module_visit: + # print("reloading:", fn_child, "from", module) + module_visit.add(fn_child) + reload_recursive_ex(module_child) + + return reload_recursive_ex(package) + + +@pytest.fixture(scope="function", autouse=True) +def lit_home(monkeypatch): + with tempfile.TemporaryDirectory() as tmp_dirname: + monkeypatch.setattr(Path, "home", lambda: tmp_dirname) + # we need to reload whole subpackage to apply the mock/fixture + reload_package(lightning.store) + yield os.path.join(tmp_dirname, ".lightning") diff --git a/tests/tests_cloud/helpers.py b/tests/tests_cloud/helpers.py deleted file mode 100644 index 282cb095dfb77..0000000000000 --- a/tests/tests_cloud/helpers.py +++ /dev/null @@ -1,12 +0,0 @@ -import os -import shutil - -from lightning.store.save import _LIGHTNING_STORAGE_DIR - - -# TODO: make this as a fixture -def cleanup(): - # todo: `LIGHTNING_MODEL_STORE_TESTING` is nor working as intended, - # so the fixture shall create temp folder and map it home... - if os.getenv("LIGHTNING_MODEL_STORE_TESTING") and os.path.isdir(_LIGHTNING_STORAGE_DIR): - shutil.rmtree(_LIGHTNING_STORAGE_DIR) diff --git a/tests/tests_cloud/test_model.py b/tests/tests_cloud/test_model.py index 6eae880b3758e..4f1a0bbd31011 100644 --- a/tests/tests_cloud/test_model.py +++ b/tests/tests_cloud/test_model.py @@ -1,57 +1,40 @@ import os +import pytest from tests_cloud import _API_KEY, _PROJECT_ID, _USERNAME -from tests_cloud.helpers import cleanup import pytorch_lightning as pl from lightning.store import download_from_cloud, load_model, upload_to_cloud -from lightning.store.save import _LIGHTNING_STORAGE_DIR +from lightning.store.save import __STORAGE_DIR_NAME from pytorch_lightning.demos.boring_classes import BoringModel -def test_model(model_name: str = "boring_model", version: str = "latest"): - cleanup() - +@pytest.mark.parametrize("pbar", [True, False]) +def test_model(lit_home, pbar, model_name: str = "boring_model", version: str = "latest"): upload_to_cloud(model_name, model=BoringModel(), api_key=_API_KEY, project_id=_PROJECT_ID) - download_from_cloud(f"{_USERNAME}/{model_name}") - assert os.path.isdir(os.path.join(_LIGHTNING_STORAGE_DIR, _USERNAME, model_name, version)) + download_from_cloud(f"{_USERNAME}/{model_name}", progress_bar=pbar) + assert os.path.isdir(os.path.join(lit_home, __STORAGE_DIR_NAME, _USERNAME, model_name, version)) model = load_model(f"{_USERNAME}/{model_name}") assert model is not None -def test_model_without_progress_bar(model_name: str = "boring_model", version: str = "latest"): - cleanup() - - upload_to_cloud(model_name, model=BoringModel(), api_key=_API_KEY, project_id=_PROJECT_ID, progress_bar=False) - - download_from_cloud(f"{_USERNAME}/{model_name}", progress_bar=False) - assert os.path.isdir(os.path.join(_LIGHTNING_STORAGE_DIR, _USERNAME, model_name, version)) - - model = load_model(f"{_USERNAME}/{model_name}") - assert model is not None - - -def test_only_weights(model_name: str = "boring_model_only_weights", version: str = "latest"): - cleanup() - +def test_only_weights(lit_home, model_name: str = "boring_model_only_weights", version: str = "latest"): model = BoringModel() trainer = pl.Trainer(fast_dev_run=True) trainer.fit(model) upload_to_cloud(model_name, model=model, weights_only=True, api_key=_API_KEY, project_id=_PROJECT_ID) download_from_cloud(f"{_USERNAME}/{model_name}") - assert os.path.isdir(os.path.join(_LIGHTNING_STORAGE_DIR, _USERNAME, model_name, version)) + assert os.path.isdir(os.path.join(lit_home, __STORAGE_DIR_NAME, _USERNAME, model_name, version)) model_with_weights = load_model(f"{_USERNAME}/{model_name}", load_weights=True, model=model) assert model_with_weights is not None assert model_with_weights.state_dict() is not None -def test_checkpoint_path(model_name: str = "boring_model_only_checkpoint_path", version: str = "latest"): - cleanup() - +def test_checkpoint_path(lit_home, model_name: str = "boring_model_only_checkpoint_path", version: str = "latest"): model = BoringModel() trainer = pl.Trainer(fast_dev_run=True) trainer.fit(model) @@ -59,7 +42,7 @@ def test_checkpoint_path(model_name: str = "boring_model_only_checkpoint_path", upload_to_cloud(model_name, checkpoint_path="tmp.ckpt", api_key=_API_KEY, project_id=_PROJECT_ID) download_from_cloud(f"{_USERNAME}/{model_name}") - assert os.path.isdir(os.path.join(_LIGHTNING_STORAGE_DIR, _USERNAME, model_name, version)) + assert os.path.isdir(os.path.join(lit_home, __STORAGE_DIR_NAME, _USERNAME, model_name, version)) ckpt = load_model(f"{_USERNAME}/{model_name}", load_checkpoint=True, model=model) assert ckpt is not None diff --git a/tests/tests_cloud/test_requirements.py b/tests/tests_cloud/test_requirements.py index d1ae6e28ab801..dbffd7d33f133 100644 --- a/tests/tests_cloud/test_requirements.py +++ b/tests/tests_cloud/test_requirements.py @@ -1,16 +1,13 @@ import os from tests_cloud import _API_KEY, _PROJECT_ID, _USERNAME -from tests_cloud.helpers import cleanup from lightning.store import download_from_cloud, upload_to_cloud -from lightning.store.save import _LIGHTNING_STORAGE_DIR +from lightning.store.save import __STORAGE_DIR_NAME from pytorch_lightning.demos.boring_classes import BoringModel -def test_requirements(version: str = "1.0.0", model_name: str = "boring_model"): - cleanup() - +def test_requirements(lit_home, version: str = "1.0.0", model_name: str = "boring_model"): requirements_list = ["pytorch_lightning==1.7.7", "lightning"] upload_to_cloud( @@ -24,7 +21,7 @@ def test_requirements(version: str = "1.0.0", model_name: str = "boring_model"): download_from_cloud(f"{_USERNAME}/{model_name}", version=version) - req_folder_path = os.path.join(_LIGHTNING_STORAGE_DIR, _USERNAME, model_name, version) + req_folder_path = os.path.join(lit_home, __STORAGE_DIR_NAME, _USERNAME, model_name, version) assert os.path.isdir(req_folder_path), "missing: %s" % req_folder_path assert "requirements.txt" in os.listdir(req_folder_path), "among files: %r" % os.listdir(req_folder_path) diff --git a/tests/tests_cloud/test_source_code.py b/tests/tests_cloud/test_source_code.py index 11ea89775e484..e4111c7aad58b 100644 --- a/tests/tests_cloud/test_source_code.py +++ b/tests/tests_cloud/test_source_code.py @@ -3,22 +3,20 @@ import tempfile from tests_cloud import _API_KEY, _PROJECT_ID, _PROJECT_ROOT, _TEST_ROOT, _USERNAME -from tests_cloud.helpers import cleanup from lightning.store import download_from_cloud, upload_to_cloud -from lightning.store.save import _LIGHTNING_STORAGE_DIR +from lightning.store.save import __STORAGE_DIR_NAME from pytorch_lightning.demos.boring_classes import BoringModel -def test_source_code_implicit(model_name: str = "model_test_source_code_implicit"): - cleanup() - +def test_source_code_implicit(lit_home, model_name: str = "model_test_source_code_implicit"): upload_to_cloud(model_name, model=BoringModel(), api_key=_API_KEY, project_id=_PROJECT_ID) download_from_cloud(f"{_USERNAME}/{model_name}") assert os.path.isfile( os.path.join( - _LIGHTNING_STORAGE_DIR, + lit_home, + __STORAGE_DIR_NAME, _USERNAME, model_name, "latest", @@ -27,15 +25,14 @@ def test_source_code_implicit(model_name: str = "model_test_source_code_implicit ) -def test_source_code_saving_disabled(model_name: str = "model_test_source_code_dont_save"): - cleanup() - +def test_source_code_saving_disabled(lit_home, model_name: str = "model_test_source_code_dont_save"): upload_to_cloud(model_name, model=BoringModel(), api_key=_API_KEY, project_id=_PROJECT_ID, save_code=False) download_from_cloud(f"{_USERNAME}/{model_name}") assert not os.path.isfile( os.path.join( - _LIGHTNING_STORAGE_DIR, + lit_home, + __STORAGE_DIR_NAME, _USERNAME, model_name, "latest", @@ -44,26 +41,29 @@ def test_source_code_saving_disabled(model_name: str = "model_test_source_code_d ) -def test_source_code_explicit_relative_folder(model_name: str = "model_test_source_code_explicit_relative"): - cleanup() - - dir_upload_path = _TEST_ROOT +def test_source_code_explicit_relative_folder(lit_home, model_name: str = "model_test_source_code_explicit_relative"): upload_to_cloud( - model_name, model=BoringModel(), source_code_path=dir_upload_path, api_key=_API_KEY, project_id=_PROJECT_ID + model_name, model=BoringModel(), source_code_path=_TEST_ROOT, api_key=_API_KEY, project_id=_PROJECT_ID ) download_from_cloud(f"{_USERNAME}/{model_name}") assert os.path.isdir( os.path.join( - _LIGHTNING_STORAGE_DIR, _USERNAME, model_name, "latest", os.path.basename(os.path.abspath(dir_upload_path)) + lit_home, + __STORAGE_DIR_NAME, + _USERNAME, + model_name, + "latest", + os.path.basename(os.path.abspath(_TEST_ROOT)), ) ) -def test_source_code_explicit_absolute_folder(model_name: str = "model_test_source_code_explicit_absolute_path"): - cleanup() - +def test_source_code_explicit_absolute_folder( + lit_home, model_name: str = "model_test_source_code_explicit_absolute_path" +): + # TODO: unify with above `test_source_code_explicit_relative_folder` with tempfile.TemporaryDirectory() as tmpdir: dir_upload_path = os.path.abspath(tmpdir) upload_to_cloud( @@ -74,14 +74,17 @@ def test_source_code_explicit_absolute_folder(model_name: str = "model_test_sour assert os.path.isdir( os.path.join( - _LIGHTNING_STORAGE_DIR, _USERNAME, model_name, "latest", os.path.basename(os.path.abspath(dir_upload_path)) + lit_home, + __STORAGE_DIR_NAME, + _USERNAME, + model_name, + "latest", + os.path.basename(os.path.abspath(dir_upload_path)), ) ) -def test_source_code_explicit_file(model_name: str = "model_test_source_code_explicit_file"): - cleanup() - +def test_source_code_explicit_file(lit_home, model_name: str = "model_test_source_code_explicit_file"): file_name = os.path.join(_PROJECT_ROOT, "setup.py") upload_to_cloud( model_name, model=BoringModel(), source_code_path=file_name, api_key=_API_KEY, project_id=_PROJECT_ID @@ -91,7 +94,8 @@ def test_source_code_explicit_file(model_name: str = "model_test_source_code_exp assert os.path.isfile( os.path.join( - _LIGHTNING_STORAGE_DIR, + lit_home, + __STORAGE_DIR_NAME, _USERNAME, model_name, "latest", diff --git a/tests/tests_cloud/test_versioning.py b/tests/tests_cloud/test_versioning.py index 373d0a7500504..f198a6c17125f 100644 --- a/tests/tests_cloud/test_versioning.py +++ b/tests/tests_cloud/test_versioning.py @@ -3,15 +3,14 @@ import pytest from tests_cloud import _API_KEY, _PROJECT_ID, _USERNAME -from tests_cloud.helpers import cleanup from lightning.store.cloud_api import download_from_cloud, upload_to_cloud -from lightning.store.save import _LIGHTNING_STORAGE_DIR +from lightning.store.save import __STORAGE_DIR_NAME from pytorch_lightning.demos.boring_classes import BoringModel -def assert_download_successful(username, model_name, version): - folder_name = os.path.join(_LIGHTNING_STORAGE_DIR, username, model_name, version) +def assert_download_successful(lit_home, username, model_name, version): + folder_name = os.path.join(lit_home, __STORAGE_DIR_NAME, username, model_name, version) assert os.path.isdir(folder_name), f"Folder name: {folder_name} doesn't exist." assert len(os.listdir(folder_name)) != 0 @@ -30,12 +29,10 @@ def assert_download_successful(username, model_name, version): ] ), ) -def test_versioning_valid_case(case, expected_case, model_name: str = "boring_model_versioning"): - cleanup() - +def test_versioning_valid_case(lit_home, case, expected_case, model_name: str = "boring_model_versioning"): upload_to_cloud(model_name, version=case, model=BoringModel(), api_key=_API_KEY, project_id=_PROJECT_ID) download_from_cloud(f"{_USERNAME}/{model_name}", version=case) - assert_download_successful(_USERNAME, model_name, expected_case) + assert_download_successful(lit_home, _USERNAME, model_name, expected_case) @pytest.mark.parametrize( @@ -50,9 +47,7 @@ def test_versioning_valid_case(case, expected_case, model_name: str = "boring_mo ] ), ) -def test_versioning_invalid_case(case, model_name: str = "boring_model_versioning"): - cleanup() - +def test_versioning_invalid_case(lit_home, case, model_name: str = "boring_model_versioning"): with pytest.raises(ConnectionRefusedError): upload_to_cloud(model_name, version=case, model=BoringModel(), api_key=_API_KEY, project_id=_PROJECT_ID)