Skip to content

Commit f57ce38

Browse files
aronzeke
authored andcommitted
Update stream interface to always use FileOutput
1 parent 25eb355 commit f57ce38

File tree

2 files changed

+17
-10
lines changed

2 files changed

+17
-10
lines changed

replicate/client.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,8 @@ def run(
164164
self,
165165
ref: str,
166166
input: Optional[Dict[str, Any]] = None,
167-
use_file_output: Optional[bool] = None,
167+
*,
168+
use_file_output: Optional[bool] = True,
168169
**params: Unpack["Predictions.CreatePredictionParams"],
169170
) -> Union[Any, Iterator[Any]]: # noqa: ANN401
170171
"""
@@ -177,7 +178,8 @@ async def async_run(
177178
self,
178179
ref: str,
179180
input: Optional[Dict[str, Any]] = None,
180-
use_file_output: Optional[bool] = None,
181+
*,
182+
use_file_output: Optional[bool] = True,
181183
**params: Unpack["Predictions.CreatePredictionParams"],
182184
) -> Union[Any, AsyncIterator[Any]]: # noqa: ANN401
183185
"""
@@ -191,28 +193,30 @@ async def async_run(
191193
def stream(
192194
self,
193195
ref: str,
196+
*,
194197
input: Optional[Dict[str, Any]] = None,
195-
use_file_output: Optional[bool] = None,
198+
use_file_output: Optional[bool] = True,
196199
**params: Unpack["Predictions.CreatePredictionParams"],
197200
) -> Iterator["ServerSentEvent"]:
198201
"""
199202
Stream a model's output.
200203
"""
201204

202-
return stream(self, ref, input, use_file_output, **params)
205+
return stream(self, ref, input, use_file_output=use_file_output, **params)
203206

204207
async def async_stream(
205208
self,
206209
ref: str,
207210
input: Optional[Dict[str, Any]] = None,
208-
use_file_output: Optional[bool] = None,
211+
*,
212+
use_file_output: Optional[bool] = True,
209213
**params: Unpack["Predictions.CreatePredictionParams"],
210214
) -> AsyncIterator["ServerSentEvent"]:
211215
"""
212216
Stream a model's output asynchronously.
213217
"""
214218

215-
return async_stream(self, ref, input, use_file_output, **params)
219+
return async_stream(self, ref, input, use_file_output=use_file_output, **params)
216220

217221

218222
# Adapted from https://github.com/encode/httpx/issues/108#issuecomment-1132753155

replicate/stream.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,12 @@ def __init__(
7171
self,
7272
client: "Client",
7373
response: "httpx.Response",
74-
use_file_output: Optional[bool] = None,
74+
*,
75+
use_file_output: Optional[bool] = True,
7576
) -> None:
7677
self.client = client
7778
self.response = response
78-
self.use_file_output = use_file_output or False
79+
self.use_file_output = use_file_output or True
7980
content_type, _, _ = response.headers["content-type"].partition(";")
8081
if content_type != "text/event-stream":
8182
raise ValueError(
@@ -193,7 +194,8 @@ def stream(
193194
client: "Client",
194195
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
195196
input: Optional[Dict[str, Any]] = None,
196-
use_file_output: Optional[bool] = None,
197+
*,
198+
use_file_output: Optional[bool] = True,
197199
**params: Unpack["Predictions.CreatePredictionParams"],
198200
) -> Iterator[ServerSentEvent]:
199201
"""
@@ -234,7 +236,8 @@ async def async_stream(
234236
client: "Client",
235237
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
236238
input: Optional[Dict[str, Any]] = None,
237-
use_file_output: Optional[bool] = None,
239+
*,
240+
use_file_output: Optional[bool] = True,
238241
**params: Unpack["Predictions.CreatePredictionParams"],
239242
) -> AsyncIterator[ServerSentEvent]:
240243
"""

0 commit comments

Comments
 (0)