Skip to content

Commit ff41075

Browse files
authored
Automatically upload prediction and training input files (#339)
Follow-up to #226 --------- Signed-off-by: Mattt Zmuda <[email protected]>
1 parent 54f9c32 commit ff41075

11 files changed

+30610
-54
lines changed

replicate/deployment.py

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
from typing_extensions import Unpack, deprecated
44

55
from replicate.account import Account
6+
from replicate.json import async_encode_json, encode_json
67
from replicate.pagination import Page
78
from replicate.prediction import (
89
Prediction,
@@ -417,6 +418,13 @@ def create(
417418
Create a new prediction with the deployment.
418419
"""
419420

421+
file_encoding_strategy = params.pop("file_encoding_strategy", None)
422+
if input is not None:
423+
input = encode_json(
424+
input,
425+
client=self._client,
426+
file_encoding_strategy=file_encoding_strategy,
427+
)
420428
body = _create_prediction_body(version=None, input=input, **params)
421429

422430
resp = self._client._request(
@@ -436,6 +444,13 @@ async def async_create(
436444
Create a new prediction with the deployment.
437445
"""
438446

447+
file_encoding_strategy = params.pop("file_encoding_strategy", None)
448+
if input is not None:
449+
input = await async_encode_json(
450+
input,
451+
client=self._client,
452+
file_encoding_strategy=file_encoding_strategy,
453+
)
439454
body = _create_prediction_body(version=None, input=input, **params)
440455

441456
resp = await self._client._async_request(
@@ -463,6 +478,14 @@ def create(
463478
"""
464479

465480
url = _create_prediction_url_from_deployment(deployment)
481+
482+
file_encoding_strategy = params.pop("file_encoding_strategy", None)
483+
if input is not None:
484+
input = encode_json(
485+
input,
486+
client=self._client,
487+
file_encoding_strategy=file_encoding_strategy,
488+
)
466489
body = _create_prediction_body(version=None, input=input, **params)
467490

468491
resp = self._client._request(
@@ -484,6 +507,14 @@ async def async_create(
484507
"""
485508

486509
url = _create_prediction_url_from_deployment(deployment)
510+
511+
file_encoding_strategy = params.pop("file_encoding_strategy", None)
512+
if input is not None:
513+
input = await async_encode_json(
514+
input,
515+
client=self._client,
516+
file_encoding_strategy=file_encoding_strategy,
517+
)
487518
body = _create_prediction_body(version=None, input=input, **params)
488519

489520
resp = await self._client._async_request(

replicate/file.py

Lines changed: 3 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
1-
import base64
21
import io
32
import json
43
import mimetypes
54
import os
65
import pathlib
76
from typing import Any, BinaryIO, Dict, List, Optional, TypedDict, Union
87

9-
import httpx
10-
from typing_extensions import NotRequired, Unpack
8+
from typing_extensions import Literal, NotRequired, Unpack
119

1210
from replicate.resource import Namespace, Resource
1311

12+
FileEncodingStrategy = Literal["base64", "url"]
13+
1414

1515
class File(Resource):
1616
"""
@@ -169,36 +169,3 @@ def _create_file_params(
169169

170170
def _json_to_file(json: Dict[str, Any]) -> File:  # pylint: disable=redefined-outer-name
    """
    Build a `File` from a decoded JSON object.

    Args:
        json: Mapping whose items are passed to `File` as keyword arguments.
    Returns:
        File: The instance constructed from those keyword arguments.
    """

    return File(**json)
172-
173-
174-
def upload_file(file: io.IOBase, output_file_prefix: Optional[str] = None) -> str:
175-
"""
176-
Upload a file to the server.
177-
178-
Args:
179-
file: A file handle to upload.
180-
output_file_prefix: A string to prepend to the output file name.
181-
Returns:
182-
str: A URL to the uploaded file.
183-
"""
184-
# Lifted straight from cog.files
185-
186-
file.seek(0)
187-
188-
if output_file_prefix is not None:
189-
name = getattr(file, "name", "output")
190-
url = output_file_prefix + os.path.basename(name)
191-
resp = httpx.put(url, files={"file": file}, timeout=None) # type: ignore
192-
resp.raise_for_status()
193-
194-
return url
195-
196-
body = file.read()
197-
# Ensure the file handle is in bytes
198-
body = body.encode("utf-8") if isinstance(body, str) else body
199-
encoded_body = base64.b64encode(body).decode("utf-8")
200-
# Use getattr to avoid mypy complaints about io.IOBase having no attribute name
201-
mime_type = (
202-
mimetypes.guess_type(getattr(file, "name", ""))[0] or "application/octet-stream"
203-
)
204-
return f"data:{mime_type};base64,{encoded_body}"

replicate/json.py

Lines changed: 77 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
1+
import base64
12
import io
3+
import mimetypes
24
from pathlib import Path
35
from types import GeneratorType
4-
from typing import Any, Callable
6+
from typing import TYPE_CHECKING, Any, Optional
7+
8+
if TYPE_CHECKING:
9+
from replicate.client import Client
10+
from replicate.file import FileEncodingStrategy
11+
512

613
try:
714
import numpy as np # type: ignore
@@ -14,22 +21,62 @@
1421
# pylint: disable=too-many-return-statements
1522
def encode_json(
    obj: Any,  # noqa: ANN401
    client: "Client",
    file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
) -> Any:  # noqa: ANN401
    """
    Return a JSON-compatible version of the object.

    Dicts are encoded value by value; lists, tuples, sets, frozensets and
    generators become lists of encoded items. Paths are opened in binary
    mode and handled like file handles. File handles are replaced by a
    string: a base64 data URI when `file_encoding_strategy` is "base64",
    otherwise the "get" URL of a file created through `client.files`.
    NumPy integers, floats and arrays are converted to built-in types when
    NumPy is available. Anything else is returned unchanged.

    Args:
        obj: The value to encode.
        client: The client used to upload file handles.
        file_encoding_strategy: "base64" to inline files as data URIs;
            any other value (including None) uploads them.
    Returns:
        The encoded value.
    """

    if isinstance(obj, dict):
        return {
            key: encode_json(value, client, file_encoding_strategy)
            for key, value in obj.items()
        }
    if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
        return [encode_json(value, client, file_encoding_strategy) for value in obj]
    if isinstance(obj, Path):
        with obj.open("rb") as file:
            return encode_json(file, client, file_encoding_strategy)
    if isinstance(obj, io.IOBase):
        if file_encoding_strategy == "base64":
            # Use the shared helper so the result is a data URI (with a MIME
            # type) and text-mode handles are converted to bytes first.
            return base64_encode_file(obj)
        return client.files.create(obj).urls["get"]
    if HAS_NUMPY:
        if isinstance(obj, np.integer):  # type: ignore
            return int(obj)
        if isinstance(obj, np.floating):  # type: ignore
            return float(obj)
        if isinstance(obj, np.ndarray):  # type: ignore
            return obj.tolist()
    return obj


async def async_encode_json(
    obj: Any,  # noqa: ANN401
    client: "Client",
    file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
) -> Any:  # noqa: ANN401
    """
    Asynchronously return a JSON-compatible version of the object.

    Applies the same rules as `encode_json`, but file uploads are awaited
    through `client.files.async_create`.

    Args:
        obj: The value to encode.
        client: The client used to upload file handles.
        file_encoding_strategy: "base64" to inline files as data URIs;
            any other value (including None) uploads them.
    Returns:
        The encoded value.
    """

    if isinstance(obj, dict):
        return {
            key: (await async_encode_json(value, client, file_encoding_strategy))
            for key, value in obj.items()
        }
    if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
        return [
            (await async_encode_json(value, client, file_encoding_strategy))
            for value in obj
        ]
    if isinstance(obj, Path):
        with obj.open("rb") as file:
            # Recurse through the async encoder so an upload of this file
            # is awaited instead of going through the synchronous client call.
            return await async_encode_json(file, client, file_encoding_strategy)
    if isinstance(obj, io.IOBase):
        if file_encoding_strategy == "base64":
            # Honor the requested strategy, matching encode_json. The file
            # is still read synchronously here.
            return base64_encode_file(obj)
        return (await client.files.async_create(obj)).urls["get"]
    if HAS_NUMPY:
        if isinstance(obj, np.integer):  # type: ignore
            return int(obj)
        if isinstance(obj, np.floating):  # type: ignore
            return float(obj)
        if isinstance(obj, np.ndarray):  # type: ignore
            return obj.tolist()
    return obj


def base64_encode_file(file: io.IOBase) -> str:
    """
    Base64 encode a file.

    The handle is rewound to the start before reading, so the whole file
    is encoded regardless of its current position.

    Args:
        file: A file handle to encode.
    Returns:
        str: A base64-encoded data URI.
    """

    file.seek(0)
    body = file.read()

    # Ensure the file contents are bytes (text-mode handles return str)
    body = body.encode("utf-8") if isinstance(body, str) else body
    encoded_body = base64.b64encode(body).decode("utf-8")

    # Use getattr because not every handle has a name (e.g. io.BytesIO);
    # fall back to a generic MIME type when none can be guessed.
    mime_type = (
        mimetypes.guess_type(getattr(file, "name", ""))[0] or "application/octet-stream"
    )
    return f"data:{mime_type};base64,{encoded_body}"

replicate/model.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from replicate.exceptions import ReplicateException
66
from replicate.identifier import ModelVersionIdentifier
7+
from replicate.json import async_encode_json, encode_json
78
from replicate.pagination import Page
89
from replicate.prediction import (
910
Prediction,
@@ -391,6 +392,14 @@ def create(
391392
"""
392393

393394
url = _create_prediction_url_from_model(model)
395+
396+
file_encoding_strategy = params.pop("file_encoding_strategy", None)
397+
if input is not None:
398+
input = encode_json(
399+
input,
400+
client=self._client,
401+
file_encoding_strategy=file_encoding_strategy,
402+
)
394403
body = _create_prediction_body(version=None, input=input, **params)
395404

396405
resp = self._client._request(
@@ -412,6 +421,14 @@ async def async_create(
412421
"""
413422

414423
url = _create_prediction_url_from_model(model)
424+
425+
file_encoding_strategy = params.pop("file_encoding_strategy", None)
426+
if input is not None:
427+
input = await async_encode_json(
428+
input,
429+
client=self._client,
430+
file_encoding_strategy=file_encoding_strategy,
431+
)
415432
body = _create_prediction_body(version=None, input=input, **params)
416433

417434
resp = await self._client._async_request(

replicate/prediction.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
from typing_extensions import NotRequired, TypedDict, Unpack
2020

2121
from replicate.exceptions import ModelError, ReplicateError
22-
from replicate.file import upload_file
23-
from replicate.json import encode_json
22+
from replicate.file import FileEncodingStrategy
23+
from replicate.json import async_encode_json, encode_json
2424
from replicate.pagination import Page
2525
from replicate.resource import Namespace, Resource
2626
from replicate.stream import EventSource
@@ -383,6 +383,9 @@ class CreatePredictionParams(TypedDict):
383383
stream: NotRequired[bool]
384384
"""Enable streaming of prediction output."""
385385

386+
file_encoding_strategy: NotRequired[FileEncodingStrategy]
387+
"""The strategy to use for encoding files in the prediction input."""
388+
386389
@overload
387390
def create(
388391
self,
@@ -453,6 +456,13 @@ def create( # type: ignore
453456
**params,
454457
)
455458

459+
file_encoding_strategy = params.pop("file_encoding_strategy", None)
460+
if input is not None:
461+
input = encode_json(
462+
input,
463+
client=self._client,
464+
file_encoding_strategy=file_encoding_strategy,
465+
)
456466
body = _create_prediction_body(
457467
version,
458468
input,
@@ -537,6 +547,13 @@ async def async_create( # type: ignore
537547
**params,
538548
)
539549

550+
file_encoding_strategy = params.pop("file_encoding_strategy", None)
551+
if input is not None:
552+
input = await async_encode_json(
553+
input,
554+
client=self._client,
555+
file_encoding_strategy=file_encoding_strategy,
556+
)
540557
body = _create_prediction_body(
541558
version,
542559
input,
@@ -593,11 +610,12 @@ def _create_prediction_body( # pylint: disable=too-many-arguments
593610
webhook_completed: Optional[str] = None,
594611
webhook_events_filter: Optional[List[str]] = None,
595612
stream: Optional[bool] = None,
613+
**_kwargs,
596614
) -> Dict[str, Any]:
597615
body = {}
598616

599617
if input is not None:
600-
body["input"] = encode_json(input, upload_file=upload_file)
618+
body["input"] = input
601619

602620
if version is not None:
603621
body["version"] = version.id if isinstance(version, Version) else version

0 commit comments

Comments
 (0)