diff --git a/replicate/account.py b/replicate/account.py new file mode 100644 index 00000000..69e522b8 --- /dev/null +++ b/replicate/account.py @@ -0,0 +1,57 @@ +from typing import Any, Dict, Literal + +from replicate.resource import Namespace, Resource + + +class Account(Resource): + """ + A user or organization account on Replicate. + """ + + type: Literal["user", "organization"] + """The type of account.""" + + username: str + """The username of the account.""" + + name: str + """The name of the account.""" + + github_url: str + """The GitHub URL of the account.""" + + +class Accounts(Namespace): + """ + Namespace for operations related to accounts. + """ + + def current(self) -> Account: + """ + Get the current account. + + Returns: + Account: The current account. + """ + + resp = self._client._request("GET", "/v1/account") + obj = resp.json() + + return _json_to_account(obj) + + async def async_current(self) -> Account: + """ + Get the current account. + + Returns: + Account: The current account. + """ + + resp = await self._client._async_request("GET", "/v1/account") + obj = resp.json() + + return _json_to_account(obj) + + +def _json_to_account(json: Dict[str, Any]) -> Account: + return Account(**json) diff --git a/replicate/client.py b/replicate/client.py index 87da11f7..67bbdcf8 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -19,6 +19,7 @@ from typing_extensions import Unpack from replicate.__about__ import __version__ +from replicate.account import Accounts from replicate.collection import Collections from replicate.deployment import Deployments from replicate.exceptions import ReplicateError @@ -92,6 +93,14 @@ async def _async_request(self, method: str, path: str, **kwargs) -> httpx.Respon return resp + @property + def accounts(self) -> Accounts: + """ + Namespace for operations related to accounts. + """ + + return Accounts(client=self) + @property def collections(self) -> Collections: """ diff --git a/tests/test_account.py b/tests/test_account.py new file mode 100644 index 00000000..5212d1bf --- /dev/null +++ b/tests/test_account.py @@ -0,0 +1,44 @@ +import httpx +import pytest +import respx + +from replicate.account import Account +from replicate.client import Client + +router = respx.Router(base_url="https://api.replicate.com/v1") +router.route( + method="GET", + path="/account", + name="accounts.current", +).mock( + return_value=httpx.Response( + 200, + json={ + "type": "organization", + "username": "replicate", + "name": "Replicate", + "github_url": "https://github.com/replicate", + }, + ) +) +router.route(host="api.replicate.com").pass_through() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_account_current(async_flag): + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + + if async_flag: + account = await client.accounts.async_current() + else: + account = client.accounts.current() + + assert router["accounts.current"].called + assert isinstance(account, Account) + assert account.type == "organization" + assert account.username == "replicate" + assert account.name == "Replicate" + assert account.github_url == "https://github.com/replicate"