added return_fields function, attempting to optionally limit fields r… #633

Open
wants to merge 3 commits into
base: main
26 changes: 25 additions & 1 deletion aredis_om/model/model.py
@@ -421,6 +421,7 @@ def __init__(
limit: Optional[int] = None,
page_size: int = DEFAULT_PAGE_SIZE,
sort_fields: Optional[List[str]] = None,
return_fields: Optional[List[str]] = None,
nocontent: bool = False,
):
if not has_redisearch(model.db()):
@@ -445,6 +446,11 @@ def __init__(
else:
self.sort_fields = []

if return_fields:
self.return_fields = self.validate_return_fields(return_fields)

Member

Since return_fields is declared below as a method, you can't perform this assignment (it makes the linter mad and probably leads to some invalid state).
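
The name clash described above can be reproduced in isolation. The sketch below uses toy classes, not the real FindQuery; `ShadowedQuery`, `RenamedQuery`, and the private attribute name `_return_fields` are illustrative assumptions, not names taken from the PR.

```python
class ShadowedQuery:
    """Toy stand-in, not the real FindQuery: attribute and method share a name."""

    def __init__(self, return_fields=None):
        # This instance attribute hides the method of the same name.
        self.return_fields = return_fields or []

    def return_fields(self, *fields):
        return ShadowedQuery(return_fields=list(fields))


class RenamedQuery:
    """Same toy class, with the state kept under a private (hypothetical) name."""

    def __init__(self, return_fields=None):
        self._return_fields = return_fields or []

    def return_fields(self, *fields):
        return RenamedQuery(return_fields=list(fields))


try:
    ShadowedQuery().return_fields("first_name")
    shadowed_outcome = "ok"
except TypeError as exc:
    # Attribute lookup finds the list on the instance before the method.
    shadowed_outcome = f"TypeError: {exc}"

renamed = RenamedQuery().return_fields("first_name", "last_name")
print(shadowed_outcome)
print(renamed._return_fields)
```

Storing the state under a different name than the chainable method is one possible way out; the PR may well settle on another.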

else:
self.return_fields = []

self._expression = None
self._query: Optional[str] = None
self._pagination: List[str] = []
@@ -502,8 +508,19 @@ def query(self):
if self._query.startswith("(") or self._query == "*"
else f"({self._query})"
) + f"=>[{self.knn}]"
if self.return_fields:
self._query += f" RETURN {','.join(self.return_fields)}"

Member

The RETURN clause shouldn't be added to the query string; that's the filtration part of the query. Rather, it needs to be passed in the list of args being sent to Redis. See the execute method inside of find_query to see what I mean: RETURN, the number of fields being returned, and each individual field being returned need to be added as individual arguments to that list for it to be interpreted correctly.
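
A minimal sketch of the argument layout described above, assuming the command is assembled as a flat list of tokens. The helper name `build_search_args` is hypothetical and this is not the library's execute method; it only shows RETURN, the count, and each field as separate list items.

```python
from typing import List, Sequence


def build_search_args(index: str, query: str, return_fields: Sequence[str]) -> List[str]:
    """Assemble FT.SEARCH tokens; the filter expression stays a single argument."""
    args = ["FT.SEARCH", index, query]
    if return_fields:
        # RETURN, then the number of tokens that follow, then one token per field.
        args += ["RETURN", str(len(return_fields)), *return_fields]
    return args


args = build_search_args(
    ":foo:index", "(@parent_id:[2 2])", ["first_name", "last_name"]
)
print(args)
```

The query string itself is left untouched, so the filter and the projection stay independent of each other.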

@DennyD17, Aug 22, 2024

Sorry, but what are you going to do about deserialization?
It seems that a query with RETURN returns data in a different shape... that was my problem.

I mean https://github.com/redis/redis-om-python/blob/main/aredis_om/model/model.py#L1583

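For me this works:

    @staticmethod
    def create_return_part(fields: List[str] | Set[str]) -> List[str]:
        # Placeholder count at index 1; filled in once all tokens are known.
        q = ["RETURN", "n"]

        for field in fields:
            # Return each JSONPath under its plain field name.
            q.extend([f"$.{field}", "AS", field])
        # RETURN's count covers every token after it, including AS and aliases.
        q[1] = str(len(q) - 2)
        return q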


Member

IMO you'd return a dict[string,dict[string,string]], as that's essentially what the API responds with (technically it's an array of strings and arrays with an integer at its head).
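
The reply shape mentioned above (an integer, then alternating keys and flat field/value arrays) could be folded into a nested dict roughly as follows. This is a sketch under assumptions: the reply is already decoded to strings, no NOCONTENT or WITHSCORES options are in play, and both `parse_search_reply` and the sample key `:foo:3` are made up for illustration.

```python
from typing import Any, Dict, List


def parse_search_reply(reply: List[Any]) -> Dict[str, Dict[str, str]]:
    """Map each document key to a dict of its returned fields."""
    rest = reply[1:]  # reply[0] is the total hit count
    results: Dict[str, Dict[str, str]] = {}
    # Keys sit at even offsets, their flat [name, value, ...] arrays at odd ones.
    for key, flat in zip(rest[0::2], rest[1::2]):
        results[key] = dict(zip(flat[0::2], flat[1::2]))
    return results


raw_reply = [1, ":foo:3", ["id", "3", "page_id", "3"]]
parsed = parse_search_reply(raw_reply)
print(parsed)
```

Every value is still a string at this point, which is what the deserialization question further down the thread is about.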

I'm talking about this behavior. I have Redis (Search + JSON). When I run a query (using redis-cli, for example) without RETURN:

    FT.SEARCH ':foo:index' '(@parent_id:[2 2])' LIMIT 1 1

I get output that looks like a JSON document:

    3) 1) "$" 2) "{\"updated_at\":\"2024-09-13 12:14:55.877717\",\"id\":3,\"page_id\":3

And here the data types are right (id and page_id are integers).

But when I use RETURN:

    FT.SEARCH ':foo:index' '(@parent_id:[2 2])' LIMIT 0 1000 RETURN 6 $.id AS id, $.page_id AS page_id

I get output like this:

    3) 1) "id,"
       2) "2"
       3) "page_id"
       4) "2"

So my question is: how are you going to deserialize these results?

My fault, guys. Everything is OK once validation runs:

    field.validate(value, values, loc=field.alias, cls=cls_)

I had turned it off, because running validation on read queries is very expensive.
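
To illustrate why validation fixes the types, here is a standard-library stand-in for that step. It is not pydantic and not the library's code path; the `Page` dataclass and `coerce_partial` helper are hypothetical, and the sketch only handles annotations that are plain callables such as int or str.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict


@dataclass
class Page:  # hypothetical model, standing in for a JsonModel
    id: int
    page_id: int
    title: str


def coerce_partial(model: type, raw: Dict[str, str]) -> Dict[str, Any]:
    """Convert RETURN's string values using the model's declared field types."""
    declared = {f.name: f.type for f in fields(model)}
    return {name: declared[name](value) for name, value in raw.items()}


row = coerce_partial(Page, {"id": "2", "page_id": "2"})
print(row)
```

Only the returned fields are converted, so the missing `title` does not cause an error here.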

But please check whether it works with dict fields: I get the error "value is not a valid dict" while processing '{"foo": "bar"}'.
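
One plausible reading of that error is that a nested object comes back from RETURN as a JSON-encoded string, so it has to be decoded before validation sees it. The sketch below works under that assumption; `decode_json_fields` and the field name `meta` are hypothetical.

```python
import json
from typing import Any, Dict, Set


def decode_json_fields(raw: Dict[str, str], json_fields: Set[str]) -> Dict[str, Any]:
    """Decode the fields known to hold JSON objects; leave the rest as strings."""
    return {
        name: json.loads(value) if name in json_fields else value
        for name, value in raw.items()
    }


decoded = decode_json_fields({"id": "2", "meta": '{"foo": "bar"}'}, {"meta"})
print(decoded)
```

Which fields need decoding would have to come from the model's type information; here the set is passed in by hand.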

return self._query

def validate_return_fields(self, return_fields: List[str]):
for field in return_fields:
if field not in self.model.__fields__: # type: ignore
raise QueryNotSupportedError(
f"You tried to return the field {field}, but that field "
f"does not exist on the model {self.model}"
)
return return_fields

@property
def query_params(self):
params: List[Union[str, bytes]] = []
@@ -956,6 +973,11 @@ def sort_by(self, *fields: str):
if not fields:
return self
return self.copy(sort_fields=list(fields))

def return_fields(self, *fields: str):
if not fields:
return self
return self.copy(return_fields=list(fields))

async def update(self, use_transaction=True, **field_values):
"""
@@ -1531,7 +1553,9 @@ def find(
*expressions: Union[Any, Expression],
knn: Optional[KNNExpression] = None,
) -> FindQuery:
return FindQuery(expressions=expressions, knn=knn, model=cls)
return FindQuery(
expressions=expressions, knn=knn, model=cls
)

@classmethod
def from_redis(cls, res: Any):
14 changes: 13 additions & 1 deletion tests/test_json_model.py
@@ -935,9 +935,21 @@ class TypeWithUuid(JsonModel):

await item.save()

@py_test_mark_asyncio
async def test_return_specified_fields(members, m):
member1, member2, member3 = members
actual = await m.Member.find(

Member

One issue to contend with here is that model validation will not work correctly when running find with a subset of fields. You'll need to make sure you return a dictionary or some subset of the model to prevent validation errors from being tossed by the find.
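
The validation problem and the "subset of the model" idea can be shown with plain dataclasses. This is a sketch, not the library's approach: `Member` here is a toy dataclass rather than the test fixture, and the `projection` helper is hypothetical.

```python
from dataclasses import asdict, dataclass, fields, make_dataclass
from typing import Sequence


@dataclass
class Member:  # toy model with more required fields than the query returns
    first_name: str
    last_name: str
    email: str
    age: int


def projection(model: type, names: Sequence[str]) -> type:
    """Build a smaller class holding only the requested fields of the model."""
    declared = {f.name: f.type for f in fields(model)}
    return make_dataclass(
        f"{model.__name__}Projection", [(name, declared[name]) for name in names]
    )


partial_row = {"first_name": "Andrew", "last_name": "Brookins"}

try:
    Member(**partial_row)
    full_model_ok = True
except TypeError:
    # The full model rejects the row because email and age are missing.
    full_model_ok = False

PartialMember = projection(Member, ["first_name", "last_name"])
partial = PartialMember(**partial_row)
print(full_model_ok, asdict(partial))
```

Returning a plain dict, as the test in this diff expects, avoids the extra class at the cost of losing attribute access.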

(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
| (m.Member.last_name == "Smith")
).all()
assert actual == [
{"first_name": "Andrew", "last_name": "Brookins"},
{"first_name": "Andrew", "last_name": "Smith"},
]


@py_test_mark_asyncio
async def test_xfix_queries(m):
async def test_xfix_queries(m):4
await m.Member(
first_name="Steve",
last_name="Lorello",