15
15
16
16
from replicate import identifier
17
17
from replicate .exceptions import ReplicateError
18
+ from replicate .helpers import transform_output
18
19
19
20
try :
20
21
from pydantic import v1 as pydantic # type: ignore
@@ -62,10 +63,19 @@ class EventSource:
62
63
A server-sent event source.
63
64
"""
64
65
66
+ client : "Client"
65
67
response : "httpx.Response"
66
-
67
- def __init__ (self , response : "httpx.Response" ) -> None :
68
+ use_file_output : bool
69
+
70
+ def __init__ (
71
+ self ,
72
+ client : "Client" ,
73
+ response : "httpx.Response" ,
74
+ use_file_output : Optional [bool ] = None ,
75
+ ) -> None :
76
+ self .client = client
68
77
self .response = response
78
+ self .use_file_output = use_file_output or False
69
79
content_type , _ , _ = response .headers ["content-type" ].partition (";" )
70
80
if content_type != "text/event-stream" :
71
81
raise ValueError (
@@ -147,6 +157,12 @@ def __iter__(self) -> Iterator[ServerSentEvent]:
147
157
if sse .event == ServerSentEvent .EventType .ERROR :
148
158
raise RuntimeError (sse .data )
149
159
160
+ if (
161
+ self .use_file_output
162
+ and sse .event == ServerSentEvent .EventType .OUTPUT
163
+ ):
164
+ sse .data = transform_output (sse .data , client = self .client )
165
+
150
166
yield sse
151
167
152
168
if sse .event == ServerSentEvent .EventType .DONE :
@@ -161,6 +177,12 @@ async def __aiter__(self) -> AsyncIterator[ServerSentEvent]:
161
177
if sse .event == ServerSentEvent .EventType .ERROR :
162
178
raise RuntimeError (sse .data )
163
179
180
+ if (
181
+ self .use_file_output
182
+ and sse .event == ServerSentEvent .EventType .OUTPUT
183
+ ):
184
+ sse .data = transform_output (sse .data , client = self .client )
185
+
164
186
yield sse
165
187
166
188
if sse .event == ServerSentEvent .EventType .DONE :
@@ -171,6 +193,7 @@ def stream(
171
193
client : "Client" ,
172
194
ref : Union ["Model" , "Version" , "ModelVersionIdentifier" , str ],
173
195
input : Optional [Dict [str , Any ]] = None ,
196
+ use_file_output : Optional [bool ] = None ,
174
197
** params : Unpack ["Predictions.CreatePredictionParams" ],
175
198
) -> Iterator [ServerSentEvent ]:
176
199
"""
@@ -204,13 +227,14 @@ def stream(
204
227
headers ["Cache-Control" ] = "no-store"
205
228
206
229
with client ._client .stream ("GET" , url , headers = headers ) as response :
207
- yield from EventSource (response )
230
+ yield from EventSource (client , response , use_file_output = use_file_output )
208
231
209
232
210
233
async def async_stream (
211
234
client : "Client" ,
212
235
ref : Union ["Model" , "Version" , "ModelVersionIdentifier" , str ],
213
236
input : Optional [Dict [str , Any ]] = None ,
237
+ use_file_output : Optional [bool ] = None ,
214
238
** params : Unpack ["Predictions.CreatePredictionParams" ],
215
239
) -> AsyncIterator [ServerSentEvent ]:
216
240
"""
@@ -244,7 +268,9 @@ async def async_stream(
244
268
headers ["Cache-Control" ] = "no-store"
245
269
246
270
async with client ._async_client .stream ("GET" , url , headers = headers ) as response :
247
- async for event in EventSource (response ):
271
+ async for event in EventSource (
272
+ client , response , use_file_output = use_file_output
273
+ ):
248
274
yield event
249
275
250
276
0 commit comments