|
1 |
| -from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple, Union |
| 1 | +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple, Union, overload |
2 | 2 |
|
3 | 3 | from typing_extensions import NotRequired, TypedDict, Unpack, deprecated
|
4 | 4 |
|
@@ -207,31 +207,40 @@ async def async_list(
|
207 | 207 |
|
208 | 208 | return Page[Model](**obj)
|
209 | 209 |
|
210 |
| - def get(self, key: str) -> Model: |
| 210 | + @overload |
| 211 | + def get(self, key: str) -> Model: ... |
| 212 | + |
| 213 | + @overload |
| 214 | + def get(self, owner: str, name: str) -> Model: ... |
| 215 | + |
| 216 | + def get(self, *args, **kwargs) -> Model: |
211 | 217 | """
|
212 | 218 | Get a model by name.
|
213 |
| -
|
214 |
| - Args: |
215 |
| - key: The qualified name of the model, in the format `owner/model-name`. |
216 |
| - Returns: |
217 |
| - The model. |
218 | 219 | """
|
219 | 220 |
|
220 |
| - resp = self._client._request("GET", f"/v1/models/{key}") |
| 221 | + url = _get_model_url(*args, **kwargs) |
| 222 | + resp = self._client._request("GET", url) |
221 | 223 |
|
222 | 224 | return _json_to_model(self._client, resp.json())
|
223 | 225 |
|
224 |
| - async def async_get(self, key: str) -> Model: |
| 226 | + @overload |
| 227 | + async def async_get(self, key: str) -> Model: ... |
| 228 | + |
| 229 | + @overload |
| 230 | + async def async_get(self, owner: str, name: str) -> Model: ... |
| 231 | + |
| 232 | + async def async_get(self, *args, **kwargs) -> Model: |
225 | 233 | """
|
226 | 234 | Get a model by name.
|
227 | 235 |
|
228 | 236 | Args:
|
229 |
| - key: The qualified name of the model, in the format `owner/model-name`. |
| 237 | + key: The qualified name of the model, in the format `owner/name`. |
230 | 238 | Returns:
|
231 | 239 | The model.
|
232 | 240 | """
|
233 | 241 |
|
234 |
| - resp = await self._client._async_request("GET", f"/v1/models/{key}") |
| 242 | + url = _get_model_url(*args, **kwargs) |
| 243 | + resp = await self._client._async_request("GET", url) |
235 | 244 |
|
236 | 245 | return _json_to_model(self._client, resp.json())
|
237 | 246 |
|
@@ -288,6 +297,13 @@ async def async_create(
|
288 | 297 |
|
289 | 298 | return _json_to_model(self._client, resp.json())
|
290 | 299 |
|
| 300 | + def delete(self, key: str) -> None: |
| 301 | + """ |
| 302 | + Delete a model. |
| 303 | + """ |
| 304 | + |
| 305 | + self._client._request("DELETE", f"/v1/models/{key}") |
| 306 | + |
291 | 307 |
|
292 | 308 | class ModelsPredictions(Namespace):
|
293 | 309 | """
|
@@ -374,6 +390,37 @@ def _create_model_body( # pylint: disable=too-many-arguments
|
374 | 390 | return body
|
375 | 391 |
|
376 | 392 |
|
| 393 | +def _get_model_url(*args, **kwargs) -> str: |
| 394 | + if len(args) > 0 and len(kwargs) > 0: |
| 395 | + raise ValueError("Cannot mix positional and keyword arguments") |
| 396 | + |
| 397 | + owner = kwargs.get("owner", None) |
| 398 | + name = kwargs.get("name", None) |
| 399 | + key = kwargs.get("key", None) |
| 400 | + |
| 401 | + if key and (owner or name): |
| 402 | + raise ValueError( |
| 403 | + "Must specify exactly one of 'owner' and 'name' or single 'key' in the format 'owner/name'" |
| 404 | + ) |
| 405 | + |
| 406 | + if args: |
| 407 | + if len(args) == 1: |
| 408 | + key = args[0] |
| 409 | + elif len(args) == 2: |
| 410 | + owner, name = args |
| 411 | + else: |
| 412 | + raise ValueError("Invalid number of arguments") |
| 413 | + |
| 414 | + if not key: |
| 415 | + if not (owner and name): |
| 416 | + raise ValueError( |
| 417 | + "Both 'owner' and 'name' must be provided if 'key' is not specified." |
| 418 | + ) |
| 419 | + key = f"{owner}/{name}" |
| 420 | + |
| 421 | + return f"/v1/models/{key}" |
| 422 | + |
| 423 | + |
377 | 424 | def _json_to_model(client: "Client", json: Dict[str, Any]) -> Model:
|
378 | 425 | model = Model(**json)
|
379 | 426 | model._client = client
|
|
0 commit comments