Skip to content

Generic queries #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: user-models
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions postgrest_py/_async/client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

from typing import Dict, Union, cast
from typing import Dict, Tuple, Type, TypeVar, Union, cast

from deprecation import deprecated
from httpx import Response, Timeout
from pydantic import BaseModel

from .. import __version__
from ..base_client import (
Expand All @@ -12,7 +13,9 @@
BasePostgrestClient,
)
from ..utils import AsyncClient
from .request_builder import AsyncRequestBuilder
from .request_builder import AsyncRequestBuilder, _AsyncModelRequestBuilder

_MT = TypeVar("_MT", bound=BaseModel)


class AsyncPostgrestClient(BasePostgrestClient):
Expand Down Expand Up @@ -56,18 +59,38 @@ async def __aexit__(self, exc_type, exc, tb) -> None:
async def aclose(self) -> None:
await self.session.aclose()

def from_(self, table: str) -> AsyncRequestBuilder:
"""Perform a table operation."""
def _pre_table_op(self, table_name: str) -> Tuple[AsyncClient, str]:
"""Prepare the session and API route before a query."""
base_url = str(self.session.base_url)
headers = dict(self.session.headers.items())
session = self.create_session(base_url, headers, self.session.timeout)
session.auth = self.session.auth
return session, f"/{table_name}"

def from_(self, table: str) -> AsyncRequestBuilder:
"""Perform a table operation."""
session, path = self._pre_table_op(table)
return AsyncRequestBuilder(session, f"/{table}")

def table(self, table: str) -> AsyncRequestBuilder:
"""Alias to self.from_()."""
return self.from_(table)

def from_model(self, model: Type[_MT]) -> _AsyncModelRequestBuilder[_MT]:
"""Perform a table operation, passing in a pydantic BaseModel.
The model will be used to parse the rows returned by the query.

Note:
The name of the table can be:
1) either the same as the name of the model
2) set as a ClassVar with the name __table_name__

If the class var is set, that will take priority.
"""
table_name = getattr(model, "__table_name__", model.__name__)
session, path = self._pre_table_op(table_name)
return _AsyncModelRequestBuilder[model](session, path)

@deprecated("0.2.0", "1.0.0", __version__, "Use self.from_() instead")
def from_table(self, table: str) -> AsyncRequestBuilder:
"""Alias to self.from_()."""
Expand Down
126 changes: 125 additions & 1 deletion postgrest_py/_async/request_builder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Dict, Optional, Type, TypeVar, Union
from typing import Any, Dict, Generic, Optional, Type, TypeVar, Union

from pydantic import BaseModel, ValidationError

Expand Down Expand Up @@ -149,3 +149,127 @@ def delete(
returning=returning,
)
return AsyncFilterRequestBuilder(self.session, self.path, method, json)


_T = TypeVar("_T", bound=BaseModel)

# The following are replicas of the normal classes above
# Their only purpose is to ensure type-safety and provide good
# editor autocompletion, when using pydantic BaseModels with queries.


class _AsyncModelQueryRequestBuilder(Generic[_T], AsyncQueryRequestBuilder):
async def execute(self) -> APIResponse[_T]:
# super().execute(model=_T) has NOT been used here
# as pyright was raising errors for it.
r = await self.session.request(
self.http_method,
self.path,
json=self.json,
)
try:
return APIResponse[_T].from_http_request_response(r)
except ValueError as e:
raise APIError(r.json()) from e


class _AsyncModelFilterRequestBuilder(BaseFilterRequestBuilder, _AsyncModelQueryRequestBuilder[_T]): # type: ignore
def __init__(
self,
session: AsyncClient,
path: str,
http_method: str,
json: dict,
) -> None:
BaseFilterRequestBuilder.__init__(self, session)
_AsyncModelQueryRequestBuilder[_T].__init__(
self, session, path, http_method, json
)


class _AsyncModelSelectRequestBuilder(
BaseSelectRequestBuilder, _AsyncModelQueryRequestBuilder[_T]
):
def __init__(
self,
session: AsyncClient,
path: str,
http_method: str,
json: dict,
) -> None:
BaseSelectRequestBuilder.__init__(self, session)
_AsyncModelQueryRequestBuilder[_T].__init__(
self, session, path, http_method, json
)


