Skip to content

Commit df67837

Browse files
committed
Add overloads to models.get method
Signed-off-by: Mattt Zmuda <[email protected]>
1 parent d8d6a4e commit df67837

File tree

1 file changed

+90
-5
lines changed

1 file changed

+90
-5
lines changed

replicate/model.py

Lines changed: 90 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple, Union
1+
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple, Union, overload
22

33
from typing_extensions import NotRequired, TypedDict, Unpack, deprecated
44

@@ -207,31 +207,65 @@ async def async_list(
207207

208208
return Page[Model](**obj)
209209

210+
@overload
210211
def get(self, key: str) -> Model:
211212
"""
212213
Get a model by name.
213214
214215
Args:
215-
key: The qualified name of the model, in the format `owner/model-name`.
216+
key: The qualified name of the model, in the format `owner/name`.
216217
Returns:
217218
The model.
218219
"""
220+
...
219221

220-
resp = self._client._request("GET", f"/v1/models/{key}")
222+
@overload
223+
def get(self, owner: str, name: str) -> Model:
224+
"""
225+
Get a model by owner and name.
226+
227+
Args:
228+
owner: The owner of the model.
229+
name: The name of the model.
230+
Returns:
231+
The model.
232+
"""
233+
...
234+
235+
def get(self, *args, **kwargs) -> Model:
236+
url = _get_model_url(*args, **kwargs)
237+
resp = self._client._request("GET", url)
221238

222239
return _json_to_model(self._client, resp.json())
223240

241+
@overload
224242
async def async_get(self, key: str) -> Model:
225243
"""
226244
Get a model by name.
227245
228246
Args:
229-
key: The qualified name of the model, in the format `owner/model-name`.
247+
key: The qualified name of the model, in the format `owner/name`.
230248
Returns:
231249
The model.
232250
"""
251+
...
233252

234-
resp = await self._client._async_request("GET", f"/v1/models/{key}")
253+
@overload
254+
async def async_get(self, owner: str, name: str) -> Model:
255+
"""
256+
Get a model by owner and name.
257+
258+
Args:
259+
owner: The owner of the model.
260+
name: The name of the model.
261+
Returns:
262+
The model.
263+
"""
264+
...
265+
266+
async def async_get(self, *args, **kwargs) -> Model:
267+
url = _get_model_url(*args, **kwargs)
268+
resp = await self._client._async_request("GET", url)
235269

236270
return _json_to_model(self._client, resp.json())
237271

@@ -288,6 +322,13 @@ async def async_create(
288322

289323
return _json_to_model(self._client, resp.json())
290324

325+
def delete(self, key: str) -> None:
326+
"""
327+
Delete a model.
328+
"""
329+
330+
self._client._request("DELETE", f"/v1/models/{key}")
331+
291332

292333
class ModelsPredictions(Namespace):
293334
"""
@@ -374,6 +415,50 @@ def _create_model_body( # pylint: disable=too-many-arguments
374415
return body
375416

376417

418+
def _get_model_url(*args, **kwargs) -> str:
419+
"""
420+
Extracts the key from the provided arguments.
421+
422+
Args:
423+
args: Positional arguments that may contain the key or owner and name.
424+
kwargs: Keyword arguments that may contain the key, owner, or name.
425+
426+
Returns:
427+
The extracted key in the format 'owner/name'.
428+
429+
Raises:
430+
ValueError: If the arguments are invalid or insufficient to determine the key.
431+
"""
432+
if len(args) > 0 and len(kwargs) > 0:
433+
raise ValueError("Cannot mix positional and keyword arguments")
434+
435+
owner = kwargs.get("owner", None)
436+
name = kwargs.get("name", None)
437+
key = kwargs.get("key", None)
438+
439+
if key and (owner or name):
440+
raise ValueError(
441+
"Must specify exactly one of 'owner' and 'name' or single 'key' in the format 'owner/name'"
442+
)
443+
444+
if args:
445+
if len(args) == 1:
446+
key = args[0]
447+
elif len(args) == 2:
448+
owner, name = args
449+
else:
450+
raise ValueError("Invalid number of arguments")
451+
452+
if not key:
453+
if not (owner and name):
454+
raise ValueError(
455+
"Both 'owner' and 'name' must be provided if 'key' is not specified."
456+
)
457+
key = f"{owner}/{name}"
458+
459+
return f"/v1/models/{key}"
460+
461+
377462
def _json_to_model(client: "Client", json: Dict[str, Any]) -> Model:
378463
model = Model(**json)
379464
model._client = client

0 commit comments

Comments
 (0)