Skip to content

Commit ba9ad8d

Browse files
authored
feat: generic query builders (#309)
* feat: make all query builders generic * feat: return generic request builders from client methods * chore: use typing.List instead of builtin * chore: use typing.List * fix: correct type of APIResponse.data * feat: make RPCFilterRequestBuilder This makes sure the return types of rpc() and other query methods are correct. See https://gist.github.com/anand2312/93d3abf401335fd3310d9e30112303bf for an explanation. * chore: use typing.List * feat: make get_origin_and_cast This fixes the type-checker error raised while accessing RequestBuilder[T].__origin__ * fix: use typing.List
1 parent 3329234 commit ba9ad8d

File tree

6 files changed

+239
-150
lines changed

6 files changed

+239
-150
lines changed

postgrest/_async/client.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Dict, Union, cast
3+
from typing import Any, Dict, Union, cast
44

55
from deprecation import deprecated
66
from httpx import Headers, QueryParams, Timeout
@@ -12,7 +12,9 @@
1212
DEFAULT_POSTGREST_CLIENT_TIMEOUT,
1313
)
1414
from ..utils import AsyncClient
15-
from .request_builder import AsyncFilterRequestBuilder, AsyncRequestBuilder
15+
from .request_builder import AsyncRequestBuilder, AsyncRPCFilterRequestBuilder
16+
17+
_TableT = Dict[str, Any]
1618

1719

1820
class AsyncPostgrestClient(BasePostgrestClient):
@@ -57,17 +59,17 @@ async def aclose(self) -> None:
5759
"""Close the underlying HTTP connections."""
5860
await self.session.aclose()
5961

60-
def from_(self, table: str) -> AsyncRequestBuilder:
62+
def from_(self, table: str) -> AsyncRequestBuilder[_TableT]:
6163
"""Perform a table operation.
6264
6365
Args:
6466
table: The name of the table
6567
Returns:
6668
:class:`AsyncRequestBuilder`
6769
"""
68-
return AsyncRequestBuilder(self.session, f"/{table}")
70+
return AsyncRequestBuilder[_TableT](self.session, f"/{table}")
6971

70-
def table(self, table: str) -> AsyncRequestBuilder:
72+
def table(self, table: str) -> AsyncRequestBuilder[_TableT]:
7173
"""Alias to :meth:`from_`."""
7274
return self.from_(table)
7375

@@ -76,24 +78,26 @@ def from_table(self, table: str) -> AsyncRequestBuilder:
7678
"""Alias to :meth:`from_`."""
7779
return self.from_(table)
7880

79-
async def rpc(self, func: str, params: dict) -> AsyncFilterRequestBuilder:
81+
async def rpc(self, func: str, params: dict) -> AsyncRPCFilterRequestBuilder[Any]:
8082
"""Perform a stored procedure call.
8183
8284
Args:
8385
func: The name of the remote procedure to run.
8486
params: The parameters to be passed to the remote procedure.
8587
Returns:
86-
:class:`AsyncFilterRequestBuilder`
88+
:class:`AsyncRPCFilterRequestBuilder`
8789
Example:
8890
.. code-block:: python
8991
9092
await client.rpc("foobar", {"arg": "value"}).execute()
9193
92-
.. versionchanged:: 0.11.0
94+
.. versionchanged:: 0.10.9
95+
This method now returns a :class:`AsyncRPCFilterRequestBuilder`.
96+
.. versionchanged:: 0.10.2
9397
This method now returns a :class:`AsyncFilterRequestBuilder` which allows you to
9498
filter on the RPC's resultset.
9599
"""
96100
# the params here are params to be sent to the RPC and not the queryparams!
97-
return AsyncFilterRequestBuilder(
101+
return AsyncRPCFilterRequestBuilder[Any](
98102
self.session, f"/rpc/{func}", "POST", Headers(), QueryParams(), json=params
99103
)

postgrest/_async/request_builder.py

+62-35
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from json import JSONDecodeError
4-
from typing import Optional, Union
4+
from typing import Any, Generic, Optional, TypeVar, Union
55

66
from httpx import Headers, QueryParams
77
from pydantic import ValidationError
@@ -20,10 +20,12 @@
2020
)
2121
from ..exceptions import APIError, generate_default_error_message
2222
from ..types import ReturnMethod
23-
from ..utils import AsyncClient
23+
from ..utils import AsyncClient, get_origin_and_cast
2424

