From 5059d3a7269d51cb4af34d841478f1ae206d3855 Mon Sep 17 00:00:00 2001
From: Aron Carroll
Date: Fri, 4 Oct 2024 13:47:18 -0700
Subject: [PATCH 1/6] Switch to using file outputs and blocking api by default

---
 replicate/prediction.py | 10 ++++++----
 replicate/run.py        | 15 ++++++++++++---
 tests/test_run.py       |  8 ++++----
 3 files changed, 22 insertions(+), 11 deletions(-)

diff --git a/replicate/prediction.py b/replicate/prediction.py
index 0e5342a..aa3e45c 100644
--- a/replicate/prediction.py
+++ b/replicate/prediction.py
@@ -395,11 +395,13 @@ class CreatePredictionParams(TypedDict):
 
     wait: NotRequired[Union[int, bool]]
     """
-    Wait until the prediction is completed before returning.
+    Block until the prediction is completed before returning.
 
-    If `True`, wait a predetermined number of seconds until the prediction
-    is completed before returning.
-    If an `int`, wait for the specified number of seconds.
+    If `True`, keep the request open for up to 60 seconds, falling back to
+    polling until the prediction is completed.
+    If an `int`, same as True but hold the request for a specified number of
+    seconds (between 1 and 60).
+    If `False`, poll for the prediction status until completed.
     """
 
     file_encoding_strategy: NotRequired[FileEncodingStrategy]
diff --git a/replicate/run.py b/replicate/run.py
index d159f11..cab6aab 100644
--- a/replicate/run.py
+++ b/replicate/run.py
@@ -29,14 +29,18 @@ def run(
     client: "Client",
     ref: Union["Model", "Version", "ModelVersionIdentifier", str],
     input: Optional[Dict[str, Any]] = None,
-    use_file_output: Optional[bool] = None,
+    *,
+    use_file_output: Optional[bool] = True,
     **params: Unpack["Predictions.CreatePredictionParams"],
 ) -> Union[Any, Iterator[Any]]:  # noqa: ANN401
     """
     Run a model and wait for its output.
     """
 
-    is_blocking = "wait" in params
+    if "wait" not in params:
+        params["wait"] = True
+    is_blocking = params["wait"] != False  # noqa: E712
+
     version, owner, name, version_id = identifier._resolve(ref)
 
     if version_id is not None:
