Skip to content

Commit 9e88cb5

Browse files
committed
Implement experimental FileOutput interface
1 parent 3cc0b86 commit 9e88cb5

File tree

3 files changed

+207
-3
lines changed

3 files changed

+207
-3
lines changed

replicate/client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,25 +164,27 @@ def run(
164164
self,
165165
ref: str,
166166
input: Optional[Dict[str, Any]] = None,
167+
use_file_output: bool = False,
167168
**params: Unpack["Predictions.CreatePredictionParams"],
168169
) -> Union[Any, Iterator[Any]]: # noqa: ANN401
169170
"""
170171
Run a model and wait for its output.
171172
"""
172173

173-
return run(self, ref, input, **params)
174+
return run(self, ref, input, use_file_output, **params)
174175

175176
async def async_run(
176177
self,
177178
ref: str,
178179
input: Optional[Dict[str, Any]] = None,
180+
use_file_output: bool = False,
179181
**params: Unpack["Predictions.CreatePredictionParams"],
180182
) -> Union[Any, AsyncIterator[Any]]: # noqa: ANN401
181183
"""
182184
Run a model and wait for its output asynchronously.
183185
"""
184186

185-
return await async_run(self, ref, input, **params)
187+
return await async_run(self, ref, input, use_file_output, **params)
186188

