Skip to content

Commit d405512

Browse files
committed
Fix a couple of bugs in the base64 file_encoding_strategy
This commit adds tests for the `file_encoding_strategy` argument for `replicate.run()` and fixes two bugs that surfaced: 1. `replicate.run()` would convert the file provided into base64 encoded data but not a valid data URL. We now use the `base64_encode_file` function used for outputs. 2. `replicate.async_run()` accepted but did not use the `file_encoding_strategy` flag at all. This is fixed, though it is worth noting that `base64_encode_file` is not optimized for async workflows and will block. This migth be okay as the file sizes expected for data URL paylaods should be very small.
1 parent 4fdd78f commit d405512

File tree

2 files changed

+139
-4
lines changed

2 files changed

+139
-4
lines changed

replicate/helpers.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def encode_json(
4343
return encode_json(file, client, file_encoding_strategy)
4444
if isinstance(obj, io.IOBase):
4545
if file_encoding_strategy == "base64":
46-
return base64.b64encode(obj.read()).decode("utf-8")
46+
return base64_encode_file(obj)
4747
else:
4848
return client.files.create(obj).urls["get"]
4949
if HAS_NUMPY:
@@ -77,9 +77,13 @@ async def async_encode_json(
7777
]
7878
if isinstance(obj, Path):
7979
with obj.open("rb") as file:
80-
return encode_json(file, client, file_encoding_strategy)
80+
return await async_encode_json(file, client, file_encoding_strategy)
8181
if isinstance(obj, io.IOBase):
82-
return (await client.files.async_create(obj)).urls["get"]
82+
if file_encoding_strategy == "base64":
83+
# TODO: This should ideally use an async based file reader path.
84+
return base64_encode_file(obj)
85+
else:
86+
return (await client.files.async_create(obj)).urls["get"]
8387
if HAS_NUMPY:
8488
if isinstance(obj, np.integer): # type: ignore
8589
return int(obj)

tests/test_run.py

Lines changed: 132 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import json
2+
import io
13
import asyncio
24
import sys
3-
from typing import AsyncIterator, Iterator, Optional, cast
5+
from typing import AsyncIterator, Iterator, Optional, cast, Type
46

57
import httpx
68
import pytest
@@ -11,6 +13,11 @@
1113
from replicate.exceptions import ModelError, ReplicateError
1214
from replicate.helpers import FileOutput
1315

16+
import email
17+
from email.message import EmailMessage
18+
from email.parser import BytesParser
19+
from email.policy import HTTP
20+
1421

1522
@pytest.mark.vcr("run.yaml")
1623
@pytest.mark.asyncio
@@ -581,6 +588,130 @@ async def test_run_with_model_error(mock_replicate_api_token):
581588
assert excinfo.value.prediction.status == "failed"
582589

583590

591+
@pytest.mark.asyncio
592+
@pytest.mark.parametrize("async_flag", [True, False])
593+
async def test_run_with_file_input_files_api(async_flag, mock_replicate_api_token):
594+
router = respx.Router(base_url="https://api.replicate.com/v1")
595+
mock_predictions_create = router.route(method="POST", path="/predictions").mock(
596+
return_value=httpx.Response(
597+
201,
598+
json=_prediction_with_status("processing"),
599+
)
600+
)
601+
router.route(
602+
method="GET",
603+
path="/models/test/example/versions/v1",
604+
).mock(
605+
return_value=httpx.Response(
606+
200,
607+
json=_version_with_schema(),
608+
)
609+
)
610+
mock_files_create = router.route(
611+
method="POST",
612+
path="/files",
613+
).mock(
614+
return_value=httpx.Response(
615+
200,
616+
json={
617+
"id": "file1",
618+
"name": "file.png",
619+
"content_type": "image/png",
620+
"size": 10,
621+
"etag": "123",
622+
"checksums": {},
623+
"metadata": {},
624+
"created_at": "",
625+
"expires_at": "",
626+
"urls": {"get": "https://api.replicate.com/files/file.txt"},
627+
},
628+
)
629+
)
630+
router.route(host="api.replicate.com").pass_through()
631+
632+
client = Client(
633+
api_token="test-token", transport=httpx.MockTransport(router.handler)
634+
)
635+
if async_flag:
636+
await client.async_run(
637+
"test/example:v1",
638+
input={"file": io.BytesIO(initial_bytes=b"hello world")},
639+
)
640+
else:
641+
client.run(
642+
"test/example:v1",
643+
input={"file": io.BytesIO(initial_bytes=b"hello world")},
644+
)
645+
646+
assert mock_predictions_create.called
647+
prediction_payload = json.loads(mock_predictions_create.calls[0].request.content)
648+
assert (
649+
prediction_payload.get("input", {}).get("file")
650+
== "https://api.replicate.com/files/file.txt"
651+
)
652+
653+
# Validate the Files API request
654+
req = mock_files_create.calls[0].request
655+
body = req.content
656+
content_type = req.headers["Content-Type"]
657+
658+
# Parse the multipart data
659+
parser = BytesParser(EmailMessage, policy=HTTP)
660+
headers = f"Content-Type: {content_type}\n\n".encode("utf-8")
661+
parsed_message_generator = parser.parsebytes(headers + body).walk()
662+
next(parsed_message_generator) # wrapper
663+
input_file = next(parsed_message_generator)
664+
assert mock_files_create.called
665+
assert input_file.get_content() == b"hello world"
666+
assert input_file.get_content_type() == "application/octet-stream"
667+
668+
669+
@pytest.mark.asyncio
670+
@pytest.mark.parametrize("async_flag", [True, False])
671+
async def test_run_with_file_input_data_url(async_flag, mock_replicate_api_token):
672+
router = respx.Router(base_url="https://api.replicate.com/v1")
673+
mock_predictions_create = router.route(method="POST", path="/predictions").mock(
674+
return_value=httpx.Response(
675+
201,
676+
json=_prediction_with_status("processing"),
677+
)
678+
)
679+
router.route(
680+
method="GET",
681+
path="/models/test/example/versions/v1",
682+
).mock(
683+
return_value=httpx.Response(
684+
200,
685+
json=_version_with_schema(),
686+
)
687+
)
688+
router.route(host="api.replicate.com").pass_through()
689+
690+
client = Client(
691+
api_token="test-token", transport=httpx.MockTransport(router.handler)
692+
)
693+
694+
if async_flag:
695+
await client.async_run(
696+
"test/example:v1",
697+
input={"file": io.BytesIO(initial_bytes=b"hello world")},
698+
file_encoding_strategy="base64",
699+
)
700+
else:
701+
client.run(
702+
"test/example:v1",
703+
input={"file": io.BytesIO(initial_bytes=b"hello world")},
704+
file_encoding_strategy="base64",
705+
)
706+
707+
assert mock_predictions_create.called
708+
prediction_payload = json.loads(mock_predictions_create.calls[0].request.content)
709+
assert (
710+
prediction_payload.get("input", {}).get("file")
711+
== "data:application/octet-stream;base64,aGVsbG8gd29ybGQ="
712+
)
713+
714+
584715
@pytest.mark.asyncio
585716
async def test_run_with_file_output(mock_replicate_api_token):
586717
router = respx.Router(base_url="https://api.replicate.com/v1")

0 commit comments

Comments
 (0)