Skip to content

Commit a79b0d2

Browse files
authored
Improve support for models.get and models.delete endpoints (#302)
Add overloads to pass model owner and name as distinct arguments or as a single string. --------- Signed-off-by: Mattt Zmuda <[email protected]>
1 parent f30dc0d commit a79b0d2

File tree

2 files changed

+90
-9
lines changed

2 files changed

+90
-9
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ background = Image.open("/tmp/out.png")
270270

271271
## List models
272272

273-
You can the models you've created:
273+
You can list the models you've created:
274274

275275
```python
276276
replicate.models.list()

replicate/model.py

Lines changed: 89 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple, Union
1+
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple, Union, overload
22

33
from typing_extensions import NotRequired, TypedDict, Unpack, deprecated
44

@@ -207,31 +207,77 @@ async def async_list(
207207

208208
return Page[Model](**obj)
209209

210-
def get(self, key: str) -> Model:
210+
@overload
211+
def get(self, key: str) -> Model: ...
212+
213+
@overload
214+
def get(self, owner: str, name: str) -> Model: ...
215+
216+
def get(self, *args, **kwargs) -> Model:
217+
"""
218+
Get a model by name.
219+
"""
220+
221+
url = _get_model_url(*args, **kwargs)
222+
resp = self._client._request("GET", url)
223+
224+
return _json_to_model(self._client, resp.json())
225+
226+
@overload
227+
async def async_get(self, key: str) -> Model: ...
228+
229+
@overload
230+
async def async_get(self, owner: str, name: str) -> Model: ...
231+
232+
async def async_get(self, *args, **kwargs) -> Model:
211233
"""
212234
Get a model by name.
213235
214236
Args:
215-
key: The qualified name of the model, in the format `owner/model-name`.
237+
key: The qualified name of the model, in the format `owner/name`.
216238
Returns:
217239
The model.
218240
"""
219241

220-
resp = self._client._request("GET", f"/v1/models/{key}")
242+
url = _get_model_url(*args, **kwargs)
243+
resp = await self._client._async_request("GET", url)
221244

222245
return _json_to_model(self._client, resp.json())
223246

224-
async def async_get(self, key: str) -> Model:
247+
@overload
248+
def delete(self, key: str) -> Model: ...
249+
250+
@overload
251+
def delete(self, owner: str, name: str) -> Model: ...
252+
253+
def delete(self, *args, **kwargs) -> Model:
225254
"""
226-
Get a model by name.
255+
Delete a model by name.
256+
"""
257+
258+
url = _delete_model_url(*args, **kwargs)
259+
resp = self._client._request("DELETE", url)
260+
261+
return _json_to_model(self._client, resp.json())
262+
263+
@overload
264+
async def async_delete(self, key: str) -> Model: ...
265+
266+
@overload
267+
async def async_delete(self, owner: str, name: str) -> Model: ...
268+
269+
async def async_delete(self, *args, **kwargs) -> Model:
270+
"""
271+
Delete a model by name.
227272
228273
Args:
229-
key: The qualified name of the model, in the format `owner/model-name`.
274+
key: The qualified name of the model, in the format `owner/name`.
230275
Returns:
231276
The model.
232277
"""
233278

234-
resp = await self._client._async_request("GET", f"/v1/models/{key}")
279+
url = _delete_model_url(*args, **kwargs)
280+
resp = await self._client._async_request("DELETE", url)
235281

236282
return _json_to_model(self._client, resp.json())
237283

@@ -374,6 +420,41 @@ def _create_model_body( # pylint: disable=too-many-arguments
374420
return body
375421

376422

423+
def _get_model_url(*args, **kwargs) -> str:
424+
if len(args) > 0 and len(kwargs) > 0:
425+
raise ValueError("Cannot mix positional and keyword arguments")
426+
427+
owner = kwargs.get("owner", None)
428+
name = kwargs.get("name", None)
429+
key = kwargs.get("key", None)
430+
431+
if key and (owner or name):
432+
raise ValueError(
433+
"Must specify exactly one of 'owner' and 'name' or single 'key' in the format 'owner/name'"
434+
)
435+
436+
if args:
437+
if len(args) == 1:
438+
key = args[0]
439+
elif len(args) == 2:
440+
owner, name = args
441+
else:
442+
raise ValueError("Invalid number of arguments")
443+
444+
if not key:
445+
if not (owner and name):
446+
raise ValueError(
447+
"Both 'owner' and 'name' must be provided if 'key' is not specified."
448+
)
449+
key = f"{owner}/{name}"
450+
451+
return f"/v1/models/{key}"
452+
453+
454+
def _delete_model_url(*args, **kwargs) -> str:
455+
return _get_model_url(*args, **kwargs)
456+
457+
377458
def _json_to_model(client: "Client", json: Dict[str, Any]) -> Model:
378459
model = Model(**json)
379460
model._client = client

0 commit comments

Comments
 (0)