Skip to content

Commit 4b6bf6f

Browse files
matttGothReigen
authored andcommitted
Add support for accounts.current endpoint (#221)
Signed-off-by: Mattt Zmuda <[email protected]>
1 parent 371731c commit 4b6bf6f

File tree

3 files changed

+110
-0
lines changed

3 files changed

+110
-0
lines changed

replicate/account.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from typing import Any, Dict, Literal
2+
3+
from replicate.resource import Namespace, Resource
4+
5+
6+
class Account(Resource):
7+
"""
8+
A user or organization account on Replicate.
9+
"""
10+
11+
type: Literal["user", "organization"]
12+
"""The type of account."""
13+
14+
username: str
15+
"""The username of the account."""
16+
17+
name: str
18+
"""The name of the account."""
19+
20+
github_url: str
21+
"""The GitHub URL of the account."""
22+
23+
24+
class Accounts(Namespace):
25+
"""
26+
Namespace for operations related to accounts.
27+
"""
28+
29+
def current(self) -> Account:
30+
"""
31+
Get the current account.
32+
33+
Returns:
34+
Account: The current account.
35+
"""
36+
37+
resp = self._client._request("GET", "/v1/account")
38+
obj = resp.json()
39+
40+
return _json_to_account(obj)
41+
42+
async def async_current(self) -> Account:
43+
"""
44+
Get the current account.
45+
46+
Returns:
47+
Account: The current account.
48+
"""
49+
50+
resp = await self._client._async_request("GET", "/v1/account")
51+
obj = resp.json()
52+
53+
return _json_to_account(obj)
54+
55+
56+
def _json_to_account(json: Dict[str, Any]) -> Account:
57+
return Account(**json)

replicate/client.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from typing_extensions import Unpack
2121

2222
from replicate.__about__ import __version__
23+
from replicate.account import Accounts
2324
from replicate.collection import Collections
2425
from replicate.deployment import Deployments
2526
from replicate.exceptions import ReplicateError
@@ -93,6 +94,14 @@ async def _async_request(self, method: str, path: str, **kwargs) -> httpx.Respon
9394

9495
return resp
9596

97+
@property
98+
def accounts(self) -> Accounts:
99+
"""
100+
Namespace for operations related to accounts.
101+
"""
102+
103+
return Accounts(client=self)
104+
96105
@property
97106
def collections(self) -> Collections:
98107
"""

tests/test_account.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import httpx
2+
import pytest
3+
import respx
4+
5+
from replicate.account import Account
6+
from replicate.client import Client
7+
8+
router = respx.Router(base_url="https://api.replicate.com/v1")
9+
router.route(
10+
method="GET",
11+
path="/account",
12+
name="accounts.current",
13+
).mock(
14+
return_value=httpx.Response(
15+
200,
16+
json={
17+
"type": "organization",
18+
"username": "replicate",
19+
"name": "Replicate",
20+
"github_url": "https://github.com/replicate",
21+
},
22+
)
23+
)
24+
router.route(host="api.replicate.com").pass_through()
25+
26+
27+
@pytest.mark.asyncio
28+
@pytest.mark.parametrize("async_flag", [True, False])
29+
async def test_account_current(async_flag):
30+
client = Client(
31+
api_token="test-token", transport=httpx.MockTransport(router.handler)
32+
)
33+
34+
if async_flag:
35+
account = await client.accounts.async_current()
36+
else:
37+
account = client.accounts.current()
38+
39+
assert router["accounts.current"].called
40+
assert isinstance(account, Account)
41+
assert account.type == "organization"
42+
assert account.username == "replicate"
43+
assert account.name == "Replicate"
44+
assert account.github_url == "https://github.com/replicate"

0 commit comments

Comments
 (0)