Skip to content

Commit db719ef

Browse files
committed
Add overloads to models.get method
Signed-off-by: Mattt Zmuda <[email protected]>
1 parent d8d6a4e commit db719ef

File tree

1 file changed

+58
-11
lines changed

1 file changed

+58
-11
lines changed

replicate/model.py

Lines changed: 58 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple, Union
1+
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple, Union, overload
22

33
from typing_extensions import NotRequired, TypedDict, Unpack, deprecated
44

@@ -207,31 +207,40 @@ async def async_list(
207207

208208
return Page[Model](**obj)
209209

210-
def get(self, key: str) -> Model:
210+
@overload
211+
def get(self, key: str) -> Model: ...
212+
213+
@overload
214+
def get(self, owner: str, name: str) -> Model: ...
215+
216+
def get(self, *args, **kwargs) -> Model:
211217
"""
212218
Get a model by name.
213-
214-
Args:
215-
key: The qualified name of the model, in the format `owner/model-name`.
216-
Returns:
217-
The model.
218219
"""
219220

220-
resp = self._client._request("GET", f"/v1/models/{key}")
221+
url = _get_model_url(*args, **kwargs)
222+
resp = self._client._request("GET", url)
221223

222224
return _json_to_model(self._client, resp.json())
223225

224-
async def async_get(self, key: str) -> Model:
226+
@overload
227+
async def async_get(self, key: str) -> Model: ...
228+
229+
@overload
230+
async def async_get(self, owner: str, name: str) -> Model: ...
231+
232+
async def async_get(self, *args, **kwargs) -> Model:
225233
"""
226234
Get a model by name.
227235
228236
Args:
229-
key: The qualified name of the model, in the format `owner/model-name`.
237+
key: The qualified name of the model, in the format `owner/name`.
230238
Returns:
231239
The model.
232240
"""
233241

234-
resp = await self._client._async_request("GET", f"/v1/models/{key}")
242+
url = _get_model_url(*args, **kwargs)
243+
resp = await self._client._async_request("GET", url)
235244

236245
return _json_to_model(self._client, resp.json())
237246

@@ -288,6 +297,13 @@ async def async_create(
288297

289298
return _json_to_model(self._client, resp.json())
290299

300+
def delete(self, key: str) -> None:
301+
"""
302+
Delete a model.
303+
"""
304+
305+
self._client._request("DELETE", f"/v1/models/{key}")
306+
291307

292308
class ModelsPredictions(Namespace):
293309
"""
@@ -374,6 +390,37 @@ def _create_model_body( # pylint: disable=too-many-arguments
374390
return body
375391

376392

393+
def _get_model_url(*args, **kwargs) -> str:
394+
if len(args) > 0 and len(kwargs) > 0:
395+
raise ValueError("Cannot mix positional and keyword arguments")
396+
397+
owner = kwargs.get("owner", None)
398+
name = kwargs.get("name", None)
399+
key = kwargs.get("key", None)
400+
401+
if key and (owner or name):
402+
raise ValueError(
403+
"Must specify exactly one of 'owner' and 'name' or single 'key' in the format 'owner/name'"
404+
)
405+
406+
if args:
407+
if len(args) == 1:
408+
key = args[0]
409+
elif len(args) == 2:
410+
owner, name = args
411+
else:
412+
raise ValueError("Invalid number of arguments")
413+
414+
if not key:
415+
if not (owner and name):
416+
raise ValueError(
417+
"Both 'owner' and 'name' must be provided if 'key' is not specified."
418+
)
419+
key = f"{owner}/{name}"
420+
421+
return f"/v1/models/{key}"
422+
423+
377424
def _json_to_model(client: "Client", json: Dict[str, Any]) -> Model:
378425
model = Model(**json)
379426
model._client = client

0 commit comments

Comments
 (0)