class _AsyncModelRequestBuilder(Generic[_T], AsyncRequestBuilder):
def select(
self,
*columns: str,
count: Optional[CountMethod] = None,
) -> _AsyncModelSelectRequestBuilder[_T]:
method, json = pre_select(self.session, *columns, count=count)
return _AsyncModelSelectRequestBuilder[_T](self.session, self.path, method, json)

def insert(
self,
json: dict,
*,
count: Optional[CountMethod] = None,
returning: ReturnMethod = ReturnMethod.representation,
upsert: bool = False,
) -> _AsyncModelQueryRequestBuilder[_T]:
method, json = pre_insert(
self.session,
json,
count=count,
returning=returning,
upsert=upsert,
)
return _AsyncModelQueryRequestBuilder[_T](self.session, self.path, method, json)

def upsert(
self,
json: dict,
*,
count: Optional[CountMethod] = None,
returning: ReturnMethod = ReturnMethod.representation,
ignore_duplicates: bool = False,
) -> _AsyncModelQueryRequestBuilder[_T]:
method, json = pre_upsert(
self.session,
json,
count=count,
returning=returning,
ignore_duplicates=ignore_duplicates,
)
return _AsyncModelQueryRequestBuilder[_T](self.session, self.path, method, json)

def update(
self,
json: dict,
*,
count: Optional[CountMethod] = None,
returning: ReturnMethod = ReturnMethod.representation,
) -> _AsyncModelFilterRequestBuilder[_T]:
method, json = pre_update(
self.session,
json,
count=count,
returning=returning,
)
return _AsyncModelFilterRequestBuilder[_T](self.session, self.path, method, json)

def delete(
self,
*,
count: Optional[CountMethod] = None,
returning: ReturnMethod = ReturnMethod.representation,
) -> _AsyncModelFilterRequestBuilder[_T]:
method, json = pre_delete(
self.session,
count=count,
returning=returning,
)
return _AsyncModelFilterRequestBuilder[_T](self.session, self.path, method, json)
33 changes: 28 additions & 5 deletions postgrest_py/_sync/client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

from typing import Dict, Union, cast
from typing import Dict, Tuple, Type, TypeVar, Union, cast

from deprecation import deprecated
from httpx import Response, Timeout
from pydantic import BaseModel

from .. import __version__
from ..base_client import (
Expand All @@ -12,7 +13,9 @@
BasePostgrestClient,
)
from ..utils import SyncClient
from .request_builder import SyncRequestBuilder
from .request_builder import SyncRequestBuilder, _SyncModelRequestBuilder

_MT = TypeVar("_MT", bound=BaseModel)


class SyncPostgrestClient(BasePostgrestClient):
Expand Down Expand Up @@ -56,18 +59,38 @@ def __exit__(self, exc_type, exc, tb) -> None:
def aclose(self) -> None:
self.session.aclose()

def from_(self, table: str) -> SyncRequestBuilder:
"""Perform a table operation."""
def _pre_table_op(self, table_name: str) -> Tuple[SyncClient, str]:
"""Prepare the session and API route before a query."""
base_url = str(self.session.base_url)
headers = dict(self.session.headers.items())
session = self.create_session(base_url, headers, self.session.timeout)
session.auth = self.session.auth
return SyncRequestBuilder(session, f"/{table}")
return session, f"/{table_name}"

def from_(self, table: str) -> SyncRequestBuilder:
"""Perform a table operation."""
session, path = self._pre_table_op(table)
return SyncRequestBuilder(session, path)

def table(self, table: str) -> SyncRequestBuilder:
"""Alias to self.from_()."""
return self.from_(table)

def from_model(self, model: Type[_MT]) -> _SyncModelRequestBuilder[_MT]:
"""Perform a table operation, passing in a pydantic BaseModel.
The model will be used to parse the rows returned by the query.

Note:
The name of the table can be:
1) either the same as the name of the model
2) set as a ClassVar with the name __table_name__

If the class var is set, that will take priority.
"""
table_name = getattr(model, "__table_name__", model.__name__)
session, path = self._pre_table_op(table_name)
return _SyncModelRequestBuilder[_MT](session, path)

