Skip to content

Commit e19cb78

Browse files
committed
Rename upload_file to base64_encode_file
Remove unused output_file_prefix parameter Add test coverage for base64_encode_file Signed-off-by: Mattt Zmuda <[email protected]>
1 parent 3cc1974 commit e19cb78

File tree

6 files changed

+51
-30
lines changed

6 files changed

+51
-30
lines changed

replicate/deployment.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing_extensions import Unpack, deprecated
55

66
from replicate.account import Account
7-
from replicate.file import upload_file
7+
from replicate.file import base64_encode_file
88
from replicate.json import encode_json
99
from replicate.pagination import Page
1010
from replicate.prediction import (
@@ -424,7 +424,7 @@ def create(
424424
if input is not None:
425425
input = encode_json(
426426
input,
427-
upload_file=upload_file
427+
upload_file=base64_encode_file
428428
if file_encoding_strategy == "base64"
429429
else lambda file: self._client.files.create(file).urls["get"],
430430
)
@@ -451,7 +451,7 @@ async def async_create(
451451
if input is not None:
452452
input = encode_json(
453453
input,
454-
upload_file=upload_file
454+
upload_file=base64_encode_file
455455
if file_encoding_strategy == "base64"
456456
else lambda file: asyncio.get_event_loop()
457457
.run_until_complete(self._client.files.async_create(file))
@@ -489,7 +489,7 @@ def create(
489489
if input is not None:
490490
input = encode_json(
491491
input,
492-
upload_file=upload_file
492+
upload_file=base64_encode_file
493493
if file_encoding_strategy == "base64"
494494
else lambda file: self._client.files.create(file).urls["get"],
495495
)
@@ -519,7 +519,7 @@ async def async_create(
519519
if input is not None:
520520
input = encode_json(
521521
input,
522-
upload_file=upload_file
522+
upload_file=base64_encode_file
523523
if file_encoding_strategy == "base64"
524524
else lambda file: asyncio.get_event_loop()
525525
.run_until_complete(self._client.files.async_create(file))

replicate/file.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import pathlib
77
from typing import Any, BinaryIO, Dict, List, Optional, TypedDict, Union
88

9-
import httpx
109
from typing_extensions import NotRequired, Unpack
1110

1211
from replicate.resource import Namespace, Resource
@@ -171,33 +170,23 @@ def _json_to_file(json: Dict[str, Any]) -> File: # pylint: disable=redefined-ou
171170
return File(**json)
172171

173172

174-
def upload_file(file: io.IOBase, output_file_prefix: Optional[str] = None) -> str:
173+
def base64_encode_file(file: io.IOBase) -> str:
175174
"""
176-
Upload a file to the server.
175+
Base64 encode a file.
177176
178177
Args:
179178
file: A file handle to upload.
180-
output_file_prefix: A string to prepend to the output file name.
181179
Returns:
182-
str: A URL to the uploaded file.
180+
str: A base64-encoded data URI.
183181
"""
184-
# Lifted straight from cog.files
185182

186183
file.seek(0)
187-
188-
if output_file_prefix is not None:
189-
name = getattr(file, "name", "output")
190-
url = output_file_prefix + os.path.basename(name)
191-
resp = httpx.put(url, files={"file": file}, timeout=None) # type: ignore
192-
resp.raise_for_status()
193-
194-
return url
195-
196184
body = file.read()
185+
197186
# Ensure the file handle is in bytes
198187
body = body.encode("utf-8") if isinstance(body, str) else body
199188
encoded_body = base64.b64encode(body).decode("utf-8")
200-
# Use getattr to avoid mypy complaints about io.IOBase having no attribute name
189+
201190
mime_type = (
202191
mimetypes.guess_type(getattr(file, "name", ""))[0] or "application/octet-stream"
203192
)

replicate/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing_extensions import NotRequired, TypedDict, Unpack, deprecated
55

66
from replicate.exceptions import ReplicateException
7-
from replicate.file import upload_file
7+
from replicate.file import base64_encode_file
88
from replicate.identifier import ModelVersionIdentifier
99
from replicate.json import encode_json
1010
from replicate.pagination import Page
@@ -399,7 +399,7 @@ def create(
399399
if input is not None:
400400
input = encode_json(
401401
input,
402-
upload_file=upload_file
402+
upload_file=base64_encode_file
403403
if file_encoding_strategy == "base64"
404404
else lambda file: self._client.files.create(file).urls["get"],
405405
)
@@ -429,7 +429,7 @@ async def async_create(
429429
if input is not None:
430430
input = encode_json(
431431
input,
432-
upload_file=upload_file
432+
upload_file=base64_encode_file
433433
if file_encoding_strategy == "base64"
434434
else lambda file: asyncio.get_event_loop()
435435
.run_until_complete(self._client.files.async_create(file))

replicate/prediction.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typing_extensions import NotRequired, TypedDict, Unpack
2020

2121
from replicate.exceptions import ModelError, ReplicateError
22-
from replicate.file import upload_file
22+
from replicate.file import base64_encode_file
2323
from replicate.json import encode_json
2424
from replicate.pagination import Page
2525
from replicate.resource import Namespace, Resource
@@ -460,7 +460,7 @@ def create( # type: ignore
460460
if input is not None:
461461
input = encode_json(
462462
input,
463-
upload_file=upload_file
463+
upload_file=base64_encode_file
464464
if file_encoding_strategy == "base64"
465465
else lambda file: self._client.files.create(file).urls["get"],
466466
)
@@ -552,7 +552,7 @@ async def async_create( # type: ignore
552552
if input is not None:
553553
input = encode_json(
554554
input,
555-
upload_file=upload_file
555+
upload_file=base64_encode_file
556556
if file_encoding_strategy == "base64"
557557
else lambda file: asyncio.get_event_loop()
558558
.run_until_complete(self._client.files.async_create(file))

replicate/training.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from typing_extensions import NotRequired, Unpack
1616

17-
from replicate.file import upload_file
17+
from replicate.file import base64_encode_file
1818
from replicate.identifier import ModelVersionIdentifier
1919
from replicate.json import encode_json
2020
from replicate.model import Model
@@ -283,7 +283,7 @@ def create( # type: ignore
283283
if input is not None:
284284
input = encode_json(
285285
input,
286-
upload_file=upload_file
286+
upload_file=base64_encode_file
287287
if file_encoding_strategy == "base64"
288288
else lambda file: self._client.files.create(file).urls["get"],
289289
)
@@ -324,7 +324,7 @@ async def async_create(
324324
if input is not None:
325325
input = encode_json(
326326
input,
327-
upload_file=upload_file
327+
upload_file=base64_encode_file
328328
if file_encoding_strategy == "base64"
329329
else lambda file: asyncio.get_event_loop()
330330
.run_until_complete(self._client.files.async_create(file))

tests/test_file.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import io
12
import tempfile
23

34
import pytest
45

56
import replicate
7+
from replicate.file import base64_encode_file
68

79

810
@pytest.mark.vcr("file-operations.yaml")
@@ -56,3 +58,33 @@ async def test_file_operations(async_flag):
5658
file_list = replicate.files.list()
5759

5860
assert all(f.id != file_id for f in file_list)
61+
62+
63+
@pytest.mark.parametrize(
64+
"content, filename, expected",
65+
[
66+
(b"Hello, World!", "test.txt", "data:text/plain;base64,SGVsbG8sIFdvcmxkIQ=="),
67+
(b"\x89PNG\r\n\x1a\n", "image.png", ""),
68+
(
69+
"{'key': 'value'}",
70+
"data.json",
71+
"data:application/json;base64,eydrZXknOiAndmFsdWUnfQ==",
72+
),
73+
(
74+
b"Random bytes",
75+
None,
76+
"data:application/octet-stream;base64,UmFuZG9tIGJ5dGVz",
77+
),
78+
],
79+
)
80+
def test_base64_encode_file(content, filename, expected):
81+
# Create a file-like object with the given content
82+
file = io.BytesIO(content if isinstance(content, bytes) else content.encode())
83+
84+
# Set the filename if provided
85+
if filename:
86+
file.name = filename
87+
88+
# Call the function and check the result
89+
result = base64_encode_file(file)
90+
assert result == expected

0 commit comments

Comments
 (0)