Skip to content

Allow run and stream methods to take model arguments, when supported #210

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 26 additions & 25 deletions replicate/identifier.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,9 @@
import re
from typing import NamedTuple
from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union


class ModelIdentifier(NamedTuple):
    """
    A reference to a model in the format owner/name.
    """

    owner: str
    name: str

    @classmethod
    def parse(cls, ref: str) -> "ModelIdentifier":
        """
        Split a reference in the format owner/name into its components.

        Args:
            ref: The model reference, e.g. ``"owner/name"``.

        Returns:
            A ``ModelIdentifier`` holding the owner and name segments.

        Raises:
            ValueError: If ``ref`` is not exactly two non-empty segments
                separated by a single ``/``, or if the name carries a
                ``:version`` suffix.
        """

        # The name segment excludes "/" as well as ":" so that references
        # with extra path segments ("a/b/c") are rejected instead of being
        # parsed as owner "a" and name "b/c".
        match = re.match(r"^(?P<owner>[^/]+)/(?P<name>[^/:]+)$", ref)
        if not match:
            raise ValueError(
                f"Invalid reference to model version: {ref}. Expected format: owner/name"
            )

        return cls(match.group("owner"), match.group("name"))
if TYPE_CHECKING:
from replicate.model import Model
from replicate.version import Version


class ModelVersionIdentifier(NamedTuple):
Expand All @@ -32,18 +13,38 @@ class ModelVersionIdentifier(NamedTuple):

owner: str
name: str
version: str
version: Optional[str] = None

@classmethod
def parse(cls, ref: str) -> "ModelVersionIdentifier":
"""
Split a reference in the format owner/name:version into its components.
"""

match = re.match(r"^(?P<owner>[^/]+)/(?P<name>[^:]+):(?P<version>.+)$", ref)
match = re.match(r"^(?P<owner>[^/]+)/(?P<name>[^/:]+)(:(?P<version>.+))?$", ref)
if not match:
raise ValueError(
f"Invalid reference to model version: {ref}. Expected format: owner/name:version"
)

return cls(match.group("owner"), match.group("name"), match.group("version"))


def _resolve(
    ref: Union["Model", "Version", "ModelVersionIdentifier", str]
) -> Tuple[Optional["Version"], Optional[str], Optional[str], Optional[str]]:
    """
    Break a model reference down into ``(version, owner, name, version_id)``.

    What is filled in depends on the type of ``ref``:

    - ``Model``: ``owner`` and ``name`` are taken from the model.
    - ``Version``: ``version`` is the object itself and ``version_id`` its ``id``.
    - ``ModelVersionIdentifier``: unpacked into ``owner``, ``name``, ``version_id``.
    - ``str``: parsed with ``ModelVersionIdentifier.parse`` (which raises
      ``ValueError`` on a malformed reference) and unpacked the same way.

    Every slot that does not apply is ``None``; a ``ref`` of any other type
    yields four ``None`` values.
    """
    # Imported inside the function rather than at module level -- presumably to
    # avoid a circular import, since the module only imports these names under
    # TYPE_CHECKING. TODO confirm.
    from replicate.model import Model  # pylint: disable=import-outside-toplevel
    from replicate.version import Version  # pylint: disable=import-outside-toplevel

    if isinstance(ref, Model):
        return None, ref.owner, ref.name, None

    if isinstance(ref, Version):
        return ref, None, None, ref.id

    if isinstance(ref, ModelVersionIdentifier):
        owner, name, version_id = ref
        return None, owner, name, version_id

    if isinstance(ref, str):
        owner, name, version_id = ModelVersionIdentifier.parse(ref)
        return None, owner, name, version_id

    return None, None, None, None
12 changes: 8 additions & 4 deletions replicate/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing_extensions import NotRequired, TypedDict, Unpack, deprecated

