From 5929c400f1c8b8590d9296e8eabc937942a81f4f Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Sun, 4 Feb 2024 02:34:06 -0800 Subject: [PATCH] Add support for deployments.get endpoint Signed-off-by: Mattt Zmuda --- replicate/deployment.py | 170 +++++++++++++++++++++++++++++++++++++-- replicate/identifier.py | 2 +- replicate/model.py | 2 +- tests/test_deployment.py | 107 +++++++++++++++++++++++- 4 files changed, 268 insertions(+), 13 deletions(-) diff --git a/replicate/deployment.py b/replicate/deployment.py index ea6c242a..3c43908d 100644 --- a/replicate/deployment.py +++ b/replicate/deployment.py @@ -1,7 +1,8 @@ -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union from typing_extensions import Unpack, deprecated +from replicate.account import Account from replicate.prediction import ( Prediction, _create_prediction_body, @@ -37,6 +38,76 @@ class Deployment(Resource): The name of the deployment. """ + class Release(Resource): + """ + A release of a deployment. + """ + + number: int + """ + The release number. + """ + + model: str + """ + The model identifier string in the format of `{model_owner}/{model_name}`. + """ + + version: str + """ + The ID of the model version used in the release. + """ + + created_at: str + """ + The time the release was created. + """ + + created_by: Optional[Account] + """ + The account that created the release. + """ + + class Configuration(Resource): + """ + A configuration for a deployment. + """ + + hardware: str + """ + The SKU for the hardware used to run the model. + """ + + class Scaling(Resource): + """ + A scaling configuration for a deployment. + """ + + min_instances: int + """ + The minimum number of instances for scaling. + """ + + max_instances: int + """ + The maximum number of instances for scaling. + """ + + scaling: Scaling + """ + The scaling configuration for the deployment. + """ + + configuration: Configuration + """ + The deployment configuration. + """ + + current_release: Optional[Release] + """ + The current release of the deployment. + """ + @property @deprecated("Use `deployment.owner` instead.") def username(self) -> str: @@ -81,10 +152,12 @@ def get(self, name: str) -> Deployment: owner, name = name.split("/", 1) - deployment = Deployment(owner=owner, name=name) - deployment._client = self._client + resp = self._client._request( + "GET", + f"/v1/deployments/{owner}/{name}", + ) - return deployment + return _json_to_deployment(self._client, resp.json()) async def async_get(self, name: str) -> Deployment: """ @@ -98,10 +171,26 @@ async def async_get(self, name: str) -> Deployment: owner, name = name.split("/", 1) - deployment = Deployment(owner=owner, name=name) - deployment._client = self._client + resp = await self._client._async_request( + "GET", + f"/v1/deployments/{owner}/{name}", + ) + + return _json_to_deployment(self._client, resp.json()) + + @property + def predictions(self) -> "DeploymentsPredictions": + """ + Get predictions for deployments. + """ + + return DeploymentsPredictions(client=self._client) + - return deployment +def _json_to_deployment(client: "Client", json: Dict[str, Any]) -> Deployment: + deployment = Deployment(**json) + deployment._client = client + return deployment class DeploymentPredictions(Namespace): @@ -152,3 +241,70 @@ async def async_create( ) return _json_to_prediction(self._client, resp.json()) + + +class DeploymentsPredictions(Namespace): + """ + Namespace for operations related to predictions in deployments. + """ + + def create( + self, + deployment: Union[str, Tuple[str, str], Deployment], + input: Dict[str, Any], + **params: Unpack["Predictions.CreatePredictionParams"], + ) -> Prediction: + """ + Create a new prediction with the deployment. + """ + + url = _create_prediction_url_from_deployment(deployment) + body = _create_prediction_body(version=None, input=input, **params) + + resp = self._client._request( + "POST", + url, + json=body, + ) + + return _json_to_prediction(self._client, resp.json()) + + async def async_create( + self, + deployment: Union[str, Tuple[str, str], Deployment], + input: Dict[str, Any], + **params: Unpack["Predictions.CreatePredictionParams"], + ) -> Prediction: + """ + Create a new prediction with the deployment. + """ + + url = _create_prediction_url_from_deployment(deployment) + body = _create_prediction_body(version=None, input=input, **params) + + resp = await self._client._async_request( + "POST", + url, + json=body, + ) + + return _json_to_prediction(self._client, resp.json()) + + +def _create_prediction_url_from_deployment( + deployment: Union[str, Tuple[str, str], Deployment], +) -> str: + owner, name = None, None + if isinstance(deployment, Deployment): + owner, name = deployment.owner, deployment.name + elif isinstance(deployment, tuple): + owner, name = deployment[0], deployment[1] + elif isinstance(deployment, str): + owner, name = deployment.split("/", 1) + + if owner is None or name is None: + raise ValueError( + "deployment must be a Deployment, a tuple of (owner, name), or a string in the format 'owner/name'" + ) + + return f"/v1/deployments/{owner}/{name}/predictions" diff --git a/replicate/identifier.py b/replicate/identifier.py index b5844526..30953240 100644 --- a/replicate/identifier.py +++ b/replicate/identifier.py @@ -31,7 +31,7 @@ def parse(cls, ref: str) -> "ModelVersionIdentifier": def _resolve( - ref: Union["Model", "Version", "ModelVersionIdentifier", str] + ref: Union["Model", "Version", "ModelVersionIdentifier", str], ) -> Tuple[Optional["Version"], Optional[str], Optional[str], Optional[str]]: from replicate.model import Model # pylint: disable=import-outside-toplevel from replicate.version import Version # pylint: disable=import-outside-toplevel diff --git a/replicate/model.py b/replicate/model.py index ab8e1406..2349fe5e 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -383,7 +383,7 @@ def _json_to_model(client: "Client", json: Dict[str, Any]) -> Model: def _create_prediction_url_from_model( - model: Union[str, Tuple[str, str], "Model"] + model: Union[str, Tuple[str, str], "Model"], ) -> str: owner, name = None, None if isinstance(model, Model): diff --git a/tests/test_deployment.py b/tests/test_deployment.py index a039567a..e7ed1bf1 100644 --- a/tests/test_deployment.py +++ b/tests/test_deployment.py @@ -7,16 +7,46 @@ from replicate.client import Client router = respx.Router(base_url="https://api.replicate.com/v1") + +router.route( + method="GET", + path="/deployments/replicate/my-app-image-generator", + name="deployments.get", +).mock( + return_value=httpx.Response( + 201, + json={ + "owner": "replicate", + "name": "my-app-image-generator", + "current_release": { + "number": 1, + "model": "stability-ai/sdxl", + "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", + "created_at": "2024-02-15T16:32:57.018467Z", + "created_by": { + "type": "organization", + "username": "acme", + "name": "Acme Corp, Inc.", + "github_url": "https://github.com/acme", + }, + "configuration": { + "hardware": "gpu-t4", + "scaling": {"min_instances": 1, "max_instances": 5}, + }, + }, + }, + ) +) router.route( method="POST", - path="/deployments/test/model/predictions", + path="/deployments/replicate/my-app-image-generator/predictions", name="deployments.predictions.create", ).mock( return_value=httpx.Response( 201, json={ "id": "p1", - "model": "test/model", + "model": "replicate/my-app-image-generator", "version": "v1", "urls": { "get": "https://api.replicate.com/v1/predictions/p1", @@ -35,6 +65,37 @@ router.route(host="api.replicate.com").pass_through() +@pytest.mark.asyncio +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_deployment_get(async_flag): + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + + if async_flag: + deployment = await client.deployments.async_get( + "replicate/my-app-image-generator" + ) + else: + deployment = client.deployments.get("replicate/my-app-image-generator") + + assert router["deployments.get"].called + + assert deployment.owner == "replicate" + assert deployment.name == "my-app-image-generator" + assert deployment.current_release is not None + assert deployment.current_release.number == 1 + assert deployment.current_release.model == "stability-ai/sdxl" + assert ( + deployment.current_release.version + == "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf" + ) + assert deployment.current_release.created_by.type == "organization" + assert deployment.current_release.created_by.username == "acme" + assert deployment.current_release.created_by.name == "Acme Corp, Inc." + assert deployment.current_release.created_by.github_url == "https://github.com/acme" + + @pytest.mark.asyncio @pytest.mark.parametrize("async_flag", [True, False]) async def test_deployment_predictions_create(async_flag): @@ -43,7 +104,9 @@ async def test_deployment_predictions_create(async_flag): ) if async_flag: - deployment = await client.deployments.async_get("test/model") + deployment = await client.deployments.async_get( + "replicate/my-app-image-generator" + ) prediction = await deployment.predictions.async_create( input={"text": "world"}, @@ -52,7 +115,7 @@ async def test_deployment_predictions_create(async_flag): stream=True, ) else: - deployment = client.deployments.get("test/model") + deployment = client.deployments.get("replicate/my-app-image-generator") prediction = deployment.predictions.create( input={"text": "world"}, @@ -71,3 +134,39 @@ async def test_deployment_predictions_create(async_flag): assert prediction.id == "p1" assert prediction.input == {"text": "world"} + + +@pytest.mark.asyncio +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_deploymentspredictions_create(async_flag): + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + + if async_flag: + prediction = await client.deployments.predictions.async_create( + deployment="replicate/my-app-image-generator", + input={"text": "world"}, + webhook="https://example.com/webhook", + webhook_events_filter=["completed"], + stream=True, + ) + else: + prediction = await client.deployments.predictions.async_create( + deployment="replicate/my-app-image-generator", + input={"text": "world"}, + webhook="https://example.com/webhook", + webhook_events_filter=["completed"], + stream=True, + ) + + assert router["deployments.predictions.create"].called + request = router["deployments.predictions.create"].calls[0].request + request_body = json.loads(request.content) + assert request_body["input"] == {"text": "world"} + assert request_body["webhook"] == "https://example.com/webhook" + assert request_body["webhook_events_filter"] == ["completed"] + assert request_body["stream"] is True + + assert prediction.id == "p1" + assert prediction.input == {"text": "world"}