Skip to content

Commit e7f699f

Browse files
aronmattt
andauthored
Introduce experimental FileOutput interface for models that output File and Path types (#348)
This PR is a proposal to add a new `FileOutput` type to the client SDK to abstract away file outputs from Replicate models. It can be enabled by passing the `use_file_output` flag to the `run()` method (this will be moved to the constructor). ```python replicate.run("black-forest-labs/flux-schnell", input={}, use_file_output=True); ``` When enabled, any URLs (and soon data-uris) will be converted into a `FileOutput` type. This is essentially an `Iterable[bytes] | AsyncIterable[bytes]` that has two additional members: the attribute `url`, referencing the underlying URL, and the method `read()`, which returns `bytes` with the file data loaded into memory. The intention here is to make it easier to work with file outputs and to allow us to optimize the delivery of file assets to the client in future iterations. Usage is as follows: ```python output = replicate.run( "black-forest-labs/flux-schnell", input={"prompt": "astronaut riding a rocket like a horse"}, use_file_output=True, ); ``` For most basic cases you'll want to use either `url` or `read()`, depending on whether you want to pass the file on or consume it directly. To access the file URL: ```python print(output.url) #=> "https://delivery.replicate.com/..." ``` To consume the file directly: ```python with open('output.bin', 'wb') as file: file.write(output.read()) ``` Or, for very large files, the output can be streamed: ```python with open(file_path, 'wb') as file: for chunk in output: file.write(chunk) ``` Each of these methods has an equivalent `asyncio` API. 
```python async with aiofiles.open(filename, 'wb') as file: await file.write(await output.aread()) async with aiofiles.open(filename, 'wb') as file: async for chunk in output: await file.write(chunk) ``` For streaming responses, the common web frameworks all support taking `Iterator` types: **Django** ```python @condition(etag_func=None) def stream_response(request): output = replicate.run("black-forest-labs/flux-schnell", input={...}, use_file_output=True) return HttpResponse(output, content_type='image/webp') ``` **FastAPI** ```python @app.get("/") async def main(): output = replicate.run("black-forest-labs/flux-schnell", input={...}, use_file_output=True) return StreamingResponse(output) ``` **Flask** ```python @app.route('/stream') def streamed_response(): output = replicate.run("black-forest-labs/flux-schnell", input={...}, use_file_output=True) return app.response_class(stream_with_context(output)) ``` --------- Signed-off-by: Mattt Zmuda <[email protected]> Co-authored-by: Mattt Zmuda <[email protected]>
1 parent e051739 commit e7f699f

13 files changed

+364
-11
lines changed

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ ignore = [
7575
"ANN003", # Missing type annotation for `**kwargs`
7676
"ANN101", # Missing type annotation for self in method
7777
"ANN102", # Missing type annotation for cls in classmethod
78+
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed in {name}
7879
"W191", # Indentation contains tabs
7980
"UP037", # Remove quotes from type annotation
8081
]
@@ -86,3 +87,7 @@ ignore = [
8687
"ANN201", # Missing return type annotation for public function
8788
"ANN202", # Missing return type annotation for private function
8889
]
90+
91+
[tool.pyright]
92+
venvPath = "."
93+
venv = ".venv"

replicate/client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,25 +164,27 @@ def run(
164164
self,
165165
ref: str,
166166
input: Optional[Dict[str, Any]] = None,
167+
use_file_output: Optional[bool] = None,
167168
**params: Unpack["Predictions.CreatePredictionParams"],
168169
) -> Union[Any, Iterator[Any]]: # noqa: ANN401
169170
"""
170171
Run a model and wait for its output.
171172
"""
172173

173-
return run(self, ref, input, **params)
174+
return run(self, ref, input, use_file_output, **params)
174175

175176
async def async_run(
176177
self,
177178
ref: str,
178179
input: Optional[Dict[str, Any]] = None,
180+
use_file_output: Optional[bool] = None,
179181
**params: Unpack["Predictions.CreatePredictionParams"],
180182
) -> Union[Any, AsyncIterator[Any]]: # noqa: ANN401
181183
"""
182184
Run a model and wait for its output asynchronously.
183185
"""
184186

185-
return await async_run(self, ref, input, **params)
187+
return await async_run(self, ref, input, use_file_output, **params)
186188

187189
def stream(
188190
self,

replicate/deployment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing_extensions import Unpack, deprecated
44

55
from replicate.account import Account
6-
from replicate.json import async_encode_json, encode_json
6+
from replicate.helpers import async_encode_json, encode_json
77
from replicate.pagination import Page
88
from replicate.prediction import (
99
Prediction,

replicate/json.py renamed to replicate/helpers.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import base64
22
import io
33
import mimetypes
4+
from collections.abc import Mapping, Sequence
45
from pathlib import Path
56
from types import GeneratorType
6-
from typing import TYPE_CHECKING, Any, Optional
7+
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, Optional
8+
9+
import httpx
710

811
if TYPE_CHECKING:
912
from replicate.client import Client
@@ -108,3 +111,80 @@ def base64_encode_file(file: io.IOBase) -> str:
108111
mimetypes.guess_type(getattr(file, "name", ""))[0] or "application/octet-stream"
109112
)
110113
return f"data:{mime_type};base64,{encoded_body}"
114+
115+
116+
class FileOutput(httpx.SyncByteStream, httpx.AsyncByteStream):
    """
    Handle to an output file produced by running a Replicate model.

    Wraps the file's URL and exposes its bytes either all at once
    (`read` / `aread`) or as a stream of chunks (sync or async iteration).
    URLs starting with ``data:`` are decoded locally; any other URL is
    fetched with a streaming GET through the owning client.
    """

    url: str
    """
    The file URL.
    """

    _client: "Client"

    def __init__(self, url: str, client: "Client") -> None:
        """Remember the file URL and the client used to fetch it."""
        self._client = client
        self.url = url

    def read(self) -> bytes:
        """
        Return the complete file contents.

        For a ``data:`` URL the text after the first comma is base64-decoded.
        Otherwise the URL is fetched via the client's synchronous transport,
        `raise_for_status()` is called on the response, and the full body is
        returned.
        """
        if not self.url.startswith("data:"):
            with self._client._client.stream("GET", self.url) as resp:
                resp.raise_for_status()
                return resp.read()

        _, payload = self.url.split(",", 1)
        return base64.b64decode(payload)

    def __iter__(self) -> Iterator[bytes]:
        """
        Yield the file contents in chunks.

        A ``data:`` URL yields a single chunk holding the decoded payload;
        any other URL streams the response body chunk by chunk.
        """
        if not self.url.startswith("data:"):
            with self._client._client.stream("GET", self.url) as resp:
                resp.raise_for_status()
                yield from resp.iter_bytes()
            return

        yield self.read()

    async def aread(self) -> bytes:
        """
        Asynchronous counterpart of `read`.

        Uses the client's asynchronous transport for non-``data:`` URLs.
        """
        if not self.url.startswith("data:"):
            async with self._client._async_client.stream("GET", self.url) as resp:
                resp.raise_for_status()
                return await resp.aread()

        _, payload = self.url.split(",", 1)
        return base64.b64decode(payload)

    async def __aiter__(self) -> AsyncIterator[bytes]:
        """
        Asynchronous counterpart of `__iter__`.

        A ``data:`` URL yields one decoded chunk; any other URL streams the
        response body through the client's asynchronous transport.
        """
        if not self.url.startswith("data:"):
            async with self._client._async_client.stream("GET", self.url) as resp:
                resp.raise_for_status()
                async for piece in resp.aiter_bytes():
                    yield piece
            return

        yield await self.aread()

    def __str__(self) -> str:
        """Return the file URL, so the object can stand in for the URL string."""
        return self.url
173+
174+
def transform_output(value: Any, client: "Client") -> Any:
    """
    Recursively replace file URLs in a prediction output with `FileOutput` objects.

    Args:
        value: The prediction output; may be a scalar, a mapping, a sequence,
            or any nesting of these.
        client: The client handed to every `FileOutput` that gets created.

    Returns:
        A transformed copy of `value`: mappings are rebuilt as dicts and
        non-string sequences as lists, with each nested value transformed in
        turn. Strings starting with ``https:`` or ``data:`` become
        `FileOutput(url, client)`. Every other value, including ``bytes`` and
        ``bytearray``, is returned unchanged.
    """

    def transform(obj: Any) -> Any:
        if isinstance(obj, str):
            # startswith accepts a tuple of prefixes: one call instead of an `or` chain.
            if obj.startswith(("https:", "data:")):
                return FileOutput(obj, client)
            return obj
        # bytes-like values are Sequences too; without this guard they would
        # be exploded into a list of ints.
        if isinstance(obj, (bytes, bytearray)):
            return obj
        if isinstance(obj, Mapping):
            return {key: transform(item) for key, item in obj.items()}
        if isinstance(obj, Sequence):
            return [transform(item) for item in obj]
        return obj

    return transform(value)

replicate/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from typing_extensions import NotRequired, TypedDict, Unpack, deprecated
44

55
from replicate.exceptions import ReplicateException
6+
from replicate.helpers import async_encode_json, encode_json
67
from replicate.identifier import ModelVersionIdentifier
7-
from replicate.json import async_encode_json, encode_json
88
from replicate.pagination import Page
99
from replicate.prediction import (
1010
Prediction,

replicate/prediction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from replicate.exceptions import ModelError, ReplicateError
2222
from replicate.file import FileEncodingStrategy
23-
from replicate.json import async_encode_json, encode_json
23+
from replicate.helpers import async_encode_json, encode_json
2424
from replicate.pagination import Page
2525
from replicate.resource import Namespace, Resource
2626
from replicate.stream import EventSource

replicate/run.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from replicate import identifier
1515
from replicate.exceptions import ModelError
16+
from replicate.helpers import transform_output
1617
from replicate.model import Model
1718
from replicate.prediction import Prediction
1819
from replicate.schema import make_schema_backwards_compatible
@@ -28,6 +29,7 @@ def run(
2829
client: "Client",
2930
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
3031
input: Optional[Dict[str, Any]] = None,
32+
use_file_output: Optional[bool] = None,
3133
**params: Unpack["Predictions.CreatePredictionParams"],
3234
) -> Union[Any, Iterator[Any]]: # noqa: ANN401
3335
"""
@@ -60,13 +62,17 @@ def run(
6062
if prediction.status == "failed":
6163
raise ModelError(prediction)
6264

65+
if use_file_output:
66+
return transform_output(prediction.output, client)
67+
6368
return prediction.output
6469

6570

6671
async def async_run(
6772
client: "Client",
6873
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
6974
input: Optional[Dict[str, Any]] = None,
75+
use_file_output: Optional[bool] = None,
7076
**params: Unpack["Predictions.CreatePredictionParams"],
7177
) -> Union[Any, AsyncIterator[Any]]: # noqa: ANN401
7278
"""
@@ -99,6 +105,9 @@ async def async_run(
99105
if prediction.status == "failed":
100106
raise ModelError(prediction)
101107

108+
if use_file_output:
109+
return transform_output(prediction.output, client)
110+
102111
return prediction.output
103112

104113

replicate/stream.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
Union,
1111
)
1212

13+
import httpx
1314
from typing_extensions import Unpack
1415

1516
from replicate import identifier
@@ -22,8 +23,6 @@
2223

2324

2425
if TYPE_CHECKING:
25-
import httpx
26-
2726
from replicate.client import Client
2827
from replicate.identifier import ModelVersionIdentifier
2928
from replicate.model import Model

replicate/training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313

1414
from typing_extensions import NotRequired, Unpack
1515

16+
from replicate.helpers import async_encode_json, encode_json
1617
from replicate.identifier import ModelVersionIdentifier
17-
from replicate.json import async_encode_json, encode_json
1818
from replicate.model import Model
1919
from replicate.pagination import Page
2020
from replicate.resource import Namespace, Resource

requirements-dev.lock

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
# features: []
77
# all-features: false
88
# with-sources: false
9+
# generate-hashes: false
10+
# universal: false
911

1012
-e file:.
1113
annotated-types==0.6.0

requirements.lock

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
# features: []
77
# all-features: false
88
# with-sources: false
9+
# generate-hashes: false
10+
# universal: false
911

1012
-e file:.
1113
annotated-types==0.6.0

tests/test_json.py renamed to tests/test_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44

5-
from replicate.json import base64_encode_file
5+
from replicate.helpers import base64_encode_file
66

77

88
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)