Skip to content

Commit 32bb517

Browse files
committed
Rename block parameter to wait
Signed-off-by: Mattt Zmuda <[email protected]>
1 parent 0ef8929 commit 32bb517

File tree

3 files changed

+20
-12
lines changed

3 files changed

+20
-12
lines changed

replicate/deployment.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ def create(
426426
client=self._client,
427427
file_encoding_strategy=file_encoding_strategy,
428428
)
429-
headers = _create_prediction_headers(block=params.pop("block", None))
429+
headers = _create_prediction_headers(wait=params.pop("wait", None))
430430
body = _create_prediction_body(version=None, input=input, **params)
431431

432432
resp = self._client._request(
@@ -454,7 +454,7 @@ async def async_create(
454454
client=self._client,
455455
file_encoding_strategy=file_encoding_strategy,
456456
)
457-
headers = _create_prediction_headers(block=params.pop("block", None))
457+
headers = _create_prediction_headers(wait=params.pop("wait", None))
458458
body = _create_prediction_body(version=None, input=input, **params)
459459

460460
resp = await self._client._async_request(

replicate/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,7 @@ def create(
401401
client=self._client,
402402
file_encoding_strategy=file_encoding_strategy,
403403
)
404-
headers = _create_prediction_headers(block=params.pop("block", None))
404+
headers = _create_prediction_headers(wait=params.pop("wait", None))
405405
body = _create_prediction_body(version=None, input=input, **params)
406406

407407
resp = self._client._request(
@@ -432,7 +432,7 @@ async def async_create(
432432
client=self._client,
433433
file_encoding_strategy=file_encoding_strategy,
434434
)
435-
headers = _create_prediction_headers(block=params.pop("block", None))
435+
headers = _create_prediction_headers(wait=params.pop("wait", None))
436436
body = _create_prediction_body(version=None, input=input, **params)
437437

438438
resp = await self._client._async_request(

replicate/prediction.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -383,8 +383,14 @@ class CreatePredictionParams(TypedDict):
383383
stream: NotRequired[bool]
384384
"""Enable streaming of prediction output."""
385385

386-
block: NotRequired[bool]
387-
"""Wait until the prediction is completed before returning."""
386+
wait: NotRequired[Union[int, bool]]
387+
"""
388+
Wait until the prediction is completed before returning.
389+
390+
If `True`, wait a predetermined number of seconds until the prediction
391+
is completed before returning.
392+
If an `int`, wait for the specified number of seconds.
393+
"""
388394

389395
file_encoding_strategy: NotRequired[FileEncodingStrategy]
390396
"""The strategy to use for encoding files in the prediction input."""
@@ -466,7 +472,7 @@ def create( # type: ignore
466472
client=self._client,
467473
file_encoding_strategy=file_encoding_strategy,
468474
)
469-
headers = _create_prediction_headers(block=params.pop("block", None))
475+
headers = _create_prediction_headers(wait=params.pop("wait", None))
470476
body = _create_prediction_body(
471477
version,
472478
input,
@@ -559,7 +565,7 @@ async def async_create( # type: ignore
559565
client=self._client,
560566
file_encoding_strategy=file_encoding_strategy,
561567
)
562-
headers = _create_prediction_headers(block=params.pop("block", None))
568+
headers = _create_prediction_headers(wait=params.pop("wait", None))
563569
body = _create_prediction_body(
564570
version,
565571
input,
@@ -612,13 +618,15 @@ async def async_cancel(self, id: str) -> Prediction:
612618

613619
def _create_prediction_headers(
614620
*,
615-
block: Optional[bool] = None,
621+
wait: Optional[Union[int, bool]] = None,
616622
) -> Dict[str, Any]:
617623
headers = {}
618624

619-
if block:
620-
headers["Prefer"] = "wait"
621-
625+
if wait:
626+
if isinstance(wait, bool):
627+
headers["Prefer"] = "wait"
628+
elif isinstance(wait, int):
629+
headers["Prefer"] = f"wait={wait}"
622630
return headers
623631

624632

0 commit comments

Comments
 (0)