From 9a8fc8f62c5d2d79832d5c80ec42c29ebab4c4c2 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 23 Sep 2024 05:10:00 -0700 Subject: [PATCH] Add use_file_output to streaming methods Signed-off-by: Mattt Zmuda --- replicate/client.py | 6 ++++-- replicate/prediction.py | 18 ++++++++++++++---- replicate/stream.py | 34 ++++++++++++++++++++++++++++++---- 3 files changed, 48 insertions(+), 10 deletions(-) diff --git a/replicate/client.py b/replicate/client.py index 3da3cc15..52d07f70 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -190,25 +190,27 @@ def stream( self, ref: str, input: Optional[Dict[str, Any]] = None, + use_file_output: Optional[bool] = None, **params: Unpack["Predictions.CreatePredictionParams"], ) -> Iterator["ServerSentEvent"]: """ Stream a model's output. """ - return stream(self, ref, input, **params) + return stream(self, ref, input, use_file_output, **params) async def async_stream( self, ref: str, input: Optional[Dict[str, Any]] = None, + use_file_output: Optional[bool] = None, **params: Unpack["Predictions.CreatePredictionParams"], ) -> AsyncIterator["ServerSentEvent"]: """ Stream a model's output asynchronously. """ - return async_stream(self, ref, input, **params) + return async_stream(self, ref, input, use_file_output, **params) # Adapted from https://github.com/encode/httpx/issues/108#issuecomment-1132753155 diff --git a/replicate/prediction.py b/replicate/prediction.py index 9770029b..a6204748 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -153,7 +153,10 @@ async def async_wait(self) -> None: await asyncio.sleep(self._client.poll_interval) await self.async_reload() - def stream(self) -> Iterator["ServerSentEvent"]: + def stream( + self, + use_file_output: Optional[bool] = None, + ) -> Iterator["ServerSentEvent"]: """ Stream the prediction output. @@ -170,9 +173,14 @@ def stream(self) -> Iterator["ServerSentEvent"]: headers["Cache-Control"] = "no-store" with self._client._client.stream("GET", url, headers=headers) as response: - yield from EventSource(response) + yield from EventSource( + self._client, response, use_file_output=use_file_output + ) - async def async_stream(self) -> AsyncIterator["ServerSentEvent"]: + async def async_stream( + self, + use_file_output: Optional[bool] = None, + ) -> AsyncIterator["ServerSentEvent"]: """ Stream the prediction output asynchronously. @@ -194,7 +202,9 @@ async def async_stream(self) -> AsyncIterator["ServerSentEvent"]: async with self._client._async_client.stream( "GET", url, headers=headers ) as response: - async for event in EventSource(response): + async for event in EventSource( + self._client, response, use_file_output=use_file_output + ): yield event def cancel(self) -> None: diff --git a/replicate/stream.py b/replicate/stream.py index 3472799e..4cf0d156 100644 --- a/replicate/stream.py +++ b/replicate/stream.py @@ -15,6 +15,7 @@ from replicate import identifier from replicate.exceptions import ReplicateError +from replicate.helpers import transform_output try: from pydantic import v1 as pydantic # type: ignore @@ -62,10 +63,19 @@ class EventSource: A server-sent event source. """ + client: "Client" response: "httpx.Response" - - def __init__(self, response: "httpx.Response") -> None: + use_file_output: bool + + def __init__( + self, + client: "Client", + response: "httpx.Response", + use_file_output: Optional[bool] = None, + ) -> None: + self.client = client self.response = response + self.use_file_output = use_file_output or False content_type, _, _ = response.headers["content-type"].partition(";") if content_type != "text/event-stream": raise ValueError( @@ -147,6 +157,12 @@ def __iter__(self) -> Iterator[ServerSentEvent]: if sse.event == ServerSentEvent.EventType.ERROR: raise RuntimeError(sse.data) + if ( + self.use_file_output + and sse.event == ServerSentEvent.EventType.OUTPUT + ): + sse.data = transform_output(sse.data, client=self.client) + yield sse if sse.event == ServerSentEvent.EventType.DONE: @@ -161,6 +177,12 @@ async def __aiter__(self) -> AsyncIterator[ServerSentEvent]: if sse.event == ServerSentEvent.EventType.ERROR: raise RuntimeError(sse.data) + if ( + self.use_file_output + and sse.event == ServerSentEvent.EventType.OUTPUT + ): + sse.data = transform_output(sse.data, client=self.client) + yield sse if sse.event == ServerSentEvent.EventType.DONE: @@ -171,6 +193,7 @@ def stream( client: "Client", ref: Union["Model", "Version", "ModelVersionIdentifier", str], input: Optional[Dict[str, Any]] = None, + use_file_output: Optional[bool] = None, **params: Unpack["Predictions.CreatePredictionParams"], ) -> Iterator[ServerSentEvent]: """ @@ -204,13 +227,14 @@ def stream( headers["Cache-Control"] = "no-store" with client._client.stream("GET", url, headers=headers) as response: - yield from EventSource(response) + yield from EventSource(client, response, use_file_output=use_file_output) async def async_stream( client: "Client", ref: Union["Model", "Version", "ModelVersionIdentifier", str], input: Optional[Dict[str, Any]] = None, + use_file_output: Optional[bool] = None, **params: Unpack["Predictions.CreatePredictionParams"], ) -> AsyncIterator[ServerSentEvent]: """ @@ -244,7 +268,9 @@ async def async_stream( headers["Cache-Control"] = "no-store" async with client._async_client.stream("GET", url, headers=headers) as response: - async for event in EventSource(response): + async for event in EventSource( + client, response, use_file_output=use_file_output + ): yield event