Skip to content

store: mock/fixture home #16536

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jan 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/lightning/store/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import os
import shutil
import tarfile
from pathlib import PurePath
from pathlib import Path, PurePath

import requests
import torch
Expand All @@ -30,9 +30,11 @@

logging.basicConfig(level=logging.INFO)

_LIGHTNING_DIR = os.path.join(os.path.expanduser("~"), ".lightning")
_LIGHTNING_STORAGE_FILE = os.path.join(_LIGHTNING_DIR, ".model_storage")
_LIGHTNING_STORAGE_DIR = os.path.join(_LIGHTNING_DIR, "model_store")
_LIGHTNING_DIR = os.path.join(Path.home(), ".lightning")
__STORAGE_FILE_NAME = ".model_storage"
_LIGHTNING_STORAGE_FILE = os.path.join(_LIGHTNING_DIR, __STORAGE_FILE_NAME)
__STORAGE_DIR_NAME = "model_store"
_LIGHTNING_STORAGE_DIR = os.path.join(_LIGHTNING_DIR, __STORAGE_DIR_NAME)


def _check_id(id: str) -> str:
Expand Down
42 changes: 42 additions & 0 deletions tests/tests_cloud/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import importlib
import os
import tempfile
import types
from pathlib import Path

import pytest

import lightning
import lightning.store


def reload_package(package):
# credit: https://stackoverflow.com/a/28516918/4521646
assert hasattr(package, "__package__")
fn = package.__file__
fn_dir = os.path.dirname(fn) + os.sep
module_visit = {fn}
del fn

def reload_recursive_ex(module):
importlib.reload(module)

for module_child in vars(module).values():
if not isinstance(module_child, types.ModuleType):
continue
fn_child = getattr(module_child, "__file__", None)
if (fn_child is not None) and fn_child.startswith(fn_dir) and fn_child not in module_visit:
# print("reloading:", fn_child, "from", module)
module_visit.add(fn_child)
reload_recursive_ex(module_child)

return reload_recursive_ex(package)


@pytest.fixture(scope="function", autouse=True)
def lit_home(monkeypatch):
with tempfile.TemporaryDirectory() as tmp_dirname:
monkeypatch.setattr(Path, "home", lambda: tmp_dirname)
# we need to reload whole subpackage to apply the mock/fixture
reload_package(lightning.store)
yield os.path.join(tmp_dirname, ".lightning")
12 changes: 0 additions & 12 deletions tests/tests_cloud/helpers.py

This file was deleted.

37 changes: 10 additions & 27 deletions tests/tests_cloud/test_model.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,48 @@
import os

import pytest
from tests_cloud import _API_KEY, _PROJECT_ID, _USERNAME
from tests_cloud.helpers import cleanup

import pytorch_lightning as pl
from lightning.store import download_from_cloud, load_model, upload_to_cloud
from lightning.store.save import _LIGHTNING_STORAGE_DIR
from lightning.store.save import __STORAGE_DIR_NAME
from pytorch_lightning.demos.boring_classes import BoringModel


def test_model(model_name: str = "boring_model", version: str = "latest"):
cleanup()

@pytest.mark.parametrize("pbar", [True, False])
def test_model(lit_home, pbar, model_name: str = "boring_model", version: str = "latest"):
upload_to_cloud(model_name, model=BoringModel(), api_key=_API_KEY, project_id=_PROJECT_ID)

download_from_cloud(f"{_USERNAME}/{model_name}")
assert os.path.isdir(os.path.join(_LIGHTNING_STORAGE_DIR, _USERNAME, model_name, version))
download_from_cloud(f"{_USERNAME}/{model_name}", progress_bar=pbar)
assert os.path.isdir(os.path.join(lit_home, __STORAGE_DIR_NAME, _USERNAME, model_name, version))

model = load_model(f"{_USERNAME}/{model_name}")
assert model is not None


def test_model_without_progress_bar(model_name: str = "boring_model", version: str = "latest"):
cleanup()

upload_to_cloud(model_name, model=BoringModel(), api_key=_API_KEY, project_id=_PROJECT_ID, progress_bar=False)

download_from_cloud(f"{_USERNAME}/{model_name}", progress_bar=False)
assert os.path.isdir(os.path.join(_LIGHTNING_STORAGE_DIR, _USERNAME, model_name, version))

model = load_model(f"{_USERNAME}/{model_name}")
assert model is not None


def test_only_weights(model_name: str = "boring_model_only_weights", version: str = "latest"):
cleanup()

def test_only_weights(lit_home, model_name: str = "boring_model_only_weights", version: str = "latest"):
model = BoringModel()
trainer = pl.Trainer(fast_dev_run=True)
trainer.fit(model)
upload_to_cloud(model_name, model=model, weights_only=True, api_key=_API_KEY, project_id=_PROJECT_ID)

