11
11
List ,
12
12
Literal ,
13
13
Optional ,
14
+ Tuple ,
14
15
Union ,
16
+ overload ,
15
17
)
16
18
17
19
from typing_extensions import NotRequired , TypedDict , Unpack
31
33
32
34
if TYPE_CHECKING :
33
35
from replicate .client import Client
36
+ from replicate .deployment import Deployment
37
+ from replicate .model import Model
34
38
from replicate .stream import ServerSentEvent
35
39
36
40
@@ -380,21 +384,82 @@ class CreatePredictionParams(TypedDict):
380
384
stream : NotRequired [bool ]
381
385
"""Enable streaming of prediction output."""
382
386
387
+ @overload
383
388
def create (
384
389
self ,
385
390
version : Union [Version , str ],
386
391
input : Optional [Dict [str , Any ]],
387
392
** params : Unpack ["Predictions.CreatePredictionParams" ],
393
+ ) -> Prediction : ...
394
+
395
+ @overload
396
+ def create (
397
+ self ,
398
+ * ,
399
+ model : Union [str , Tuple [str , str ], "Model" ],
400
+ input : Optional [Dict [str , Any ]],
401
+ ** params : Unpack ["Predictions.CreatePredictionParams" ],
402
+ ) -> Prediction : ...
403
+
404
+ @overload
405
+ def create (
406
+ self ,
407
+ * ,
408
+ deployment : Union [str , Tuple [str , str ], "Deployment" ],
409
+ input : Optional [Dict [str , Any ]],
410
+ ** params : Unpack ["Predictions.CreatePredictionParams" ],
411
+ ) -> Prediction : ...
412
+
413
+ def create ( # type: ignore
414
+ self ,
415
+ * args ,
416
+ model : Optional [Union [str , Tuple [str , str ], "Model" ]] = None ,
417
+ version : Optional [Union [Version , str , "Version" ]] = None ,
418
+ deployment : Optional [Union [str , Tuple [str , str ], "Deployment" ]] = None ,
419
+ input : Optional [Dict [str , Any ]] = None ,
420
+ ** params : Unpack ["Predictions.CreatePredictionParams" ],
388
421
) -> Prediction :
389
422
"""
390
- Create a new prediction for the specified model version.
423
+ Create a new prediction for the specified model, version, or deployment .
391
424
"""
392
425
426
+ if args :
427
+ version = args [0 ] if len (args ) > 0 else None
428
+ input = args [1 ] if len (args ) > 1 else input
429
+
430
+ if sum (bool (x ) for x in [model , version , deployment ]) != 1 :
431
+ raise ValueError (
432
+ "Exactly one of 'model', 'version', or 'deployment' must be specified."
433
+ )
434
+
435
+ if model is not None :
436
+ from replicate .model import ( # pylint: disable=import-outside-toplevel
437
+ Models ,
438
+ )
439
+
440
+ return Models (self ._client ).predictions .create (
441
+ model = model ,
442
+ input = input or {},
443
+ ** params ,
444
+ )
445
+
446
+ if deployment is not None :
447
+ from replicate .deployment import ( # pylint: disable=import-outside-toplevel
448
+ Deployments ,
449
+ )
450
+
451
+ return Deployments (self ._client ).predictions .create (
452
+ deployment = deployment ,
453
+ input = input or {},
454
+ ** params ,
455
+ )
456
+
393
457
body = _create_prediction_body (
394
458
version ,
395
459
input ,
396
460
** params ,
397
461
)
462
+
398
463
resp = self ._client ._request (
399
464
"POST" ,
400
465
"/v1/predictions" ,
@@ -403,21 +468,82 @@ def create(
403
468
404
469
return _json_to_prediction (self ._client , resp .json ())
405
470
471
+ @overload
406
472
async def async_create (
407
473
self ,
408
474
version : Union [Version , str ],
409
475
input : Optional [Dict [str , Any ]],
410
476
** params : Unpack ["Predictions.CreatePredictionParams" ],
477
+ ) -> Prediction : ...
478
+
479
+ @overload
480
+ async def async_create (
481
+ self ,
482
+ * ,
483
+ model : Union [str , Tuple [str , str ], "Model" ],
484
+ input : Optional [Dict [str , Any ]],
485
+ ** params : Unpack ["Predictions.CreatePredictionParams" ],
486
+ ) -> Prediction : ...
487
+
488
+ @overload
489
+ async def async_create (
490
+ self ,
491
+ * ,
492
+ deployment : Union [str , Tuple [str , str ], "Deployment" ],
493
+ input : Optional [Dict [str , Any ]],
494
+ ** params : Unpack ["Predictions.CreatePredictionParams" ],
495
+ ) -> Prediction : ...
496
+
497
+ async def async_create ( # type: ignore
498
+ self ,
499
+ * args ,
500
+ model : Optional [Union [str , Tuple [str , str ], "Model" ]] = None ,
501
+ version : Optional [Union [Version , str , "Version" ]] = None ,
502
+ deployment : Optional [Union [str , Tuple [str , str ], "Deployment" ]] = None ,
503
+ input : Optional [Dict [str , Any ]] = None ,
504
+ ** params : Unpack ["Predictions.CreatePredictionParams" ],
411
505
) -> Prediction :
412
506
"""
413
- Create a new prediction for the specified model version.
507
+ Create a new prediction for the specified model, version, or deployment .
414
508
"""
415
509
510
+ if args :
511
+ version = args [0 ] if len (args ) > 0 else None
512
+ input = args [1 ] if len (args ) > 1 else input
513
+
514
+ if sum (bool (x ) for x in [model , version , deployment ]) != 1 :
515
+ raise ValueError (
516
+ "Exactly one of 'model', 'version', or 'deployment' must be specified."
517
+ )
518
+
519
+ if model is not None :
520
+ from replicate .model import ( # pylint: disable=import-outside-toplevel
521
+ Models ,
522
+ )
523
+
524
+ return await Models (self ._client ).predictions .async_create (
525
+ model = model ,
526
+ input = input or {},
527
+ ** params ,
528
+ )
529
+
530
+ if deployment is not None :
531
+ from replicate .deployment import ( # pylint: disable=import-outside-toplevel
532
+ Deployments ,
533
+ )
534
+
535
+ return await Deployments (self ._client ).predictions .async_create (
536
+ deployment = deployment ,
537
+ input = input or {},
538
+ ** params ,
539
+ )
540
+
416
541
body = _create_prediction_body (
417
542
version ,
418
543
input ,
419
544
** params ,
420
545
)
546
+
421
547
resp = await self ._client ._async_request (
422
548
"POST" ,
423
549
"/v1/predictions" ,
0 commit comments