Skip to content

Commit 5929c40

Browse files
committed
Add support for deployments.get endpoint
Signed-off-by: Mattt Zmuda <[email protected]>
1 parent 2f4f180 commit 5929c40

File tree

4 files changed

+268
-13
lines changed

4 files changed

+268
-13
lines changed

replicate/deployment.py

Lines changed: 163 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from typing import TYPE_CHECKING, Any, Dict
1+
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
22

33
from typing_extensions import Unpack, deprecated
44

5+
from replicate.account import Account
56
from replicate.prediction import (
67
Prediction,
78
_create_prediction_body,
@@ -37,6 +38,76 @@ class Deployment(Resource):
3738
The name of the deployment.
3839
"""
3940

41+
class Release(Resource):
42+
"""
43+
A release of a deployment.
44+
"""
45+
46+
number: int
47+
"""
48+
The release number.
49+
"""
50+
51+
model: str
52+
"""
53+
The model identifier string in the format of `{model_owner}/{model_name}`.
54+
"""
55+
56+
version: str
57+
"""
58+
The ID of the model version used in the release.
59+
"""
60+
61+
created_at: str
62+
"""
63+
The time the release was created.
64+
"""
65+
66+
created_by: Optional[Account]
67+
"""
68+
The account that created the release.
69+
"""
70+
71+
class Configuration(Resource):
72+
"""
73+
A configuration for a deployment.
74+
"""
75+
76+
hardware: str
77+
"""
78+
The SKU for the hardware used to run the model.
79+
"""
80+
81+
class Scaling(Resource):
82+
"""
83+
A scaling configuration for a deployment.
84+
"""
85+
86+
min_instances: int
87+
"""
88+
The minimum number of instances for scaling.
89+
"""
90+
91+
max_instances: int
92+
"""
93+
The maximum number of instances for scaling.
94+
"""
95+
96+
scaling: Scaling
97+
"""
98+
The scaling configuration for the deployment.
99+
"""
100+
101+
configuration: Configuration
102+
"""
103+
The deployment configuration.
104+
"""
105+
106+
current_release: Optional[Release]
107+
"""
108+
The current release of the deployment.
109+
"""
110+
40111
@property
41112
@deprecated("Use `deployment.owner` instead.")
42113
def username(self) -> str:
@@ -81,10 +152,12 @@ def get(self, name: str) -> Deployment:
81152

82153
owner, name = name.split("/", 1)
83154

84-
deployment = Deployment(owner=owner, name=name)
85-
deployment._client = self._client
155+
resp = self._client._request(
156+
"GET",
157+
f"/v1/deployments/{owner}/{name}",
158+
)
86159

87-
return deployment
160+
return _json_to_deployment(self._client, resp.json())
88161

89162
async def async_get(self, name: str) -> Deployment:
90163
"""
@@ -98,10 +171,26 @@ async def async_get(self, name: str) -> Deployment:
98171

99172
owner, name = name.split("/", 1)
100173

101-
deployment = Deployment(owner=owner, name=name)
102-
deployment._client = self._client
174+
resp = await self._client._async_request(
175+
"GET",
176+
f"/v1/deployments/{owner}/{name}",
177+
)
178+
179+
return _json_to_deployment(self._client, resp.json())
180+
181+
@property
182+
def predictions(self) -> "DeploymentsPredictions":
183+
"""
184+
Get predictions for deployments.
185+
"""
186+
187+
return DeploymentsPredictions(client=self._client)
188+
103189

104-
return deployment
190+
def _json_to_deployment(client: "Client", json: Dict[str, Any]) -> Deployment:
191+
deployment = Deployment(**json)
192+
deployment._client = client
193+
return deployment
105194

106195

107196
class DeploymentPredictions(Namespace):
@@ -152,3 +241,70 @@ async def async_create(
152241
)
153242

154243
return _json_to_prediction(self._client, resp.json())
244+
245+
246+
class DeploymentsPredictions(Namespace):
247+
"""
248+
Namespace for operations related to predictions in deployments.
249+
"""
250+
251+
def create(
252+
self,
253+
deployment: Union[str, Tuple[str, str], Deployment],
254+
input: Dict[str, Any],
255+
**params: Unpack["Predictions.CreatePredictionParams"],
256+
) -> Prediction:
257+
"""
258+
Create a new prediction with the deployment.
259+
"""
260+
261+
url = _create_prediction_url_from_deployment(deployment)
262+
body = _create_prediction_body(version=None, input=input, **params)
263+
264+
resp = self._client._request(
265+
"POST",
266+
url,
267+
json=body,
268+
)
269+
270+
return _json_to_prediction(self._client, resp.json())
271+
272+
async def async_create(
273+
self,
274+
deployment: Union[str, Tuple[str, str], Deployment],
275+
input: Dict[str, Any],
276+
**params: Unpack["Predictions.CreatePredictionParams"],
277+
) -> Prediction:
278+
"""
279+
Create a new prediction with the deployment.
280+
"""
281+
282+
url = _create_prediction_url_from_deployment(deployment)
283+
body = _create_prediction_body(version=None, input=input, **params)
284+
285+
resp = await self._client._async_request(
286+
"POST",
287+
url,
288+
json=body,
289+
)
290+
291+
return _json_to_prediction(self._client, resp.json())
292+
293+
294+
def _create_prediction_url_from_deployment(
295+
deployment: Union[str, Tuple[str, str], Deployment],
296+
) -> str:
297+
owner, name = None, None
298+
if isinstance(deployment, Deployment):
299+
owner, name = deployment.owner, deployment.name
300+
elif isinstance(deployment, tuple):
301+
owner, name = deployment[0], deployment[1]
302+
elif isinstance(deployment, str):
303+
owner, name = deployment.split("/", 1)
304+
305+
if owner is None or name is None:
306+
raise ValueError(
307+
"deployment must be a Deployment, a tuple of (owner, name), or a string in the format 'owner/name'"
308+
)
309+
310+
return f"/v1/deployments/{owner}/{name}/predictions"

replicate/identifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def parse(cls, ref: str) -> "ModelVersionIdentifier":
3131

3232

3333
def _resolve(
34-
ref: Union["Model", "Version", "ModelVersionIdentifier", str]
34+
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
3535
) -> Tuple[Optional["Version"], Optional[str], Optional[str], Optional[str]]:
3636
from replicate.model import Model # pylint: disable=import-outside-toplevel
3737
from replicate.version import Version # pylint: disable=import-outside-toplevel

replicate/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ def _json_to_model(client: "Client", json: Dict[str, Any]) -> Model:
383383

384384

385385
def _create_prediction_url_from_model(
386-
model: Union[str, Tuple[str, str], "Model"]
386+
model: Union[str, Tuple[str, str], "Model"],
387387
) -> str:
388388
owner, name = None, None
389389
if isinstance(model, Model):

tests/test_deployment.py

Lines changed: 103 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,46 @@
77
from replicate.client import Client
88

99
router = respx.Router(base_url="https://api.replicate.com/v1")
10+
11+
router.route(
12+
method="GET",
13+
path="/deployments/replicate/my-app-image-generator",
14+
name="deployments.get",
15+
).mock(
16+
return_value=httpx.Response(
17+
201,
18+
json={
19+
"owner": "replicate",
20+
"name": "my-app-image-generator",
21+
"current_release": {
22+
"number": 1,
23+
"model": "stability-ai/sdxl",
24+
"version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf",
25+
"created_at": "2024-02-15T16:32:57.018467Z",
26+
"created_by": {
27+
"type": "organization",
28+
"username": "acme",
29+
"name": "Acme Corp, Inc.",
30+
"github_url": "https://github.com/acme",
31+
},
32+
"configuration": {
33+
"hardware": "gpu-t4",
34+
"scaling": {"min_instances": 1, "max_instances": 5},
35+
},
36+
},
37+
},
38+
)
39+
)
1040
router.route(
1141
method="POST",
12-
path="/deployments/test/model/predictions",
42+
path="/deployments/replicate/my-app-image-generator/predictions",
1343
name="deployments.predictions.create",
1444
).mock(
1545
return_value=httpx.Response(
1646
201,
1747
json={
1848
"id": "p1",
19-
"model": "test/model",
49+
"model": "replicate/my-app-image-generator",
2050
"version": "v1",
2151
"urls": {
2252
"get": "https://api.replicate.com/v1/predictions/p1",
@@ -35,6 +65,37 @@
3565
router.route(host="api.replicate.com").pass_through()
3666

3767

68+
@pytest.mark.asyncio
69+
@pytest.mark.parametrize("async_flag", [True, False])
70+
async def test_deployment_get(async_flag):
71+
client = Client(
72+
api_token="test-token", transport=httpx.MockTransport(router.handler)
73+
)
74+
75+
if async_flag:
76+
deployment = await client.deployments.async_get(
77+
"replicate/my-app-image-generator"
78+
)
79+
else:
80+
deployment = client.deployments.get("replicate/my-app-image-generator")
81+
82+
assert router["deployments.get"].called
83+
84+
assert deployment.owner == "replicate"
85+
assert deployment.name == "my-app-image-generator"
86+
assert deployment.current_release is not None
87+
assert deployment.current_release.number == 1
88+
assert deployment.current_release.model == "stability-ai/sdxl"
89+
assert (
90+
deployment.current_release.version
91+
== "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf"
92+
)
93+
assert deployment.current_release.created_by.type == "organization"
94+
assert deployment.current_release.created_by.username == "acme"
95+
assert deployment.current_release.created_by.name == "Acme Corp, Inc."
96+
assert deployment.current_release.created_by.github_url == "https://github.com/acme"
97+
98+
3899
@pytest.mark.asyncio
39100
@pytest.mark.parametrize("async_flag", [True, False])
40101
async def test_deployment_predictions_create(async_flag):
@@ -43,7 +104,9 @@ async def test_deployment_predictions_create(async_flag):
43104
)
44105

45106
if async_flag:
46-
deployment = await client.deployments.async_get("test/model")
107+
deployment = await client.deployments.async_get(
108+
"replicate/my-app-image-generator"
109+
)
47110

48111
prediction = await deployment.predictions.async_create(
49112
input={"text": "world"},
@@ -52,7 +115,7 @@ async def test_deployment_predictions_create(async_flag):
52115
stream=True,
53116
)
54117
else:
55-
deployment = client.deployments.get("test/model")
118+
deployment = client.deployments.get("replicate/my-app-image-generator")
56119

57120
prediction = deployment.predictions.create(
58121
input={"text": "world"},
@@ -71,3 +134,39 @@ async def test_deployment_predictions_create(async_flag):
71134

72135
assert prediction.id == "p1"
73136
assert prediction.input == {"text": "world"}
137+
138+
139+
@pytest.mark.asyncio
140+
@pytest.mark.parametrize("async_flag", [True, False])
141+
async def test_deploymentspredictions_create(async_flag):
142+
client = Client(
143+
api_token="test-token", transport=httpx.MockTransport(router.handler)
144+
)
145+
146+
if async_flag:
147+
prediction = await client.deployments.predictions.async_create(
148+
deployment="replicate/my-app-image-generator",
149+
input={"text": "world"},
150+
webhook="https://example.com/webhook",
151+
webhook_events_filter=["completed"],
152+
stream=True,
153+
)
154+
else:
155+
prediction = await client.deployments.predictions.async_create(
156+
deployment="replicate/my-app-image-generator",
157+
input={"text": "world"},
158+
webhook="https://example.com/webhook",
159+
webhook_events_filter=["completed"],
160+
stream=True,
161+
)
162+
163+
assert router["deployments.predictions.create"].called
164+
request = router["deployments.predictions.create"].calls[0].request
165+
request_body = json.loads(request.content)
166+
assert request_body["input"] == {"text": "world"}
167+
assert request_body["webhook"] == "https://example.com/webhook"
168+
assert request_body["webhook_events_filter"] == ["completed"]
169+
assert request_body["stream"] is True
170+
171+
assert prediction.id == "p1"
172+
assert prediction.input == {"text": "world"}

0 commit comments

Comments
 (0)