Skip to content

Commit e53bd02

Browse files
authored
Add use_file_output to streaming methods (#355)
Signed-off-by: Mattt Zmuda <[email protected]>
1 parent 4885f19 commit e53bd02

File tree

3 files changed

+48
-10
lines changed

3 files changed

+48
-10
lines changed

replicate/client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,25 +190,27 @@ def stream(
190190
self,
191191
ref: str,
192192
input: Optional[Dict[str, Any]] = None,
193+
use_file_output: Optional[bool] = None,
193194
**params: Unpack["Predictions.CreatePredictionParams"],
194195
) -> Iterator["ServerSentEvent"]:
195196
"""
196197
Stream a model's output.
197198
"""
198199

199-
return stream(self, ref, input, **params)
200+
return stream(self, ref, input, use_file_output, **params)
200201

201202
async def async_stream(
202203
self,
203204
ref: str,
204205
input: Optional[Dict[str, Any]] = None,
206+
use_file_output: Optional[bool] = None,
205207
**params: Unpack["Predictions.CreatePredictionParams"],
206208
) -> AsyncIterator["ServerSentEvent"]:
207209
"""
208210
Stream a model's output asynchronously.
209211
"""
210212

211-
return async_stream(self, ref, input, **params)
213+
return async_stream(self, ref, input, use_file_output, **params)
212214

213215

214216
# Adapted from https://github.com/encode/httpx/issues/108#issuecomment-1132753155

replicate/prediction.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,10 @@ async def async_wait(self) -> None:
153153
await asyncio.sleep(self._client.poll_interval)
154154
await self.async_reload()
155155

156-
def stream(self) -> Iterator["ServerSentEvent"]:
156+
def stream(
157+
self,
158+
use_file_output: Optional[bool] = None,
159+
) -> Iterator["ServerSentEvent"]:
157160
"""
158161
Stream the prediction output.
159162
@@ -170,9 +173,14 @@ def stream(self) -> Iterator["ServerSentEvent"]:
170173
headers["Cache-Control"] = "no-store"
171174

172175
with self._client._client.stream("GET", url, headers=headers) as response:
173-
yield from EventSource(response)
176+
yield from EventSource(
177+
self._client, response, use_file_output=use_file_output
178+
)
174179

175-
async def async_stream(self) -> AsyncIterator["ServerSentEvent"]:
180+
async def async_stream(
181+
self,
182+
use_file_output: Optional[bool] = None,
183+
) -> AsyncIterator["ServerSentEvent"]:
176184
"""
177185
Stream the prediction output asynchronously.
178186
@@ -194,7 +202,9 @@ async def async_stream(self) -> AsyncIterator["ServerSentEvent"]:
194202
async with self._client._async_client.stream(
195203
"GET", url, headers=headers
196204
) as response:
197-
async for event in EventSource(response):
205+
async for event in EventSource(
206+
self._client, response, use_file_output=use_file_output
207+
):
198208
yield event
199209

200210
def cancel(self) -> None:

replicate/stream.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from replicate import identifier
1717
from replicate.exceptions import ReplicateError
18+
from replicate.helpers import transform_output
1819

1920
try:
2021
from pydantic import v1 as pydantic # type: ignore
@@ -62,10 +63,19 @@ class EventSource:
6263
A server-sent event source.
6364
"""
6465

66+
client: "Client"
6567
response: "httpx.Response"
66-
67-
def __init__(self, response: "httpx.Response") -> None:
68+
use_file_output: bool
69+
70+
def __init__(
71+
self,
72+
client: "Client",
73+
response: "httpx.Response",
74+
use_file_output: Optional[bool] = None,
75+
) -> None:
76+
self.client = client
6877
self.response = response
78+
self.use_file_output = use_file_output or False
6979
content_type, _, _ = response.headers["content-type"].partition(";")
7080
if content_type != "text/event-stream":
7181
raise ValueError(
@@ -147,6 +157,12 @@ def __iter__(self) -> Iterator[ServerSentEvent]:
147157
if sse.event == ServerSentEvent.EventType.ERROR:
148158
raise RuntimeError(sse.data)
149159

160+
if (
161+
self.use_file_output
162+
and sse.event == ServerSentEvent.EventType.OUTPUT
163+
):
164+
sse.data = transform_output(sse.data, client=self.client)
165+
150166
yield sse
151167

152168
if sse.event == ServerSentEvent.EventType.DONE:
@@ -161,6 +177,12 @@ async def __aiter__(self) -> AsyncIterator[ServerSentEvent]:
161177
if sse.event == ServerSentEvent.EventType.ERROR:
162178
raise RuntimeError(sse.data)
163179

180+
if (
181+
self.use_file_output
182+
and sse.event == ServerSentEvent.EventType.OUTPUT
183+
):
184+
sse.data = transform_output(sse.data, client=self.client)
185+
164186
yield sse
165187

166188
if sse.event == ServerSentEvent.EventType.DONE:
@@ -171,6 +193,7 @@ def stream(
171193
client: "Client",
172194
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
173195
input: Optional[Dict[str, Any]] = None,
196+
use_file_output: Optional[bool] = None,
174197
**params: Unpack["Predictions.CreatePredictionParams"],
175198
) -> Iterator[ServerSentEvent]:
176199
"""
@@ -204,13 +227,14 @@ def stream(
204227
headers["Cache-Control"] = "no-store"
205228

206229
with client._client.stream("GET", url, headers=headers) as response:
207-
yield from EventSource(response)
230+
yield from EventSource(client, response, use_file_output=use_file_output)
208231

209232

210233
async def async_stream(
211234
client: "Client",
212235
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
213236
input: Optional[Dict[str, Any]] = None,
237+
use_file_output: Optional[bool] = None,
214238
**params: Unpack["Predictions.CreatePredictionParams"],
215239
) -> AsyncIterator[ServerSentEvent]:
216240
"""
@@ -244,7 +268,9 @@ async def async_stream(
244268
headers["Cache-Control"] = "no-store"
245269

246270
async with client._async_client.stream("GET", url, headers=headers) as response:
247-
async for event in EventSource(response):
271+
async for event in EventSource(
272+
client, response, use_file_output=use_file_output
273+
):
248274
yield event
249275

250276

0 commit comments

Comments
 (0)