Skip to content

Commit 21dd8cd

Browse files
committed
Add support for deployments endpoints
Signed-off-by: Mattt Zmuda <[email protected]>
1 parent cd3d2eb commit 21dd8cd

File tree

4 files changed

+680
-13
lines changed

4 files changed

+680
-13
lines changed

replicate/deployment.py

Lines changed: 351 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
from typing import TYPE_CHECKING, Any, Dict
1+
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, TypedDict, Union
22

33
from typing_extensions import Unpack, deprecated
44

5+
from replicate.account import Account
6+
from replicate.pagination import Page
57
from replicate.prediction import (
68
Prediction,
79
_create_prediction_body,
@@ -37,6 +39,76 @@ class Deployment(Resource):
3739
The name of the deployment.
3840
"""
3941

42+
class Release(Resource):
43+
"""
44+
A release of a deployment.
45+
"""
46+
47+
number: int
48+
"""
49+
The release number.
50+
"""
51+
52+
model: str
53+
"""
54+
The model identifier string in the format of `{model_owner}/{model_name}`.
55+
"""
56+
57+
version: str
58+
"""
59+
The ID of the model version used in the release.
60+
"""
61+
62+
created_at: str
63+
"""
64+
The time the release was created.
65+
"""
66+
67+
created_by: Optional[Account]
68+
"""
69+
The account that created the release.
70+
"""
71+
72+
class Configuration(Resource):
73+
"""
74+
A configuration for a deployment.
75+
"""
76+
77+
hardware: str
78+
"""
79+
The SKU for the hardware used to run the model.
80+
"""
81+
82+
class Scaling(Resource):
    """
    Instance-count bounds used when scaling a deployment.
    """

    min_instances: int
    """The minimum number of instances for scaling."""

    max_instances: int
    """The maximum number of instances for scaling."""
96+
97+
scaling: Scaling
98+
"""
99+
The scaling configuration for the deployment.
100+
"""
101+
102+
configuration: Configuration
103+
"""
104+
The deployment configuration.
105+
"""
106+
107+
current_release: Optional[Release]
108+
"""
109+
The current release of the deployment.
110+
"""
111+
40112
@property
41113
@deprecated("Use `deployment.owner` instead.")
42114
def username(self) -> str:
@@ -69,6 +141,55 @@ class Deployments(Namespace):
69141

70142
_client: "Client"
71143

144+
def list(
    self,
    cursor: Union[str, "ellipsis", None] = ...,  # noqa: F821
) -> Page[Deployment]:
    """
    Fetch a page of deployments.

    Args:
        cursor: The path to request for the page. When left as `...`, `/v1/deployments` is requested.
    Returns:
        A page of Deployments.
    Raises:
        ValueError: If `cursor` is None.
    """

    if cursor is None:
        raise ValueError("cursor cannot be None")

    # `...` is the "not supplied" sentinel; any other value is used as the request path.
    path = cursor if cursor is not ... else "/v1/deployments"
    payload = self._client._request("GET", path).json()

    deployments = []
    for item in payload["results"]:
        deployments.append(_json_to_deployment(self._client, item))
    payload["results"] = deployments

    return Page[Deployment](**payload)
168+
169+
async def async_list(
    self,
    cursor: Union[str, "ellipsis", None] = ...,  # noqa: F821
) -> Page[Deployment]:
    """
    Fetch a page of deployments asynchronously.

    Args:
        cursor: The path to request for the page. When left as `...`, `/v1/deployments` is requested.
    Returns:
        A page of Deployments.
    Raises:
        ValueError: If `cursor` is None.
    """

    if cursor is None:
        raise ValueError("cursor cannot be None")

    # `...` is the "not supplied" sentinel; any other value is used as the request path.
    path = cursor if cursor is not ... else "/v1/deployments"
    response = await self._client._async_request("GET", path)
    payload = response.json()

    deployments = []
    for item in payload["results"]:
        deployments.append(_json_to_deployment(self._client, item))
    payload["results"] = deployments

    return Page[Deployment](**payload)
192+
72193
def get(self, name: str) -> Deployment:
73194
"""
74195
Get a deployment by name.
@@ -81,10 +202,12 @@ def get(self, name: str) -> Deployment:
81202

82203
owner, name = name.split("/", 1)
83204

84-
deployment = Deployment(owner=owner, name=name)
85-
deployment._client = self._client
205+
resp = self._client._request(
206+
"GET",
207+
f"/v1/deployments/{owner}/{name}",
208+
)
86209

87-
return deployment
210+
return _json_to_deployment(self._client, resp.json())
88211

89212
async def async_get(self, name: str) -> Deployment:
90213
"""
@@ -98,10 +221,164 @@ async def async_get(self, name: str) -> Deployment:
98221

99222
owner, name = name.split("/", 1)
100223

101-
deployment = Deployment(owner=owner, name=name)
102-
deployment._client = self._client
224+
resp = await self._client._async_request(
225+
"GET",
226+
f"/v1/deployments/{owner}/{name}",
227+
)
103228

104-
return deployment
229+
return _json_to_deployment(self._client, resp.json())
230+
231+
class CreateDeploymentParams(TypedDict):
    """
    Keyword arguments accepted when creating a new deployment.

    Every key is required.
    """

    name: str
    """Name of the deployment."""

    model: str
    """Model identifier string, formatted as `{model_owner}/{model_name}`."""

    version: str
    """Version of the model to deploy."""

    hardware: str
    """SKU of the hardware used to run the model."""

    min_instances: int
    """Minimum number of instances for scaling."""

    max_instances: int
    """Maximum number of instances for scaling."""
253+
254+
def create(self, **params: Unpack[CreateDeploymentParams]) -> Deployment:
    """
    Create a new deployment.

    Args:
        params: Configuration for the new deployment. If `name` contains a "/", only the part after the first "/" is sent.
    Returns:
        The newly created Deployment.
    """

    requested_name = params.get("name", None)
    if requested_name and "/" in requested_name:
        # Keep only what follows the first slash (presumably dropping an owner prefix).
        params["name"] = requested_name.split("/", 1)[1]

    response = self._client._request("POST", "/v1/deployments", json=params)
    return _json_to_deployment(self._client, response.json())
276+
277+
async def async_create(
    self, **params: Unpack[CreateDeploymentParams]
) -> Deployment:
    """
    Create a new deployment asynchronously.

    Args:
        params: Configuration for the new deployment. If `name` contains a "/", only the part after the first "/" is sent.
    Returns:
        The newly created Deployment.
    """

    requested_name = params.get("name", None)
    if requested_name and "/" in requested_name:
        # Keep only what follows the first slash (presumably dropping an owner prefix).
        params["name"] = requested_name.split("/", 1)[1]

    response = await self._client._async_request(
        "POST", "/v1/deployments", json=params
    )
    return _json_to_deployment(self._client, response.json())
301+
302+
class UpdateDeploymentParams(TypedDict, total=False):
    """
    Keyword arguments accepted when updating an existing deployment.

    Every key is optional (`total=False`).
    """

    version: str
    """Version of the model to deploy."""

    hardware: str
    """SKU of the hardware used to run the model."""

    min_instances: int
    """Minimum number of instances for scaling."""

    max_instances: int
    """Maximum number of instances for scaling."""
318+
319+
def update(
    self,
    deployment_owner: str,
    deployment_name: str,
    **params: Unpack[UpdateDeploymentParams],
) -> Deployment:
    """
    Update an existing deployment.

    Args:
        deployment_owner: The owner of the deployment.
        deployment_name: The name of the deployment.
        params: Configuration updates for the deployment, sent as the JSON body of the request.
    Returns:
        The updated Deployment.
    """

    path = f"/v1/deployments/{deployment_owner}/{deployment_name}"
    response = self._client._request("PATCH", path, json=params)

    return _json_to_deployment(self._client, response.json())
343+
344+
async def async_update(
    self,
    deployment_owner: str,
    deployment_name: str,
    **params: Unpack[UpdateDeploymentParams],
) -> Deployment:
    """
    Update an existing deployment asynchronously.

    Args:
        deployment_owner: The owner of the deployment.
        deployment_name: The name of the deployment.
        params: Configuration updates for the deployment, sent as the JSON body of the request.
    Returns:
        The updated Deployment.
    """

    path = f"/v1/deployments/{deployment_owner}/{deployment_name}"
    response = await self._client._async_request("PATCH", path, json=params)

    return _json_to_deployment(self._client, response.json())
368+
369+
@property
def predictions(self) -> "DeploymentsPredictions":
    """
    Namespace for prediction operations on deployments.

    A new DeploymentsPredictions bound to this namespace's client is built on each access.
    """

    namespace = DeploymentsPredictions(client=self._client)
    return namespace
376+
377+
378+
def _json_to_deployment(client: "Client", json: Dict[str, Any]) -> Deployment:
    """
    Build a Deployment from a decoded JSON object and attach the client to it.
    """

    instance = Deployment(**json)
    instance._client = client
    return instance
105382

106383

107384
class DeploymentPredictions(Namespace):
@@ -152,3 +429,70 @@ async def async_create(
152429
)
153430

154431
return _json_to_prediction(self._client, resp.json())
432+
433+
434+
class DeploymentsPredictions(Namespace):
    """
    Namespace for creating predictions through deployments.
    """

    def create(
        self,
        deployment: Union[str, Tuple[str, str], Deployment],
        input: Dict[str, Any],
        **params: Unpack["Predictions.CreatePredictionParams"],
    ) -> Prediction:
        """
        Create a new prediction with the deployment.

        Args:
            deployment: A Deployment, a tuple of (owner, name), or a string in the format `owner/name`.
            input: The input passed to the prediction.
            params: Additional prediction parameters forwarded into the request body.
        Returns:
            The Prediction built from the response.
        """

        # No explicit version is sent in the body for a deployment prediction.
        payload = _create_prediction_body(version=None, input=input, **params)
        endpoint = _create_prediction_url_from_deployment(deployment)

        response = self._client._request("POST", endpoint, json=payload)
        return _json_to_prediction(self._client, response.json())

    async def async_create(
        self,
        deployment: Union[str, Tuple[str, str], Deployment],
        input: Dict[str, Any],
        **params: Unpack["Predictions.CreatePredictionParams"],
    ) -> Prediction:
        """
        Create a new prediction with the deployment, asynchronously.

        Args:
            deployment: A Deployment, a tuple of (owner, name), or a string in the format `owner/name`.
            input: The input passed to the prediction.
            params: Additional prediction parameters forwarded into the request body.
        Returns:
            The Prediction built from the response.
        """

        # No explicit version is sent in the body for a deployment prediction.
        payload = _create_prediction_body(version=None, input=input, **params)
        endpoint = _create_prediction_url_from_deployment(deployment)

        response = await self._client._async_request("POST", endpoint, json=payload)
        return _json_to_prediction(self._client, response.json())
480+
481+
482+
def _create_prediction_url_from_deployment(
483+
deployment: Union[str, Tuple[str, str], Deployment],
484+
) -> str:
485+
owner, name = None, None
486+
if isinstance(deployment, Deployment):
487+
owner, name = deployment.owner, deployment.name
488+
elif isinstance(deployment, tuple):
489+
owner, name = deployment[0], deployment[1]
490+
elif isinstance(deployment, str):
491+
owner, name = deployment.split("/", 1)
492+
493+
if owner is None or name is None:
494+
raise ValueError(
495+
"deployment must be a Deployment, a tuple of (owner, name), or a string in the format 'owner/name'"
496+
)
497+
498+
return f"/v1/deployments/{owner}/{name}/predictions"

0 commit comments

Comments
 (0)