
Commit eb3cebf

Support predictions.create with model, version, or deployment parameters (#290)
This PR updates `predictions.create` to support overloads with `model`, `version`, or `deployment` parameters. With these changes, API consumers can more easily switch between official models, model versions, and deployments.

```python
import replicate

prediction = replicate.predictions.create(
    model="meta/meta-llama-3-8b-instruct",
    input={"prompt": "write a haiku about corgis"},
)

prediction = replicate.predictions.create(
    version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
    input={"prompt": "a studio photo of a rainbow colored corgi"},
)

prediction = replicate.predictions.create(
    deployment="my-username/my-embeddings-model",
    input={"text": "hello world"},
)
```

---------

Signed-off-by: Mattt Zmuda <[email protected]>
1 parent c0cf1ec commit eb3cebf
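
As a quick illustration of the new call semantics (a sketch based on the validation added in `replicate/prediction.py` below, not part of the commit itself): exactly one of `model`, `version`, or `deployment` must be provided, and supplying more than one, or none, raises a `ValueError` before any request is made.

```python
import replicate

# Hypothetical misuse: `model` and `version` supplied together.
# The new check in `predictions.create` rejects this combination.
try:
    replicate.predictions.create(
        model="meta/meta-llama-3-8b-instruct",
        version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
        input={"prompt": "write a haiku about corgis"},
    )
except ValueError as err:
    print(err)  # Exactly one of 'model', 'version', or 'deployment' must be specified.
```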

17 files changed: +3509 -1972 lines

replicate/prediction.py

Lines changed: 128 additions & 2 deletions
```diff
@@ -11,7 +11,9 @@
     List,
     Literal,
     Optional,
+    Tuple,
     Union,
+    overload,
 )

 from typing_extensions import NotRequired, TypedDict, Unpack
@@ -31,6 +33,8 @@

 if TYPE_CHECKING:
     from replicate.client import Client
+    from replicate.deployment import Deployment
+    from replicate.model import Model
     from replicate.stream import ServerSentEvent


@@ -380,21 +384,82 @@ class CreatePredictionParams(TypedDict):
         stream: NotRequired[bool]
         """Enable streaming of prediction output."""

+    @overload
     def create(
         self,
         version: Union[Version, str],
         input: Optional[Dict[str, Any]],
         **params: Unpack["Predictions.CreatePredictionParams"],
+    ) -> Prediction: ...
+
+    @overload
+    def create(
+        self,
+        *,
+        model: Union[str, Tuple[str, str], "Model"],
+        input: Optional[Dict[str, Any]],
+        **params: Unpack["Predictions.CreatePredictionParams"],
+    ) -> Prediction: ...
+
+    @overload
+    def create(
+        self,
+        *,
+        deployment: Union[str, Tuple[str, str], "Deployment"],
+        input: Optional[Dict[str, Any]],
+        **params: Unpack["Predictions.CreatePredictionParams"],
+    ) -> Prediction: ...
+
+    def create(  # type: ignore
+        self,
+        *args,
+        model: Optional[Union[str, Tuple[str, str], "Model"]] = None,
+        version: Optional[Union[Version, str, "Version"]] = None,
+        deployment: Optional[Union[str, Tuple[str, str], "Deployment"]] = None,
+        input: Optional[Dict[str, Any]] = None,
+        **params: Unpack["Predictions.CreatePredictionParams"],
     ) -> Prediction:
         """
-        Create a new prediction for the specified model version.
+        Create a new prediction for the specified model, version, or deployment.
         """

+        if args:
+            version = args[0] if len(args) > 0 else None
+            input = args[1] if len(args) > 1 else input
+
+        if sum(bool(x) for x in [model, version, deployment]) != 1:
+            raise ValueError(
+                "Exactly one of 'model', 'version', or 'deployment' must be specified."
+            )
+
+        if model is not None:
+            from replicate.model import (  # pylint: disable=import-outside-toplevel
+                Models,
+            )
+
+            return Models(self._client).predictions.create(
+                model=model,
+                input=input or {},
+                **params,
+            )
+
+        if deployment is not None:
+            from replicate.deployment import (  # pylint: disable=import-outside-toplevel
+                Deployments,
+            )
+
+            return Deployments(self._client).predictions.create(
+                deployment=deployment,
+                input=input or {},
+                **params,
+            )
+
         body = _create_prediction_body(
             version,
             input,
             **params,
         )
+
         resp = self._client._request(
             "POST",
             "/v1/predictions",
```
```diff
@@ -403,21 +468,82 @@ def create(

         return _json_to_prediction(self._client, resp.json())

+    @overload
     async def async_create(
         self,
         version: Union[Version, str],
         input: Optional[Dict[str, Any]],
         **params: Unpack["Predictions.CreatePredictionParams"],
+    ) -> Prediction: ...
+
+    @overload
+    async def async_create(
+        self,
+        *,
+        model: Union[str, Tuple[str, str], "Model"],
+        input: Optional[Dict[str, Any]],
+        **params: Unpack["Predictions.CreatePredictionParams"],
+    ) -> Prediction: ...
+
+    @overload
+    async def async_create(
+        self,
+        *,
+        deployment: Union[str, Tuple[str, str], "Deployment"],
+        input: Optional[Dict[str, Any]],
+        **params: Unpack["Predictions.CreatePredictionParams"],
+    ) -> Prediction: ...
+
+    async def async_create(  # type: ignore
+        self,
+        *args,
+        model: Optional[Union[str, Tuple[str, str], "Model"]] = None,
+        version: Optional[Union[Version, str, "Version"]] = None,
+        deployment: Optional[Union[str, Tuple[str, str], "Deployment"]] = None,
+        input: Optional[Dict[str, Any]] = None,
+        **params: Unpack["Predictions.CreatePredictionParams"],
     ) -> Prediction:
         """
-        Create a new prediction for the specified model version.
+        Create a new prediction for the specified model, version, or deployment.
         """

+        if args:
+            version = args[0] if len(args) > 0 else None
+            input = args[1] if len(args) > 1 else input
+
+        if sum(bool(x) for x in [model, version, deployment]) != 1:
+            raise ValueError(
+                "Exactly one of 'model', 'version', or 'deployment' must be specified."
+            )
+
+        if model is not None:
+            from replicate.model import (  # pylint: disable=import-outside-toplevel
+                Models,
+            )
+
+            return await Models(self._client).predictions.async_create(
+                model=model,
+                input=input or {},
+                **params,
+            )
+
+        if deployment is not None:
+            from replicate.deployment import (  # pylint: disable=import-outside-toplevel
+                Deployments,
+            )
+
+            return await Deployments(self._client).predictions.async_create(
+                deployment=deployment,
+                input=input or {},
+                **params,
+            )
+
         body = _create_prediction_body(
             version,
             input,
             **params,
         )
+
         resp = await self._client._async_request(
             "POST",
             "/v1/predictions",
```

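The async path mirrors the sync one: `async_create` gains the same overloads and delegates to `Models(...).predictions.async_create` or `Deployments(...).predictions.async_create` when `model` or `deployment` is given. A usage sketch, assuming the module-level `replicate.predictions` namespace (as used in the examples above) also exposes `async_create`:

```python
import asyncio

import replicate


async def main() -> None:
    # Same keyword-only choice of model/version/deployment as the sync API.
    prediction = await replicate.predictions.async_create(
        model="meta/meta-llama-3-8b-instruct",
        input={"prompt": "write a haiku about corgis"},
    )
    print(prediction.id)


asyncio.run(main())
```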