download_from_cloud(f"{_USERNAME}/{model_name}")
assert os.path.isdir(os.path.join(_LIGHTNING_STORAGE_DIR, _USERNAME, model_name, version))
assert os.path.isdir(os.path.join(lit_home, __STORAGE_DIR_NAME, _USERNAME, model_name, version))

model_with_weights = load_model(f"{_USERNAME}/{model_name}", load_weights=True, model=model)
assert model_with_weights is not None
assert model_with_weights.state_dict() is not None


def test_checkpoint_path(model_name: str = "boring_model_only_checkpoint_path", version: str = "latest"):
cleanup()

def test_checkpoint_path(lit_home, model_name: str = "boring_model_only_checkpoint_path", version: str = "latest"):
model = BoringModel()
trainer = pl.Trainer(fast_dev_run=True)
trainer.fit(model)
trainer.save_checkpoint("tmp.ckpt")
upload_to_cloud(model_name, checkpoint_path="tmp.ckpt", api_key=_API_KEY, project_id=_PROJECT_ID)

download_from_cloud(f"{_USERNAME}/{model_name}")
assert os.path.isdir(os.path.join(_LIGHTNING_STORAGE_DIR, _USERNAME, model_name, version))
assert os.path.isdir(os.path.join(lit_home, __STORAGE_DIR_NAME, _USERNAME, model_name, version))

ckpt = load_model(f"{_USERNAME}/{model_name}", load_checkpoint=True, model=model)
assert ckpt is not None
9 changes: 3 additions & 6 deletions tests/tests_cloud/test_requirements.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
import os

from tests_cloud import _API_KEY, _PROJECT_ID, _USERNAME
from tests_cloud.helpers import cleanup

from lightning.store import download_from_cloud, upload_to_cloud
from lightning.store.save import _LIGHTNING_STORAGE_DIR
from lightning.store.save import __STORAGE_DIR_NAME
from pytorch_lightning.demos.boring_classes import BoringModel


def test_requirements(version: str = "1.0.0", model_name: str = "boring_model"):
cleanup()

def test_requirements(lit_home, version: str = "1.0.0", model_name: str = "boring_model"):
requirements_list = ["pytorch_lightning==1.7.7", "lightning"]

