1
1
from __future__ import annotations
2
2
3
3
from json import JSONDecodeError
4
- from typing import Optional , Union
4
+ from typing import Any , Generic , Optional , TypeVar , Union
5
5
6
6
from httpx import Headers , QueryParams
7
7
from pydantic import ValidationError
20
20
)
21
21
from ..exceptions import APIError , generate_default_error_message
22
22
from ..types import ReturnMethod
23
- from ..utils import AsyncClient
23
+ from ..utils import AsyncClient , get_origin_and_cast
24
24
25
+ _ReturnT = TypeVar ("_ReturnT" )
25
26
26
- class AsyncQueryRequestBuilder :
27
+
28
+ class AsyncQueryRequestBuilder (Generic [_ReturnT ]):
27
29
def __init__ (
28
30
self ,
29
31
session : AsyncClient ,
@@ -40,7 +42,7 @@ def __init__(
40
42
self .params = params
41
43
self .json = json
42
44
43
- async def execute (self ) -> APIResponse :
45
+ async def execute (self ) -> APIResponse [ _ReturnT ] :
44
46
"""Execute the query.
45
47
46
48
.. tip::
@@ -63,7 +65,7 @@ async def execute(self) -> APIResponse:
63
65
if (
64
66
200 <= r .status_code <= 299
65
67
): # Response.ok from JS (https://developer.mozilla.org/en-US/docs/Web/API/Response/ok)
66
- return APIResponse .from_http_request_response (r )
68
+ return APIResponse [ _ReturnT ] .from_http_request_response (r )
67
69
else :
68
70
raise APIError (r .json ())
69
71
except ValidationError as e :
@@ -72,7 +74,7 @@ async def execute(self) -> APIResponse:
72
74
raise APIError (generate_default_error_message (r ))
73
75
74
76
75
- class AsyncSingleRequestBuilder :
77
+ class AsyncSingleRequestBuilder ( Generic [ _ReturnT ]) :
76
78
def __init__ (
77
79
self ,
78
80
session : AsyncClient ,
@@ -89,7 +91,7 @@ def __init__(
89
91
self .params = params
90
92
self .json = json
91
93
92
- async def execute (self ) -> SingleAPIResponse :
94
+ async def execute (self ) -> SingleAPIResponse [ _ReturnT ] :
93
95
"""Execute the query.
94
96
95
97
.. tip::
@@ -112,7 +114,7 @@ async def execute(self) -> SingleAPIResponse:
112
114
if (
113
115
200 <= r .status_code <= 299
114
116
): # Response.ok from JS (https://developer.mozilla.org/en-US/docs/Web/API/Response/ok)
115
- return SingleAPIResponse .from_http_request_response (r )
117
+ return SingleAPIResponse [ _ReturnT ] .from_http_request_response (r )
116
118
else :
117
119
raise APIError (r .json ())
118
120
except ValidationError as e :
@@ -121,11 +123,11 @@ async def execute(self) -> SingleAPIResponse:
121
123
raise APIError (generate_default_error_message (r ))
122
124
123
125
124
- class AsyncMaybeSingleRequestBuilder (AsyncSingleRequestBuilder ):
125
- async def execute (self ) -> Optional [SingleAPIResponse ]:
126
+ class AsyncMaybeSingleRequestBuilder (AsyncSingleRequestBuilder [ _ReturnT ] ):
127
+ async def execute (self ) -> Optional [SingleAPIResponse [ _ReturnT ] ]:
126
128
r = None
127
129
try :
128
- r = await super () .execute ()
130
+ r = await AsyncSingleRequestBuilder [ _ReturnT ] .execute (self )
129
131
except APIError as e :
130
132
if e .details and "The result contains 0 rows" in e .details :
131
133
return None
@@ -142,7 +144,7 @@ async def execute(self) -> Optional[SingleAPIResponse]:
142
144
143
145
144
146
# ignoring type checking as a workaround for https://github.com/python/mypy/issues/9319
145
- class AsyncFilterRequestBuilder (BaseFilterRequestBuilder , AsyncQueryRequestBuilder ): # type: ignore
147
+ class AsyncFilterRequestBuilder (BaseFilterRequestBuilder [ _ReturnT ] , AsyncQueryRequestBuilder [ _ReturnT ] ): # type: ignore
146
148
def __init__ (
147
149
self ,
148
150
session : AsyncClient ,
@@ -152,14 +154,37 @@ def __init__(
152
154
params : QueryParams ,
153
155
json : dict ,
154
156
) -> None :
155
- BaseFilterRequestBuilder .__init__ (self , session , headers , params )
156
- AsyncQueryRequestBuilder .__init__ (
157
+ get_origin_and_cast (BaseFilterRequestBuilder [_ReturnT ]).__init__ (
158
+ self , session , headers , params
159
+ )
160
+ get_origin_and_cast (AsyncQueryRequestBuilder [_ReturnT ]).__init__ (
161
+ self , session , path , http_method , headers , params , json
162
+ )
163
+
164
+
165
+ # this exists for type-safety. see https://gist.github.com/anand2312/93d3abf401335fd3310d9e30112303bf
166
+ class AsyncRPCFilterRequestBuilder (
167
+ BaseFilterRequestBuilder [_ReturnT ], AsyncSingleRequestBuilder [_ReturnT ]
168
+ ):
169
+ def __init__ (
170
+ self ,
171
+ session : AsyncClient ,
172
+ path : str ,
173
+ http_method : str ,
174
+ headers : Headers ,
175
+ params : QueryParams ,
176
+ json : dict ,
177
+ ) -> None :
178
+ get_origin_and_cast (BaseFilterRequestBuilder [_ReturnT ]).__init__ (
179
+ self , session , headers , params
180
+ )
181
+ get_origin_and_cast (AsyncSingleRequestBuilder [_ReturnT ]).__init__ (
157
182
self , session , path , http_method , headers , params , json
158
183
)
159
184
160
185
161
186
# ignoring type checking as a workaround for https://github.com/python/mypy/issues/9319
162
- class AsyncSelectRequestBuilder (BaseSelectRequestBuilder , AsyncQueryRequestBuilder ): # type: ignore
187
+ class AsyncSelectRequestBuilder (BaseSelectRequestBuilder [ _ReturnT ] , AsyncQueryRequestBuilder [ _ReturnT ] ): # type: ignore
163
188
def __init__ (
164
189
self ,
165
190
session : AsyncClient ,
@@ -169,19 +194,21 @@ def __init__(
169
194
params : QueryParams ,
170
195
json : dict ,
171
196
) -> None :
172
- BaseSelectRequestBuilder .__init__ (self , session , headers , params )
173
- AsyncQueryRequestBuilder .__init__ (
197
+ get_origin_and_cast (BaseSelectRequestBuilder [_ReturnT ]).__init__ (
198
+ self , session , headers , params
199
+ )
200
+ get_origin_and_cast (AsyncQueryRequestBuilder [_ReturnT ]).__init__ (
174
201
self , session , path , http_method , headers , params , json
175
202
)
176
203
177
- def single (self ) -> AsyncSingleRequestBuilder :
204
+ def single (self ) -> AsyncSingleRequestBuilder [ _ReturnT ] :
178
205
"""Specify that the query will only return a single row in response.
179
206
180
207
.. caution::
181
208
The API will raise an error if the query returned more than one row.
182
209
"""
183
210
self .headers ["Accept" ] = "application/vnd.pgrst.object+json"
184
- return AsyncSingleRequestBuilder (
211
+ return AsyncSingleRequestBuilder [ _ReturnT ] (
185
212
headers = self .headers ,
186
213
http_method = self .http_method ,
187
214
json = self .json ,
@@ -190,10 +217,10 @@ def single(self) -> AsyncSingleRequestBuilder:
190
217
session = self .session , # type: ignore
191
218
)
192
219
193
- def maybe_single (self ) -> AsyncMaybeSingleRequestBuilder :
220
+ def maybe_single (self ) -> AsyncMaybeSingleRequestBuilder [ _ReturnT ] :
194
221
"""Retrieves at most one row from the result. Result must be at most one row (e.g. using `eq` on a UNIQUE column), otherwise this will result in an error."""
195
222
self .headers ["Accept" ] = "application/vnd.pgrst.object+json"
196
- return AsyncMaybeSingleRequestBuilder (
223
+ return AsyncMaybeSingleRequestBuilder [ _ReturnT ] (
197
224
headers = self .headers ,
198
225
http_method = self .http_method ,
199
226
json = self .json ,
@@ -203,8 +230,8 @@ def maybe_single(self) -> AsyncMaybeSingleRequestBuilder:
203
230
)
204
231
205
232
def text_search (
206
- self , column : str , query : str , options : Dict [str , any ] = {}
207
- ) -> AsyncFilterRequestBuilder :
233
+ self , column : str , query : str , options : dict [str , Any ] = {}
234
+ ) -> AsyncFilterRequestBuilder [ _ReturnT ] :
208
235
type_ = options .get ("type" )
209
236
type_part = ""
210
237
if type_ == "plain" :
@@ -216,7 +243,7 @@ def text_search(
216
243
config_part = f"({ options .get ('config' )} )" if options .get ("config" ) else ""
217
244
self .params = self .params .add (column , f"{ type_part } fts{ config_part } .{ query } " )
218
245
219
- return AsyncQueryRequestBuilder (
246
+ return AsyncQueryRequestBuilder [ _ReturnT ] (
220
247
headers = self .headers ,
221
248
http_method = self .http_method ,
222
249
json = self .json ,
@@ -226,7 +253,7 @@ def text_search(
226
253
)
227
254
228
255
229
- class AsyncRequestBuilder :
256
+ class AsyncRequestBuilder ( Generic [ _ReturnT ]) :
230
257
def __init__ (self , session : AsyncClient , path : str ) -> None :
231
258
self .session = session
232
259
self .path = path
@@ -235,7 +262,7 @@ def select(
235
262
self ,
236
263
* columns : str ,
237
264
count : Optional [CountMethod ] = None ,
238
- ) -> AsyncSelectRequestBuilder :
265
+ ) -> AsyncSelectRequestBuilder [ _ReturnT ] :
239
266
"""Run a SELECT query.
240
267
241
268
Args:
@@ -245,7 +272,7 @@ def select(
245
272
:class:`AsyncSelectRequestBuilder`
246
273
"""
247
274
method , params , headers , json = pre_select (* columns , count = count )
248
- return AsyncSelectRequestBuilder (
275
+ return AsyncSelectRequestBuilder [ _ReturnT ] (
249
276
self .session , self .path , method , headers , params , json
250
277
)
251
278
@@ -256,7 +283,7 @@ def insert(
256
283
count : Optional [CountMethod ] = None ,
257
284
returning : ReturnMethod = ReturnMethod .representation ,
258
285
upsert : bool = False ,
259
- ) -> AsyncQueryRequestBuilder :
286
+ ) -> AsyncQueryRequestBuilder [ _ReturnT ] :
260
287
"""Run an INSERT query.
261
288
262
289
Args:
@@ -273,7 +300,7 @@ def insert(
273
300
returning = returning ,
274
301
upsert = upsert ,
275
302
)
276
- return AsyncQueryRequestBuilder (
303
+ return AsyncQueryRequestBuilder [ _ReturnT ] (
277
304
self .session , self .path , method , headers , params , json
278
305
)
279
306
@@ -285,7 +312,7 @@ def upsert(
285
312
returning : ReturnMethod = ReturnMethod .representation ,
286
313
ignore_duplicates : bool = False ,
287
314
on_conflict : str = "" ,
288
- ) -> AsyncQueryRequestBuilder :
315
+ ) -> AsyncQueryRequestBuilder [ _ReturnT ] :
289
316
"""Run an upsert (INSERT ... ON CONFLICT DO UPDATE) query.
290
317
291
318
Args:
@@ -304,7 +331,7 @@ def upsert(
304
331
ignore_duplicates = ignore_duplicates ,
305
332
on_conflict = on_conflict ,
306
333
)
307
- return AsyncQueryRequestBuilder (
334
+ return AsyncQueryRequestBuilder [ _ReturnT ] (
308
335
self .session , self .path , method , headers , params , json
309
336
)
310
337
@@ -314,7 +341,7 @@ def update(
314
341
* ,
315
342
count : Optional [CountMethod ] = None ,
316
343
returning : ReturnMethod = ReturnMethod .representation ,
317
- ) -> AsyncFilterRequestBuilder :
344
+ ) -> AsyncFilterRequestBuilder [ _ReturnT ] :
318
345
"""Run an UPDATE query.
319
346
320
347
Args:
@@ -329,7 +356,7 @@ def update(
329
356
count = count ,
330
357
returning = returning ,
331
358
)
332
- return AsyncFilterRequestBuilder (
359
+ return AsyncFilterRequestBuilder [ _ReturnT ] (
333
360
self .session , self .path , method , headers , params , json
334
361
)
335
362
@@ -338,7 +365,7 @@ def delete(
338
365
* ,
339
366
count : Optional [CountMethod ] = None ,
340
367
returning : ReturnMethod = ReturnMethod .representation ,
341
- ) -> AsyncFilterRequestBuilder :
368
+ ) -> AsyncFilterRequestBuilder [ _ReturnT ] :
342
369
"""Run a DELETE query.
343
370
344
371
Args:
@@ -351,7 +378,7 @@ def delete(
351
378
count = count ,
352
379
returning = returning ,
353
380
)
354
- return AsyncFilterRequestBuilder (
381
+ return AsyncFilterRequestBuilder [ _ReturnT ] (
355
382
self .session , self .path , method , headers , params , json
356
383
)
357
384
0 commit comments