187189
def stream(
188190
self,

replicate/run.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from replicate.prediction import Prediction
1818
from replicate.schema import make_schema_backwards_compatible
1919
from replicate.version import Version, Versions
20+
from replicate.stream import FileOutputProvider
2021

2122
if TYPE_CHECKING:
2223
from replicate.client import Client
@@ -28,6 +29,7 @@ def run(
2829
client: "Client",
2930
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
3031
input: Optional[Dict[str, Any]] = None,
32+
use_file_output: bool = False,
3133
**params: Unpack["Predictions.CreatePredictionParams"],
3234
) -> Union[Any, Iterator[Any]]: # noqa: ANN401
3335
"""
@@ -60,13 +62,17 @@ def run(
6062
if prediction.status == "failed":
6163
raise ModelError(prediction)
6264

65+
if use_file_output:
66+
return transform_output(prediction.output, client)
67+
6368
return prediction.output
6469

6570

6671
async def async_run(
6772
client: "Client",
6873
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
6974
input: Optional[Dict[str, Any]] = None,
75+
use_file_output: bool = False,
7076
**params: Unpack["Predictions.CreatePredictionParams"],
7177
) -> Union[Any, AsyncIterator[Any]]: # noqa: ANN401
7278
"""
@@ -99,6 +105,9 @@ async def async_run(
99105
if prediction.status == "failed":
100106
raise ModelError(prediction)
101107

108+
if use_file_output:
109+
return transform_output(prediction.output, client)
110+
102111
return prediction.output
103112

104113

@@ -130,4 +139,22 @@ def _make_async_output_iterator(
130139
return None
131140

132141

142+
def transform(obj, func):
143+
if isinstance(obj, dict):
144+
return {k: transform(v, func) for k, v in obj.items()}
145+
elif isinstance(obj, list):
146+
return [transform(item, func) for item in obj]
147+
else:
148+
return func(obj)
149+
150+
151+
def transform_output(value: Any, client: "Client"):
152+
def wrapper(x):
153+
if isinstance(x, str) and (x.startswith("https:") or x.startswith("data:")):
154+
return FileOutputProvider(x, client)
155+
return x
156+
157+
return transform(value, wrapper)
158+
159+
133160
__all__: List = []

tests/test_run.py

Lines changed: 176 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
import pytest
66
import respx
77

8+
from typing import cast
89
import replicate
910
from replicate.client import Client
1011
from replicate.exceptions import ModelError, ReplicateError
12+
from replicate.stream import FileOutputProvider
1113

1214

1315
@pytest.mark.vcr("run.yaml")
@@ -73,7 +75,7 @@ async def test_run_concurrently(mock_replicate_api_token, record_mode):
7375
results = await asyncio.gather(*tasks)
7476
assert len(results) == len(prompts)
7577
assert all(isinstance(result, list) for result in results)
76-
assert all(len(result) > 0 for result in results)
78+
assert all(len(results) > 0 for result in results)
7779

7880

7981
@pytest.mark.vcr("run.yaml")
@@ -253,3 +255,176 @@ def prediction_with_status(status: str) -> dict:
253255
assert str(excinfo.value) == "OOM"
254256
assert excinfo.value.prediction.error == "OOM"
255257
assert excinfo.value.prediction.status == "failed"
258+
259+
260+
@pytest.mark.asyncio
261+
async def test_run_with_file_output(mock_replicate_api_token):
262+
def prediction_with_status(
263+
status: str, output: str | list[str] | None = None
264+
) -> dict:
265+
return {
266+
"id": "p1",
267+
"model": "test/example",
268+
"version": "v1",
269+
"urls": {
270+
"get": "https://api.replicate.com/v1/predictions/p1",
271+
"cancel": "https://api.replicate.com/v1/predictions/p1/cancel",
272+
},
273+
"created_at": "2023-10-05T12:00:00.000000Z",
274+
"source": "api",
275+
"status": status,
276+
"input": {"text": "world"},
277+
"output": output,
278+
"error": "OOM" if status == "failed" else None,
279+
"logs": "",
280+
}
281+
282+
router = respx.Router(base_url="https://api.replicate.com/v1")
283+
router.route(method="POST", path="/predictions").mock(
284+
return_value=httpx.Response(
285+
201,
286+
json=prediction_with_status("processing"),
287+
)
288+
)
289+
router.route(method="GET", path="/predictions/p1").mock(
290+
return_value=httpx.Response(
291+
200,
292+
json=prediction_with_status(
293+
"succeeded", "https://api.replicate.com/v1/assets/output.txt"
294+
),
295+
)
296+
)
297+
router.route(
298+
method="GET",
299+
path="/models/test/example/versions/v1",
300+
).mock(
301+
return_value=httpx.Response(
302+
201,
303+
json={
304+
"id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1",
305+
"created_at": "2024-07-18T00:35:56.210272Z",
306+
"cog_version": "0.9.10",
307+
"openapi_schema": {
308+
"openapi": "3.0.2",
309+
},
310+
},
311+
)
312+
)
313+
router.route(method="GET", path="/assets/output.txt").mock(
314+
return_value=httpx.Response(200, content=b"Hello, world!")
315+
)
316+
317+
client = Client(
318+
api_token="test-token", transport=httpx.MockTransport(router.handler)
319+
)
320+
client.poll_interval = 0.001
321+
322+
output = cast(
323+
FileOutputProvider,
324+
client.run(
325+
"test/example:v1",
326+
input={
327+
"text": "Hello, world!",
328+
},
329+
use_file_output=True,
330+
),
331+
)
332+
333+
assert output.url == "https://api.replicate.com/v1/assets/output.txt"
334+
335+
assert output.read() == b"Hello, world!"
336+
with output.stream() as file:
337+
for chunk in file:
338+
assert chunk == b"Hello, world!"
339+
340+
assert await output.aread() == b"Hello, world!"
341+
async with output.astream() as file:
342+
async for chunk in file:
343+
assert chunk == b"Hello, world!"
344+
345+
346+
@pytest.mark.asyncio
347+
async def test_run_with_file_output_array(mock_replicate_api_token):
348+
def prediction_with_status(
349+
status: str, output: str | list[str] | None = None
350+
) -> dict:
351+
return {
352+
"id": "p1",
353+
"model": "test/example",
354+
"version": "v1",
355+
"urls": {
356+
"get": "https://api.replicate.com/v1/predictions/p1",
357+
"cancel": "https://api.replicate.com/v1/predictions/p1/cancel",
358+
},
359+
"created_at": "2023-10-05T12:00:00.000000Z",
360+
"source": "api",
361+
"status": status,
362+
"input": {"text": "world"},
363+
"output": output,
364+
"error": "OOM" if status == "failed" else None,
365+
"logs": "",
366+
}
367+
368+
router = respx.Router(base_url="https://api.replicate.com/v1")
369+
router.route(method="POST", path="/predictions").mock(
370+
return_value=httpx.Response(
371+
201,
372+
json=prediction_with_status("processing"),
373+
)
374+
)
375+
router.route(method="GET", path="/predictions/p1").mock(
376+
return_value=httpx.Response(
377+
200,
378+
json=prediction_with_status(
379+
"succeeded",
380+
[
381+
"https://api.replicate.com/v1/assets/hello.txt",
382+
"https://api.replicate.com/v1/assets/world.txt",
383+
],
384+
),
385+
)
386+
)
387+
router.route(
388+
method="GET",
389+
path="/models/test/example/versions/v1",
390+
).mock(
391+
return_value=httpx.Response(
392+
201,
393+
json={
394+
"id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1",
395+
"created_at": "2024-07-18T00:35:56.210272Z",
396+
"cog_version": "0.9.10",
397+
"openapi_schema": {
398+
"openapi": "3.0.2",
399+
},
400+
},
401+
)
402+
)
403+
router.route(method="GET", path="/assets/hello.txt").mock(
404+
return_value=httpx.Response(200, content=b"Hello,")
405+
)
406+
router.route(method="GET", path="/assets/world.txt").mock(
407+
return_value=httpx.Response(200, content=b" world!")
408+
)
409+
410+
client = Client(
411+
api_token="test-token", transport=httpx.MockTransport(router.handler)
412+
)
413+
client.poll_interval = 0.001
414+
415+
[output1, output2] = cast(
416+
list[FileOutputProvider],
417+
client.run(
418+
"test/example:v1",
419+
input={
420+
"text": "Hello, world!",
421+
},
422+
use_file_output=True,
423+
),
424+
)
425+
426+
assert output1.url == "https://api.replicate.com/v1/assets/hello.txt"
427+
assert output2.url == "https://api.replicate.com/v1/assets/world.txt"
428+
429+
assert output1.read() == b"Hello,"
430+
assert output2.read() == b" world!"

0 commit comments

Comments
 (0)