Skip to content

Commit c0a4b13

Browse files
committed
Add support to FileOutput to read data-uris
1 parent 04be83f commit c0a4b13

File tree

2 files changed

+97
-0
lines changed

2 files changed

+97
-0
lines changed

replicate/helpers.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,21 +131,37 @@ def __init__(self, url: str, client: "Client") -> None:
131131
self._client = client
132132

133133
def read(self) -> bytes:
134+
if self.url.startswith("data:"):
135+
_, encoded = self.url.split(",", 1)
136+
return base64.b64decode(encoded)
137+
134138
with self._client._client.stream("GET", self.url) as response:
135139
response.raise_for_status()
136140
return response.read()
137141

138142
def __iter__(self) -> Iterator[bytes]:
143+
if self.url.startswith("data:"):
144+
yield self.read()
145+
return
146+
139147
with self._client._client.stream("GET", self.url) as response:
140148
response.raise_for_status()
141149
yield from response.iter_bytes()
142150

143151
async def aread(self) -> bytes:
152+
if self.url.startswith("data:"):
153+
_, encoded = self.url.split(",", 1)
154+
return base64.b64decode(encoded)
155+
144156
async with self._client._async_client.stream("GET", self.url) as response:
145157
response.raise_for_status()
146158
return await response.aread()
147159

148160
async def __aiter__(self) -> AsyncIterator[bytes]:
161+
if self.url.startswith("data:"):
162+
yield await self.aread()
163+
return
164+
149165
async with self._client._async_client.stream("GET", self.url) as response:
150166
response.raise_for_status()
151167
async for chunk in response.aiter_bytes():

tests/test_run.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,3 +426,84 @@ def prediction_with_status(
426426

427427
assert output1.read() == b"Hello,"
428428
assert output2.read() == b" world!"
429+
430+
431+
@pytest.mark.asyncio
432+
async def test_run_with_file_output_data_uri(mock_replicate_api_token):
433+
def prediction_with_status(
434+
status: str, output: str | list[str] | None = None
435+
) -> dict:
436+
return {
437+
"id": "p1",
438+
"model": "test/example",
439+
"version": "v1",
440+
"urls": {
441+
"get": "https://api.replicate.com/v1/predictions/p1",
442+
"cancel": "https://api.replicate.com/v1/predictions/p1/cancel",
443+
},
444+
"created_at": "2023-10-05T12:00:00.000000Z",
445+
"source": "api",
446+
"status": status,
447+
"input": {"text": "world"},
448+
"output": output,
449+
"error": "OOM" if status == "failed" else None,
450+
"logs": "",
451+
}
452+
453+
router = respx.Router(base_url="https://api.replicate.com/v1")
454+
router.route(method="POST", path="/predictions").mock(
455+
return_value=httpx.Response(
456+
201,
457+
json=prediction_with_status("processing"),
458+
)
459+
)
460+
router.route(method="GET", path="/predictions/p1").mock(
461+
return_value=httpx.Response(
462+
200,
463+
json=prediction_with_status(
464+
"succeeded",
465+
"data:text/plain;base64,SGVsbG8sIHdvcmxkIQ==",
466+
),
467+
)
468+
)
469+
router.route(
470+
method="GET",
471+
path="/models/test/example/versions/v1",
472+
).mock(
473+
return_value=httpx.Response(
474+
201,
475+
json={
476+
"id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1",
477+
"created_at": "2024-07-18T00:35:56.210272Z",
478+
"cog_version": "0.9.10",
479+
"openapi_schema": {
480+
"openapi": "3.0.2",
481+
},
482+
},
483+
)
484+
)
485+
486+
client = Client(
487+
api_token="test-token", transport=httpx.MockTransport(router.handler)
488+
)
489+
client.poll_interval = 0.001
490+
491+
output = cast(
492+
FileOutput,
493+
client.run(
494+
"test/example:v1",
495+
input={
496+
"text": "Hello, world!",
497+
},
498+
use_file_output=True,
499+
),
500+
)
501+
502+
assert output.url == "data:text/plain;base64,SGVsbG8sIHdvcmxkIQ=="
503+
assert output.read() == b"Hello, world!"
504+
for chunk in output:
505+
assert chunk == b"Hello, world!"
506+
507+
assert await output.aread() == b"Hello, world!"
508+
async for chunk in output:
509+
assert chunk == b"Hello, world!"

0 commit comments

Comments
 (0)