diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 27ebcc59..dd8ef02b 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -423,6 +423,7 @@ def __init__( limit: Optional[int] = None, page_size: int = DEFAULT_PAGE_SIZE, sort_fields: Optional[List[str]] = None, + return_fields: Optional[List[str]] = None, nocontent: bool = False, ): if not has_redisearch(model.db()): @@ -447,6 +448,11 @@ def __init__( else: self.sort_fields = [] + if return_fields: + self.return_fields = self.validate_return_fields(return_fields) + else: + self.return_fields = [] + self._expression = None self._query: Optional[str] = None self._pagination: List[str] = [] @@ -504,8 +510,19 @@ def query(self): if self._query.startswith("(") or self._query == "*" else f"({self._query})" ) + f"=>[{self.knn}]" + if self.return_fields: + self._query += f" RETURN {','.join(self.return_fields)}" return self._query + def validate_return_fields(self, return_fields: List[str]): + for field in return_fields: + if field not in self.model.__fields__: # type: ignore + raise QueryNotSupportedError( + f"You tried to return the field {field}, but that field " + f"does not exist on the model {self.model}" + ) + return return_fields + @property def query_params(self): params: List[Union[str, bytes]] = [] @@ -967,6 +984,11 @@ def sort_by(self, *fields: str): if not fields: return self return self.copy(sort_fields=list(fields)) + + def return_fields(self, *fields: str): + if not fields: + return self + return self.copy(return_fields=list(fields)) async def update(self, use_transaction=True, **field_values): """ @@ -1553,7 +1575,9 @@ def find( *expressions: Union[Any, Expression], knn: Optional[KNNExpression] = None, ) -> FindQuery: - return FindQuery(expressions=expressions, knn=knn, model=cls) + return FindQuery( + expressions=expressions, knn=knn, model=cls + ) @classmethod def from_redis(cls, res: Any): diff --git a/tests/test_json_model.py b/tests/test_json_model.py index ea275552..aff67410 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -950,9 +950,21 @@ class TypeWithUuid(JsonModel): await item.save() +@py_test_mark_asyncio +async def test_return_specified_fields(members, m): + member1, member2, member3 = members + actual = await m.Member.find( + (m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins") + | (m.Member.last_name == "Smith") + ).all() + assert actual == [ + {"first_name": "Andrew", "last_name": "Brookins"}, + {"first_name": "Andrew", "last_name": "Smith"}, + ] + @py_test_mark_asyncio -async def test_xfix_queries(m): +async def test_xfix_queries(m):4 await m.Member( first_name="Steve", last_name="Lorello",