upload_to_cloud(
Expand All @@ -24,7 +21,7 @@ def test_requirements(version: str = "1.0.0", model_name: str = "boring_model"):

download_from_cloud(f"{_USERNAME}/{model_name}", version=version)

req_folder_path = os.path.join(_LIGHTNING_STORAGE_DIR, _USERNAME, model_name, version)
req_folder_path = os.path.join(lit_home, __STORAGE_DIR_NAME, _USERNAME, model_name, version)
assert os.path.isdir(req_folder_path), "missing: %s" % req_folder_path
assert "requirements.txt" in os.listdir(req_folder_path), "among files: %r" % os.listdir(req_folder_path)

Expand Down
52 changes: 28 additions & 24 deletions tests/tests_cloud/test_source_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,20 @@
import tempfile

from tests_cloud import _API_KEY, _PROJECT_ID, _PROJECT_ROOT, _TEST_ROOT, _USERNAME
from tests_cloud.helpers import cleanup

from lightning.store import download_from_cloud, upload_to_cloud
from lightning.store.save import _LIGHTNING_STORAGE_DIR
from lightning.store.save import __STORAGE_DIR_NAME
from pytorch_lightning.demos.boring_classes import BoringModel


def test_source_code_implicit(model_name: str = "model_test_source_code_implicit"):
cleanup()

def test_source_code_implicit(lit_home, model_name: str = "model_test_source_code_implicit"):
upload_to_cloud(model_name, model=BoringModel(), api_key=_API_KEY, project_id=_PROJECT_ID)

download_from_cloud(f"{_USERNAME}/{model_name}")
assert os.path.isfile(
os.path.join(
_LIGHTNING_STORAGE_DIR,
lit_home,
__STORAGE_DIR_NAME,
_USERNAME,
model_name,
"latest",
Expand All @@ -27,15 +25,14 @@ def test_source_code_implicit(model_name: str = "model_test_source_code_implicit
)


def test_source_code_saving_disabled(model_name: str = "model_test_source_code_dont_save"):
cleanup()

def test_source_code_saving_disabled(lit_home, model_name: str = "model_test_source_code_dont_save"):
upload_to_cloud(model_name, model=BoringModel(), api_key=_API_KEY, project_id=_PROJECT_ID, save_code=False)

download_from_cloud(f"{_USERNAME}/{model_name}")
assert not os.path.isfile(
os.path.join(
_LIGHTNING_STORAGE_DIR,
lit_home,
__STORAGE_DIR_NAME,
_USERNAME,
model_name,
"latest",
Expand All @@ -44,26 +41,29 @@ def test_source_code_saving_disabled(model_name: str = "model_test_source_code_d
)


def test_source_code_explicit_relative_folder(model_name: str = "model_test_source_code_explicit_relative"):
cleanup()

dir_upload_path = _TEST_ROOT
def test_source_code_explicit_relative_folder(lit_home, model_name: str = "model_test_source_code_explicit_relative"):
upload_to_cloud(
model_name, model=BoringModel(), source_code_path=dir_upload_path, api_key=_API_KEY, project_id=_PROJECT_ID
model_name, model=BoringModel(), source_code_path=_TEST_ROOT, api_key=_API_KEY, project_id=_PROJECT_ID
)

download_from_cloud(f"{_USERNAME}/{model_name}")

assert os.path.isdir(
os.path.join(
_LIGHTNING_STORAGE_DIR, _USERNAME, model_name, "latest", os.path.basename(os.path.abspath(dir_upload_path))
lit_home,
__STORAGE_DIR_NAME,
_USERNAME,
model_name,
"latest",
os.path.basename(os.path.abspath(_TEST_ROOT)),
)
)


def test_source_code_explicit_absolute_folder(model_name: str = "model_test_source_code_explicit_absolute_path"):
cleanup()

def test_source_code_explicit_absolute_folder(
lit_home, model_name: str = "model_test_source_code_explicit_absolute_path"
):
# TODO: unify with above `test_source_code_explicit_relative_folder`
with tempfile.TemporaryDirectory() as tmpdir:
dir_upload_path = os.path.abspath(tmpdir)
upload_to_cloud(
Expand All @@ -74,14 +74,17 @@ def test_source_code_explicit_absolute_folder(model_name: str = "model_test_sour

assert os.path.isdir(
os.path.join(
_LIGHTNING_STORAGE_DIR, _USERNAME, model_name, "latest", os.path.basename(os.path.abspath(dir_upload_path))
lit_home,
__STORAGE_DIR_NAME,
_USERNAME,
model_name,
"latest",
os.path.basename(os.path.abspath(dir_upload_path)),
)
)


def test_source_code_explicit_file(model_name: str = "model_test_source_code_explicit_file"):
cleanup()

def test_source_code_explicit_file(lit_home, model_name: str = "model_test_source_code_explicit_file"):
file_name = os.path.join(_PROJECT_ROOT, "setup.py")
upload_to_cloud(
model_name, model=BoringModel(), source_code_path=file_name, api_key=_API_KEY, project_id=_PROJECT_ID
Expand All @@ -91,7 +94,8 @@ def test_source_code_explicit_file(model_name: str = "model_test_source_code_exp

assert os.path.isfile(
os.path.join(
_LIGHTNING_STORAGE_DIR,
lit_home,
__STORAGE_DIR_NAME,
_USERNAME,
model_name,
"latest",
Expand Down
17 changes: 6 additions & 11 deletions tests/tests_cloud/test_versioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@

import pytest
from tests_cloud import _API_KEY, _PROJECT_ID, _USERNAME
from tests_cloud.helpers import cleanup

from lightning.store.cloud_api import download_from_cloud, upload_to_cloud
from lightning.store.save import _LIGHTNING_STORAGE_DIR
from lightning.store.save import __STORAGE_DIR_NAME
from pytorch_lightning.demos.boring_classes import BoringModel


def assert_download_successful(username, model_name, version):
folder_name = os.path.join(_LIGHTNING_STORAGE_DIR, username, model_name, version)
def assert_download_successful(lit_home, username, model_name, version):
folder_name = os.path.join(lit_home, __STORAGE_DIR_NAME, username, model_name, version)
assert os.path.isdir(folder_name), f"Folder name: {folder_name} doesn't exist."
assert len(os.listdir(folder_name)) != 0

Expand All @@ -30,12 +29,10 @@ def assert_download_successful(username, model_name, version):
]
),
)
def test_versioning_valid_case(case, expected_case, model_name: str = "boring_model_versioning"):
cleanup()

def test_versioning_valid_case(lit_home, case, expected_case, model_name: str = "boring_model_versioning"):
upload_to_cloud(model_name, version=case, model=BoringModel(), api_key=_API_KEY, project_id=_PROJECT_ID)
download_from_cloud(f"{_USERNAME}/{model_name}", version=case)
assert_download_successful(_USERNAME, model_name, expected_case)
assert_download_successful(lit_home, _USERNAME, model_name, expected_case)


@pytest.mark.parametrize(
Expand All @@ -50,9 +47,7 @@ def test_versioning_valid_case(case, expected_case, model_name: str = "boring_mo
]
),
)
def test_versioning_invalid_case(case, model_name: str = "boring_model_versioning"):
cleanup()

def test_versioning_invalid_case(lit_home, case, model_name: str = "boring_model_versioning"):
with pytest.raises(ConnectionRefusedError):
upload_to_cloud(model_name, version=case, model=BoringModel(), api_key=_API_KEY, project_id=_PROJECT_ID)

Expand Down