Skip to content

Commit ae5f80a

Browse files
authored
fix: upsert and insert with default_to_null boolean argument (#398)
1 parent b4c740d commit ae5f80a

File tree

5 files changed

+130
-9
lines changed

5 files changed

+130
-9
lines changed

postgrest/_async/request_builder.py

+11
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,7 @@ def insert(
303303
count: Optional[CountMethod] = None,
304304
returning: ReturnMethod = ReturnMethod.representation,
305305
upsert: bool = False,
306+
default_to_null: bool = True,
306307
) -> AsyncQueryRequestBuilder[_ReturnT]:
307308
"""Run an INSERT query.
308309
@@ -311,6 +312,9 @@ def insert(
311312
count: The method to use to get the count of rows returned.
312313
returning: Either 'minimal' or 'representation'
313314
upsert: Whether the query should be an upsert.
315+
default_to_null: Make missing fields default to `null`.
316+
Otherwise, use the default value for the column.
317+
Only applies for bulk inserts.
314318
Returns:
315319
:class:`AsyncQueryRequestBuilder`
316320
"""
@@ -319,6 +323,7 @@ def insert(
319323
count=count,
320324
returning=returning,
321325
upsert=upsert,
326+
default_to_null=default_to_null,
322327
)
323328
return AsyncQueryRequestBuilder[_ReturnT](
324329
self.session, self.path, method, headers, params, json
@@ -332,6 +337,7 @@ def upsert(
332337
returning: ReturnMethod = ReturnMethod.representation,
333338
ignore_duplicates: bool = False,
334339
on_conflict: str = "",
340+
default_to_null: bool = True,
335341
) -> AsyncQueryRequestBuilder[_ReturnT]:
336342
"""Run an upsert (INSERT ... ON CONFLICT DO UPDATE) query.
337343
@@ -341,6 +347,10 @@ def upsert(
341347
returning: Either 'minimal' or 'representation'
342348
ignore_duplicates: Whether duplicate rows should be ignored.
343349
on_conflict: Specified columns to be made to work with UNIQUE constraint.
350+
default_to_null: Make missing fields default to `null`. Otherwise, use the
351+
default value for the column. This only applies when inserting new rows,
352+
not when merging with existing rows under `ignoreDuplicates: false`.
353+
This also only applies when doing bulk upserts.
344354
Returns:
345355
:class:`AsyncQueryRequestBuilder`
346356
"""
@@ -350,6 +360,7 @@ def upsert(
350360
returning=returning,
351361
ignore_duplicates=ignore_duplicates,
352362
on_conflict=on_conflict,
363+
default_to_null=default_to_null,
353364
)
354365
return AsyncQueryRequestBuilder[_ReturnT](
355366
self.session, self.path, method, headers, params, json

postgrest/_sync/request_builder.py

+17-6
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __init__(
3434
http_method: str,
3535
headers: Headers,
3636
params: QueryParams,
37-
json: dict,
37+
json: Union[dict, list],
3838
) -> None:
3939
self.session = session
4040
self.path = path
@@ -289,7 +289,7 @@ def select(
289289
*columns: The names of the columns to fetch.
290290
count: The method to use to get the count of rows returned.
291291
Returns:
292-
:class:`AsyncSelectRequestBuilder`
292+
:class:`SyncSelectRequestBuilder`
293293
"""
294294
method, params, headers, json = pre_select(*columns, count=count)
295295
return SyncSelectRequestBuilder[_ReturnT](
@@ -303,6 +303,7 @@ def insert(
303303
count: Optional[CountMethod] = None,
304304
returning: ReturnMethod = ReturnMethod.representation,
305305
upsert: bool = False,
306+
default_to_null: bool = True,
306307
) -> SyncQueryRequestBuilder[_ReturnT]:
307308
"""Run an INSERT query.
308309
@@ -311,14 +312,18 @@ def insert(
311312
count: The method to use to get the count of rows returned.
312313
returning: Either 'minimal' or 'representation'
313314
upsert: Whether the query should be an upsert.
315+
default_to_null: Make missing fields default to `null`.
316+
Otherwise, use the default value for the column.
317+
Only applies for bulk inserts.
314318
Returns:
315-
:class:`AsyncQueryRequestBuilder`
319+
:class:`SyncQueryRequestBuilder`
316320
"""
317321
method, params, headers, json = pre_insert(
318322
json,
319323
count=count,
320324
returning=returning,
321325
upsert=upsert,
326+
default_to_null=default_to_null,
322327
)
323328
return SyncQueryRequestBuilder[_ReturnT](
324329
self.session, self.path, method, headers, params, json
@@ -332,6 +337,7 @@ def upsert(
332337
returning: ReturnMethod = ReturnMethod.representation,
333338
ignore_duplicates: bool = False,
334339
on_conflict: str = "",
340+
default_to_null: bool = True,
335341
) -> SyncQueryRequestBuilder[_ReturnT]:
336342
"""Run an upsert (INSERT ... ON CONFLICT DO UPDATE) query.
337343
@@ -341,15 +347,20 @@ def upsert(
341347
returning: Either 'minimal' or 'representation'
342348
ignore_duplicates: Whether duplicate rows should be ignored.
343349
on_conflict: Specified columns to be made to work with UNIQUE constraint.
350+
default_to_null: Make missing fields default to `null`. Otherwise, use the
351+
default value for the column. This only applies when inserting new rows,
352+
not when merging with existing rows under `ignoreDuplicates: false`.
353+
This also only applies when doing bulk upserts.
344354
Returns:
345-
:class:`AsyncQueryRequestBuilder`
355+
:class:`SyncQueryRequestBuilder`
346356
"""
347357
method, params, headers, json = pre_upsert(
348358
json,
349359
count=count,
350360
returning=returning,
351361
ignore_duplicates=ignore_duplicates,
352362
on_conflict=on_conflict,
363+
default_to_null=default_to_null,
353364
)
354365
return SyncQueryRequestBuilder[_ReturnT](
355366
self.session, self.path, method, headers, params, json
@@ -369,7 +380,7 @@ def update(
369380
count: The method to use to get the count of rows returned.
370381
returning: Either 'minimal' or 'representation'
371382
Returns:
372-
:class:`AsyncFilterRequestBuilder`
383+
:class:`SyncFilterRequestBuilder`
373384
"""
374385
method, params, headers, json = pre_update(
375386
json,
@@ -392,7 +403,7 @@ def delete(
392403
count: The method to use to get the count of rows returned.
393404
returning: Either 'minimal' or 'representation'
394405
Returns:
395-
:class:`AsyncFilterRequestBuilder`
406+
:class:`SyncFilterRequestBuilder`
396407
"""
397408
method, params, headers, json = pre_delete(
398409
count=count,

postgrest/base_request_builder.py

+22-3
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@ class QueryArgs(NamedTuple):
4646
json: Dict[Any, Any]
4747

4848

49+
def _unique_columns(json: List[Dict]):
50+
unique_keys = {key for row in json for key in row.keys()}
51+
columns = ",".join([f'"{k}"' for k in unique_keys])
52+
return columns
53+
54+
4955
def pre_select(
5056
*columns: str,
5157
count: Optional[CountMethod] = None,
@@ -61,38 +67,51 @@ def pre_select(
6167

6268

6369
def pre_insert(
64-
json: dict,
70+
json: Union[dict, list],
6571
*,
6672
count: Optional[CountMethod],
6773
returning: ReturnMethod,
6874
upsert: bool,
75+
default_to_null: bool = True,
6976
) -> QueryArgs:
7077
prefer_headers = [f"return={returning}"]
7178
if count:
7279
prefer_headers.append(f"count={count}")
7380
if upsert:
7481
prefer_headers.append("resolution=merge-duplicates")
82+
if not default_to_null:
83+
prefer_headers.append("missing=default")
7584
headers = Headers({"Prefer": ",".join(prefer_headers)})
76-
return QueryArgs(RequestMethod.POST, QueryParams(), headers, json)
85+
# Adding 'columns' query parameters
86+
query_params = {}
87+
if isinstance(json, list):
88+
query_params = {"columns": _unique_columns(json)}
89+
return QueryArgs(RequestMethod.POST, QueryParams(query_params), headers, json)
7790

7891

7992
def pre_upsert(
80-
json: dict,
93+
json: Union[dict, list],
8194
*,
8295
count: Optional[CountMethod],
8396
returning: ReturnMethod,
8497
ignore_duplicates: bool,
8598
on_conflict: str = "",
99+
default_to_null: bool = True,
86100
) -> QueryArgs:
87101
query_params = {}
88102
prefer_headers = [f"return={returning}"]
89103
if count:
90104
prefer_headers.append(f"count={count}")
91105
resolution = "ignore" if ignore_duplicates else "merge"
92106
prefer_headers.append(f"resolution={resolution}-duplicates")
107+
if not default_to_null:
108+
prefer_headers.append("missing=default")
93109
headers = Headers({"Prefer": ",".join(prefer_headers)})
94110
if on_conflict:
95111
query_params["on_conflict"] = on_conflict
112+
# Adding 'columns' query parameters
113+
if isinstance(json, list):
114+
query_params["columns"] = _unique_columns(json)
96115
return QueryArgs(RequestMethod.POST, QueryParams(query_params), headers, json)
97116

98117

tests/_async/test_request_builder.py

+40
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,31 @@ def test_insert_with_upsert(self, request_builder: AsyncRequestBuilder):
7171
assert builder.http_method == "POST"
7272
assert builder.json == {"key1": "val1"}
7373

74+
def test_upsert_with_default_single(self, request_builder: AsyncRequestBuilder):
75+
builder = request_builder.upsert([{"key1": "val1"}], default_to_null=False)
76+
assert builder.headers.get_list("prefer", True) == [
77+
"return=representation",
78+
"resolution=merge-duplicates",
79+
"missing=default",
80+
]
81+
assert builder.http_method == "POST"
82+
assert builder.json == [{"key1": "val1"}]
83+
assert builder.params.get("columns") == '"key1"'
84+
85+
def test_bulk_insert_using_default(self, request_builder: AsyncRequestBuilder):
86+
builder = request_builder.insert(
87+
[{"key1": "val1", "key2": "val2"}, {"key3": "val3"}], default_to_null=False
88+
)
89+
assert builder.headers.get_list("prefer", True) == [
90+
"return=representation",
91+
"missing=default",
92+
]
93+
assert builder.http_method == "POST"
94+
assert builder.json == [{"key1": "val1", "key2": "val2"}, {"key3": "val3"}]
95+
assert set(builder.params["columns"].split(",")) == set(
96+
'"key1","key2","key3"'.split(",")
97+
)
98+
7499
def test_upsert(self, request_builder: AsyncRequestBuilder):
75100
builder = request_builder.upsert({"key1": "val1"})
76101

@@ -81,6 +106,21 @@ def test_upsert(self, request_builder: AsyncRequestBuilder):
81106
assert builder.http_method == "POST"
82107
assert builder.json == {"key1": "val1"}
83108

109+
def test_bulk_upsert_with_default(self, request_builder: AsyncRequestBuilder):
110+
builder = request_builder.upsert(
111+
[{"key1": "val1", "key2": "val2"}, {"key3": "val3"}], default_to_null=False
112+
)
113+
assert builder.headers.get_list("prefer", True) == [
114+
"return=representation",
115+
"resolution=merge-duplicates",
116+
"missing=default",
117+
]
118+
assert builder.http_method == "POST"
119+
assert builder.json == [{"key1": "val1", "key2": "val2"}, {"key3": "val3"}]
120+
assert set(builder.params["columns"].split(",")) == set(
121+
'"key1","key2","key3"'.split(",")
122+
)
123+
84124

85125
class TestUpdate:
86126
def test_update(self, request_builder: AsyncRequestBuilder):

tests/_sync/test_request_builder.py

+40
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,20 @@ def test_insert_with_upsert(self, request_builder: SyncRequestBuilder):
7171
assert builder.http_method == "POST"
7272
assert builder.json == {"key1": "val1"}
7373

74+
def test_bulk_insert_using_default(self, request_builder: SyncRequestBuilder):
75+
builder = request_builder.insert(
76+
[{"key1": "val1", "key2": "val2"}, {"key3": "val3"}], default_to_null=False
77+
)
78+
assert builder.headers.get_list("prefer", True) == [
79+
"return=representation",
80+
"missing=default",
81+
]
82+
assert builder.http_method == "POST"
83+
assert builder.json == [{"key1": "val1", "key2": "val2"}, {"key3": "val3"}]
84+
assert set(builder.params["columns"].split(",")) == set(
85+
'"key1","key2","key3"'.split(",")
86+
)
87+
7488
def test_upsert(self, request_builder: SyncRequestBuilder):
7589
builder = request_builder.upsert({"key1": "val1"})
7690

@@ -81,6 +95,32 @@ def test_upsert(self, request_builder: SyncRequestBuilder):
8195
assert builder.http_method == "POST"
8296
assert builder.json == {"key1": "val1"}
8397

98+
def test_upsert_with_default_single(self, request_builder: SyncRequestBuilder):
99+
builder = request_builder.upsert([{"key1": "val1"}], default_to_null=False)
100+
assert builder.headers.get_list("prefer", True) == [
101+
"return=representation",
102+
"resolution=merge-duplicates",
103+
"missing=default",
104+
]
105+
assert builder.http_method == "POST"
106+
assert builder.json == [{"key1": "val1"}]
107+
assert builder.params.get("columns") == '"key1"'
108+
109+
def test_bulk_upsert_with_default(self, request_builder: SyncRequestBuilder):
110+
builder = request_builder.upsert(
111+
[{"key1": "val1", "key2": "val2"}, {"key3": "val3"}], default_to_null=False
112+
)
113+
assert builder.headers.get_list("prefer", True) == [
114+
"return=representation",
115+
"resolution=merge-duplicates",
116+
"missing=default",
117+
]
118+
assert builder.http_method == "POST"
119+
assert builder.json == [{"key1": "val1", "key2": "val2"}, {"key3": "val3"}]
120+
assert set(builder.params["columns"].split(",")) == set(
121+
'"key1","key2","key3"'.split(",")
122+
)
123+
84124

85125
class TestUpdate:
86126
def test_update(self, request_builder: SyncRequestBuilder):

0 commit comments

Comments
 (0)