Skip to content

Commit 8480c3a

Browse files
committed
Add support for files API endpoints
Signed-off-by: Mattt Zmuda <[email protected]>
1 parent b0bbc04 commit 8480c3a

File tree

1 file changed

+129
-1
lines changed

1 file changed

+129
-1
lines changed

replicate/files.py

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,138 @@
22
import io
33
import mimetypes
44
import os
5-
from typing import Optional
5+
from typing import Any, Dict, List, Optional, Tuple
66

77
import httpx
88

9+
from replicate.resource import Namespace, Resource
10+
11+
12+
class File(Resource):
13+
"""
14+
A file uploaded to Replicate that can be used as an input to a model.
15+
"""
16+
17+
id: str
18+
"""The ID of the file."""
19+
20+
name: str
21+
"""The name of the file."""
22+
23+
content_type: str
24+
"""The content type of the file."""
25+
26+
size: int
27+
"""The size of the file in bytes."""
28+
29+
etag: str
30+
"""The ETag of the file."""
31+
32+
checksum: str
33+
"""The checksum of the file."""
34+
35+
metadata: Dict[str, Any]
36+
"""The metadata of the file."""
37+
38+
created_at: str
39+
"""The time the file was created."""
40+
41+
expires_at: Optional[str]
42+
"""The time the file will expire."""
43+
44+
urls: Dict[str, str]
45+
"""The URLs of the file."""
46+
47+
48+
class Files(Namespace):
49+
def create(
50+
self, file: io.IOBase, metadata: Optional[Dict[str, Any]] = None
51+
) -> File:
52+
"""Create a file that can be passed as an input when running a model."""
53+
54+
file.seek(0)
55+
56+
resp = self._client._request(
57+
"POST",
58+
"/files",
59+
data={
60+
"content": _file_content(file),
61+
"metadata": metadata,
62+
},
63+
timeout=None,
64+
)
65+
66+
return _json_to_file(resp.json())
67+
68+
async def async_create(
69+
self, file: io.IOBase, metadata: Optional[Dict[str, Any]] = None
70+
) -> File:
71+
"""Create a file asynchronously that can be passed as an input when running a model."""
72+
73+
file.seek(0)
74+
75+
resp = await self._client._async_request(
76+
"POST",
77+
"/files",
78+
data={
79+
"content": _file_content(file),
80+
"metadata": metadata,
81+
},
82+
timeout=None,
83+
)
84+
85+
return _json_to_file(resp.json())
86+
87+
def get(self, file_id: str) -> File:
88+
"""Get a file from the server by its ID."""
89+
90+
resp = self._client._request("GET", f"/files/{file_id}")
91+
return _json_to_file(resp.json())
92+
93+
async def async_get(self, file_id: str) -> File:
94+
"""Get a file from the server by its ID asynchronously."""
95+
96+
resp = await self._client._async_request("GET", f"/files/{file_id}")
97+
return _json_to_file(resp.json())
98+
99+
def list(self) -> List[File]:
100+
"""List all files on the server."""
101+
102+
resp = self._client._request("GET", "/files")
103+
return [_json_to_file(file_json) for file_json in resp.json()]
104+
105+
async def async_list(self) -> List[File]:
106+
"""List all files on the server asynchronously."""
107+
108+
resp = await self._client._async_request("GET", "/files")
109+
return [_json_to_file(file_json) for file_json in resp.json()]
110+
111+
def delete(self, file_id: str) -> File:
112+
"""Delete a file from the server by its ID."""
113+
114+
resp = self._client._request("DELETE", f"/files/{file_id}")
115+
return _json_to_file(resp.json())
116+
117+
async def async_delete(self, file_id: str) -> File:
118+
"""Delete a file from the server by its ID asynchronously."""
119+
120+
resp = await self._client._async_request("DELETE", f"/files/{file_id}")
121+
return _json_to_file(resp.json())
122+
123+
124+
def _file_content(file: io.IOBase) -> Tuple[str, io.IOBase, str]:
125+
"""Get the file content details including name, file object, and content type."""
126+
127+
name = getattr(file, "name", "output")
128+
content_type = (
129+
mimetypes.guess_type(getattr(file, "name", ""))[0] or "application/octet-stream"
130+
)
131+
return (os.path.basename(name), file, content_type)
132+
133+
134+
def _json_to_file(json: Dict[str, Any]) -> File:
135+
return File(**json)
136+
9137

10138
def upload_file(file: io.IOBase, output_file_prefix: Optional[str] = None) -> str:
11139
"""

0 commit comments

Comments
 (0)