Skip to content

Commit 07c8fbb

Browse files
authored
Fix a couple of bugs in the base64 file_encoding_strategy (#398)
This commit adds tests for the `file_encoding_strategy` argument of `replicate.run()` and fixes two bugs that surfaced:

1. `replicate.run()` converted the provided file into base64-encoded data, but not into a valid data URL. It now uses the same `base64_encode_file` function that is used for outputs.
2. `replicate.async_run()` accepted the `file_encoding_strategy` flag but did not use it at all. This is now fixed, though it is worth noting that `base64_encode_file` is not optimized for async workflows and will block. This may be acceptable, because the files expected in data URL payloads should be very small.
1 parent 4fdd78f commit 07c8fbb

File tree

2 files changed

+136
-3
lines changed

2 files changed

+136
-3
lines changed

replicate/helpers.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def encode_json(
4343
return encode_json(file, client, file_encoding_strategy)
4444
if isinstance(obj, io.IOBase):
4545
if file_encoding_strategy == "base64":
46-
return base64.b64encode(obj.read()).decode("utf-8")
46+
return base64_encode_file(obj)
4747
else:
4848
return client.files.create(obj).urls["get"]
4949
if HAS_NUMPY:
@@ -77,9 +77,13 @@ async def async_encode_json(
7777
]
7878
if isinstance(obj, Path):
7979
with obj.open("rb") as file:
80-
return encode_json(file, client, file_encoding_strategy)
80+
return await async_encode_json(file, client, file_encoding_strategy)
8181
if isinstance(obj, io.IOBase):
82-
return (await client.files.async_create(obj)).urls["get"]
82+
if file_encoding_strategy == "base64":
83+
# TODO: This should ideally use an async based file reader path.
84+
return base64_encode_file(obj)
85+
else:
86+
return (await client.files.async_create(obj)).urls["get"]
8387
if HAS_NUMPY:
8488
if isinstance(obj, np.integer): # type: ignore
8589
return int(obj)

tests/test_run.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
import asyncio
2+
import io
3+
import json
24
import sys
5+
from email.message import EmailMessage
6+
from email.parser import BytesParser
7+
from email.policy import HTTP
38
from typing import AsyncIterator, Iterator, Optional, cast
49

510
import httpx
@@ -581,6 +586,130 @@ async def test_run_with_model_error(mock_replicate_api_token):
581586
assert excinfo.value.prediction.status == "failed"
582587

583588

589+
@pytest.mark.asyncio
@pytest.mark.parametrize("async_flag", [True, False])
async def test_run_with_file_input_files_api(async_flag, mock_replicate_api_token):
    """Verify that a file input is uploaded through the Files API by default.

    The test runs once via ``client.async_run`` and once via ``client.run``
    (selected by ``async_flag``) without passing ``file_encoding_strategy``.
    It then checks that the prediction payload carries the URL returned by
    the mocked ``POST /files`` endpoint, and that the multipart upload sent
    to that endpoint contains the original bytes.
    """
    router = respx.Router(base_url="https://api.replicate.com/v1")
    mock_predictions_create = router.route(method="POST", path="/predictions").mock(
        return_value=httpx.Response(
            201,
            json=_prediction_with_status("processing"),
        )
    )
    router.route(
        method="GET",
        path="/models/test/example/versions/v1",
    ).mock(
        return_value=httpx.Response(
            200,
            json=_version_with_schema(),
        )
    )
    mock_files_create = router.route(
        method="POST",
        path="/files",
    ).mock(
        return_value=httpx.Response(
            200,
            json={
                "id": "file1",
                "name": "file.png",
                "content_type": "image/png",
                "size": 10,
                "etag": "123",
                "checksums": {},
                "metadata": {},
                "created_at": "",
                "expires_at": "",
                "urls": {"get": "https://api.replicate.com/files/file.txt"},
            },
        )
    )
    router.route(host="api.replicate.com").pass_through()

    client = Client(
        api_token="test-token", transport=httpx.MockTransport(router.handler)
    )

    # Both code paths receive the same input, so build it once.
    file_input = {"file": io.BytesIO(initial_bytes=b"hello world")}
    if async_flag:
        await client.async_run(
            "test/example:v1",
            input=file_input,
        )
    else:
        client.run(
            "test/example:v1",
            input=file_input,
        )

    assert mock_predictions_create.called
    prediction_payload = json.loads(mock_predictions_create.calls[0].request.content)
    assert (
        prediction_payload.get("input", {}).get("file")
        == "https://api.replicate.com/files/file.txt"
    )

    # Validate the Files API request. Check that the upload happened *before*
    # indexing into the recorded calls, so a missing upload fails on this
    # assertion instead of on the `calls[0]` lookup below.
    assert mock_files_create.called
    req = mock_files_create.calls[0].request
    body = req.content
    content_type = req.headers["Content-Type"]

    # Parse the multipart data by prepending the request's Content-Type as a
    # header block and feeding the result to the stdlib email parser.
    parser = BytesParser(EmailMessage, policy=HTTP)
    headers = f"Content-Type: {content_type}\n\n".encode()
    parsed_message_generator = parser.parsebytes(headers + body).walk()
    next(parsed_message_generator)  # wrapper
    input_file = next(parsed_message_generator)
    assert input_file.get_content() == b"hello world"
    assert input_file.get_content_type() == "application/octet-stream"
665+
666+
667+
@pytest.mark.asyncio
@pytest.mark.parametrize("async_flag", [True, False])
async def test_run_with_file_input_data_url(async_flag, mock_replicate_api_token):
    """Verify that ``file_encoding_strategy="base64"`` sends a data URL.

    The test runs once via ``client.async_run`` and once via ``client.run``
    (selected by ``async_flag``). It asserts that the ``file`` input in the
    prediction request body is the base64 data URL for ``b"hello world"``.
    """
    router = respx.Router(base_url="https://api.replicate.com/v1")

    prediction_response = httpx.Response(
        201,
        json=_prediction_with_status("processing"),
    )
    mock_predictions_create = router.route(method="POST", path="/predictions").mock(
        return_value=prediction_response
    )

    version_response = httpx.Response(
        200,
        json=_version_with_schema(),
    )
    router.route(
        method="GET",
        path="/models/test/example/versions/v1",
    ).mock(return_value=version_response)

    router.route(host="api.replicate.com").pass_through()

    transport = httpx.MockTransport(router.handler)
    client = Client(api_token="test-token", transport=transport)

    # The sync and async entry points are called with identical arguments.
    run_kwargs = {
        "input": {"file": io.BytesIO(initial_bytes=b"hello world")},
        "file_encoding_strategy": "base64",
    }
    if async_flag:
        await client.async_run("test/example:v1", **run_kwargs)
    else:
        client.run("test/example:v1", **run_kwargs)

    assert mock_predictions_create.called
    first_request = mock_predictions_create.calls[0].request
    sent_payload = json.loads(first_request.content)
    sent_file = sent_payload.get("input", {}).get("file")
    assert sent_file == "data:application/octet-stream;base64,aGVsbG8gd29ybGQ="
711+
712+
584713
@pytest.mark.asyncio
585714
async def test_run_with_file_output(mock_replicate_api_token):
586715
router = respx.Router(base_url="https://api.replicate.com/v1")

0 commit comments

Comments
 (0)