-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Init: Models store API #15811
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Init: Models store API #15811
Changes from 41 commits
Commits
Show all changes
51 commits
Select commit
Hold shift + click to select a range
17d2ac5
placeholder
Borda 505f6b1
porting code
Borda 3a820d7
move > lai
Borda 9847557
cleaning
Borda 701e240
readme
Borda a259966
precommit
Borda 074c337
store
Borda a086d85
docs
Borda a411488
ci
Borda daad294
azure
Borda 8dec395
latest
Borda 0e22d6c
py3.9
Borda 112c6b9
dir
Borda e1268f3
_
Borda b910df4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 03781df
tests
Borda fcefc30
Merge branch 'feat/model2cloud' of https://github.com/PyTorchLightnin…
Borda 86d5820
imports
Borda 4938a8d
imports
Borda 29e41d1
cleaning
Borda e157483
env
Borda 010edd0
exceptions
Borda 3283ef9
another pass
Borda 8d873d1
error
Borda a8308df
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 741b33b
assert
Borda 84c1591
paths
Borda 3d8e35f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 1f003b8
debug
Borda 6782769
if
Borda a1f2567
move CONSTANTs
Borda 94c90ad
imports
Borda acc2b6e
drop test dev
Borda 274abdc
ci
Borda 8af21a2
Merge branch 'master' into feat/model2cloud
Borda eff305d
reuse constants
Borda dd4a6ea
Merge branch 'feat/model2cloud' of https://github.com/PyTorchLightnin…
Borda 3fa82b0
cleaning
Borda 193a7da
Merge branch 'master' into feat/model2cloud
Borda 955b6b8
Merge branch 'master' into feat/model2cloud
Borda ef366ad
explicit
Borda 003b25c
todo
Borda 4cef07c
Merge branch 'master' into feat/model2cloud
Borda 318036f
Apply suggestions from code review
Borda 01e3e30
protected
Borda 8f5aa05
Merge branch 'master' into feat/model2cloud
Borda c5dce15
fix
Borda 017f271
renaming
Borda 5e1f7c6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] a4ab167
Merge branch 'master' into feat/model2cloud
Borda 33986c0
Merge branch 'master' into feat/model2cloud
Borda File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# Python package | ||
# Create and test a Python package on multiple Python versions. | ||
# Add steps that analyze code, save the dist with the build record, publish to a PyPI-compatible index, and more: | ||
# https://docs.microsoft.com/azure/devops/pipelines/languages/python | ||
|
||
trigger: | ||
tags: | ||
include: | ||
- '*' | ||
branches: | ||
include: | ||
- "master" | ||
- "release/*" | ||
- "refs/tags/*" | ||
|
||
pr: | ||
branches: | ||
include: | ||
- "master" | ||
- "release/*" | ||
paths: | ||
include: | ||
- ".actions/**" | ||
- ".azure/app-cloud-store.yml" | ||
- "src/lightning/store/**" | ||
- "tests/tests_cloud/**" | ||
- "setup.py" | ||
exclude: | ||
- "*.md" | ||
- "**/*.md" | ||
|
||
jobs: | ||
- job: test_store | ||
pool: | ||
vmImage: $(imageName) | ||
strategy: | ||
matrix: | ||
Linux: | ||
imageName: 'ubuntu-latest' | ||
Mac: | ||
imageName: 'macOS-latest' | ||
Windows: | ||
imageName: 'windows-latest' | ||
timeoutInMinutes: "20" | ||
cancelTimeoutInMinutes: "1" | ||
workspace: | ||
clean: all | ||
variables: | ||
FREEZE_REQUIREMENTS: "1" | ||
TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html" | ||
steps: | ||
- task: UsePythonVersion@0 | ||
inputs: | ||
versionSpec: '3.9' | ||
|
||
- bash: pip install -e .[test] -f $(TORCH_URL) | ||
displayName: 'Install Lightning & dependencies' | ||
|
||
- bash: | | ||
python -m pytest -m "not cloud" tests_cloud --timeout=300 -v | ||
workingDirectory: tests/ | ||
env: | ||
API_KEY: $(LIGHTNING_API_KEY_PROD) | ||
API_USERNAME: $(LIGHTNING_USERNAME_PROD) | ||
PROJECT_ID: $(LIGHTNING_PROJECT_ID_PROD) | ||
LIGHTNING_CLOUD_URL: $(LIGHTNING_CLOUD_URL_PROD) | ||
displayName: 'Run the tests' |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
## Getting Started | ||
|
||
- Login to lightning.ai (_optional_) \<-- takes less than a minute. ⏩ | ||
- Store your models on the cloud \<-- simple call: `to_lightning_cloud(...)`. 🗳️ | ||
- Share it with your friends \<-- just share the "username/model_name" (and version if required) format. :handshake: | ||
- They download using a simple call: `download_from_lightning_cloud("username/model_name", version="your_version")`. :wink: | ||
- They load your cool model. `load_from_lightning_cloud("username/model_name", version="your_version")`. :tada: | ||
- Lightning :zap: fast, isn't it?. :heart: | ||
|
||
## Usage | ||
|
||
**Storing to the cloud** | ||
|
||
```python | ||
from lightning.store import to_lightning_cloud | ||
Borda marked this conversation as resolved.
Show resolved
Hide resolved
|
||
from sample.model import LitAutoEncoder, Encoder, Decoder | ||
|
||
# Initialize your model here | ||
autoencoder = LitAutoEncoder(Encoder(), Decoder()) | ||
|
||
# Pass the model object: | ||
# No need to pass the username (we'll deduce ourselves), just pass the model name you want as the first argument (with an optional version): | ||
# format: `model_name:version` (version can either be latest or combination of digits and full-stops: 1.0.0 for example) | ||
to_lightning_cloud("unique_model_mnist", model=autoencoder, source_code_path="sample") | ||
|
||
# version: | ||
to_lightning_cloud( | ||
"unique_model_mnist", | ||
version="1.0.0", | ||
model=autoencoder, | ||
source_code_path="sample/model.py", | ||
) | ||
|
||
# OR: (this will save the file which has the model defined) | ||
to_lightning_cloud("krshrimali/unique_model_mnist", model=autoencoder) | ||
``` | ||
|
||
You can also pass the checkpoint path: `to_lightning_cloud("model_name", version="latest", checkpoint_path=...)`. | ||
|
||
**Downloading from the cloud** | ||
|
||
```python | ||
from lightning.store import download_from_lightning_cloud | ||
Borda marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
download_from_lightning_cloud("krshrimali/unique_model_mnist", output_dir="your_output_dir") | ||
# OR: (default to lightning_model_storage $HOME/.lightning/lightning_model_store/username/<model_name>/version_<version_with_dots_replaced_by_underscores>/ folder) | ||
download_from_lightning_cloud("krshrimali/unique_model_mnist") | ||
``` | ||
|
||
**Loading model** | ||
|
||
```python | ||
from lightning.store import load_from_lightning_cloud | ||
|
||
# from <username>.<model_name>.version_<version_with_dots_replaced_by_underscores>.<model_source_file> import LitAutoEncoder, Encoder, Decoder | ||
model = load_from_lightning_cloud( | ||
Borda marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"<username>/<model_name>>", version="version" | ||
) # version is optional (defaults to latest) | ||
|
||
# OR: load weights or checkpoint (if they were uploaded) | ||
load_from_lightning_cloud( | ||
"<username>/<model_name>", version="version", load_weights=True / False, load_checkpoint=True / False | ||
) | ||
print(model) | ||
``` | ||
|
||
**Loading model weights** | ||
Borda marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
```python | ||
from lightning.store import load_from_lightning_cloud | ||
|
||
# If you had passed an `output_dir=...` to download_from_lightning_cloud(...), then you can just do: | ||
# from output_dir.<model_source_file> import LitAutoEncoder, Encoder, Decoder | ||
|
||
model = LitAutoEncoder(Encoder(), Decoder()) | ||
|
||
model = load_from_lightning_cloud(load_weights=True, model=model) | ||
print("State dict: ", model.state_dict()) | ||
``` | ||
|
||
Loading checkpoint is similar, just do: `load_checkpoint=True`. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from lightning.store.cloud_api import download_from_lightning_cloud, load_from_lightning_cloud, to_lightning_cloud | ||
|
||
__all__ = ["download_from_lightning_cloud", "load_from_lightning_cloud", "to_lightning_cloud"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Copyright The Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import json | ||
import webbrowser | ||
|
||
import requests | ||
from requests.models import HTTPBasicAuth | ||
|
||
from lightning.app.core.constants import get_lightning_cloud_url | ||
from lightning.app.utilities.network import LightningClient | ||
|
||
_LIGHTNING_CLOUD_URL = get_lightning_cloud_url() | ||
|
||
|
||
def get_user_details(): | ||
client = LightningClient() | ||
user_details = client.auth_service_get_user() | ||
return user_details.username, user_details.api_key | ||
|
||
|
||
def get_username_from_api_key(api_key: str): | ||
Borda marked this conversation as resolved.
Show resolved
Hide resolved
|
||
response = requests.get(url=f"{_LIGHTNING_CLOUD_URL}/v1/auth/user", auth=HTTPBasicAuth("lightning", api_key)) | ||
if response.status_code != 200: | ||
raise ConnectionRefusedError( | ||
"API_KEY provided is either invalid or wasn't found in the database." | ||
" Please ensure that you passed the correct API_KEY." | ||
) | ||
return json.loads(response.content)["username"] | ||
|
||
|
||
def _check_browser_runnable(): | ||
try: | ||
webbrowser.get() | ||
except webbrowser.Error: | ||
return False | ||
return True | ||
|
||
|
||
def authenticate(inp_api_key: str = ""): | ||
Borda marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if not inp_api_key: | ||
if not _check_browser_runnable(): | ||
raise ValueError( | ||
"Couldn't find a runnable browser in the current system/server." | ||
" In order to run the commands on this system, we suggest passing the `api_key`" | ||
" after logging into https://lightning.ai." | ||
) | ||
username, inp_api_key = get_user_details() | ||
else: | ||
username = get_username_from_api_key(inp_api_key) | ||
return username, inp_api_key |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.