diff --git a/replicate/client.py b/replicate/client.py
index 52d07f7..3e767d6 100644
--- a/replicate/client.py
+++ b/replicate/client.py
@@ -164,53 +164,59 @@ def run(
         self,
         ref: str,
         input: Optional[Dict[str, Any]] = None,
-        use_file_output: Optional[bool] = None,
+        *,
+        use_file_output: Optional[bool] = True,
         **params: Unpack["Predictions.CreatePredictionParams"],
     ) -> Union[Any, Iterator[Any]]:  # noqa: ANN401
         """
         Run a model and wait for its output.
         """

-        return run(self, ref, input, use_file_output, **params)
+        return run(self, ref, input, use_file_output=use_file_output, **params)

     async def async_run(
         self,
         ref: str,
         input: Optional[Dict[str, Any]] = None,
-        use_file_output: Optional[bool] = None,
+        *,
+        use_file_output: Optional[bool] = True,
         **params: Unpack["Predictions.CreatePredictionParams"],
     ) -> Union[Any, AsyncIterator[Any]]:  # noqa: ANN401
         """
         Run a model and wait for its output asynchronously.
         """

-        return await async_run(self, ref, input, use_file_output, **params)
+        return await async_run(
+            self, ref, input, use_file_output=use_file_output, **params
+        )

     def stream(
         self,
         ref: str,
+        *,
         input: Optional[Dict[str, Any]] = None,
-        use_file_output: Optional[bool] = None,
+        use_file_output: Optional[bool] = True,
         **params: Unpack["Predictions.CreatePredictionParams"],
     ) -> Iterator["ServerSentEvent"]:
         """
         Stream a model's output.
         """

-        return stream(self, ref, input, use_file_output, **params)
+        return stream(self, ref, input, use_file_output=use_file_output, **params)

     async def async_stream(
         self,
         ref: str,
         input: Optional[Dict[str, Any]] = None,
-        use_file_output: Optional[bool] = None,
+        *,
+        use_file_output: Optional[bool] = True,
         **params: Unpack["Predictions.CreatePredictionParams"],
     ) -> AsyncIterator["ServerSentEvent"]:
         """
         Stream a model's output asynchronously.
         """

-        return async_stream(self, ref, input, use_file_output, **params)
+        return async_stream(self, ref, input, use_file_output=use_file_output, **params)


 # Adapted from https://github.com/encode/httpx/issues/108#issuecomment-1132753155
diff --git a/replicate/prediction.py b/replicate/prediction.py
index 0e5342a..aa3e45c 100644
--- a/replicate/prediction.py
+++ b/replicate/prediction.py
@@ -395,11 +395,13 @@ class CreatePredictionParams(TypedDict):

         wait: NotRequired[Union[int, bool]]
         """
-        Wait until the prediction is completed before returning.
+        Block until the prediction is completed before returning.

-        If `True`, wait a predetermined number of seconds until the prediction
-        is completed before returning.
-        If an `int`, wait for the specified number of seconds.
+        If `True`, keep the request open for up to 60 seconds, falling back to
+        polling until the prediction is completed.
+        If an `int`, same as True but hold the request for a specified number of
+        seconds (between 1 and 60).
+        If `False`, poll for the prediction status until completed.
         """

         file_encoding_strategy: NotRequired[FileEncodingStrategy]
diff --git a/replicate/run.py b/replicate/run.py
index d159f11..3b6bddb 100644
--- a/replicate/run.py
+++ b/replicate/run.py
@@ -29,14 +29,18 @@ def run(
     client: "Client",
     ref: Union["Model", "Version", "ModelVersionIdentifier", str],
     input: Optional[Dict[str, Any]] = None,
-    use_file_output: Optional[bool] = None,
+    *,
+    use_file_output: Optional[bool] = True,
     **params: Unpack["Predictions.CreatePredictionParams"],
 ) -> Union[Any, Iterator[Any]]:  # noqa: ANN401
     """
     Run a model and wait for its output.
     """
-    is_blocking = "wait" in params
+    if "wait" not in params:
+        params["wait"] = True
+    is_blocking = params["wait"] != False  # noqa: E712
+
     version, owner, name, version_id = identifier._resolve(ref)

     if version_id is not None:
@@ -74,13 +78,18 @@ async def async_run(
     client: "Client",
     ref: Union["Model", "Version", "ModelVersionIdentifier", str],
     input: Optional[Dict[str, Any]] = None,
-    use_file_output: Optional[bool] = None,
+    *,
+    use_file_output: Optional[bool] = True,
     **params: Unpack["Predictions.CreatePredictionParams"],
 ) -> Union[Any, AsyncIterator[Any]]:  # noqa: ANN401
     """
     Run a model and wait for its output asynchronously.
     """
+    if "wait" not in params:
+        params["wait"] = True
+    is_blocking = params["wait"] != False  # noqa: E712
+
     version, owner, name, version_id = identifier._resolve(ref)

     if version or version_id:
@@ -102,7 +111,8 @@ async def async_run(
     if version and (iterator := _make_async_output_iterator(version, prediction)):
         return iterator

-    await prediction.async_wait()
+    if not (is_blocking and prediction.status != "starting"):
+        await prediction.async_wait()

     if prediction.status == "failed":
         raise ModelError(prediction)
diff --git a/replicate/stream.py b/replicate/stream.py
index 4cf0d15..e837abd 100644
--- a/replicate/stream.py
+++ b/replicate/stream.py
@@ -71,11 +71,12 @@ def __init__(
         self,
         client: "Client",
         response: "httpx.Response",
-        use_file_output: Optional[bool] = None,
+        *,
+        use_file_output: Optional[bool] = True,
     ) -> None:
         self.client = client
         self.response = response
-        self.use_file_output = use_file_output or False
+        self.use_file_output = use_file_output or True
         content_type, _, _ = response.headers["content-type"].partition(";")
         if content_type != "text/event-stream":
             raise ValueError(
@@ -193,7 +194,8 @@ def stream(
     client: "Client",
     ref: Union["Model", "Version", "ModelVersionIdentifier", str],
     input: Optional[Dict[str, Any]] = None,
-    use_file_output: Optional[bool] = None,
+    *,
+    use_file_output: Optional[bool] = True,
     **params: Unpack["Predictions.CreatePredictionParams"],
 ) -> Iterator[ServerSentEvent]:
     """
@@ -234,7 +236,8 @@ async def async_stream(
     client: "Client",
     ref: Union["Model", "Version", "ModelVersionIdentifier", str],
     input: Optional[Dict[str, Any]] = None,
-    use_file_output: Optional[bool] = None,
+    *,
+    use_file_output: Optional[bool] = True,
     **params: Unpack["Predictions.CreatePredictionParams"],
 ) -> AsyncIterator[ServerSentEvent]:
     """
diff --git a/tests/test_run.py b/tests/test_run.py
index 7d963a4..8eac091 100644
--- a/tests/test_run.py
+++ b/tests/test_run.py
@@ -32,17 +32,20 @@ async def test_run(async_flag, record_mode):
         output = await replicate.async_run(
             f"stability-ai/sdxl:{version}",
             input=input,
+            use_file_output=True,
         )
     else:
         output = replicate.run(
             f"stability-ai/sdxl:{version}",
             input=input,
+            use_file_output=True,
         )

     assert output is not None
     assert isinstance(output, list)
     assert len(output) > 0
-    assert output[0].startswith("https://")
+    assert isinstance(output[0], FileOutput)
+    assert output[0].url.startswith("https://")


 @pytest.mark.vcr("run__concurrently.yaml")
@@ -123,7 +126,7 @@ def prediction_with_status(status: str) -> dict:
     router.route(method="POST", path="/predictions").mock(
         return_value=httpx.Response(
             201,
-            json=prediction_with_status("processing"),
+            json=prediction_with_status("starting"),
         )
     )
     router.route(method="GET", path="/predictions/p1").mock(
@@ -212,7 +215,7 @@ def prediction_with_status(status: str) -> dict:
     router.route(method="POST", path="/predictions").mock(
         return_value=httpx.Response(
             201,
-            json=prediction_with_status("processing"),
+            json=prediction_with_status("starting"),
         )
     )
     router.route(method="GET", path="/predictions/p1").mock(
@@ -327,6 +330,7 @@ def prediction_with_status(
                 "text": "Hello, world!",
             },
             use_file_output=True,
+            wait=False,
         ),
     )
@@ -409,7 +413,7 @@ def prediction_with_status(
                 "text": "Hello, world!",
             },
             use_file_output=True,
-            wait=True,
+            # wait=True (this is the default)
         ),
     )
@@ -454,7 +458,7 @@ def prediction_with_status(
     router.route(method="POST", path="/predictions").mock(
         return_value=httpx.Response(
             201,
-            json=prediction_with_status("processing"),
+            json=prediction_with_status("starting"),
         )
     )
     router.route(method="GET", path="/predictions/p1").mock(
@@ -505,6 +509,7 @@ def prediction_with_status(
                 "text": "Hello, world!",
             },
             use_file_output=True,
+            wait=False,
        ),
     )
@@ -541,7 +546,7 @@ def prediction_with_status(
     router.route(method="POST", path="/predictions").mock(
         return_value=httpx.Response(
             201,
-            json=prediction_with_status("processing"),
+            json=prediction_with_status("starting"),
         )
     )
     router.route(method="GET", path="/predictions/p1").mock(
@@ -583,6 +588,7 @@ def prediction_with_status(
                 "text": "Hello, world!",
             },
             use_file_output=True,
+            wait=False,
         ),
     )
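
Taken together, the changes above flip two defaults for callers: `run()`/`async_run()` now send `wait=True` (the API holds the request open for up to 60 seconds before the client falls back to polling) and `use_file_output=True` (file results come back as `FileOutput` objects rather than bare URL strings, as the updated `test_run` assertions show). Below is a minimal usage sketch of the new defaults; the model ref and prompt are placeholders, and it assumes `REPLICATE_API_TOKEN` is set in the environment.

```python
import replicate

# Defaults after this change: wait=True (blocking request) and
# use_file_output=True (outputs wrapped in FileOutput).
output = replicate.run(
    "stability-ai/sdxl:<version>",  # placeholder ref
    input={"prompt": "an astronaut riding a rainbow unicorn"},
)
print(output[0].url)  # FileOutput exposes the underlying https:// URL

# Restoring the previous behaviour: plain URL strings and client-side polling.
output = replicate.run(
    "stability-ai/sdxl:<version>",  # placeholder ref
    input={"prompt": "an astronaut riding a rainbow unicorn"},
    use_file_output=False,
    wait=False,
)
print(output[0])  # a plain "https://..." string, as before
```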
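The `run.py` hunks are easiest to review as a single decision: `wait` is injected as `True` when the caller did not pass it, and the client only polls when it either opted out of blocking or the blocking response still reports the prediction as `starting`. The sketch below paraphrases that control flow under a hypothetical name (`should_poll` is not part of the library); it is a reviewer's aid, not the implementation.

```python
def should_poll(params: dict, prediction_status: str) -> bool:
    """Paraphrase of the new control flow in replicate/run.py (sketch only)."""
    # Default to blocking mode: ask the API to hold the request open.
    if "wait" not in params:
        params["wait"] = True
    # wait=False opts out of blocking; True or an int (1-60 seconds) keeps it.
    is_blocking = params["wait"] != False  # noqa: E712
    # Skip polling only when we blocked and the returned prediction has already
    # moved past "starting"; otherwise fall back to wait()/async_wait().
    return not (is_blocking and prediction_status != "starting")


assert should_poll({}, "succeeded") is False               # blocked and finished: use response as-is
assert should_poll({}, "starting") is True                 # blocked but still starting: poll
assert should_poll({"wait": False}, "processing") is True  # explicit opt-out: poll as before
```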