from replicate.exceptions import ReplicateException
from replicate.identifier import ModelIdentifier
from replicate.identifier import ModelVersionIdentifier
from replicate.pagination import Page
from replicate.prediction import (
Prediction,
Expand Down Expand Up @@ -296,7 +296,7 @@ class ModelsPredictions(Namespace):

def create(
self,
model: Optional[Union[str, Tuple[str, str], "Model"]],
model: Union[str, Tuple[str, str], "Model"],
input: Dict[str, Any],
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Prediction:
Expand All @@ -317,7 +317,7 @@ def create(

async def async_create(
self,
model: Optional[Union[str, Tuple[str, str], "Model"]],
model: Union[str, Tuple[str, str], "Model"],
input: Dict[str, Any],
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Prediction:
Expand Down Expand Up @@ -391,7 +391,11 @@ def _create_prediction_url_from_model(
elif isinstance(model, tuple):
owner, name = model[0], model[1]
elif isinstance(model, str):
owner, name = ModelIdentifier.parse(model)
owner, name, version_id = ModelVersionIdentifier.parse(model)
if version_id is not None:
raise ValueError(
f"Invalid reference to model version: {model}. Expected model or reference in the format owner/name"
)

if owner is None or name is None:
raise ValueError(
Expand Down
93 changes: 54 additions & 39 deletions replicate/run.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,50 @@
import asyncio
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union

from typing_extensions import Unpack

from replicate import identifier
from replicate.exceptions import ModelError
from replicate.identifier import ModelVersionIdentifier
from replicate.model import Model
from replicate.prediction import Prediction
from replicate.schema import make_schema_backwards_compatible
from replicate.version import Versions
from replicate.version import Version, Versions

if TYPE_CHECKING:
from replicate.client import Client
from replicate.identifier import ModelVersionIdentifier
from replicate.prediction import Predictions


def run(
client: "Client",
ref: str,
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
input: Optional[Dict[str, Any]] = None,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Union[Any, Iterator[Any]]: # noqa: ANN401
"""
Run a model and wait for its output.
"""

owner, name, version_id = ModelVersionIdentifier.parse(ref)
version, owner, name, version_id = identifier._resolve(ref)

prediction = client.predictions.create(
version=version_id, input=input or {}, **params
)
if version_id is not None:
prediction = client.predictions.create(
version=version_id, input=input or {}, **params
)
elif owner and name:
prediction = client.models.predictions.create(
model=(owner, name), input=input or {}, **params
)
else:
raise ValueError(
f"Invalid argument: {ref}. Expected model, version, or reference in the format owner/name or owner/name:version"
)

if owner and name:
if not version and (owner and name and version_id):
version = Versions(client, model=(owner, name)).get(version_id)

# Return an iterator of the output
schema = make_schema_backwards_compatible(
version.openapi_schema, version.cog_version
)
output = schema["components"]["schemas"]["Output"]
if (
output.get("type") == "array"
and output.get("x-cog-array-type") == "iterator"
):
return prediction.output_iterator()
if version and (iterator := _make_output_iterator(version, prediction)):
return iterator

prediction.wait()

Expand All @@ -53,42 +56,54 @@ def run(

async def async_run(
    client: "Client",
    ref: Union["Model", "Version", "ModelVersionIdentifier", str],
    input: Optional[Dict[str, Any]] = None,
    **params: Unpack["Predictions.CreatePredictionParams"],
) -> Union[Any, Iterator[Any]]:  # noqa: ANN401
    """
    Run a model and wait for its output asynchronously.

    Args:
        client: The client used to create and poll the prediction.
        ref: A model, a version, a parsed identifier, or a string in the
            format owner/name or owner/name:version.
        input: The input passed to the prediction; defaults to an empty dict.
        **params: Extra parameters forwarded to the prediction-creation call.

    Returns:
        The prediction's output iterator when the version's schema declares an
        iterator output, otherwise the prediction's final output.

    Raises:
        ValueError: If ``ref`` resolves to neither a version nor an
            owner/name pair.
        ModelError: If the prediction finishes with status ``"failed"``.
    """

    version, owner, name, version_id = identifier._resolve(ref)

    if version or version_id:
        prediction = await client.predictions.async_create(
            version=(version or version_id), input=input or {}, **params
        )
    elif owner and name:
        prediction = await client.models.predictions.async_create(
            model=(owner, name), input=input or {}, **params
        )
    else:
        raise ValueError(
            f"Invalid argument: {ref}. Expected model, version, or reference in the format owner/name or owner/name:version"
        )

    # Fetch the version with the awaitable getter; the synchronous ``get``
    # would block the event loop for the duration of the HTTP request.
    if not version and (owner and name and version_id):
        version = await Versions(client, model=(owner, name)).async_get(version_id)

    if version and (iterator := _make_output_iterator(version, prediction)):
        return iterator

    # Poll cooperatively instead of calling the blocking ``prediction.wait()``
    # so other tasks on the loop keep running while the prediction finishes.
    while prediction.status not in ["succeeded", "failed", "canceled"]:
        await asyncio.sleep(client.poll_interval)
        prediction = await client.predictions.async_get(prediction.id)

    if prediction.status == "failed":
        raise ModelError(prediction.error)

    return prediction.output


def _make_output_iterator(
    version: Version, prediction: Prediction
) -> Optional[Iterator[Any]]:
    """
    Return the prediction's output iterator if the version declares one.

    The version's OpenAPI schema is first passed through
    ``make_schema_backwards_compatible`` together with its cog version. When
    the resulting ``Output`` schema is an array whose ``x-cog-array-type`` is
    ``"iterator"``, ``prediction.output_iterator()`` is returned; in every
    other case the result is ``None``.
    """
    compatible_schema = make_schema_backwards_compatible(
        version.openapi_schema, version.cog_version
    )
    output_schema = compatible_schema["components"]["schemas"]["Output"]

    is_array = output_schema.get("type") == "array"
    yields_iterator = (
        is_array and output_schema.get("x-cog-array-type") == "iterator"
    )

    return prediction.output_iterator() if yields_iterator else None


__all__: List = []
49 changes: 38 additions & 11 deletions replicate/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
Iterator,
List,
Optional,
Union,
)

from typing_extensions import Unpack

from replicate import identifier
from replicate.exceptions import ReplicateError
from replicate.identifier import ModelVersionIdentifier

try:
from pydantic import v1 as pydantic # type: ignore
Expand All @@ -24,7 +25,10 @@
import httpx

from replicate.client import Client
from replicate.identifier import ModelVersionIdentifier
from replicate.model import Model
from replicate.prediction import Predictions
from replicate.version import Version


class ServerSentEvent(pydantic.BaseModel): # type: ignore
Expand Down Expand Up @@ -157,7 +161,7 @@ async def __aiter__(self) -> AsyncIterator[ServerSentEvent]:

def stream(
client: "Client",
ref: str,
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
input: Optional[Dict[str, Any]] = None,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Iterator[ServerSentEvent]:
Expand All @@ -168,10 +172,20 @@ def stream(
params = params or {}
params["stream"] = True

_, _, version_id = ModelVersionIdentifier.parse(ref)
prediction = client.predictions.create(
version=version_id, input=input or {}, **params
)
version, owner, name, version_id = identifier._resolve(ref)

if version or version_id:
prediction = client.predictions.create(
version=(version or version_id), input=input or {}, **params
)
elif owner and name:
prediction = client.models.predictions.create(
model=(owner, name), input=input or {}, **params
)
else:
raise ValueError(
f"Invalid argument: {ref}. Expected model, version, or reference in the format owner/name or owner/name:version"
)

url = prediction.urls and prediction.urls.get("stream", None)
if not url or not isinstance(url, str):
Expand All @@ -187,7 +201,7 @@ def stream(

async def async_stream(
client: "Client",
ref: str,
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
input: Optional[Dict[str, Any]] = None,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> AsyncIterator[ServerSentEvent]:
Expand All @@ -198,10 +212,20 @@ async def async_stream(
params = params or {}
params["stream"] = True

_, _, version_id = ModelVersionIdentifier.parse(ref)
prediction = await client.predictions.async_create(
version=version_id, input=input or {}, **params
)
version, owner, name, version_id = identifier._resolve(ref)

if version or version_id:
prediction = await client.predictions.async_create(
version=(version or version_id), input=input or {}, **params
)
elif owner and name:
prediction = await client.models.predictions.async_create(
model=(owner, name), input=input or {}, **params
)
else:
raise ValueError(
f"Invalid argument: {ref}. Expected model, version, or reference in the format owner/name or owner/name:version"
)

url = prediction.urls and prediction.urls.get("stream", None)
if not url or not isinstance(url, str):
Expand All @@ -214,3 +238,6 @@ async def async_stream(
async with client._async_client.stream("GET", url, headers=headers) as response:
async for event in EventSource(response):
yield event


__all__ = ["ServerSentEvent"]
6 changes: 3 additions & 3 deletions replicate/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing_extensions import NotRequired, Unpack

from replicate.files import upload_file
from replicate.identifier import ModelIdentifier, ModelVersionIdentifier
from replicate.identifier import ModelVersionIdentifier
from replicate.json import encode_json
from replicate.model import Model
from replicate.pagination import Page
Expand Down Expand Up @@ -373,14 +373,14 @@ def _create_training_url_from_shorthand(ref: str) -> str:

def _create_training_url_from_model_and_version(
model: Union[str, Tuple[str, str], "Model"],
version: Union[str, Version],
version: Union[str, "Version"],
) -> str:
if isinstance(model, Model):
owner, name = model.owner, model.name
elif isinstance(model, tuple):
owner, name = model[0], model[1]
elif isinstance(model, str):
owner, name = ModelIdentifier.parse(model)
owner, name, _ = ModelVersionIdentifier.parse(model)
else:
raise ValueError(
"model must be a Model, a tuple of (owner, name), or a string in the format 'owner/name'"
Expand Down
Loading