diff --git a/postgrest_py/base_request_builder.py b/postgrest_py/base_request_builder.py index 5853c0cc..ac7591cc 100644 --- a/postgrest_py/base_request_builder.py +++ b/postgrest_py/base_request_builder.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json from re import search from typing import Any, Dict, Iterable, Optional, Tuple, Type, Union @@ -209,6 +210,27 @@ def cd(self, column: str, values: Iterable[Any]): values = ",".join(values) return self.filter(column, Filters.CD, f"{{{values}}}") + def contains(self, column: str, value: Union[Iterable[Any], str, Dict[Any, Any]]): + if isinstance(value, str): + # range types can be inclusive '[', ']' or exclusive '(', ')' so just + # keep it simple and accept a string + return self.filter(column, Filters.CS, value) + if not isinstance(value, dict) and isinstance(value, Iterable): + # Expected to be some type of iterable + stringified_values = ",".join(value) + return self.filter(column, Filters.CS, f"{{{stringified_values}}}") + + return self.filter(column, Filters.CS, json.dumps(value)) + + def contained_by(self, column: str, value: Union[Iterable[Any], str, Dict[Any, Any]]): + if isinstance(value, str): + # range + return self.filter(column, Filters.CD, value) + if not isinstance(value, dict) and isinstance(value, Iterable): + stringified_values = ",".join(value) + return self.filter(column, Filters.CD, f"{{{stringified_values}}}") + return self.filter(column, Filters.CD, json.dumps(value)) + def ov(self, column: str, values: Iterable[Any]): values = map(sanitize_param, values) values = ",".join(values) diff --git a/tests/_async/test_filter_request_builder.py b/tests/_async/test_filter_request_builder.py index 28554271..78c2054a 100644 --- a/tests/_async/test_filter_request_builder.py +++ b/tests/_async/test_filter_request_builder.py @@ -40,3 +40,37 @@ def test_multivalued_param(filter_request_builder): def test_match(filter_request_builder): builder = filter_request_builder.match({"id": "1", "done": "false"}) assert str(builder.session.params) == "id=eq.1&done=eq.false" + + +def test_contains(filter_request_builder): + builder = filter_request_builder.contains("x", "a") + + assert str(builder.session.params) == "x=cs.a" + + +def test_contains_dictionary(filter_request_builder): + builder = filter_request_builder.contains("x", {"a": "b"}) + + # {"a":"b"} + assert str(builder.session.params) == "x=cs.%7B%22a%22%3A+%22b%22%7D" + + +def test_contains_any_item(filter_request_builder): + builder = filter_request_builder.contains("x", ["a", "b"]) + + # {a,b} + assert str(builder.session.params) == "x=cs.%7Ba%2Cb%7D" + + +def test_contains_in_list(filter_request_builder): + builder = filter_request_builder.contains("x", '[{"a": "b"}]') + + # [{"a":+"b"}] (the + represents the space) + assert str(builder.session.params) == "x=cs.%5B%7B%22a%22%3A+%22b%22%7D%5D" + + +def test_contained_by_mixed_items(filter_request_builder): + builder = filter_request_builder.contained_by("x", ["a", '["b", "c"]']) + + # {a,["b",+"c"]} + assert str(builder.session.params) == "x=cd.%7Ba%2C%5B%22b%22%2C+%22c%22%5D%7D" diff --git a/tests/_sync/test_filter_request_builder.py b/tests/_sync/test_filter_request_builder.py index a5a4c6ca..6d06d44e 100644 --- a/tests/_sync/test_filter_request_builder.py +++ b/tests/_sync/test_filter_request_builder.py @@ -40,3 +40,37 @@ def test_multivalued_param(filter_request_builder): def test_match(filter_request_builder): builder = filter_request_builder.match({"id": "1", "done": "false"}) assert str(builder.session.params) == "id=eq.1&done=eq.false" + + +def test_contains(filter_request_builder): + builder = filter_request_builder.contains("x", "a") + + assert str(builder.session.params) == "x=cs.a" + + +def test_contains_dictionary(filter_request_builder): + builder = filter_request_builder.contains("x", {"a": "b"}) + + # {"a":"b"} + assert str(builder.session.params) == "x=cs.%7B%22a%22%3A+%22b%22%7D" + + +def test_contains_any_item(filter_request_builder): + builder = filter_request_builder.contains("x", ["a", "b"]) + + # {a,b} + assert str(builder.session.params) == "x=cs.%7Ba%2Cb%7D" + + +def test_contains_in_list(filter_request_builder): + builder = filter_request_builder.contains("x", '[{"a": "b"}]') + + # [{"a":+"b"}] (the + represents the space) + assert str(builder.session.params) == "x=cs.%5B%7B%22a%22%3A+%22b%22%7D%5D" + + +def test_contained_by_mixed_items(filter_request_builder): + builder = filter_request_builder.contained_by("x", ["a", '["b", "c"]']) + + # {a,["b",+"c"]} + assert str(builder.session.params) == "x=cd.%7Ba%2C%5B%22b%22%2C+%22c%22%5D%7D"