Add support for deployments.get endpoint #247

Closed · wants to merge 1 commit
170 changes: 163 additions & 7 deletions replicate/deployment.py
@@ -1,7 +1,8 @@
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

from typing_extensions import Unpack, deprecated

from replicate.account import Account
from replicate.prediction import (
Prediction,
_create_prediction_body,
@@ -37,6 +38,76 @@ class Deployment(Resource):
    The name of the deployment.
    """

    class Release(Resource):
        """
        A release of a deployment.
        """

        number: int
        """
        The release number.
        """

        model: str
        """
        The model identifier string in the format of `{model_owner}/{model_name}`.
        """

        version: str
        """
        The ID of the model version used in the release.
        """

        created_at: str
        """
        The time the release was created.
        """

        created_by: Optional[Account]
        """
        The account that created the release.
        """

        class Configuration(Resource):
            """
            A configuration for a deployment.
            """

            hardware: str
            """
            The SKU for the hardware used to run the model.
            """

            class Scaling(Resource):
                """
                A scaling configuration for a deployment.
                """

                min_instances: int
                """
                The minimum number of instances for scaling.
                """

                max_instances: int
                """
                The maximum number of instances for scaling.
                """

            scaling: Scaling
            """
            The scaling configuration for the deployment.
            """

        configuration: Configuration
        """
        The deployment configuration.
        """

    current_release: Optional[Release]
    """
    The current release of the deployment.
    """

    @property
    @deprecated("Use `deployment.owner` instead.")
    def username(self) -> str:
@@ -81,10 +152,12 @@ def get(self, name: str) -> Deployment:

owner, name = name.split("/", 1)

deployment = Deployment(owner=owner, name=name)
deployment._client = self._client
resp = self._client._request(
"GET",
f"/v1/deployments/{owner}/{name}",
)

return deployment
return _json_to_deployment(self._client, resp.json())

async def async_get(self, name: str) -> Deployment:
"""
@@ -98,10 +171,26 @@ async def async_get(self, name: str) -> Deployment:

owner, name = name.split("/", 1)

deployment = Deployment(owner=owner, name=name)
deployment._client = self._client
resp = await self._client._async_request(
"GET",
f"/v1/deployments/{owner}/{name}",
)

return _json_to_deployment(self._client, resp.json())

@property
def predictions(self) -> "DeploymentsPredictions":
"""
Get predictions for deployments.
"""

return DeploymentsPredictions(client=self._client)


return deployment
def _json_to_deployment(client: "Client", json: Dict[str, Any]) -> Deployment:
deployment = Deployment(**json)
deployment._client = client
return deployment


class DeploymentPredictions(Namespace):
@@ -152,3 +241,70 @@ async def async_create(
)

return _json_to_prediction(self._client, resp.json())


class DeploymentsPredictions(Namespace):
"""
Namespace for operations related to predictions in deployments.
"""

def create(
self,
deployment: Union[str, Tuple[str, str], Deployment],
input: Dict[str, Any],
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Prediction:
"""
Create a new prediction with the deployment.
"""

url = _create_prediction_url_from_deployment(deployment)
body = _create_prediction_body(version=None, input=input, **params)

resp = self._client._request(
"POST",
url,
json=body,
)

return _json_to_prediction(self._client, resp.json())

async def async_create(
self,
deployment: Union[str, Tuple[str, str], Deployment],
input: Dict[str, Any],
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Prediction:
"""
Create a new prediction with the deployment.
"""

url = _create_prediction_url_from_deployment(deployment)
body = _create_prediction_body(version=None, input=input, **params)

resp = await self._client._async_request(
"POST",
url,
json=body,
)

return _json_to_prediction(self._client, resp.json())


def _create_prediction_url_from_deployment(
deployment: Union[str, Tuple[str, str], Deployment],
) -> str:
owner, name = None, None
if isinstance(deployment, Deployment):
owner, name = deployment.owner, deployment.name
elif isinstance(deployment, tuple):
owner, name = deployment[0], deployment[1]
elif isinstance(deployment, str):
owner, name = deployment.split("/", 1)

if owner is None or name is None:
raise ValueError(
"deployment must be a Deployment, a tuple of (owner, name), or a string in the format 'owner/name'"
)

return f"/v1/deployments/{owner}/{name}/predictions"
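
For context, `_create_prediction_url_from_deployment` is what lets the new `client.deployments.predictions.create` accept a deployment in several forms. A minimal usage sketch, assuming a valid API token; the `acme/my-app-image-generator` deployment and the prompt are placeholders, not values from this diff:

from replicate.client import Client

client = Client(api_token="r8_***")  # placeholder token

# Each call below resolves to the same URL:
# POST /v1/deployments/acme/my-app-image-generator/predictions
prediction = client.deployments.predictions.create(
    deployment="acme/my-app-image-generator",        # "owner/name" string
    input={"prompt": "an astronaut riding a horse"},
)
prediction = client.deployments.predictions.create(
    deployment=("acme", "my-app-image-generator"),    # (owner, name) tuple
    input={"prompt": "an astronaut riding a horse"},
)
deployment = client.deployments.get("acme/my-app-image-generator")
prediction = client.deployments.predictions.create(
    deployment=deployment,                            # Deployment instance
    input={"prompt": "an astronaut riding a horse"},
)
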
2 changes: 1 addition & 1 deletion replicate/identifier.py
@@ -31,7 +31,7 @@ def parse(cls, ref: str) -> "ModelVersionIdentifier":


def _resolve(
ref: Union["Model", "Version", "ModelVersionIdentifier", str]
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
) -> Tuple[Optional["Version"], Optional[str], Optional[str], Optional[str]]:
from replicate.model import Model # pylint: disable=import-outside-toplevel
from replicate.version import Version # pylint: disable=import-outside-toplevel
2 changes: 1 addition & 1 deletion replicate/model.py
@@ -383,7 +383,7 @@ def _json_to_model(client: "Client", json: Dict[str, Any]) -> Model:


def _create_prediction_url_from_model(
model: Union[str, Tuple[str, str], "Model"]
model: Union[str, Tuple[str, str], "Model"],
) -> str:
owner, name = None, None
if isinstance(model, Model):
107 changes: 103 additions & 4 deletions tests/test_deployment.py
@@ -7,16 +7,46 @@
from replicate.client import Client

router = respx.Router(base_url="https://api.replicate.com/v1")

router.route(
method="GET",
path="/deployments/replicate/my-app-image-generator",
name="deployments.get",
).mock(
return_value=httpx.Response(
200,
json={
"owner": "replicate",
"name": "my-app-image-generator",
"current_release": {
"number": 1,
"model": "stability-ai/sdxl",
"version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf",
"created_at": "2024-02-15T16:32:57.018467Z",
"created_by": {
"type": "organization",
"username": "acme",
"name": "Acme Corp, Inc.",
"github_url": "https://github.com/acme",
},
"configuration": {
"hardware": "gpu-t4",
"scaling": {"min_instances": 1, "max_instances": 5},
},
},
},
)
)
router.route(
method="POST",
path="/deployments/test/model/predictions",
path="/deployments/replicate/my-app-image-generator/predictions",
name="deployments.predictions.create",
).mock(
return_value=httpx.Response(
201,
json={
"id": "p1",
"model": "test/model",
"model": "replicate/my-app-image-generator",
"version": "v1",
"urls": {
"get": "https://api.replicate.com/v1/predictions/p1",
@@ -35,6 +65,37 @@
router.route(host="api.replicate.com").pass_through()


@pytest.mark.asyncio
@pytest.mark.parametrize("async_flag", [True, False])
async def test_deployment_get(async_flag):
client = Client(
api_token="test-token", transport=httpx.MockTransport(router.handler)
)

if async_flag:
deployment = await client.deployments.async_get(
"replicate/my-app-image-generator"
)
else:
deployment = client.deployments.get("replicate/my-app-image-generator")

assert router["deployments.get"].called

assert deployment.owner == "replicate"
assert deployment.name == "my-app-image-generator"
assert deployment.current_release is not None
assert deployment.current_release.number == 1
assert deployment.current_release.model == "stability-ai/sdxl"
assert (
deployment.current_release.version
== "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf"
)
assert deployment.current_release.created_by.type == "organization"
assert deployment.current_release.created_by.username == "acme"
assert deployment.current_release.created_by.name == "Acme Corp, Inc."
assert deployment.current_release.created_by.github_url == "https://github.com/acme"


@pytest.mark.asyncio
@pytest.mark.parametrize("async_flag", [True, False])
async def test_deployment_predictions_create(async_flag):
@@ -43,7 +104,9 @@ async def test_deployment_predictions_create(async_flag):
)

if async_flag:
deployment = await client.deployments.async_get("test/model")
deployment = await client.deployments.async_get(
"replicate/my-app-image-generator"
)

prediction = await deployment.predictions.async_create(
input={"text": "world"},
@@ -52,7 +115,7 @@
stream=True,
)
else:
deployment = client.deployments.get("test/model")
deployment = client.deployments.get("replicate/my-app-image-generator")

prediction = deployment.predictions.create(
input={"text": "world"},
@@ -71,3 +134,39 @@

assert prediction.id == "p1"
assert prediction.input == {"text": "world"}


@pytest.mark.asyncio
@pytest.mark.parametrize("async_flag", [True, False])
async def test_deploymentspredictions_create(async_flag):
client = Client(
api_token="test-token", transport=httpx.MockTransport(router.handler)
)

if async_flag:
prediction = await client.deployments.predictions.async_create(
deployment="replicate/my-app-image-generator",
input={"text": "world"},
webhook="https://example.com/webhook",
webhook_events_filter=["completed"],
stream=True,
)
else:
prediction = client.deployments.predictions.create(
deployment="replicate/my-app-image-generator",
input={"text": "world"},
webhook="https://example.com/webhook",
webhook_events_filter=["completed"],
stream=True,
)

assert router["deployments.predictions.create"].called
request = router["deployments.predictions.create"].calls[0].request
request_body = json.loads(request.content)
assert request_body["input"] == {"text": "world"}
assert request_body["webhook"] == "https://example.com/webhook"
assert request_body["webhook_events_filter"] == ["completed"]
assert request_body["stream"] is True

assert prediction.id == "p1"
assert prediction.input == {"text": "world"}
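
Taken together, the tests above exercise roughly the following flow. A short sketch, assuming a valid API token; the owner, deployment name, and prompt are placeholders rather than values from this PR:

from replicate.client import Client

client = Client(api_token="r8_***")  # placeholder token

# New in this PR: fetch deployment metadata from GET /v1/deployments/{owner}/{name}
deployment = client.deployments.get("acme/my-app-image-generator")

release = deployment.current_release
if release is not None:
    # Nested models added in this PR: Release -> Configuration -> Scaling
    print(release.number, release.model, release.version)
    print(release.configuration.hardware)
    print(release.configuration.scaling.min_instances,
          release.configuration.scaling.max_instances)

# Pre-existing per-deployment namespace, still available on the fetched object
prediction = deployment.predictions.create(input={"prompt": "a watercolor otter"})
print(prediction.id, prediction.status)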