
Commit 7b2e7f4

Implement experimental FileOutput interface
1 parent 3cc0b86 commit 7b2e7f4

4 files changed: +222 −40 lines changed

replicate/client.py

Lines changed: 4 additions & 2 deletions
@@ -164,25 +164,27 @@ def run(
         self,
         ref: str,
         input: Optional[Dict[str, Any]] = None,
+        use_file_output: bool = False,
         **params: Unpack["Predictions.CreatePredictionParams"],
     ) -> Union[Any, Iterator[Any]]:  # noqa: ANN401
         """
         Run a model and wait for its output.
         """

-        return run(self, ref, input, **params)
+        return run(self, ref, input, use_file_output, **params)

     async def async_run(
         self,
         ref: str,
         input: Optional[Dict[str, Any]] = None,
+        use_file_output: bool = False,
         **params: Unpack["Predictions.CreatePredictionParams"],
     ) -> Union[Any, AsyncIterator[Any]]:  # noqa: ANN401
         """
         Run a model and wait for its output asynchronously.
         """

-        return await async_run(self, ref, input, **params)
+        return await async_run(self, ref, input, use_file_output, **params)

     def stream(
         self,
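
For orientation, a minimal usage sketch of the new flag. This is not part of the commit: the model reference, token, and output shape below are made-up examples, and it assumes a model whose output is a single file URL.

from replicate.client import Client

client = Client(api_token="...")  # placeholder token

# With use_file_output=True, a URL in the prediction output is wrapped in a
# FileOutput object instead of being returned as a plain string.
output = client.run(
    "owner/model:version",  # hypothetical model reference
    input={"prompt": "a photo of a cat"},
    use_file_output=True,
)

with open("result.bin", "wb") as f:
    f.write(output.read())  # FileOutput.read() downloads the file bytes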

replicate/run.py

Lines changed: 27 additions & 0 deletions
@@ -17,6 +17,7 @@
 from replicate.prediction import Prediction
 from replicate.schema import make_schema_backwards_compatible
 from replicate.version import Version, Versions
+from replicate.stream import FileOutput
 
 if TYPE_CHECKING:
     from replicate.client import Client
@@ -28,6 +29,7 @@ def run(
     client: "Client",
     ref: Union["Model", "Version", "ModelVersionIdentifier", str],
     input: Optional[Dict[str, Any]] = None,
+    use_file_output: bool = False,
     **params: Unpack["Predictions.CreatePredictionParams"],
 ) -> Union[Any, Iterator[Any]]:  # noqa: ANN401
     """
@@ -60,13 +62,17 @@ def run(
     if prediction.status == "failed":
         raise ModelError(prediction)
 
+    if use_file_output:
+        return transform_output(prediction.output, client)
+
     return prediction.output
 
 
 async def async_run(
     client: "Client",
     ref: Union["Model", "Version", "ModelVersionIdentifier", str],
     input: Optional[Dict[str, Any]] = None,
+    use_file_output: bool = False,
     **params: Unpack["Predictions.CreatePredictionParams"],
 ) -> Union[Any, AsyncIterator[Any]]:  # noqa: ANN401
     """
@@ -99,6 +105,9 @@ async def async_run(
     if prediction.status == "failed":
         raise ModelError(prediction)
 
+    if use_file_output:
+        return transform_output(prediction.output, client)
+
     return prediction.output
 
 
@@ -130,4 +139,22 @@ def _make_async_output_iterator(
     return None
 
 
+def transform(obj, func):
+    if isinstance(obj, dict):
+        return {k: transform(v, func) for k, v in obj.items()}
+    elif isinstance(obj, list):
+        return [transform(item, func) for item in obj]
+    else:
+        return func(obj)
+
+
+def transform_output(value: Any, client: "Client"):
+    def wrapper(x):
+        if isinstance(x, str) and (x.startswith("https:") or x.startswith("data:")):
+            return FileOutput(x, client)
+        return x
+
+    return transform(value, wrapper)
+
+
 __all__: List = []
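
To illustrate the helpers added above, a small sketch of what transform_output does to a nested output. The dict is a made-up example, and `client` is assumed to be a replicate Client instance in scope.

# Assumes transform_output from replicate/run.py and a Client instance `client`.
nested = {
    "images": ["https://example.com/a.png", "https://example.com/b.png"],
    "seed": 42,
}
wrapped = transform_output(nested, client)
# Every "https:" or "data:" string leaf is now a FileOutput bound to that URL;
# non-URL leaves such as the integer seed pass through unchanged.
assert wrapped["seed"] == 42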

replicate/stream.py

Lines changed: 17 additions & 37 deletions
@@ -33,7 +33,7 @@
 from replicate.version import Version
 
 
-class FileOutputProvider:
+class FileOutput(httpx.ByteStream, httpx.AsyncByteStream):
     url: str
     client: "Client"
 
@@ -42,52 +42,32 @@ def __init__(self, url: str, client: "Client"):
         self.client = client
 
     def read(self) -> bytes:
-        with self.stream() as file:
-            return file.read()
-
-    @contextmanager
-    def stream(self) -> Iterator["FileOutput"]:
         with self.client._client.stream("GET", self.url) as response:
             response.raise_for_status()
-            yield FileOutput(response)
+            return response.read()
 
-    @asynccontextmanager
-    async def astream(self) -> AsyncIterator["FileOutput"]:
-        async with self.client._async_client.stream("GET", self.url) as response:
+    def __iter__(self) -> Iterator[bytes]:
+        with self.client._client.stream("GET", self.url) as response:
             response.raise_for_status()
-            yield FileOutput(response)
+            for chunk in response.iter_bytes():
+                yield chunk
 
     async def aread(self) -> bytes:
-        async with self.astream() as file:
-            return await file.aread()
-
-    def __repr__(self) -> str:
-        return self.url
-
-
-class FileOutput(httpx.ByteStream, httpx.AsyncByteStream):
-    def __init__(self, response: httpx.Response):
-        self.response = response
-
-    def __iter__(self) -> Iterator[bytes]:
-        for bytes in self.response.iter_bytes():
-            yield bytes
-
-    def close(self):
-        return self.response.close()
-
-    def read(self):
-        return self.response.read()
+        async with self.client._async_client.stream("GET", self.url) as response:
+            response.raise_for_status()
+            return await response.aread()
 
     async def __aiter__(self) -> AsyncIterator[bytes]:
-        async for bytes in self.response.aiter_bytes():
-            yield bytes
+        async with self.client._async_client.stream("GET", self.url) as response:
+            response.raise_for_status()
+            async for chunk in response.aiter_bytes():
+                yield chunk
 
-    async def aclose(self):
-        return await self.response.aclose()
+    def __str__(self) -> str:
+        return self.url
 
-    async def aread(self):
-        return await self.response.aread()
+    def __repr__(self) -> str:
+        return self.url
 
 
 class ServerSentEvent(pydantic.BaseModel):  # type: ignore
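
A minimal consumption sketch for the reworked FileOutput. This is not from the commit: `fo` stands for a FileOutput returned by client.run(..., use_file_output=True), and the file names are placeholders.

# Eager download (sync): fetches the URL and returns the body as bytes.
data = fo.read()

# Streaming download (sync): iterating yields raw byte chunks.
with open("out.bin", "wb") as f:
    for chunk in fo:
        f.write(chunk)

# Async equivalents (inside an async function).
data = await fo.aread()
async for chunk in fo:
    ...

# str() and repr() both give back the underlying URL.
print(fo)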

tests/test_run.py

Lines changed: 174 additions & 1 deletion
@@ -5,9 +5,11 @@
 import pytest
 import respx
 
+from typing import cast
 import replicate
 from replicate.client import Client
 from replicate.exceptions import ModelError, ReplicateError
+from replicate.stream import FileOutput
 
 
 @pytest.mark.vcr("run.yaml")
@@ -73,7 +75,7 @@ async def test_run_concurrently(mock_replicate_api_token, record_mode):
     results = await asyncio.gather(*tasks)
     assert len(results) == len(prompts)
     assert all(isinstance(result, list) for result in results)
-    assert all(len(result) > 0 for result in results)
+    assert all(len(results) > 0 for result in results)
 
 
 @pytest.mark.vcr("run.yaml")
@@ -253,3 +255,174 @@ def prediction_with_status(status: str) -> dict:
     assert str(excinfo.value) == "OOM"
     assert excinfo.value.prediction.error == "OOM"
     assert excinfo.value.prediction.status == "failed"
+
+
+@pytest.mark.asyncio
+async def test_run_with_file_output(mock_replicate_api_token):
+    def prediction_with_status(
+        status: str, output: str | list[str] | None = None
+    ) -> dict:
+        return {
+            "id": "p1",
+            "model": "test/example",
+            "version": "v1",
+            "urls": {
+                "get": "https://api.replicate.com/v1/predictions/p1",
+                "cancel": "https://api.replicate.com/v1/predictions/p1/cancel",
+            },
+            "created_at": "2023-10-05T12:00:00.000000Z",
+            "source": "api",
+            "status": status,
+            "input": {"text": "world"},
+            "output": output,
+            "error": "OOM" if status == "failed" else None,
+            "logs": "",
+        }
+
+    router = respx.Router(base_url="https://api.replicate.com/v1")
+    router.route(method="POST", path="/predictions").mock(
+        return_value=httpx.Response(
+            201,
+            json=prediction_with_status("processing"),
+        )
+    )
+    router.route(method="GET", path="/predictions/p1").mock(
+        return_value=httpx.Response(
+            200,
+            json=prediction_with_status(
+                "succeeded", "https://api.replicate.com/v1/assets/output.txt"
+            ),
+        )
+    )
+    router.route(
+        method="GET",
+        path="/models/test/example/versions/v1",
+    ).mock(
+        return_value=httpx.Response(
+            201,
+            json={
+                "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1",
+                "created_at": "2024-07-18T00:35:56.210272Z",
+                "cog_version": "0.9.10",
+                "openapi_schema": {
+                    "openapi": "3.0.2",
+                },
+            },
+        )
+    )
+    router.route(method="GET", path="/assets/output.txt").mock(
+        return_value=httpx.Response(200, content=b"Hello, world!")
+    )
+
+    client = Client(
+        api_token="test-token", transport=httpx.MockTransport(router.handler)
+    )
+    client.poll_interval = 0.001
+
+    output = cast(
+        FileOutput,
+        client.run(
+            "test/example:v1",
+            input={
+                "text": "Hello, world!",
+            },
+            use_file_output=True,
+        ),
+    )
+
+    assert output.url == "https://api.replicate.com/v1/assets/output.txt"
+
+    assert output.read() == b"Hello, world!"
+    for chunk in output:
+        assert chunk == b"Hello, world!"
+
+    assert await output.aread() == b"Hello, world!"
+    async for chunk in output:
+        assert chunk == b"Hello, world!"
+
+
+@pytest.mark.asyncio
+async def test_run_with_file_output_array(mock_replicate_api_token):
+    def prediction_with_status(
+        status: str, output: str | list[str] | None = None
+    ) -> dict:
+        return {
+            "id": "p1",
+            "model": "test/example",
+            "version": "v1",
+            "urls": {
+                "get": "https://api.replicate.com/v1/predictions/p1",
+                "cancel": "https://api.replicate.com/v1/predictions/p1/cancel",
+            },
+            "created_at": "2023-10-05T12:00:00.000000Z",
+            "source": "api",
+            "status": status,
+            "input": {"text": "world"},
+            "output": output,
+            "error": "OOM" if status == "failed" else None,
+            "logs": "",
+        }
+
+    router = respx.Router(base_url="https://api.replicate.com/v1")
+    router.route(method="POST", path="/predictions").mock(
+        return_value=httpx.Response(
+            201,
+            json=prediction_with_status("processing"),
+        )
+    )
+    router.route(method="GET", path="/predictions/p1").mock(
+        return_value=httpx.Response(
+            200,
+            json=prediction_with_status(
+                "succeeded",
+                [
+                    "https://api.replicate.com/v1/assets/hello.txt",
+                    "https://api.replicate.com/v1/assets/world.txt",
+                ],
+            ),
+        )
+    )
+    router.route(
+        method="GET",
+        path="/models/test/example/versions/v1",
+    ).mock(
+        return_value=httpx.Response(
+            201,
+            json={
+                "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1",
+                "created_at": "2024-07-18T00:35:56.210272Z",
+                "cog_version": "0.9.10",
+                "openapi_schema": {
+                    "openapi": "3.0.2",
+                },
+            },
+        )
+    )
+    router.route(method="GET", path="/assets/hello.txt").mock(
+        return_value=httpx.Response(200, content=b"Hello,")
+    )
+    router.route(method="GET", path="/assets/world.txt").mock(
+        return_value=httpx.Response(200, content=b" world!")
+    )
+
+    client = Client(
+        api_token="test-token", transport=httpx.MockTransport(router.handler)
+    )
+    client.poll_interval = 0.001
+
+    [output1, output2] = cast(
+        list[FileOutput],
+        client.run(
+            "test/example:v1",
+            input={
+                "text": "Hello, world!",
+            },
+            use_file_output=True,
+        ),
+    )
+
+    assert output1.url == "https://api.replicate.com/v1/assets/hello.txt"
+    assert output2.url == "https://api.replicate.com/v1/assets/world.txt"
+
+    assert output1.read() == b"Hello,"
+    assert output2.read() == b" world!"