25+
_ReturnT = TypeVar("_ReturnT")
2526

26-
class AsyncQueryRequestBuilder:
27+
28+
class AsyncQueryRequestBuilder(Generic[_ReturnT]):
2729
def __init__(
2830
self,
2931
session: AsyncClient,
@@ -40,7 +42,7 @@ def __init__(
4042
self.params = params
4143
self.json = json
4244

43-
async def execute(self) -> APIResponse:
45+
async def execute(self) -> APIResponse[_ReturnT]:
4446
"""Execute the query.
4547
4648
.. tip::
@@ -63,7 +65,7 @@ async def execute(self) -> APIResponse:
6365
if (
6466
200 <= r.status_code <= 299
6567
): # Response.ok from JS (https://developer.mozilla.org/en-US/docs/Web/API/Response/ok)
66-
return APIResponse.from_http_request_response(r)
68+
return APIResponse[_ReturnT].from_http_request_response(r)
6769
else:
6870
raise APIError(r.json())
6971
except ValidationError as e:
@@ -72,7 +74,7 @@ async def execute(self) -> APIResponse:
7274
raise APIError(generate_default_error_message(r))
7375

7476

75-
class AsyncSingleRequestBuilder:
77+
class AsyncSingleRequestBuilder(Generic[_ReturnT]):
7678
def __init__(
7779
self,
7880
session: AsyncClient,
@@ -89,7 +91,7 @@ def __init__(
8991
self.params = params
9092
self.json = json
9193

92-
async def execute(self) -> SingleAPIResponse:
94+
async def execute(self) -> SingleAPIResponse[_ReturnT]:
9395
"""Execute the query.
9496
9597
.. tip::
@@ -112,7 +114,7 @@ async def execute(self) -> SingleAPIResponse:
112114
if (
113115
200 <= r.status_code <= 299
114116
): # Response.ok from JS (https://developer.mozilla.org/en-US/docs/Web/API/Response/ok)
115-
return SingleAPIResponse.from_http_request_response(r)
117+
return SingleAPIResponse[_ReturnT].from_http_request_response(r)
116118
else:
117119
raise APIError(r.json())
118120
except ValidationError as e:
@@ -121,11 +123,11 @@ async def execute(self) -> SingleAPIResponse:
121123
raise APIError(generate_default_error_message(r))
122124

123125

124-
class AsyncMaybeSingleRequestBuilder(AsyncSingleRequestBuilder):
125-
async def execute(self) -> Optional[SingleAPIResponse]:
126+
class AsyncMaybeSingleRequestBuilder(AsyncSingleRequestBuilder[_ReturnT]):
127+
async def execute(self) -> Optional[SingleAPIResponse[_ReturnT]]:
126128
r = None
127129
try:
128-
r = await super().execute()
130+
r = await AsyncSingleRequestBuilder[_ReturnT].execute(self)
129131
except APIError as e:
130132
if e.details and "The result contains 0 rows" in e.details:
131133
return None
@@ -142,7 +144,7 @@ async def execute(self) -> Optional[SingleAPIResponse]:
142144

143145

144146
# ignoring type checking as a workaround for https://github.com/python/mypy/issues/9319
145-
class AsyncFilterRequestBuilder(BaseFilterRequestBuilder, AsyncQueryRequestBuilder): # type: ignore
147+
class AsyncFilterRequestBuilder(BaseFilterRequestBuilder[_ReturnT], AsyncQueryRequestBuilder[_ReturnT]): # type: ignore
146148
def __init__(
147149
self,
148150
session: AsyncClient,
@@ -152,14 +154,37 @@ def __init__(
152154
params: QueryParams,
153155
json: dict,
154156
) -> None:
155-
BaseFilterRequestBuilder.__init__(self, session, headers, params)
156-
AsyncQueryRequestBuilder.__init__(
157+
get_origin_and_cast(BaseFilterRequestBuilder[_ReturnT]).__init__(
158+
self, session, headers, params
159+
)
160+
get_origin_and_cast(AsyncQueryRequestBuilder[_ReturnT]).__init__(
161+
self, session, path, http_method, headers, params, json
162+
)
163+
164+
165+
# this exists for type-safety. see https://gist.github.com/anand2312/93d3abf401335fd3310d9e30112303bf
166+
class AsyncRPCFilterRequestBuilder(
167+
BaseFilterRequestBuilder[_ReturnT], AsyncSingleRequestBuilder[_ReturnT]
168+
):
169+
def __init__(
170+
self,
171+
session: AsyncClient,
172+
path: str,
173+
http_method: str,
174+
headers: Headers,
175+
params: QueryParams,
176+
json: dict,
177+
) -> None:
178+
get_origin_and_cast(BaseFilterRequestBuilder[_ReturnT]).__init__(
179+
self, session, headers, params
180+
)
181+
get_origin_and_cast(AsyncSingleRequestBuilder[_ReturnT]).__init__(
157182
self, session, path, http_method, headers, params, json
158183
)
159184

160185

161186
# ignoring type checking as a workaround for https://github.com/python/mypy/issues/9319
162-
class AsyncSelectRequestBuilder(BaseSelectRequestBuilder, AsyncQueryRequestBuilder): # type: ignore
187+
class AsyncSelectRequestBuilder(BaseSelectRequestBuilder[_ReturnT], AsyncQueryRequestBuilder[_ReturnT]): # type: ignore
163188
def __init__(
164189
self,
165190
session: AsyncClient,
@@ -169,19 +194,21 @@ def __init__(
169194
params: QueryParams,
170195
json: dict,
171196
) -> None:
172-
BaseSelectRequestBuilder.__init__(self, session, headers, params)
173-
AsyncQueryRequestBuilder.__init__(
197+
get_origin_and_cast(BaseSelectRequestBuilder[_ReturnT]).__init__(
198+
self, session, headers, params
199+
)
200+
get_origin_and_cast(AsyncQueryRequestBuilder[_ReturnT]).__init__(
174201
self, session, path, http_method, headers, params, json
175202
)
176203

177-
def single(self) -> AsyncSingleRequestBuilder:
204+
def single(self) -> AsyncSingleRequestBuilder[_ReturnT]:
178205
"""Specify that the query will only return a single row in response.
179206
180207
.. caution::
181208
The API will raise an error if the query returned more than one row.
182209
"""
183210
self.headers["Accept"] = "application/vnd.pgrst.object+json"
184-
return AsyncSingleRequestBuilder(
211+
return AsyncSingleRequestBuilder[_ReturnT](
185212
headers=self.headers,
186213
http_method=self.http_method,
187214
json=self.json,
@@ -190,10 +217,10 @@ def single(self) -> AsyncSingleRequestBuilder:
190217
session=self.session, # type: ignore
191218
)
192219

193-
def maybe_single(self) -> AsyncMaybeSingleRequestBuilder:
220+
def maybe_single(self) -> AsyncMaybeSingleRequestBuilder[_ReturnT]:
194221
"""Retrieves at most one row from the result. Result must be at most one row (e.g. using `eq` on a UNIQUE column), otherwise this will result in an error."""
195222
self.headers["Accept"] = "application/vnd.pgrst.object+json"
196-
return AsyncMaybeSingleRequestBuilder(
223+
return AsyncMaybeSingleRequestBuilder[_ReturnT](
197224
headers=self.headers,
198225
http_method=self.http_method,
199226
json=self.json,
@@ -203,8 +230,8 @@ def maybe_single(self) -> AsyncMaybeSingleRequestBuilder:
203230
)
204231

205232
def text_search(
206-
self, column: str, query: str, options: Dict[str, any] = {}
207-
) -> AsyncFilterRequestBuilder:
233+
self, column: str, query: str, options: dict[str, Any] = {}
234+
) -> AsyncFilterRequestBuilder[_ReturnT]:
208235
type_ = options.get("type")
209236
type_part = ""
210237
if type_ == "plain":
@@ -216,7 +243,7 @@ def text_search(
216243
config_part = f"({options.get('config')})" if options.get("config") else ""
217244
self.params = self.params.add(column, f"{type_part}fts{config_part}.{query}")
218245

219-
return AsyncQueryRequestBuilder(
246+
return AsyncQueryRequestBuilder[_ReturnT](
220247
headers=self.headers,
221248
http_method=self.http_method,
222249
json=self.json,
@@ -226,7 +253,7 @@ def text_search(
226253
)
227254

228255

229-
class AsyncRequestBuilder:
256+
class AsyncRequestBuilder(Generic[_ReturnT]):
230257
def __init__(self, session: AsyncClient, path: str) -> None:
231258
self.session = session
232259
self.path = path
@@ -235,7 +262,7 @@ def select(
235262
self,
236263
*columns: str,
237264
count: Optional[CountMethod] = None,
238-
) -> AsyncSelectRequestBuilder:
265+
) -> AsyncSelectRequestBuilder[_ReturnT]:
239266
"""Run a SELECT query.
240267
241268
Args:
@@ -245,7 +272,7 @@ def select(
245272
:class:`AsyncSelectRequestBuilder`
246273
"""
247274
method, params, headers, json = pre_select(*columns, count=count)
248-
return AsyncSelectRequestBuilder(
275+
return AsyncSelectRequestBuilder[_ReturnT](
249276
self.session, self.path, method, headers, params, json
250277
)
251278

@@ -256,7 +283,7 @@ def insert(
256283
count: Optional[CountMethod] = None,
257284
returning: ReturnMethod = ReturnMethod.representation,
258285
upsert: bool = False,
259-
) -> AsyncQueryRequestBuilder:
286+
) -> AsyncQueryRequestBuilder[_ReturnT]:
260287
"""Run an INSERT query.
261288
262289
Args:
@@ -273,7 +300,7 @@ def insert(
273300
returning=returning,
274301
upsert=upsert,
275302
)
276-
return AsyncQueryRequestBuilder(
303+
return AsyncQueryRequestBuilder[_ReturnT](
277304
self.session, self.path, method, headers, params, json
278305
)
279306

@@ -285,7 +312,7 @@ def upsert(
285312
returning: ReturnMethod = ReturnMethod.representation,
286313
ignore_duplicates: bool = False,
287314
on_conflict: str = "",
288-
) -> AsyncQueryRequestBuilder:
315+
) -> AsyncQueryRequestBuilder[_ReturnT]:
289316
"""Run an upsert (INSERT ... ON CONFLICT DO UPDATE) query.
290317
291318
Args:
@@ -304,7 +331,7 @@ def upsert(
304331
ignore_duplicates=ignore_duplicates,
305332
on_conflict=on_conflict,
306333
)
307-
return AsyncQueryRequestBuilder(
334+
return AsyncQueryRequestBuilder[_ReturnT](
308335
self.session, self.path, method, headers, params, json
309336
)
310337

@@ -314,7 +341,7 @@ def update(
314341
*,
315342
count: Optional[CountMethod] = None,
316343
returning: ReturnMethod = ReturnMethod.representation,
317-
) -> AsyncFilterRequestBuilder:
344+
) -> AsyncFilterRequestBuilder[_ReturnT]:
318345
"""Run an UPDATE query.
319346
320347
Args:
@@ -329,7 +356,7 @@ def update(
329356
count=count,
330357
returning=returning,
331358
)
332-
return AsyncFilterRequestBuilder(
359+
return AsyncFilterRequestBuilder[_ReturnT](
333360
self.session, self.path, method, headers, params, json
334361
)
335362

@@ -338,7 +365,7 @@ def delete(
338365
*,
339366
count: Optional[CountMethod] = None,
340367
returning: ReturnMethod = ReturnMethod.representation,
341-
) -> AsyncFilterRequestBuilder:
368+
) -> AsyncFilterRequestBuilder[_ReturnT]:
342369
"""Run a DELETE query.
343370
344371
Args:
@@ -351,7 +378,7 @@ def delete(
351378
count=count,
352379
returning=returning,
353380
)
354-
return AsyncFilterRequestBuilder(
381+
return AsyncFilterRequestBuilder[_ReturnT](
355382
self.session, self.path, method, headers, params, json
356383
)
357384

0 commit comments

Comments
 (0)