Skip to content

Add support for models.predictions.create endpoint #207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ jobs:

name: "Test Python ${{ matrix.python-version }}"

timeout-minutes: 10

strategy:
fail-fast: false
matrix:
Expand Down
23 changes: 23 additions & 0 deletions replicate/identifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,29 @@
from typing import NamedTuple


class ModelIdentifier(NamedTuple):
"""
A reference to a model in the format owner/name:version.
"""

owner: str
name: str

@classmethod
def parse(cls, ref: str) -> "ModelIdentifier":
"""
Split a reference in the format owner/name:version into its components.
"""

match = re.match(r"^(?P<owner>[^/]+)/(?P<name>[^:]+)$", ref)
if not match:
raise ValueError(
f"Invalid reference to model version: {ref}. Expected format: owner/name"
)

return cls(match.group("owner"), match.group("name"))


class ModelVersionIdentifier(NamedTuple):
"""
A reference to a model version in the format owner/name:version.
Expand Down
85 changes: 83 additions & 2 deletions replicate/model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple, Union

from typing_extensions import NotRequired, TypedDict, Unpack, deprecated

from replicate.exceptions import ReplicateException
from replicate.identifier import ModelIdentifier
from replicate.pagination import Page
from replicate.prediction import Prediction
from replicate.prediction import (
Prediction,
_create_prediction_body,
_json_to_prediction,
)
from replicate.resource import Namespace, Resource
from replicate.version import Version, Versions

Expand All @@ -16,6 +21,7 @@

if TYPE_CHECKING:
from replicate.client import Client
from replicate.prediction import Predictions


class Model(Resource):
Expand Down Expand Up @@ -140,6 +146,14 @@ class Models(Namespace):

model = Model

@property
def predictions(self) -> "ModelsPredictions":
"""
Get a namespace for operations related to predictions on a model.
"""

return ModelsPredictions(client=self._client)

def list(self, cursor: Union[str, "ellipsis", None] = ...) -> Page[Model]: # noqa: F821
"""
List all public models.
Expand Down Expand Up @@ -275,6 +289,54 @@ async def async_create(
return _json_to_model(self._client, resp.json())


class ModelsPredictions(Namespace):
"""
Namespace for operations related to predictions in a deployment.
"""

def create(
self,
model: Optional[Union[str, Tuple[str, str], "Model"]],
input: Dict[str, Any],
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Prediction:
"""
Create a new prediction with the deployment.
"""

url = _create_prediction_url_from_model(model)
body = _create_prediction_body(version=None, input=input, **params)

resp = self._client._request(
"POST",
url,
json=body,
)

return _json_to_prediction(self._client, resp.json())

async def async_create(
self,
model: Optional[Union[str, Tuple[str, str], "Model"]],
input: Dict[str, Any],
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Prediction:
"""
Create a new prediction with the deployment.
"""

url = _create_prediction_url_from_model(model)
body = _create_prediction_body(version=None, input=input, **params)

resp = await self._client._async_request(
"POST",
url,
json=body,
)

return _json_to_prediction(self._client, resp.json())


def _create_model_body( # pylint: disable=too-many-arguments
owner: str,
name: str,
Expand Down Expand Up @@ -318,3 +380,22 @@ def _json_to_model(client: "Client", json: Dict[str, Any]) -> Model:
if model.default_example is not None:
model.default_example._client = client
return model


def _create_prediction_url_from_model(
model: Union[str, Tuple[str, str], "Model"]
) -> str:
owner, name = None, None
if isinstance(model, Model):
owner, name = model.owner, model.name
elif isinstance(model, tuple):
owner, name = model[0], model[1]
elif isinstance(model, str):
owner, name = ModelIdentifier.parse(model)

if owner is None or name is None:
raise ValueError(
"model must be a Model, a tuple of (owner, name), or a string in the format 'owner/name'"
)

return f"/v1/models/{owner}/{name}/predictions"
8 changes: 6 additions & 2 deletions replicate/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from typing_extensions import NotRequired, Unpack

from replicate.files import upload_file
from replicate.identifier import ModelVersionIdentifier
from replicate.identifier import ModelIdentifier, ModelVersionIdentifier
from replicate.json import encode_json
from replicate.model import Model
from replicate.pagination import Page
Expand Down Expand Up @@ -378,8 +378,12 @@ def _create_training_url_from_model_and_version(
owner, name = model.owner, model.name
elif isinstance(model, tuple):
owner, name = model[0], model[1]
elif isinstance(model, str):
owner, name = ModelIdentifier.parse(model)
else:
raise ValueError("model must be a Model or a tuple of (owner, name)")
raise ValueError(
"model must be a Model, a tuple of (owner, name), or a string in the format 'owner/name'"
)

if isinstance(version, Version):
version_id = version.id
Expand Down
2 changes: 1 addition & 1 deletion replicate/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class Versions(Namespace):
model: Tuple[str, str]

def __init__(
self, client: "Client", model: Union["Model", str, Tuple[str, str]]
self, client: "Client", model: Union[str, Tuple[str, str], "Model"]
) -> None:
super().__init__(client=client)

Expand Down
55 changes: 55 additions & 0 deletions tests/cassettes/models-predictions-create.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
interactions:
- request:
body: '{"input": {"prompt": "Please write a haiku about llamas."}}'
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '59'
content-type:
- application/json
host:
- api.replicate.com
user-agent:
- replicate-python/0.21.0
method: POST
uri: https://api.replicate.com/v1/models/meta/llama-2-70b-chat/predictions
response:
content: '{"id":"heat2o3bzn3ahtr6bjfftvbaci","model":"replicate/lifeboat-70b","version":"d-c6559c5791b50af57b69f4a73f8e021c","input":{"prompt":"Please
write a haiku about llamas."},"logs":"","error":null,"status":"starting","created_at":"2023-11-27T13:35:45.99397566Z","urls":{"cancel":"https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci/cancel","get":"https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci"}}

'
headers:
CF-Cache-Status:
- DYNAMIC
CF-RAY:
- 82cac197efaec53d-SEA
Connection:
- keep-alive
Content-Length:
- '431'
Content-Type:
- application/json
Date:
- Mon, 27 Nov 2023 13:35:46 GMT
NEL:
- '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}'
Report-To:
- '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=7R5RONMF6xaGRc39n0wnSe3jU1FbpX64Xz4U%2B%2F2nasvFaz0pKARxPhnzDgYkLaWgdK9zWrD2jxU04aKOy5HMPHAXboJ993L4zfsOyto56lBtdqSjNgkptzzxYEsKD%2FxIhe2F"}],"group":"cf-nel","max_age":604800}'
Server:
- cloudflare
Strict-Transport-Security:
- max-age=15552000
ratelimit-remaining:
- '599'
ratelimit-reset:
- '1'
via:
- 1.1 google
http_version: HTTP/1.1
status_code: 201
version: 1
23 changes: 23 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,26 @@ async def test_models_create_with_positional_arguments(async_flag):
assert model.owner == "test"
assert model.name == "python-example"
assert model.visibility == "private"


@pytest.mark.vcr("models-predictions-create.yaml")
@pytest.mark.asyncio
@pytest.mark.parametrize("async_flag", [True, False])
async def test_models_predictions_create(async_flag):
input = {
"prompt": "Please write a haiku about llamas.",
}

if async_flag:
prediction = await replicate.models.predictions.async_create(
"meta/llama-2-70b-chat", input=input
)
else:
prediction = replicate.models.predictions.create(
"meta/llama-2-70b-chat", input=input
)

assert prediction.id is not None
# assert prediction.model == "meta/llama-2-70b-chat"
assert prediction.model == "replicate/lifeboat-70b" # FIXME: this is temporary
assert prediction.status == "starting"