diff --git a/replicate/account.py b/replicate/account.py
index 69e522b8..fdcdca6d 100644
--- a/replicate/account.py
+++ b/replicate/account.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Literal
+from typing import Any, Dict, Literal, Optional
 
 from replicate.resource import Namespace, Resource
 
@@ -17,7 +17,7 @@ class Account(Resource):
     name: str
     """The name of the account."""
 
-    github_url: str
+    github_url: Optional[str]
     """The GitHub URL of the account."""
 
 
diff --git a/replicate/deployment.py b/replicate/deployment.py
index ea6c242a..1f1b81c0 100644
--- a/replicate/deployment.py
+++ b/replicate/deployment.py
@@ -1,7 +1,9 @@
-from typing import TYPE_CHECKING, Any, Dict
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, TypedDict, Union
 
 from typing_extensions import Unpack, deprecated
 
+from replicate.account import Account
+from replicate.pagination import Page
 from replicate.prediction import (
     Prediction,
     _create_prediction_body,
@@ -37,6 +39,76 @@ class Deployment(Resource):
     The name of the deployment.
     """
 
+    class Release(Resource):
+        """
+        A release of a deployment.
+        """
+
+        number: int
+        """
+        The release number.
+        """
+
+        model: str
+        """
+        The model identifier string in the format of `{model_owner}/{model_name}`.
+        """
+
+        version: str
+        """
+        The ID of the model version used in the release.
+        """
+
+        created_at: str
+        """
+        The time the release was created.
+        """
+
+        created_by: Optional[Account]
+        """
+        The account that created the release.
+        """
+
+        class Configuration(Resource):
+            """
+            A configuration for a deployment.
+            """
+
+            hardware: str
+            """
+            The SKU for the hardware used to run the model.
+            """
+
+            class Scaling(Resource):
+                """
+                A scaling configuration for a deployment.
+                """
+
+                min_instances: int
+                """
+                The minimum number of instances for scaling.
+                """
+
+                max_instances: int
+                """
+                The maximum number of instances for scaling.
+                """
+
+            scaling: Scaling
+            """
+            The scaling configuration for the deployment.
+            """
+
+        configuration: Configuration
+        """
+        The deployment configuration.
+        """
+
+    current_release: Optional[Release]
+    """
+    The current release of the deployment.
+    """
+
     @property
     @deprecated("Use `deployment.owner` instead.")
     def username(self) -> str:
@@ -69,6 +141,55 @@ class Deployments(Namespace):
 
     _client: "Client"
 
+    def list(
+        self,
+        cursor: Union[str, "ellipsis", None] = ...,  # noqa: F821
+    ) -> Page[Deployment]:
+        """
+        List all deployments.
+
+        Returns:
+            A page of Deployments.
+        """
+
+        if cursor is None:
+            raise ValueError("cursor cannot be None")
+
+        resp = self._client._request(
+            "GET", "/v1/deployments" if cursor is ... else cursor
+        )
+
+        obj = resp.json()
+        obj["results"] = [
+            _json_to_deployment(self._client, result) for result in obj["results"]
+        ]
+
+        return Page[Deployment](**obj)
+
+    async def async_list(
+        self,
+        cursor: Union[str, "ellipsis", None] = ...,  # noqa: F821
+    ) -> Page[Deployment]:
+        """
+        List all deployments.
+
+        Returns:
+            A page of Deployments.
+        """
+        if cursor is None:
+            raise ValueError("cursor cannot be None")
+
+        resp = await self._client._async_request(
+            "GET", "/v1/deployments" if cursor is ... else cursor
+        )
+
+        obj = resp.json()
+        obj["results"] = [
+            _json_to_deployment(self._client, result) for result in obj["results"]
+        ]
+
+        return Page[Deployment](**obj)
+
     def get(self, name: str) -> Deployment:
         """
         Get a deployment by name.
@@ -81,10 +202,12 @@ def get(self, name: str) -> Deployment:
 
         owner, name = name.split("/", 1)
 
-        deployment = Deployment(owner=owner, name=name)
-        deployment._client = self._client
+        resp = self._client._request(
+            "GET",
+            f"/v1/deployments/{owner}/{name}",
+        )
 
-        return deployment
+        return _json_to_deployment(self._client, resp.json())
 
     async def async_get(self, name: str) -> Deployment:
         """
@@ -98,10 +221,164 @@ async def async_get(self, name: str) -> Deployment:
 
         owner, name = name.split("/", 1)
 
-        deployment = Deployment(owner=owner, name=name)
-        deployment._client = self._client
+        resp = await self._client._async_request(
+            "GET",
+            f"/v1/deployments/{owner}/{name}",
+        )
 
-        return deployment
+        return _json_to_deployment(self._client, resp.json())
+
+    class CreateDeploymentParams(TypedDict):
+        """
+        Parameters for creating a new deployment.
+        """
+
+        name: str
+        """The name of the deployment."""
+
+        model: str
+        """The model identifier string in the format of `{model_owner}/{model_name}`."""
+
+        version: str
+        """The version of the model to deploy."""
+
+        hardware: str
+        """The SKU for the hardware used to run the model."""
+
+        min_instances: int
+        """The minimum number of instances for scaling."""
+
+        max_instances: int
+        """The maximum number of instances for scaling."""
+
+    def create(self, **params: Unpack[CreateDeploymentParams]) -> Deployment:
+        """
+        Create a new deployment.
+
+        Args:
+            params: Configuration for the new deployment.
+        Returns:
+            The newly created Deployment.
+        """
+
+        if name := params.get("name", None):
+            if "/" in name:
+                _, name = name.split("/", 1)
+            params["name"] = name
+
+        resp = self._client._request(
+            "POST",
+            "/v1/deployments",
+            json=params,
+        )
+
+        return _json_to_deployment(self._client, resp.json())
+
+    async def async_create(
+        self, **params: Unpack[CreateDeploymentParams]
+    ) -> Deployment:
+        """
+        Create a new deployment.
+
+        Args:
+            params: Configuration for the new deployment.
+        Returns:
+            The newly created Deployment.
+        """
+
+        if name := params.get("name", None):
+            if "/" in name:
+                _, name = name.split("/", 1)
+            params["name"] = name
+
+        resp = await self._client._async_request(
+            "POST",
+            "/v1/deployments",
+            json=params,
+        )
+
+        return _json_to_deployment(self._client, resp.json())
+
+    class UpdateDeploymentParams(TypedDict, total=False):
+        """
+        Parameters for updating an existing deployment.
+        """
+
+        version: str
+        """The version of the model to deploy."""
+
+        hardware: str
+        """The SKU for the hardware used to run the model."""
+
+        min_instances: int
+        """The minimum number of instances for scaling."""
+
+        max_instances: int
+        """The maximum number of instances for scaling."""
+
+    def update(
+        self,
+        deployment_owner: str,
+        deployment_name: str,
+        **params: Unpack[UpdateDeploymentParams],
+    ) -> Deployment:
+        """
+        Update an existing deployment.
+
+        Args:
+            deployment_owner: The owner of the deployment.
+            deployment_name: The name of the deployment.
+            params: Configuration updates for the deployment.
+        Returns:
+            The updated Deployment.
+        """
+
+        resp = self._client._request(
+            "PATCH",
+            f"/v1/deployments/{deployment_owner}/{deployment_name}",
+            json=params,
+        )
+
+        return _json_to_deployment(self._client, resp.json())
+
+    async def async_update(
+        self,
+        deployment_owner: str,
+        deployment_name: str,
+        **params: Unpack[UpdateDeploymentParams],
+    ) -> Deployment:
+        """
+        Update an existing deployment.
+
+        Args:
+            deployment_owner: The owner of the deployment.
+            deployment_name: The name of the deployment.
+            params: Configuration updates for the deployment.
+        Returns:
+            The updated Deployment.
+        """
+
+        resp = await self._client._async_request(
+            "PATCH",
+            f"/v1/deployments/{deployment_owner}/{deployment_name}",
+            json=params,
+        )
+
+        return _json_to_deployment(self._client, resp.json())
+
+    @property
+    def predictions(self) -> "DeploymentsPredictions":
+        """
+        Get predictions for deployments.
+        """
+
+        return DeploymentsPredictions(client=self._client)
+
+
+def _json_to_deployment(client: "Client", json: Dict[str, Any]) -> Deployment:
+    deployment = Deployment(**json)
+    deployment._client = client
+    return deployment
 
 
 class DeploymentPredictions(Namespace):
@@ -152,3 +429,70 @@ async def async_create(
         )
 
         return _json_to_prediction(self._client, resp.json())
+
+
+class DeploymentsPredictions(Namespace):
+    """
+    Namespace for operations related to predictions in deployments.
+    """
+
+    def create(
+        self,
+        deployment: Union[str, Tuple[str, str], Deployment],
+        input: Dict[str, Any],
+        **params: Unpack["Predictions.CreatePredictionParams"],
+    ) -> Prediction:
+        """
+        Create a new prediction with the deployment.
+        """
+
+        url = _create_prediction_url_from_deployment(deployment)
+        body = _create_prediction_body(version=None, input=input, **params)
+
+        resp = self._client._request(
+            "POST",
+            url,
+            json=body,
+        )
+
+        return _json_to_prediction(self._client, resp.json())
+
+    async def async_create(
+        self,
+        deployment: Union[str, Tuple[str, str], Deployment],
+        input: Dict[str, Any],
+        **params: Unpack["Predictions.CreatePredictionParams"],
+    ) -> Prediction:
+        """
+        Create a new prediction with the deployment.
+        """
+
+        url = _create_prediction_url_from_deployment(deployment)
+        body = _create_prediction_body(version=None, input=input, **params)
+
+        resp = await self._client._async_request(
+            "POST",
+            url,
+            json=body,
+        )
+
+        return _json_to_prediction(self._client, resp.json())
+
+
+def _create_prediction_url_from_deployment(
+    deployment: Union[str, Tuple[str, str], Deployment],
+) -> str:
+    owner, name = None, None
+    if isinstance(deployment, Deployment):
+        owner, name = deployment.owner, deployment.name
+    elif isinstance(deployment, tuple):
+        owner, name = deployment[0], deployment[1]
+    elif isinstance(deployment, str):
+        owner, _, name = deployment.partition("/")
+
+    if not owner or not name:
+        raise ValueError(
+            "deployment must be a Deployment, a tuple of (owner, name), or a string in the format 'owner/name'"
+        )
+
+    return f"/v1/deployments/{owner}/{name}/predictions"
diff --git a/tests/test_deployment.py b/tests/test_deployment.py
index a039567a..5f93fdc3 100644
--- a/tests/test_deployment.py
+++ b/tests/test_deployment.py
@@ -7,16 +7,46 @@
 from replicate.client import Client
 
 router = respx.Router(base_url="https://api.replicate.com/v1")
+
+router.route(
+    method="GET",
+    path="/deployments/replicate/my-app-image-generator",
+    name="deployments.get",
+).mock(
+    return_value=httpx.Response(
+        200,
+        json={
+            "owner": "replicate",
+            "name": "my-app-image-generator",
+            "current_release": {
+                "number": 1,
+                "model": "stability-ai/sdxl",
+                "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf",
+                "created_at": "2024-02-15T16:32:57.018467Z",
+                "created_by": {
+                    "type": "organization",
+                    "username": "acme",
+                    "name": "Acme Corp, Inc.",
+                    "github_url": "https://github.com/acme",
+                },
+                "configuration": {
+                    "hardware": "gpu-t4",
+                    "scaling": {"min_instances": 1, "max_instances": 5},
+                },
+            },
+        },
+    )
+)
 router.route(
     method="POST",
-    path="/deployments/test/model/predictions",
+    path="/deployments/replicate/my-app-image-generator/predictions",
     name="deployments.predictions.create",
 ).mock(
     return_value=httpx.Response(
         201,
         json={
             "id": "p1",
-            "model": "test/model",
+            "model": "replicate/my-app-image-generator",
             "version": "v1",
             "urls": {
                 "get": "https://api.replicate.com/v1/predictions/p1",
@@ -32,9 +62,154 @@
         },
     )
 )
+router.route(
+    method="GET",
+    path="/deployments",
+    name="deployments.list",
+).mock(
+    return_value=httpx.Response(
+        200,
+        json={
+            "results": [
+                {
+                    "owner": "acme",
+                    "name": "image-upscaler",
+                    "current_release": {
+                        "number": 1,
+                        "model": "acme/esrgan",
+                        "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
+                        "created_at": "2022-01-01T00:00:00Z",
+                        "created_by": {
+                            "type": "organization",
+                            "username": "acme",
+                            "name": "Acme, Inc.",
+                        },
+                        "configuration": {
+                            "hardware": "gpu-t4",
+                            "scaling": {"min_instances": 1, "max_instances": 5},
+                        },
+                    },
+                },
+                {
+                    "owner": "acme",
+                    "name": "text-generator",
+                    "current_release": {
+                        "number": 2,
+                        "model": "acme/acme-llama",
+                        "version": "4b7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccbb",
+                        "created_at": "2022-02-02T00:00:00Z",
+                        "created_by": {
+                            "type": "organization",
+                            "username": "acme",
+                            "name": "Acme, Inc.",
+                        },
+                        "configuration": {
+                            "hardware": "cpu",
+                            "scaling": {"min_instances": 2, "max_instances": 10},
+                        },
+                    },
+                },
+            ]
+        },
+    )
+)
+
+router.route(
+    method="POST",
+    path="/deployments",
+    name="deployments.create",
+).mock(
+    return_value=httpx.Response(
+        201,
+        json={
+            "owner": "acme",
+            "name": "new-deployment",
+            "current_release": {
+                "number": 1,
+                "model": "acme/new-model",
+                "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
+                "created_at": "2022-01-01T00:00:00Z",
+                "created_by": {
+                    "type": "organization",
+                    "username": "acme",
+                    "name": "Acme, Inc.",
+                },
+                "configuration": {
+                    "hardware": "gpu-t4",
+                    "scaling": {"min_instances": 1, "max_instances": 5},
+                },
+            },
+        },
+    )
+)
+
+
+router.route(
+    method="PATCH",
+    path="/deployments/acme/image-upscaler",
+    name="deployments.update",
+).mock(
+    return_value=httpx.Response(
+        200,
+        json={
+            "owner": "acme",
+            "name": "image-upscaler",
+            "current_release": {
+                "number": 2,
+                "model": "acme/esrgan-updated",
+                "version": "new-version-id",
+                "created_at": "2022-02-02T00:00:00Z",
+                "created_by": {
+                    "type": "organization",
+                    "username": "acme",
+                    "name": "Acme, Inc.",
+                },
+                "configuration": {
+                    "hardware": "gpu-v100",
+                    "scaling": {"min_instances": 2, "max_instances": 10},
+                },
+            },
+        },
+    )
+)
+
+
 router.route(host="api.replicate.com").pass_through()
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize("async_flag", [True, False])
+async def test_deployment_get(async_flag):
+    client = Client(
+        api_token="test-token", transport=httpx.MockTransport(router.handler)
+    )
+
+    if async_flag:
+        deployment = await client.deployments.async_get(
+            "replicate/my-app-image-generator"
+        )
+    else:
+        deployment = client.deployments.get("replicate/my-app-image-generator")
+
+    assert router["deployments.get"].called
+
+    assert deployment.owner == "replicate"
+    assert deployment.name == "my-app-image-generator"
+    assert deployment.current_release is not None
+    assert deployment.current_release.number == 1
+    assert deployment.current_release.model == "stability-ai/sdxl"
+    assert (
+        deployment.current_release.version
+        == "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf"
+    )
+    assert deployment.current_release is not None
+    assert deployment.current_release.created_by is not None
+    assert deployment.current_release.created_by.type == "organization"
+    assert deployment.current_release.created_by.username == "acme"
+    assert deployment.current_release.created_by.name == "Acme Corp, Inc."
+    assert deployment.current_release.created_by.github_url == "https://github.com/acme"
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("async_flag", [True, False])
 async def test_deployment_predictions_create(async_flag):
@@ -43,7 +218,9 @@ async def test_deployment_predictions_create(async_flag):
     )
 
     if async_flag:
-        deployment = await client.deployments.async_get("test/model")
+        deployment = await client.deployments.async_get(
+            "replicate/my-app-image-generator"
+        )
 
         prediction = await deployment.predictions.async_create(
             input={"text": "world"},
@@ -52,7 +229,7 @@ async def test_deployment_predictions_create(async_flag):
             stream=True,
         )
     else:
-        deployment = client.deployments.get("test/model")
+        deployment = client.deployments.get("replicate/my-app-image-generator")
 
         prediction = deployment.predictions.create(
             input={"text": "world"},
@@ -71,3 +248,149 @@ async def test_deployment_predictions_create(async_flag):
 
     assert prediction.id == "p1"
     assert prediction.input == {"text": "world"}
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("async_flag", [True, False])
+async def test_deploymentspredictions_create(async_flag):
+    client = Client(
+        api_token="test-token", transport=httpx.MockTransport(router.handler)
+    )
+
+    if async_flag:
+        prediction = await client.deployments.predictions.async_create(
+            deployment="replicate/my-app-image-generator",
+            input={"text": "world"},
+            webhook="https://example.com/webhook",
+            webhook_events_filter=["completed"],
+            stream=True,
+        )
+    else:
+        prediction = client.deployments.predictions.create(
+            deployment="replicate/my-app-image-generator",
+            input={"text": "world"},
+            webhook="https://example.com/webhook",
+            webhook_events_filter=["completed"],
+            stream=True,
+        )
+
+    assert router["deployments.predictions.create"].called
+    request = router["deployments.predictions.create"].calls[0].request
+    request_body = json.loads(request.content)
+    assert request_body["input"] == {"text": "world"}
+    assert request_body["webhook"] == "https://example.com/webhook"
+    assert request_body["webhook_events_filter"] == ["completed"]
+    assert request_body["stream"] is True
+
+    assert prediction.id == "p1"
+    assert prediction.input == {"text": "world"}
+
+
+@respx.mock
+@pytest.mark.asyncio
+@pytest.mark.parametrize("async_flag", [True, False])
+async def test_deployments_list(async_flag):
+    client = Client(
+        api_token="test-token", transport=httpx.MockTransport(router.handler)
+    )
+
+    if async_flag:
+        deployments = await client.deployments.async_list()
+    else:
+        deployments = client.deployments.list()
+
+    assert router["deployments.list"].called
+
+    assert len(deployments.results) == 2
+    assert deployments.results[0].owner == "acme"
+    assert deployments.results[0].name == "image-upscaler"
+    assert deployments.results[0].current_release is not None
+    assert deployments.results[0].current_release.number == 1
+    assert deployments.results[0].current_release.model == "acme/esrgan"
+    assert deployments.results[1].owner == "acme"
+    assert deployments.results[1].name == "text-generator"
+    assert deployments.results[1].current_release is not None
+    assert deployments.results[1].current_release.number == 2
+    assert deployments.results[1].current_release.model == "acme/acme-llama"
+
+
+@respx.mock
+@pytest.mark.asyncio
+@pytest.mark.parametrize("async_flag", [True, False])
+async def test_create_deployment(async_flag):
+    client = Client(
+        api_token="test-token", transport=httpx.MockTransport(router.handler)
+    )
+
+    config = {
+        "name": "new-deployment",
+        "model": "acme/new-model",
+        "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
+        "hardware": "gpu-t4",
+        "min_instances": 1,
+        "max_instances": 5,
+    }
+
+    if async_flag:
+        deployment = await client.deployments.async_create(**config)
+    else:
+        deployment = client.deployments.create(**config)
+
+    assert router["deployments.create"].called
+
+    assert deployment.owner == "acme"
+    assert deployment.name == "new-deployment"
+    assert deployment.current_release is not None
+    assert deployment.current_release.number == 1
+    assert deployment.current_release.model == "acme/new-model"
+    assert (
+        deployment.current_release.version
+        == "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa"
+    )
+    assert deployment.current_release.created_by is not None
+    assert deployment.current_release.created_by.type == "organization"
+    assert deployment.current_release.created_by.username == "acme"
+    assert deployment.current_release.created_by.name == "Acme, Inc."
+    assert deployment.current_release.configuration.hardware == "gpu-t4"
+    assert deployment.current_release.configuration.scaling.min_instances == 1
+    assert deployment.current_release.configuration.scaling.max_instances == 5
+
+
+@respx.mock
+@pytest.mark.asyncio
+@pytest.mark.parametrize("async_flag", [True, False])
+async def test_update_deployment(async_flag):
+    config = {
+        "version": "new-version-id",
+        "hardware": "gpu-v100",
+        "min_instances": 2,
+        "max_instances": 10,
+    }
+
+    client = Client(
+        api_token="test-token", transport=httpx.MockTransport(router.handler)
+    )
+
+    if async_flag:
+        updated_deployment = await client.deployments.async_update(
+            deployment_owner="acme", deployment_name="image-upscaler", **config
+        )
+    else:
+        updated_deployment = client.deployments.update(
+            deployment_owner="acme", deployment_name="image-upscaler", **config
+        )
+
+    assert router["deployments.update"].called
+    request = router["deployments.update"].calls[0].request
+    request_body = json.loads(request.content)
+    assert request_body == config
+
+    assert updated_deployment.owner == "acme"
+    assert updated_deployment.name == "image-upscaler"
+    assert updated_deployment.current_release is not None
+    assert updated_deployment.current_release.number == 2
+    assert updated_deployment.current_release.model == "acme/esrgan-updated"
+    assert updated_deployment.current_release.version == "new-version-id"
+    assert updated_deployment.current_release.configuration.hardware == "gpu-v100"
+    assert updated_deployment.current_release.configuration.scaling.min_instances == 2
+    assert updated_deployment.current_release.configuration.scaling.max_instances == 10