Skip to content

Commit 88c03a5

Browse files
committed
Add support for deployment list, create, and update endpoints
Signed-off-by: Mattt Zmuda <[email protected]>
1 parent 5929c40 commit 88c03a5

File tree

1 file changed

+171
-1
lines changed

1 file changed

+171
-1
lines changed

replicate/deployment.py

Lines changed: 171 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
1+
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, TypedDict, Union
22

33
from typing_extensions import Unpack, deprecated
44

55
from replicate.account import Account
6+
from replicate.pagination import Page
67
from replicate.prediction import (
78
Prediction,
89
_create_prediction_body,
@@ -140,6 +141,55 @@ class Deployments(Namespace):
140141

141142
_client: "Client"
142143

144+
def list(
    self,
    cursor: Union[str, "ellipsis", None] = ...,  # noqa: F821
) -> Page[Deployment]:
    """
    Fetch a page of deployments.

    Args:
        cursor: Request path for the page to fetch. When left at its
            default (`...`), `/v1/deployments` is requested instead.
    Returns:
        A Page whose results have been converted to Deployment objects.
    Raises:
        ValueError: If `cursor` is explicitly None.
    """

    if cursor is None:
        raise ValueError("cursor cannot be None")

    # Ellipsis is the "argument not supplied" sentinel; any other value is
    # handed to the client as the request path unchanged.
    path = cursor if cursor is not ... else "/v1/deployments"
    resp = self._client._request("GET", path)

    payload = resp.json()
    payload["results"] = [
        _json_to_deployment(self._client, item) for item in payload["results"]
    ]

    return Page[Deployment](**payload)
168+
169+
async def async_list(
    self,
    cursor: Union[str, "ellipsis", None] = ...,  # noqa: F821
) -> Page[Deployment]:
    """
    Fetch a page of deployments asynchronously.

    Args:
        cursor: Request path for the page to fetch. When left at its
            default (`...`), `/v1/deployments` is requested instead.
    Returns:
        A Page whose results have been converted to Deployment objects.
    Raises:
        ValueError: If `cursor` is explicitly None.
    """

    if cursor is None:
        raise ValueError("cursor cannot be None")

    # Ellipsis is the "argument not supplied" sentinel; any other value is
    # handed to the client as the request path unchanged.
    path = cursor if cursor is not ... else "/v1/deployments"
    resp = await self._client._async_request("GET", path)

    payload = resp.json()
    payload["results"] = [
        _json_to_deployment(self._client, item) for item in payload["results"]
    ]

    return Page[Deployment](**payload)
192+
143193
def get(self, name: str) -> Deployment:
144194
"""
145195
Get a deployment by name.
@@ -178,6 +228,126 @@ async def async_get(self, name: str) -> Deployment:
178228

179229
return _json_to_deployment(self._client, resp.json())
180230

231+
class CreateDeploymentParams(TypedDict):
    """
    Shape of the configuration dict accepted when creating a deployment.

    The TypedDict is total, so every key below is required.
    """

    # Name of the deployment.
    name: str
    # NOTE(review): presumably identifies the model to deploy - confirm the
    # expected format against the API docs.
    model: str
    # NOTE(review): presumably the model version identifier - confirm.
    version: str
    # NOTE(review): presumably a hardware SKU string - confirm allowed values.
    hardware: str
    # NOTE(review): presumably lower/upper bounds on running instances - confirm.
    min_instances: int
    max_instances: int
242+
243+
def create(self, deployment_config: CreateDeploymentParams) -> Deployment:
    """
    Create a new deployment.

    Args:
        deployment_config: Configuration for the new deployment. If its
            `name` contains a "/", everything up to and including the
            first "/" is dropped before the request is sent. The dict
            passed in is not modified.
    Returns:
        The newly created Deployment.
    """

    # Build the request body from a shallow copy so that normalizing the
    # name never writes back into the caller's dict.
    body: Dict[str, Any] = dict(deployment_config)

    name = body.get("name", None)
    if name and "/" in name:
        # Keep only the part after the first "/".
        body["name"] = name.split("/", 1)[1]

    resp = self._client._request(
        "POST",
        "/v1/deployments",
        json=body,
    )

    return _json_to_deployment(self._client, resp.json())
265+
266+
async def async_create(
    self, deployment_config: CreateDeploymentParams
) -> Deployment:
    """
    Create a new deployment asynchronously.

    Args:
        deployment_config: Configuration for the new deployment. If its
            `name` contains a "/", everything up to and including the
            first "/" is dropped before the request is sent. The dict
            passed in is not modified.
    Returns:
        The newly created Deployment.
    """

    # Build the request body from a shallow copy so that normalizing the
    # name never writes back into the caller's dict.
    body: Dict[str, Any] = dict(deployment_config)

    name = body.get("name", None)
    if name and "/" in name:
        # Keep only the part after the first "/".
        body["name"] = name.split("/", 1)[1]

    resp = await self._client._async_request(
        "POST",
        "/v1/deployments",
        json=body,
    )

    return _json_to_deployment(self._client, resp.json())
290+
291+
class UpdateDeploymentParams(TypedDict, total=False):
    """
    Shape of the configuration dict accepted when updating a deployment.

    Declared with `total=False`, so every key below is optional.
    """

    # NOTE(review): presumably the model version identifier - confirm.
    version: str
    # NOTE(review): presumably a hardware SKU string - confirm allowed values.
    hardware: str
    # NOTE(review): presumably lower/upper bounds on running instances - confirm.
    min_instances: int
    max_instances: int
300+
301+
def update(
    self,
    deployment_owner: str,
    deployment_name: str,
    deployment_config: UpdateDeploymentParams,
) -> Deployment:
    """
    Apply configuration changes to an existing deployment.

    Args:
        deployment_owner: The owner of the deployment.
        deployment_name: The name of the deployment.
        deployment_config: The fields to change, sent as the JSON body.
    Returns:
        The Deployment built from the response to the PATCH request.
    """

    # Owner and name are interpolated into the path exactly as given.
    path = f"/v1/deployments/{deployment_owner}/{deployment_name}"
    resp = self._client._request("PATCH", path, json=deployment_config)

    return _json_to_deployment(self._client, resp.json())
325+
326+
async def async_update(
    self,
    deployment_owner: str,
    deployment_name: str,
    deployment_config: UpdateDeploymentParams,
) -> Deployment:
    """
    Apply configuration changes to an existing deployment asynchronously.

    Args:
        deployment_owner: The owner of the deployment.
        deployment_name: The name of the deployment.
        deployment_config: The fields to change, sent as the JSON body.
    Returns:
        The Deployment built from the response to the PATCH request.
    """

    # Owner and name are interpolated into the path exactly as given.
    path = f"/v1/deployments/{deployment_owner}/{deployment_name}"
    resp = await self._client._async_request(
        "PATCH", path, json=deployment_config
    )

    return _json_to_deployment(self._client, resp.json())
350+
181351
@property
182352
def predictions(self) -> "DeploymentsPredictions":
183353
"""

0 commit comments

Comments
 (0)