Skip to content

store: download model before load #16552

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 15 commits into from
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 52 additions & 58 deletions src/lightning/store/cloud_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from lightning.app.core.constants import LIGHTNING_MODELS_PUBLIC_REGISTRY
from lightning.store.authentication import _authenticate
from lightning.store.save import (
_download_and_extract_data_to,
_download_and_extract_data,
_get_linked_output_dir,
_LIGHTNING_STORAGE_DIR,
_LIGHTNING_STORAGE_FILE,
Expand Down Expand Up @@ -103,7 +103,7 @@ def upload_model(
or
`upload_model("model_name", checkpoint_path="your_checkpoint_path.ckpt", ...)`
is required.
"""
"""
)

if weights_only and not model:
Expand Down Expand Up @@ -141,12 +141,7 @@ def upload_model(
)

if requirements:
stored = _write_and_save_requirements(
model_name,
requirements=requirements,
stored=stored,
tmpdir=tmpdir,
)
stored = _write_and_save_requirements(model_name, requirements=requirements, stored=stored, tmpdir=tmpdir)

url = _save_meta_data(
model_name,
Expand Down Expand Up @@ -260,7 +255,7 @@ def download_model(
meta_data = download_url_response["metadata"]

logging.info(f"Downloading the model data for {name} to {output_dir} folder.")
_download_and_extract_data_to(output_dir, download_url, progress_bar)
_download_and_extract_data(output_dir, download_url, progress_bar)

if linked_output_dir:
logging.info(f"Linking the downloaded folder from {output_dir} to {linked_output_dir} folder.")
Expand Down Expand Up @@ -307,6 +302,7 @@ def load_model(
load_weights: bool = False,
load_checkpoint: bool = False,
model: Union[PL.LightningModule, L.LightningModule, None] = None,
progress_bar: bool = True,
*args,
**kwargs,
):
Expand All @@ -323,59 +319,57 @@ def load_model(
Loads checkpoint if this is set to `True`. Only a `LightningModule` model is supported for this feature.
model:
Model class to be used.
progress_bar:
Show progress on download.
"""
if load_weights and load_checkpoint:
raise ValueError(
f"You passed load_weights={load_weights} and load_checkpoint={load_checkpoint},"
" it's expected that only one of them are requested in a single call."
)

if os.path.exists(_LIGHTNING_STORAGE_FILE):
version = version or "latest"
model_data = _get_model_data(name, version)
output_dir = model_data["output_dir"]
linked_output_dir = model_data["linked_output_dir"]
meta_data = model_data["metadata"]
stored = {"code": {}}

for key, val in meta_data.items():
if key.startswith("stored_"):
if key.startswith("stored_code_"):
stored["code"][key.split("_code_")[1]] = val
else:
stored[key.split("_")[1]] = val

_validate_output_dir(output_dir)
if linked_output_dir:
_validate_output_dir(linked_output_dir)

if load_weights:
# This first loads the model - and then the weights
if not model:
raise ValueError(
"Expected model=... to be passed for loading weights, please pass"
f" your model object to load_model({name}, {version}, model=ModelObj)"
)
return _load_weights(model, stored, linked_output_dir or output_dir, *args, **kwargs)
elif load_checkpoint:
if not model:
raise ValueError(
"You need to pass the LightningModule object (model) to be able to"
f" load the checkpoint. `load_model({name}, {version},"
" load_checkpoint=True, model=...)`"
)
if not isinstance(model, (PL.LightningModule, L.LightningModule)):
raise TypeError(
"For loading checkpoints, the model is required to be a LightningModule"
f" or a subclass of LightningModule, got type {type(model)}."
)

return _load_checkpoint(model, stored, linked_output_dir or output_dir, *args, **kwargs)
else:
return _load_model(stored, linked_output_dir or output_dir, *args, **kwargs)
else:
raise ValueError(
f"Could not find the model (for {name}:{version}) in the local system."
" Did you make sure to download the model using: `download_model(...)`"
" before calling `load_model(...)`?"
)
if not os.path.exists(_LIGHTNING_STORAGE_FILE):
download_model(name=name, version=version, output_dir=_LIGHTNING_STORAGE_DIR, progress_bar=progress_bar)
Comment on lines +339 to +340
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks weird, we only want to download it if we haven't before right? Does download_model handle that? Or can you only download one at a time?

Also, after you do this once that file will exist, but then if you change the model the file will still exist, no? So probably this only works the first time you do it...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks weird, we only want to download it if we haven't before right? Does download_model handle that? Or can you only download one at a time?

Good point, but I think this should be handled on the download side...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't the condition just before solving the issue?

if not os.path.exists(_LIGHTNING_STORAGE_FILE):

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ethanwharris's point is that this only has a chance to run once, after which _LIGHTNING_STORAGE_FILE will exist.

Note that _LIGHTNING_STORAGE_FILE is defined as

_LIGHTNING_DIR = os.path.join(Path.home(), ".lightning")
__STORAGE_FILE_NAME = ".model_storage"
_LIGHTNING_STORAGE_FILE = os.path.join(_LIGHTNING_DIR, __STORAGE_FILE_NAME)

so it has nothing to do with the actual location where the specific model file has been saved. It only contains metadata, see:

https://github.com/Lightning-AI/lightning/blob/21a7aa2735d1617fc0ed581b318c6a8b5ca09d89/src/lightning/store/cloud_api.py#L284

The actual solution here is to peek into this file, extract the actual location, and check whether the specific model file exists. Then, if the version is "latest", we need to resolve it and see whether it has increased with respect to what is contained in the storage file.


version = version or "latest"
model_data = _get_model_data(name, version)
output_dir = model_data["output_dir"]
linked_output_dir = model_data["linked_output_dir"]
meta_data = model_data["metadata"]
stored = {"code": {}}

for key, val in meta_data.items():
if key.startswith("stored_"):
if key.startswith("stored_code_"):
stored["code"][key.split("_code_")[1]] = val
else:
stored[key.split("_")[1]] = val

_validate_output_dir(output_dir)
if linked_output_dir:
_validate_output_dir(linked_output_dir)

if load_weights:
# This first loads the model - and then the weights
if not model:
raise ValueError(
"Expected model=... to be passed for loading weights, please pass"
f" your model object to load_model({name}, {version}, model=ModelObj)"
)
return _load_weights(model, stored, linked_output_dir or output_dir, *args, **kwargs)
elif load_checkpoint:
if not model:
raise ValueError(
"You need to pass the LightningModule object (model) to be able to"
f" load the checkpoint. `load_model({name}, {version},"
" load_checkpoint=True, model=...)`"
)
if not isinstance(model, (PL.LightningModule, L.LightningModule)):
raise TypeError(
"For loading checkpoints, the model is required to be a LightningModule"
f" or a subclass of LightningModule, got type {type(model)}."
)

return _load_checkpoint(model, stored, linked_output_dir or output_dir, *args, **kwargs)

return _load_model(stored, linked_output_dir or output_dir, *args, **kwargs)
29 changes: 5 additions & 24 deletions src/lightning/store/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,18 +164,10 @@ def _get_url(response_content):
content = json.loads(response_content)
return content["uploadUrl"]

json_field = {
"name": f"{username}/{name}",
"version": version,
"metadata": meta_data,
}
json_field = {"name": f"{username}/{name}", "version": version, "metadata": meta_data}
if project_id:
json_field["project_id"] = project_id
response = requests.post(
LIGHTNING_MODELS_PUBLIC_REGISTRY,
auth=HTTPBasicAuth(username, api_key),
json=json_field,
)
response = requests.post(LIGHTNING_MODELS_PUBLIC_REGISTRY, auth=HTTPBasicAuth(username, api_key), json=json_field)
if response.status_code != 200:
raise ConnectionRefusedError(f"Unable to upload content.\n Error: {response.content}\n for load: {json_field}")
return _get_url(response.content)
Expand All @@ -196,12 +188,7 @@ def _process_stored(stored: dict):
meta_data.update(_process_stored(stored))

return _upload_metadata(
meta_data,
name=name,
version=version,
username=username,
api_key=api_key,
project_id=project_id,
meta_data, name=name, version=version, username=username, api_key=api_key, project_id=project_id
)


Expand All @@ -213,13 +200,7 @@ def _make_tar(tmpdir, archive_output_path):
def upload_from_file(src, dst):
file_size = os.path.getsize(src)
with open(src, "rb") as fd:
with tqdm(
desc="Uploading",
total=file_size,
unit="B",
unit_scale=True,
unit_divisor=1024,
) as t:
with tqdm(desc="Uploading", total=file_size, unit="B", unit_scale=True, unit_divisor=1024) as t:
reader_wrapper = CallbackIOWrapper(t.update, fd, "read")
response = requests.put(dst, data=reader_wrapper)
response.raise_for_status()
Expand Down Expand Up @@ -257,7 +238,7 @@ def _common_clean_up(output_dir: str) -> None:
shutil.rmtree(dir_file_path)


def _download_and_extract_data_to(output_dir: str, download_url: str, progress_bar: bool) -> None:
def _download_and_extract_data(output_dir: str, download_url: str, progress_bar: bool) -> None:
try:
_download_tarfile(download_url, output_dir, progress_bar)

Expand Down
26 changes: 17 additions & 9 deletions tests/tests_cloud/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,40 +9,48 @@
from pytorch_lightning.demos.boring_classes import BoringModel


@pytest.mark.parametrize("pre_download", [True, False])
@pytest.mark.parametrize("pbar", [True, False])
def test_model(lit_home, pbar, pre_download, model_name: str = "boring_model", version: str = "latest"):
    """Round-trip a model through the store: upload, optionally pre-download, then load.

    With ``pre_download=False`` this exercises the new behavior under review:
    ``load_model`` is expected to download the model itself when it is not
    already present locally.
    """
    upload_model(model_name, model=BoringModel(), api_key=_API_KEY, project_id=_PROJECT_ID)

    if pre_download:
        # Explicit download before load; the downloaded artifact must land in
        # the per-user storage directory under the requested version.
        download_model(f"{_USERNAME}/{model_name}", progress_bar=pbar)
        assert os.path.isdir(os.path.join(lit_home, __STORAGE_DIR_NAME, _USERNAME, model_name, version))

    model = load_model(f"{_USERNAME}/{model_name}")
    assert model is not None


@pytest.mark.parametrize("pre_download", [True, False])
def test_only_weights(lit_home, pre_download, model_name: str = "boring_model_only_weights", version: str = "latest"):
    """Upload a weights-only model, optionally pre-download it, then load the weights back.

    With ``pre_download=False``, ``load_model`` is expected to fetch the model
    on demand before loading the weights into the provided model object.
    """
    model = BoringModel()
    trainer = pl.Trainer(fast_dev_run=True)
    trainer.fit(model)
    upload_model(model_name, model=model, weights_only=True, api_key=_API_KEY, project_id=_PROJECT_ID)

    if pre_download:
        # Explicit download before load; artifact must appear in local storage.
        download_model(f"{_USERNAME}/{model_name}")
        assert os.path.isdir(os.path.join(lit_home, __STORAGE_DIR_NAME, _USERNAME, model_name, version))

    model_with_weights = load_model(f"{_USERNAME}/{model_name}", load_weights=True, model=model)
    assert model_with_weights is not None
    assert model_with_weights.state_dict() is not None


@pytest.mark.parametrize("pre_download", [True, False])
def test_checkpoint_path(
    lit_home, pre_download, model_name: str = "boring_model_only_checkpoint_path", version: str = "latest"
):
    """Upload a model from a checkpoint file, optionally pre-download, then load the checkpoint.

    With ``pre_download=False``, ``load_model`` is expected to fetch the model
    on demand before loading the checkpoint into the provided LightningModule.
    """
    model = BoringModel()
    trainer = pl.Trainer(fast_dev_run=True)
    trainer.fit(model)
    trainer.save_checkpoint("tmp.ckpt")
    upload_model(model_name, checkpoint_path="tmp.ckpt", api_key=_API_KEY, project_id=_PROJECT_ID)

    if pre_download:
        # Explicit download before load; artifact must appear in local storage.
        download_model(f"{_USERNAME}/{model_name}")
        assert os.path.isdir(os.path.join(lit_home, __STORAGE_DIR_NAME, _USERNAME, model_name, version))

    ckpt = load_model(f"{_USERNAME}/{model_name}", load_checkpoint=True, model=model)
    assert ckpt is not None