Skip to content

Commit d738ab1

Browse files
krshrimaliawaelchliBordapre-commit-ci[bot]
authored
Init: Models store API (#15811)
Co-authored-by: Adrian Wälchli <[email protected]> Co-authored-by: Jirka <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <[email protected]>
1 parent 25e1aff commit d738ab1

16 files changed

+1329
-3
lines changed

.azure/app-cloud-e2e.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ jobs:
117117
workspace:
118118
clean: all
119119
variables:
120+
FREEZE_REQUIREMENTS: "1"
120121
HEADLESS: '1'
121122
PACKAGE_LIGHTNING: '1'
122123
CLOUD: '1'
@@ -146,8 +147,6 @@ jobs:
146147
- bash: |
147148
pip install -e .[test] \
148149
-f https://download.pytorch.org/whl/cpu/torch_stable.html
149-
env:
150-
FREEZE_REQUIREMENTS: "1"
151150
displayName: 'Install Lightning & dependencies'
152151
153152
- bash: |

.azure/app-cloud-store.yml

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Python package
2+
# Create and test a Python package on multiple Python versions.
3+
# Add steps that analyze code, save the dist with the build record, publish to a PyPI-compatible index, and more:
4+
# https://docs.microsoft.com/azure/devops/pipelines/languages/python
5+
6+
trigger:
7+
tags:
8+
include:
9+
- '*'
10+
branches:
11+
include:
12+
- "master"
13+
- "release/*"
14+
- "refs/tags/*"
15+
16+
pr:
17+
branches:
18+
include:
19+
- "master"
20+
- "release/*"
21+
paths:
22+
include:
23+
- ".actions/**"
24+
- ".azure/app-cloud-store.yml"
25+
- "src/lightning/store/**"
26+
- "tests/tests_cloud/**"
27+
- "setup.py"
28+
exclude:
29+
- "*.md"
30+
- "**/*.md"
31+
32+
jobs:
33+
- job: test_store
34+
pool:
35+
vmImage: $(imageName)
36+
strategy:
37+
matrix:
38+
Linux:
39+
imageName: 'ubuntu-latest'
40+
Mac:
41+
imageName: 'macOS-latest'
42+
Windows:
43+
imageName: 'windows-latest'
44+
timeoutInMinutes: "20"
45+
cancelTimeoutInMinutes: "1"
46+
workspace:
47+
clean: all
48+
variables:
49+
FREEZE_REQUIREMENTS: "1"
50+
TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html"
51+
steps:
52+
- task: UsePythonVersion@0
53+
inputs:
54+
versionSpec: '3.9'
55+
56+
- bash: pip install -e .[test] -f $(TORCH_URL)
57+
displayName: 'Install Lightning & dependencies'
58+
59+
- bash: |
60+
python -m pytest -m "not cloud" tests_cloud --timeout=300 -v
61+
workingDirectory: tests/
62+
env:
63+
API_KEY: $(LIGHTNING_API_KEY_PROD)
64+
API_USERNAME: $(LIGHTNING_USERNAME_PROD)
65+
PROJECT_ID: $(LIGHTNING_PROJECT_ID_PROD)
66+
LIGHTNING_CLOUD_URL: $(LIGHTNING_CLOUD_URL_PROD)
67+
displayName: 'Run the tests'

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,10 @@ wheels/
4848
*.egg-info/
4949
.installed.cfg
5050
*.egg
51-
src/lightning/*/
5251
src/*/version.info
52+
src/lightning/app/
53+
src/lightning/fabric/
54+
src/lightning/pytorch/
5355

5456
# PyInstaller
5557
# Usually these files are written by a python script from a template

src/lightning/store/README.md

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
## Getting Started
2+
3+
- Login to lightning.ai (_optional_) \<-- takes less than a minute. ⏩
4+
- Store your models on the cloud \<-- simple call: `upload_to_cloud(...)`. 🗳️
5+
- Share it with your friends \<-- just share the "username/model_name" (and version if required) format. :handshake:
6+
- They download using a simple call: `download_from_cloud("username/model_name", version="your_version")`. :wink:
7+
- They load your cool model. `load_model("username/model_name", version="your_version")`. :tada:
8+
- Lightning :zap: fast, isn't it?. :heart:
9+
10+
## Usage
11+
12+
**Storing to the cloud**
13+
14+
```python
15+
import lightning as L
16+
from sample.model import LitAutoEncoder, Encoder, Decoder
17+
18+
# Initialize your model here
19+
autoencoder = LitAutoEncoder(Encoder(), Decoder())
20+
21+
# Pass the model object:
22+
# No need to pass the username (we'll deduce ourselves), just pass the model name you want as the first argument (with an optional version):
23+
# format: `model_name:version` (version can either be latest or combination of digits and full-stops: 1.0.0 for example)
24+
L.store.upload_to_cloud("unique_model_mnist", model=autoencoder, source_code_path="sample")
25+
26+
# version:
27+
L.store.upload_to_cloud(
28+
"unique_model_mnist",
29+
version="1.0.0",
30+
model=autoencoder,
31+
source_code_path="sample/model.py",
32+
)
33+
34+
# OR: (this will save the file which has the model defined)
35+
L.store.upload_to_cloud("krshrimali/unique_model_mnist", model=autoencoder)
36+
```
37+
38+
You can also pass the checkpoint path: `to_lightning_cloud("model_name", version="latest", checkpoint_path=...)`.
39+
40+
**Downloading from the cloud**
41+
42+
At first, you need to download the model to your local machine.
43+
44+
```python
45+
import lightning as L
46+
47+
L.store.download_from_cloud(
48+
"krshrimali/unique_model_mnist",
49+
output_dir="your_output_dir",
50+
)
51+
# OR: (default to model_storage
52+
# $HOME
53+
# |- .lightning
54+
# | |- model_store
55+
# | | |- username
56+
# | | | |- <model_name>
57+
# | | | | |- version_<version_with_dots_replaced_by_underscores>
58+
# folder)
59+
L.store.download_from_cloud("krshrimali/unique_model_mnist")
60+
```
61+
62+
**Loading model**
63+
64+
Then you can load the model to your program.
65+
66+
```python
67+
import lightning as L
68+
69+
# from <username>.<model_name>.version_<version_with_dots_replaced_by_underscores>.<model_source_file> import LitAutoEncoder, Encoder, Decoder
70+
model = L.store.load_model("<username>/<model_name>>", version="version") # version is optional (defaults to latest)
71+
72+
# OR: load weights or checkpoint (if they were uploaded)
73+
L.store.load_model(
74+
"<username>/<model_name>", version="version", load_weights=True | False, load_checkpoint=True | False
75+
)
76+
print(model)
77+
```
78+
79+
**Loading model weights**
80+
81+
```python
82+
import lightning as L
83+
from sample.model import LitAutoEncoder, Encoder, Decoder
84+
85+
# If you had passed an `output_dir=...` to download_from_lightning_cloud(...), then you can just do:
86+
# from output_dir.<model_source_file> import LitAutoEncoder, Encoder, Decoder
87+
88+
model = LitAutoEncoder(Encoder(), Decoder())
89+
90+
model = L.store.load_model(load_weights=True, model=model)
91+
print("State dict: ", model.state_dict())
92+
```
93+
94+
Loading checkpoint is similar, just do: `load_checkpoint=True`.
95+
96+
## Known limitations
97+
98+
- missing web UI for user to brows his uploads
99+
- missing CLI/API to list and delete uploaded models

src/lightning/store/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from lightning.store.cloud_api import download_from_cloud, load_model, upload_to_cloud
2+
3+
__all__ = ["download_from_cloud", "load_model", "upload_to_cloud"]

src/lightning/store/authentication.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright The Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import json
16+
import webbrowser
17+
18+
import requests
19+
from requests.models import HTTPBasicAuth
20+
21+
from lightning.app.core.constants import get_lightning_cloud_url
22+
from lightning.app.utilities.network import LightningClient
23+
24+
_LIGHTNING_CLOUD_URL = get_lightning_cloud_url()
25+
26+
27+
def _get_user_details():
28+
client = LightningClient()
29+
user_details = client.auth_service_get_user()
30+
return user_details.username, user_details.api_key
31+
32+
33+
def _get_username_from_api_key(api_key: str):
34+
response = requests.get(url=f"{_LIGHTNING_CLOUD_URL}/v1/auth/user", auth=HTTPBasicAuth("lightning", api_key))
35+
if response.status_code != 200:
36+
raise ConnectionRefusedError(
37+
"API_KEY provided is either invalid or wasn't found in the database."
38+
" Please ensure that you passed the correct API_KEY."
39+
)
40+
return json.loads(response.content)["username"]
41+
42+
43+
def _check_browser_runnable():
44+
try:
45+
webbrowser.get()
46+
except webbrowser.Error:
47+
return False
48+
return True
49+
50+
51+
def _authenticate(inp_api_key: str = ""):
52+
# TODO: we have headless login now,
53+
# so it could be reasonable to just point to that if browser can't be opened / user can't be authed
54+
if not inp_api_key:
55+
if not _check_browser_runnable():
56+
raise ValueError(
57+
"Couldn't find a runnable browser in the current system/server."
58+
" In order to run the commands on this system, we suggest passing the `api_key`"
59+
" after logging into https://lightning.ai."
60+
)
61+
username, inp_api_key = _get_user_details()
62+
else:
63+
username = _get_username_from_api_key(inp_api_key)
64+
return username, inp_api_key

0 commit comments

Comments
 (0)