@deprecated("0.2.0", "1.0.0", __version__, "Use self.from_() instead")
def from_table(self, table: str) -> SyncRequestBuilder:
"""Alias to self.from_()."""
Expand Down
122 changes: 121 additions & 1 deletion postgrest_py/_sync/request_builder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Dict, Optional, Type, TypeVar, Union
from typing import Any, Dict, Generic, Optional, Type, TypeVar, Union

from pydantic import BaseModel

Expand Down Expand Up @@ -149,3 +149,123 @@ def delete(
returning=returning,
)
return SyncFilterRequestBuilder(self.session, self.path, method, json)


_T = TypeVar("_T", bound=BaseModel)

# The following are replicas of the normal classes above
# Their only purpose is to ensure type-safety and provide good
# editor autocompletion, when using pydantic BaseModels with queries.


class _SyncModelQueryRequestBuilder(Generic[_T], SyncQueryRequestBuilder):
def execute(self) -> APIResponse[_T]:
# super().execute(model=_T) has NOT been used here
# as pyright was raising errors for it.
r = self.session.request(
self.http_method,
self.path,
json=self.json,
)
try:
return APIResponse[_T].from_http_request_response(r)
except ValueError as e:
raise APIError(r.json()) from e


class _SyncModelFilterRequestBuilder(BaseFilterRequestBuilder, _SyncModelQueryRequestBuilder[_T]): # type: ignore
def __init__(
self,
session: SyncClient,
path: str,
http_method: str,
json: dict,
) -> None:
BaseFilterRequestBuilder.__init__(self, session)
_SyncModelQueryRequestBuilder[_T].__init__(self, session, path, http_method, json)


class _SyncModelSelectRequestBuilder(
BaseSelectRequestBuilder, _SyncModelQueryRequestBuilder[_T]
):
def __init__(
self,
session: SyncClient,
path: str,
http_method: str,
json: dict,
) -> None:
BaseSelectRequestBuilder.__init__(self, session)
_SyncModelQueryRequestBuilder[_T].__init__(self, session, path, http_method, json)


class _SyncModelRequestBuilder(Generic[_T], SyncRequestBuilder):
def select(
self,
*columns: str,
count: Optional[CountMethod] = None,
) -> _SyncModelSelectRequestBuilder[_T]:
method, json = pre_select(self.session, *columns, count=count)
return _SyncModelSelectRequestBuilder[_T](self.session, self.path, method, json)

def insert(
self,
json: dict,
*,
count: Optional[CountMethod] = None,
returning: ReturnMethod = ReturnMethod.representation,
upsert: bool = False,
) -> _SyncModelQueryRequestBuilder[_T]:
method, json = pre_insert(
self.session,
json,
count=count,
returning=returning,
upsert=upsert,
)
return _SyncModelQueryRequestBuilder[_T](self.session, self.path, method, json)

def upsert(
self,
json: dict,
*,
count: Optional[CountMethod] = None,
returning: ReturnMethod = ReturnMethod.representation,
ignore_duplicates: bool = False,
) -> _SyncModelQueryRequestBuilder[_T]:
method, json = pre_upsert(
self.session,
json,
count=count,
returning=returning,
ignore_duplicates=ignore_duplicates,
)
return _SyncModelQueryRequestBuilder[_T](self.session, self.path, method, json)

def update(
self,
json: dict,
*,
count: Optional[CountMethod] = None,
returning: ReturnMethod = ReturnMethod.representation,
) -> _SyncModelFilterRequestBuilder[_T]:
method, json = pre_update(
self.session,
json,
count=count,
returning=returning,
)
return _SyncModelFilterRequestBuilder[_T](self.session, self.path, method, json)

def delete(
self,
*,
count: Optional[CountMethod] = None,
returning: ReturnMethod = ReturnMethod.representation,
) -> _SyncModelFilterRequestBuilder[_T]:
method, json = pre_delete(
self.session,
count=count,
returning=returning,
)
return _SyncModelFilterRequestBuilder[_T](self.session, self.path, method, json)