Skip to content

Commit 54f9c32

Browse files
authored
Add support for files API endpoints (#226)
Signed-off-by: Mattt Zmuda <[email protected]>
1 parent d7ac942 commit 54f9c32

File tree

10 files changed

+613
-53
lines changed

10 files changed

+613
-53
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ rye = { dev-dependencies = [
3333
] }
3434

3535
[tool.pytest.ini_options]
36+
asyncio_mode = "auto"
3637
testpaths = "tests/"
3738

3839
[tool.setuptools]

replicate/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414
async_paginate = _async_paginate
1515

1616
collections = default_client.collections
17-
hardware = default_client.hardware
1817
deployments = default_client.deployments
18+
files = default_client.files
19+
hardware = default_client.hardware
1920
models = default_client.models
2021
predictions = default_client.predictions
2122
trainings = default_client.trainings

replicate/client.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from replicate.collection import Collections
2525
from replicate.deployment import Deployments
2626
from replicate.exceptions import ReplicateError
27+
from replicate.file import Files
2728
from replicate.hardware import HardwareNamespace as Hardware
2829
from replicate.model import Models
2930
from replicate.prediction import Predictions
@@ -117,6 +118,13 @@ def deployments(self) -> Deployments:
117118
"""
118119
return Deployments(client=self)
119120

121+
@property
122+
def files(self) -> Files:
123+
"""
124+
Namespace for operations related to files.
125+
"""
126+
return Files(client=self)
127+
120128
@property
121129
def hardware(self) -> Hardware:
122130
"""

replicate/file.py

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
import base64
2+
import io
3+
import json
4+
import mimetypes
5+
import os
6+
import pathlib
7+
from typing import Any, BinaryIO, Dict, List, Optional, TypedDict, Union
8+
9+
import httpx
10+
from typing_extensions import NotRequired, Unpack
11+
12+
from replicate.resource import Namespace, Resource
13+
14+
15+
class File(Resource):
16+
"""
17+
A file uploaded to Replicate that can be used as an input to a model.
18+
"""
19+
20+
id: str
21+
"""The ID of the file."""
22+
23+
name: str
24+
"""The name of the file."""
25+
26+
content_type: str
27+
"""The content type of the file."""
28+
29+
size: int
30+
"""The size of the file in bytes."""
31+
32+
etag: str
33+
"""The ETag of the file."""
34+
35+
checksums: Dict[str, str]
36+
"""The checksums of the file."""
37+
38+
metadata: Dict[str, Any]
39+
"""The metadata of the file."""
40+
41+
created_at: str
42+
"""The time the file was created."""
43+
44+
expires_at: Optional[str]
45+
"""The time the file will expire."""
46+
47+
urls: Dict[str, str]
48+
"""The URLs of the file."""
49+
50+
51+
class Files(Namespace):
52+
class CreateFileParams(TypedDict):
53+
"""Parameters for creating a file."""
54+
55+
filename: NotRequired[str]
56+
"""The name of the file."""
57+
58+
content_type: NotRequired[str]
59+
"""The content type of the file."""
60+
61+
metadata: NotRequired[Dict[str, Any]]
62+
"""The file metadata."""
63+
64+
def create(
65+
self,
66+
file: Union[str, pathlib.Path, BinaryIO, io.IOBase],
67+
**params: Unpack["Files.CreateFileParams"],
68+
) -> File:
69+
"""
70+
Upload a file that can be passed as an input when running a model.
71+
"""
72+
73+
if isinstance(file, (str, pathlib.Path)):
74+
with open(file, "rb") as f:
75+
return self.create(f, **params)
76+
elif not isinstance(file, (io.IOBase, BinaryIO)):
77+
raise ValueError(
78+
"Unsupported file type. Must be a file path or file-like object."
79+
)
80+
81+
resp = self._client._request(
82+
"POST", "/v1/files", timeout=None, **_create_file_params(file, **params)
83+
)
84+
85+
return _json_to_file(resp.json())
86+
87+
async def async_create(
88+
self,
89+
file: Union[str, pathlib.Path, BinaryIO, io.IOBase],
90+
**params: Unpack["Files.CreateFileParams"],
91+
) -> File:
92+
"""Upload a file asynchronously that can be passed as an input when running a model."""
93+
94+
if isinstance(file, (str, pathlib.Path)):
95+
with open(file, "rb") as f:
96+
return self.create(f, **params)
97+
elif not isinstance(file, (io.IOBase, BinaryIO)):
98+
raise ValueError(
99+
"Unsupported file type. Must be a file path or file-like object."
100+
)
101+
102+
resp = await self._client._async_request(
103+
"POST", "/v1/files", timeout=None, **_create_file_params(file, **params)
104+
)
105+
106+
return _json_to_file(resp.json())
107+
108+
def get(self, file_id: str) -> File:
109+
"""Get an uploaded file by its ID."""
110+
111+
resp = self._client._request("GET", f"/v1/files/{file_id}")
112+
return _json_to_file(resp.json())
113+
114+
async def async_get(self, file_id: str) -> File:
115+
"""Get an uploaded file by its ID asynchronously."""
116+
117+
resp = await self._client._async_request("GET", f"/v1/files/{file_id}")
118+
return _json_to_file(resp.json())
119+
120+
def list(self) -> List[File]:
121+
"""List all uploaded files."""
122+
123+
resp = self._client._request("GET", "/v1/files")
124+
return [_json_to_file(obj) for obj in resp.json().get("results", [])]
125+
126+
async def async_list(self) -> List[File]:
127+
"""List all uploaded files asynchronously."""
128+
129+
resp = await self._client._async_request("GET", "/v1/files")
130+
return [_json_to_file(obj) for obj in resp.json().get("results", [])]
131+
132+
def delete(self, file_id: str) -> None:
133+
"""Delete an uploaded file by its ID."""
134+
135+
_ = self._client._request("DELETE", f"/v1/files/{file_id}")
136+
137+
async def async_delete(self, file_id: str) -> None:
138+
"""Delete an uploaded file by its ID asynchronously."""
139+
140+
_ = await self._client._async_request("DELETE", f"/v1/files/{file_id}")
141+
142+
143+
def _create_file_params(
144+
file: Union[BinaryIO, io.IOBase],
145+
**params: Unpack["Files.CreateFileParams"],
146+
) -> Dict[str, Any]:
147+
file.seek(0)
148+
149+
if params is None:
150+
params = {}
151+
152+
filename = params.get("filename", os.path.basename(getattr(file, "name", "file")))
153+
content_type = (
154+
params.get("content_type")
155+
or mimetypes.guess_type(filename)[0]
156+
or "application/octet-stream"
157+
)
158+
metadata = params.get("metadata")
159+
160+
data = {}
161+
if metadata:
162+
data["metadata"] = json.dumps(metadata)
163+
164+
return {
165+
"files": {"content": (filename, file, content_type)},
166+
"data": data,
167+
}
168+
169+
170+
def _json_to_file(json: Dict[str, Any]) -> File: # pylint: disable=redefined-outer-name
171+
return File(**json)
172+
173+
174+
def upload_file(file: io.IOBase, output_file_prefix: Optional[str] = None) -> str:
175+
"""
176+
Upload a file to the server.
177+
178+
Args:
179+
file: A file handle to upload.
180+
output_file_prefix: A string to prepend to the output file name.
181+
Returns:
182+
str: A URL to the uploaded file.
183+
"""
184+
# Lifted straight from cog.files
185+
186+
file.seek(0)
187+
188+
if output_file_prefix is not None:
189+
name = getattr(file, "name", "output")
190+
url = output_file_prefix + os.path.basename(name)
191+
resp = httpx.put(url, files={"file": file}, timeout=None) # type: ignore
192+
resp.raise_for_status()
193+
194+
return url
195+
196+
body = file.read()
197+
# Ensure the file handle is in bytes
198+
body = body.encode("utf-8") if isinstance(body, str) else body
199+
encoded_body = base64.b64encode(body).decode("utf-8")
200+
# Use getattr to avoid mypy complaints about io.IOBase having no attribute name
201+
mime_type = (
202+
mimetypes.guess_type(getattr(file, "name", ""))[0] or "application/octet-stream"
203+
)
204+
return f"data:{mime_type};base64,{encoded_body}"

replicate/files.py

Lines changed: 0 additions & 40 deletions
This file was deleted.

replicate/prediction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typing_extensions import NotRequired, TypedDict, Unpack
2020

2121
from replicate.exceptions import ModelError, ReplicateError
22-
from replicate.files import upload_file
22+
from replicate.file import upload_file
2323
from replicate.json import encode_json
2424
from replicate.pagination import Page
2525
from replicate.resource import Namespace, Resource

replicate/training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from typing_extensions import NotRequired, Unpack
1515

16-
from replicate.files import upload_file
16+
from replicate.file import upload_file
1717
from replicate.identifier import ModelVersionIdentifier
1818
from replicate.json import encode_json
1919
from replicate.model import Model

0 commit comments

Comments
 (0)