Skip to content

Commit 0002e8f

Browse files
authored
feat: Add RPC request builder class for additional filters (#372)
1 parent 6d55e49 commit 0002e8f

6 files changed

+143
-2
lines changed

infra/init.sql

+7
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,10 @@ insert into public.issues (id, title, tags) values
6969
(2, 'Use better names', array['is:open', 'severity:low', 'priority:medium']),
7070
(3, 'Add missing postgrest filters', array['is:open', 'severity:low', 'priority:high']),
7171
(4, 'Add alias to filters', array['is:closed', 'severity:low', 'priority:medium']);
72+
73+
create or replace function public.list_stored_countries()
74+
returns setof countries
75+
language sql
76+
as $function$
77+
select * from countries;
78+
$function$

postgrest/_async/request_builder.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from ..base_request_builder import (
1010
APIResponse,
1111
BaseFilterRequestBuilder,
12+
BaseRPCRequestBuilder,
1213
BaseSelectRequestBuilder,
1314
CountMethod,
1415
SingleAPIResponse,
@@ -164,7 +165,7 @@ def __init__(
164165

165166
# this exists for type-safety. see https://gist.github.com/anand2312/93d3abf401335fd3310d9e30112303bf
166167
class AsyncRPCFilterRequestBuilder(
167-
BaseFilterRequestBuilder[_ReturnT], AsyncSingleRequestBuilder[_ReturnT]
168+
BaseRPCRequestBuilder[_ReturnT], AsyncSingleRequestBuilder[_ReturnT]
168169
):
169170
def __init__(
170171
self,

postgrest/_sync/request_builder.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from ..base_request_builder import (
1010
APIResponse,
1111
BaseFilterRequestBuilder,
12+
BaseRPCRequestBuilder,
1213
BaseSelectRequestBuilder,
1314
CountMethod,
1415
SingleAPIResponse,
@@ -164,7 +165,7 @@ def __init__(
164165

165166
# this exists for type-safety. see https://gist.github.com/anand2312/93d3abf401335fd3310d9e30112303bf
166167
class SyncRPCFilterRequestBuilder(
167-
BaseFilterRequestBuilder[_ReturnT], SyncSingleRequestBuilder[_ReturnT]
168+
BaseRPCRequestBuilder[_ReturnT], SyncSingleRequestBuilder[_ReturnT]
168169
):
169170
def __init__(
170171
self,

postgrest/base_request_builder.py

+50
Original file line numberDiff line numberDiff line change
@@ -576,3 +576,53 @@ def range(
576576
end - start + 1,
577577
)
578578
return self
579+
580+
581+
class BaseRPCRequestBuilder(BaseSelectRequestBuilder[_ReturnT]):
582+
def __init__(
583+
self,
584+
session: Union[AsyncClient, SyncClient],
585+
headers: Headers,
586+
params: QueryParams,
587+
) -> None:
588+
# Generic[T] is an instance of typing._GenericAlias, so doing Generic[T].__init__
589+
# tries to call _GenericAlias.__init__ - which is the wrong method
590+
# The __origin__ attribute of the _GenericAlias is the actual class
591+
get_origin_and_cast(BaseSelectRequestBuilder[_ReturnT]).__init__(
592+
self, session, headers, params
593+
)
594+
595+
def select(
596+
self,
597+
*columns: str,
598+
) -> Self:
599+
"""Run a SELECT query.
600+
601+
Args:
602+
*columns: The names of the columns to fetch.
603+
Returns:
604+
:class:`BaseSelectRequestBuilder`
605+
"""
606+
method, params, headers, json = pre_select(*columns, count=None)
607+
self.params = self.params.add("select", params.get("select"))
608+
self.headers["Prefer"] = "return=representation"
609+
return self
610+
611+
def single(self) -> Self:
612+
"""Specify that the query will only return a single row in response.
613+
614+
.. caution::
615+
The API will raise an error if the query returned more than one row.
616+
"""
617+
self.headers["Accept"] = "application/vnd.pgrst.object+json"
618+
return self
619+
620+
def maybe_single(self) -> Self:
621+
"""Retrieves at most one row from the result. Result must be at most one row (e.g. using `eq` on a UNIQUE column), otherwise this will result in an error."""
622+
self.headers["Accept"] = "application/vnd.pgrst.object+json"
623+
return self
624+
625+
def csv(self) -> Self:
626+
"""Specify that the query must retrieve data as a single CSV string."""
627+
self.headers["Accept"] = "text/csv"
628+
return self

tests/_async/test_filter_request_builder_integration.py

+41
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import pytest
2+
13
from .client import rest_client
24

35

@@ -387,3 +389,42 @@ async def test_or_on_reference_table():
387389
],
388390
},
389391
]
392+
393+
394+
async def test_rpc_with_single():
395+
res = (
396+
await rest_client()
397+
.rpc("list_stored_countries", {})
398+
.select("nicename, country_name, iso")
399+
.eq("nicename", "Albania")
400+
.single()
401+
.execute()
402+
)
403+
404+
assert res.data == {"nicename": "Albania", "country_name": "ALBANIA", "iso": "AL"}
405+
406+
407+
async def test_rpc_with_limit():
408+
res = (
409+
await rest_client()
410+
.rpc("list_stored_countries", {})
411+
.select("nicename, country_name, iso")
412+
.eq("nicename", "Albania")
413+
.limit(1)
414+
.execute()
415+
)
416+
417+
assert res.data == [{"nicename": "Albania", "country_name": "ALBANIA", "iso": "AL"}]
418+
419+
420+
@pytest.mark.skip(reason="Need to re-implement range to use query parameters")
421+
async def test_rpc_with_range():
422+
res = (
423+
await rest_client()
424+
.rpc("list_stored_countries", {})
425+
.select("nicename, iso")
426+
.range(0, 1)
427+
.execute()
428+
)
429+
430+
assert res.data == [{"nicename": "Albania", "iso": "AL"}]

tests/_sync/test_filter_request_builder_integration.py

+41
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import pytest
2+
13
from .client import rest_client
24

35

@@ -380,3 +382,42 @@ def test_or_on_reference_table():
380382
],
381383
},
382384
]
385+
386+
387+
def test_rpc_with_single():
388+
res = (
389+
rest_client()
390+
.rpc("list_stored_countries", {})
391+
.select("nicename, country_name, iso")
392+
.eq("nicename", "Albania")
393+
.single()
394+
.execute()
395+
)
396+
397+
assert res.data == {"nicename": "Albania", "country_name": "ALBANIA", "iso": "AL"}
398+
399+
400+
def test_rpc_with_limit():
401+
res = (
402+
rest_client()
403+
.rpc("list_stored_countries", {})
404+
.select("nicename, country_name, iso")
405+
.eq("nicename", "Albania")
406+
.limit(1)
407+
.execute()
408+
)
409+
410+
assert res.data == [{"nicename": "Albania", "country_name": "ALBANIA", "iso": "AL"}]
411+
412+
413+
@pytest.mark.skip(reason="Need to re-implement range to use query parameters")
414+
def test_rpc_with_range():
415+
res = (
416+
rest_client()
417+
.rpc("list_stored_countries", {})
418+
.select("nicename, iso")
419+
.range(0, 1)
420+
.execute()
421+
)
422+
423+
assert res.data == [{"nicename": "Albania", "iso": "AL"}]

0 commit comments

Comments
 (0)