From fae7d939757c38054405cfc6e1b4955be343bf82 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 23 Sep 2024 03:55:50 -0700 Subject: [PATCH 1/3] Add block parameter to prediction creation methods Signed-off-by: Mattt Zmuda --- replicate/deployment.py | 5 +++++ replicate/model.py | 5 +++++ replicate/prediction.py | 19 +++++++++++++++++++ 3 files changed, 29 insertions(+) diff --git a/replicate/deployment.py b/replicate/deployment.py index e17edcbc..8f68baf8 100644 --- a/replicate/deployment.py +++ b/replicate/deployment.py @@ -8,6 +8,7 @@ from replicate.prediction import ( Prediction, _create_prediction_body, + _create_prediction_headers, _json_to_prediction, ) from replicate.resource import Namespace, Resource @@ -425,12 +426,14 @@ def create( client=self._client, file_encoding_strategy=file_encoding_strategy, ) + headers = _create_prediction_headers(block=params.pop("block", None)) body = _create_prediction_body(version=None, input=input, **params) resp = self._client._request( "POST", f"/v1/deployments/{self._deployment.owner}/{self._deployment.name}/predictions", json=body, + headers=headers, ) return _json_to_prediction(self._client, resp.json()) @@ -451,12 +454,14 @@ async def async_create( client=self._client, file_encoding_strategy=file_encoding_strategy, ) + headers = _create_prediction_headers(block=params.pop("block", None)) body = _create_prediction_body(version=None, input=input, **params) resp = await self._client._async_request( "POST", f"/v1/deployments/{self._deployment.owner}/{self._deployment.name}/predictions", json=body, + headers=headers, ) return _json_to_prediction(self._client, resp.json()) diff --git a/replicate/model.py b/replicate/model.py index ba5e1113..449387a4 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -9,6 +9,7 @@ from replicate.prediction import ( Prediction, _create_prediction_body, + _create_prediction_headers, _json_to_prediction, ) from replicate.resource import Namespace, Resource @@ -400,12 +401,14 @@ def create( client=self._client, file_encoding_strategy=file_encoding_strategy, ) + headers = _create_prediction_headers(block=params.pop("block", None)) body = _create_prediction_body(version=None, input=input, **params) resp = self._client._request( "POST", url, json=body, + headers=headers, ) return _json_to_prediction(self._client, resp.json()) @@ -429,12 +432,14 @@ async def async_create( client=self._client, file_encoding_strategy=file_encoding_strategy, ) + headers = _create_prediction_headers(block=params.pop("block", None)) body = _create_prediction_body(version=None, input=input, **params) resp = await self._client._async_request( "POST", url, json=body, + headers=headers, ) return _json_to_prediction(self._client, resp.json()) diff --git a/replicate/prediction.py b/replicate/prediction.py index 9770029b..3bdba63d 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -383,6 +383,9 @@ class CreatePredictionParams(TypedDict): stream: NotRequired[bool] """Enable streaming of prediction output.""" + block: NotRequired[bool] + """Wait until the prediction is completed before returning.""" + file_encoding_strategy: NotRequired[FileEncodingStrategy] """The strategy to use for encoding files in the prediction input.""" @@ -463,6 +466,7 @@ def create( # type: ignore client=self._client, file_encoding_strategy=file_encoding_strategy, ) + headers = _create_prediction_headers(block=params.pop("block", None)) body = _create_prediction_body( version, input, @@ -472,6 +476,7 @@ def create( # type: ignore resp = self._client._request( "POST", "/v1/predictions", + headers=headers, json=body, ) @@ -554,6 +559,7 @@ async def async_create( # type: ignore client=self._client, file_encoding_strategy=file_encoding_strategy, ) + headers = _create_prediction_headers(block=params.pop("block", None)) body = _create_prediction_body( version, input, @@ -563,6 +569,7 @@ async def async_create( # type: ignore resp = await self._client._async_request( "POST", "/v1/predictions", + headers=headers, json=body, ) @@ -603,6 +610,18 @@ async def async_cancel(self, id: str) -> Prediction: return _json_to_prediction(self._client, resp.json()) +def _create_prediction_headers( + *, + block: Optional[bool] = None, +) -> Dict[str, Any]: + headers = {} + + if block: + headers["X-Sync"] = "true" + + return headers + + def _create_prediction_body( # pylint: disable=too-many-arguments version: Optional[Union[Version, str]], input: Optional[Dict[str, Any]], From 0ef89294e90762e55fe6f75e2c86e594ed8352c1 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 25 Sep 2024 11:11:11 -0700 Subject: [PATCH 2/3] Replace X-Sync header with Prefer: wait Signed-off-by: Mattt Zmuda --- replicate/prediction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/replicate/prediction.py b/replicate/prediction.py index 3bdba63d..989841eb 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -617,7 +617,7 @@ def _create_prediction_headers( headers = {} if block: - headers["X-Sync"] = "true" + headers["Prefer"] = "wait" return headers From 32bb517b05b40ef6e8a8c1dff12657e9677a8e2e Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 25 Sep 2024 11:14:02 -0700 Subject: [PATCH 3/3] Rename block parameter to wait Signed-off-by: Mattt Zmuda --- replicate/deployment.py | 4 ++-- replicate/model.py | 4 ++-- replicate/prediction.py | 24 ++++++++++++++++-------- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/replicate/deployment.py b/replicate/deployment.py index 8f68baf8..1f0fdaba 100644 --- a/replicate/deployment.py +++ b/replicate/deployment.py @@ -426,7 +426,7 @@ def create( client=self._client, file_encoding_strategy=file_encoding_strategy, ) - headers = _create_prediction_headers(block=params.pop("block", None)) + headers = _create_prediction_headers(wait=params.pop("wait", None)) body = _create_prediction_body(version=None, input=input, **params) resp = self._client._request( @@ -454,7 +454,7 @@ async def async_create( client=self._client, file_encoding_strategy=file_encoding_strategy, ) - headers = _create_prediction_headers(block=params.pop("block", None)) + headers = _create_prediction_headers(wait=params.pop("wait", None)) body = _create_prediction_body(version=None, input=input, **params) resp = await self._client._async_request( diff --git a/replicate/model.py b/replicate/model.py index 449387a4..31f625af 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -401,7 +401,7 @@ def create( client=self._client, file_encoding_strategy=file_encoding_strategy, ) - headers = _create_prediction_headers(block=params.pop("block", None)) + headers = _create_prediction_headers(wait=params.pop("wait", None)) body = _create_prediction_body(version=None, input=input, **params) resp = self._client._request( @@ -432,7 +432,7 @@ async def async_create( client=self._client, file_encoding_strategy=file_encoding_strategy, ) - headers = _create_prediction_headers(block=params.pop("block", None)) + headers = _create_prediction_headers(wait=params.pop("wait", None)) body = _create_prediction_body(version=None, input=input, **params) resp = await self._client._async_request( diff --git a/replicate/prediction.py b/replicate/prediction.py index 989841eb..d09ef504 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -383,8 +383,14 @@ class CreatePredictionParams(TypedDict): stream: NotRequired[bool] """Enable streaming of prediction output.""" - block: NotRequired[bool] - """Wait until the prediction is completed before returning.""" + wait: NotRequired[Union[int, bool]] + """ + Wait until the prediction is completed before returning. + + If `True`, wait a predetermined number of seconds until the prediction + is completed before returning. + If an `int`, wait for the specified number of seconds. + """ file_encoding_strategy: NotRequired[FileEncodingStrategy] """The strategy to use for encoding files in the prediction input.""" @@ -466,7 +472,7 @@ def create( # type: ignore client=self._client, file_encoding_strategy=file_encoding_strategy, ) - headers = _create_prediction_headers(block=params.pop("block", None)) + headers = _create_prediction_headers(wait=params.pop("wait", None)) body = _create_prediction_body( version, input, @@ -559,7 +565,7 @@ async def async_create( # type: ignore client=self._client, file_encoding_strategy=file_encoding_strategy, ) - headers = _create_prediction_headers(block=params.pop("block", None)) + headers = _create_prediction_headers(wait=params.pop("wait", None)) body = _create_prediction_body( version, input, @@ -612,13 +618,15 @@ async def async_cancel(self, id: str) -> Prediction: def _create_prediction_headers( *, - block: Optional[bool] = None, + wait: Optional[Union[int, bool]] = None, ) -> Dict[str, Any]: headers = {} - if block: - headers["Prefer"] = "wait" - + if wait: + if isinstance(wait, bool): + headers["Prefer"] = "wait" + elif isinstance(wait, int): + headers["Prefer"] = f"wait={wait}" return headers