@@ -74,13 +78,18 @@ async def async_run(
     client: "Client",
     ref: Union["Model", "Version", "ModelVersionIdentifier", str],
     input: Optional[Dict[str, Any]] = None,
-    use_file_output: Optional[bool] = None,
+    *,
+    use_file_output: Optional[bool] = True,
     **params: Unpack["Predictions.CreatePredictionParams"],
 ) -> Union[Any, AsyncIterator[Any]]:  # noqa: ANN401
     """
     Run a model and wait for its output asynchronously.
""" + if "wait" not in params: + params["wait"] = True + is_blocking = params["wait"] != False # noqa: E712 + version, owner, name, version_id = identifier._resolve(ref) if version or version_id: diff --git a/tests/test_run.py b/tests/test_run.py index 7d963a4..0f9aed2 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -123,7 +123,7 @@ def prediction_with_status(status: str) -> dict: router.route(method="POST", path="/predictions").mock( return_value=httpx.Response( 201, - json=prediction_with_status("processing"), + json=prediction_with_status("starting"), ) ) router.route(method="GET", path="/predictions/p1").mock( @@ -212,7 +212,7 @@ def prediction_with_status(status: str) -> dict: router.route(method="POST", path="/predictions").mock( return_value=httpx.Response( 201, - json=prediction_with_status("processing"), + json=prediction_with_status("starting"), ) ) router.route(method="GET", path="/predictions/p1").mock( @@ -454,7 +454,7 @@ def prediction_with_status( router.route(method="POST", path="/predictions").mock( return_value=httpx.Response( 201, - json=prediction_with_status("processing"), + json=prediction_with_status("starting"), ) ) router.route(method="GET", path="/predictions/p1").mock( @@ -541,7 +541,7 @@ def prediction_with_status( router.route(method="POST", path="/predictions").mock( return_value=httpx.Response( 201, - json=prediction_with_status("processing"), + json=prediction_with_status("starting"), ) ) router.route(method="GET", path="/predictions/p1").mock( From 20f71387c93d8b45a618a463fa0310175ccc10a8 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Fri, 4 Oct 2024 16:47:01 -0700 Subject: [PATCH 2/6] Fix runtime error with incorrect variable arguments This commit fixes a bug introduced by requiring `use_file_output` to be a keyword parameter. --- replicate/client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/replicate/client.py b/replicate/client.py index 52d07f7..8220dd4 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -171,7 +171,7 @@ def run( Run a model and wait for its output. """ - return run(self, ref, input, use_file_output, **params) + return run(self, ref, input, use_file_output=use_file_output, **params) async def async_run( self, @@ -184,7 +184,7 @@ async def async_run( Run a model and wait for its output asynchronously. """ - return await async_run(self, ref, input, use_file_output, **params) + return await async_run(self, ref, input, use_file_output=use_file_output, **params) def stream( self, From 9921f4c64399113a73a364ae9a79894696bd52bb Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Fri, 4 Oct 2024 16:51:50 -0700 Subject: [PATCH 3/6] Do not poll if blocking prediction has succeeded --- replicate/client.py | 4 +++- replicate/run.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/replicate/client.py b/replicate/client.py index 8220dd4..a883bd1 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -184,7 +184,9 @@ async def async_run( Run a model and wait for its output asynchronously. 
""" - return await async_run(self, ref, input, use_file_output=use_file_output, **params) + return await async_run( + self, ref, input, use_file_output=use_file_output, **params + ) def stream( self, diff --git a/replicate/run.py b/replicate/run.py index cab6aab..3b6bddb 100644 --- a/replicate/run.py +++ b/replicate/run.py @@ -111,7 +111,8 @@ async def async_run( if version and (iterator := _make_async_output_iterator(version, prediction)): return iterator - await prediction.async_wait() + if not (is_blocking and prediction.status != "starting"): + await prediction.async_wait() if prediction.status == "failed": raise ModelError(prediction) From 57e7255370b31bfe2d017ddba618b4d0975a66a7 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Fri, 4 Oct 2024 17:02:48 -0700 Subject: [PATCH 4/6] Update stream interface to always use FileOutput --- replicate/client.py | 16 ++++++++++------ replicate/stream.py | 11 +++++++---- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/replicate/client.py b/replicate/client.py index a883bd1..3e767d6 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -164,7 +164,8 @@ def run( self, ref: str, input: Optional[Dict[str, Any]] = None, - use_file_output: Optional[bool] = None, + *, + use_file_output: Optional[bool] = True, **params: Unpack["Predictions.CreatePredictionParams"], ) -> Union[Any, Iterator[Any]]: # noqa: ANN401 """ @@ -177,7 +178,8 @@ async def async_run( self, ref: str, input: Optional[Dict[str, Any]] = None, - use_file_output: Optional[bool] = None, + *, + use_file_output: Optional[bool] = True, **params: Unpack["Predictions.CreatePredictionParams"], ) -> Union[Any, AsyncIterator[Any]]: # noqa: ANN401 """ @@ -191,28 +193,30 @@ async def async_run( def stream( self, ref: str, + *, input: Optional[Dict[str, Any]] = None, - use_file_output: Optional[bool] = None, + use_file_output: Optional[bool] = True, **params: Unpack["Predictions.CreatePredictionParams"], ) -> Iterator["ServerSentEvent"]: """ Stream a model's output. """ - return stream(self, ref, input, use_file_output, **params) + return stream(self, ref, input, use_file_output=use_file_output, **params) async def async_stream( self, ref: str, input: Optional[Dict[str, Any]] = None, - use_file_output: Optional[bool] = None, + *, + use_file_output: Optional[bool] = True, **params: Unpack["Predictions.CreatePredictionParams"], ) -> AsyncIterator["ServerSentEvent"]: """ Stream a model's output asynchronously. 
""" - return async_stream(self, ref, input, use_file_output, **params) + return async_stream(self, ref, input, use_file_output=use_file_output, **params) # Adapted from https://github.com/encode/httpx/issues/108#issuecomment-1132753155 diff --git a/replicate/stream.py b/replicate/stream.py index 4cf0d15..e837abd 100644 --- a/replicate/stream.py +++ b/replicate/stream.py @@ -71,11 +71,12 @@ def __init__( self, client: "Client", response: "httpx.Response", - use_file_output: Optional[bool] = None, + *, + use_file_output: Optional[bool] = True, ) -> None: self.client = client self.response = response - self.use_file_output = use_file_output or False + self.use_file_output = use_file_output or True content_type, _, _ = response.headers["content-type"].partition(";") if content_type != "text/event-stream": raise ValueError( @@ -193,7 +194,8 @@ def stream( client: "Client", ref: Union["Model", "Version", "ModelVersionIdentifier", str], input: Optional[Dict[str, Any]] = None, - use_file_output: Optional[bool] = None, + *, + use_file_output: Optional[bool] = True, **params: Unpack["Predictions.CreatePredictionParams"], ) -> Iterator[ServerSentEvent]: """ @@ -234,7 +236,8 @@ async def async_stream( client: "Client", ref: Union["Model", "Version", "ModelVersionIdentifier", str], input: Optional[Dict[str, Any]] = None, - use_file_output: Optional[bool] = None, + *, + use_file_output: Optional[bool] = True, **params: Unpack["Predictions.CreatePredictionParams"], ) -> AsyncIterator[ServerSentEvent]: """ From 5b6ef4f62c93386a2cbf6dc72cec9a24b0ae31e3 Mon Sep 17 00:00:00 2001 From: Zeke Sikelianos Date: Wed, 9 Oct 2024 15:42:57 -0700 Subject: [PATCH 5/6] try not waiting in tests --- tests/test_run.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_run.py b/tests/test_run.py index 0f9aed2..3bf50eb 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -327,6 +327,7 @@ def prediction_with_status( "text": "Hello, world!", }, use_file_output=True, + wait=False, ), ) @@ -409,7 +410,7 @@ def prediction_with_status( "text": "Hello, world!", }, use_file_output=True, - wait=True, + # wait=True (this is the default) ), ) @@ -505,6 +506,7 @@ def prediction_with_status( "text": "Hello, world!", }, use_file_output=True, + wait=False, ), ) @@ -583,6 +585,7 @@ def prediction_with_status( "text": "Hello, world!", }, use_file_output=True, + wait=False, ), ) From 6624448fc55ac140159cf1fc7927221f66e85795 Mon Sep 17 00:00:00 2001 From: Zeke Sikelianos Date: Wed, 9 Oct 2024 15:46:19 -0700 Subject: [PATCH 6/6] use file output in tests --- tests/test_run.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_run.py b/tests/test_run.py index 3bf50eb..8eac091 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -32,17 +32,20 @@ async def test_run(async_flag, record_mode): output = await replicate.async_run( f"stability-ai/sdxl:{version}", input=input, + use_file_output=True, ) else: output = replicate.run( f"stability-ai/sdxl:{version}", input=input, + use_file_output=True, ) assert output is not None assert isinstance(output, list) assert len(output) > 0 - assert output[0].startswith("https://") + assert isinstance(output[0], FileOutput) + assert output[0].url.startswith("https://") @pytest.mark.vcr("run__concurrently.yaml")