Commit 4885f19

Add wait parameter to prediction creation methods (#354)
Signed-off-by: Mattt Zmuda <[email protected]>
1 parent: cd599f8

File tree

replicate/deployment.py
replicate/model.py
replicate/prediction.py

3 files changed: +37 -0 lines changed
replicate/deployment.py

Lines changed: 5 additions & 0 deletions

@@ -8,6 +8,7 @@
 from replicate.prediction import (
     Prediction,
     _create_prediction_body,
+    _create_prediction_headers,
     _json_to_prediction,
 )
 from replicate.resource import Namespace, Resource
@@ -425,12 +426,14 @@ def create(
             client=self._client,
             file_encoding_strategy=file_encoding_strategy,
         )
+        headers = _create_prediction_headers(wait=params.pop("wait", None))
         body = _create_prediction_body(version=None, input=input, **params)

         resp = self._client._request(
             "POST",
             f"/v1/deployments/{self._deployment.owner}/{self._deployment.name}/predictions",
             json=body,
+            headers=headers,
         )

         return _json_to_prediction(self._client, resp.json())
@@ -451,12 +454,14 @@ async def async_create(
             client=self._client,
             file_encoding_strategy=file_encoding_strategy,
         )
+        headers = _create_prediction_headers(wait=params.pop("wait", None))
         body = _create_prediction_body(version=None, input=input, **params)

         resp = await self._client._async_request(
             "POST",
             f"/v1/deployments/{self._deployment.owner}/{self._deployment.name}/predictions",
             json=body,
+            headers=headers,
         )

         return _json_to_prediction(self._client, resp.json())
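
For context, a minimal usage sketch of the new parameter on a deployment. The deployment name and prompt are hypothetical, and the client is assumed to read REPLICATE_API_TOKEN from the environment; per this diff, wait=True asks the API to hold the request open via a Prefer: wait header.

import replicate

client = replicate.Client()  # assumes REPLICATE_API_TOKEN is set

# "acme/my-app" is a hypothetical deployment name.
deployment = client.deployments.get("acme/my-app")

# wait=True sends "Prefer: wait", so the API blocks until the
# prediction completes (up to a server-side limit) before responding.
prediction = deployment.predictions.create(
    input={"prompt": "a photo of a flying saucer"},
    wait=True,
)
print(prediction.status, prediction.output)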

replicate/model.py

Lines changed: 5 additions & 0 deletions

@@ -9,6 +9,7 @@
 from replicate.prediction import (
     Prediction,
     _create_prediction_body,
+    _create_prediction_headers,
     _json_to_prediction,
 )
 from replicate.resource import Namespace, Resource
@@ -400,12 +401,14 @@ def create(
             client=self._client,
             file_encoding_strategy=file_encoding_strategy,
         )
+        headers = _create_prediction_headers(wait=params.pop("wait", None))
         body = _create_prediction_body(version=None, input=input, **params)

         resp = self._client._request(
             "POST",
             url,
             json=body,
+            headers=headers,
         )

         return _json_to_prediction(self._client, resp.json())
@@ -429,12 +432,14 @@ async def async_create(
             client=self._client,
             file_encoding_strategy=file_encoding_strategy,
         )
+        headers = _create_prediction_headers(wait=params.pop("wait", None))
         body = _create_prediction_body(version=None, input=input, **params)

         resp = await self._client._async_request(
             "POST",
             url,
             json=body,
+            headers=headers,
        )

         return _json_to_prediction(self._client, resp.json())
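
The model-scoped create gains the same parameter. A sketch of the integer form follows, which bounds the wait rather than blocking open-endedly; the model name is illustrative, and the models.predictions call shape is an assumption based on this diff's use of a model-derived url.

import replicate

client = replicate.Client()  # assumes REPLICATE_API_TOKEN is set

# Illustrative model name. wait=30 sends "Prefer: wait=30": hold the
# connection open for up to 30 seconds, returning early if the
# prediction finishes sooner.
prediction = client.models.predictions.create(
    "stability-ai/sdxl",
    input={"prompt": "an astronaut riding a horse"},
    wait=30,
)
print(prediction.status)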

replicate/prediction.py

Lines changed: 27 additions & 0 deletions

@@ -383,6 +383,15 @@ class CreatePredictionParams(TypedDict):
     stream: NotRequired[bool]
     """Enable streaming of prediction output."""

+    wait: NotRequired[Union[int, bool]]
+    """
+    Wait until the prediction is completed before returning.
+
+    If `True`, wait a predetermined number of seconds until the prediction
+    is completed before returning.
+    If an `int`, wait for the specified number of seconds.
+    """
+
     file_encoding_strategy: NotRequired[FileEncodingStrategy]
     """The strategy to use for encoding files in the prediction input."""

@@ -463,6 +472,7 @@ def create(  # type: ignore
             client=self._client,
             file_encoding_strategy=file_encoding_strategy,
         )
+        headers = _create_prediction_headers(wait=params.pop("wait", None))
         body = _create_prediction_body(
             version,
             input,
@@ -472,6 +482,7 @@ def create(  # type: ignore
         resp = self._client._request(
             "POST",
             "/v1/predictions",
+            headers=headers,
             json=body,
         )

@@ -554,6 +565,7 @@ async def async_create(  # type: ignore
             client=self._client,
             file_encoding_strategy=file_encoding_strategy,
         )
+        headers = _create_prediction_headers(wait=params.pop("wait", None))
         body = _create_prediction_body(
             version,
             input,
@@ -563,6 +575,7 @@ async def async_create(  # type: ignore
         resp = await self._client._async_request(
             "POST",
             "/v1/predictions",
+            headers=headers,
             json=body,
         )

@@ -603,6 +616,20 @@ async def async_cancel(self, id: str) -> Prediction:
         return _json_to_prediction(self._client, resp.json())


+def _create_prediction_headers(
+    *,
+    wait: Optional[Union[int, bool]] = None,
+) -> Dict[str, Any]:
+    headers = {}
+
+    if wait:
+        if isinstance(wait, bool):
+            headers["Prefer"] = "wait"
+        elif isinstance(wait, int):
+            headers["Prefer"] = f"wait={wait}"
+    return headers
+
+
 def _create_prediction_body(  # pylint: disable=too-many-arguments
     version: Optional[Union[Version, str]],
     input: Optional[Dict[str, Any]],
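
A note on the new _create_prediction_headers: the bool branch must come before the int branch, because bool is a subclass of int in Python (isinstance(True, int) is True), and falsy values (False, 0, None) produce no Prefer header at all. A standalone sketch reproducing the mapping, with asserts as a quick check:

from typing import Any, Dict, Optional, Union


def _create_prediction_headers(
    *,
    wait: Optional[Union[int, bool]] = None,
) -> Dict[str, Any]:
    headers = {}
    if wait:
        # bool is checked first: isinstance(True, int) is also True,
        # so the order of these branches matters.
        if isinstance(wait, bool):
            headers["Prefer"] = "wait"
        elif isinstance(wait, int):
            headers["Prefer"] = f"wait={wait}"
    return headers


assert _create_prediction_headers(wait=True) == {"Prefer": "wait"}
assert _create_prediction_headers(wait=30) == {"Prefer": "wait=30"}
assert _create_prediction_headers(wait=False) == {}
assert _create_prediction_headers(wait=0) == {}  # falsy int: no header
assert _create_prediction_headers(wait=None) == {}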
