diff --git a/elasticsearch/helpers/vectorstore/_async/strategies.py b/elasticsearch/helpers/vectorstore/_async/strategies.py index a7f813f43..10524e243 100644 --- a/elasticsearch/helpers/vectorstore/_async/strategies.py +++ b/elasticsearch/helpers/vectorstore/_async/strategies.py @@ -283,10 +283,9 @@ def _hybrid( ) -> Dict[str, Any]: # Add a query to the knn query. # RRF is used to even the score from the knn query and text query - # RRF has two optional parameters: {'rank_constant':int, 'window_size':int} + # RRF has two optional parameters: {'rank_constant':int, 'rank_window_size':int} # https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html - query_body = { - "knn": knn, + standard_query = { "query": { "bool": { "must": [ @@ -300,14 +299,36 @@ def _hybrid( ], "filter": filter, } - }, + } } - if isinstance(self.rrf, Dict): - query_body["rank"] = {"rrf": self.rrf} - elif isinstance(self.rrf, bool) and self.rrf is True: - query_body["rank"] = {"rrf": {}} - + if self.rrf is False: + query_body = { + "knn": knn, + **standard_query, + } + else: + rrf_options = {} + if isinstance(self.rrf, Dict): + if "rank_constant" in self.rrf: + rrf_options["rank_constant"] = self.rrf["rank_constant"] + if "window_size" in self.rrf: + # 'window_size' was renamed to 'rank_window_size', but we support + # the older name for backwards compatibility + rrf_options["rank_window_size"] = self.rrf["window_size"] + if "rank_window_size" in self.rrf: + rrf_options["rank_window_size"] = self.rrf["rank_window_size"] + query_body = { + "retriever": { + "rrf": { + "retrievers": [ + {"standard": standard_query}, + {"knn": knn}, + ], + **rrf_options, + }, + }, + } return query_body def needs_inference(self) -> bool: diff --git a/elasticsearch/helpers/vectorstore/_sync/strategies.py b/elasticsearch/helpers/vectorstore/_sync/strategies.py index 928d34143..99c9baec2 100644 --- a/elasticsearch/helpers/vectorstore/_sync/strategies.py +++ b/elasticsearch/helpers/vectorstore/_sync/strategies.py @@ -283,10 +283,9 @@ def _hybrid( ) -> Dict[str, Any]: # Add a query to the knn query. # RRF is used to even the score from the knn query and text query - # RRF has two optional parameters: {'rank_constant':int, 'window_size':int} + # RRF has two optional parameters: {'rank_constant':int, 'rank_window_size':int} # https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html - query_body = { - "knn": knn, + standard_query = { "query": { "bool": { "must": [ @@ -300,14 +299,36 @@ def _hybrid( ], "filter": filter, } - }, + } } - if isinstance(self.rrf, Dict): - query_body["rank"] = {"rrf": self.rrf} - elif isinstance(self.rrf, bool) and self.rrf is True: - query_body["rank"] = {"rrf": {}} - + if self.rrf is False: + query_body = { + "knn": knn, + **standard_query, + } + else: + rrf_options = {} + if isinstance(self.rrf, Dict): + if "rank_constant" in self.rrf: + rrf_options["rank_constant"] = self.rrf["rank_constant"] + if "window_size" in self.rrf: + # 'window_size' was renamed to 'rank_window_size', but we support + # the older name for backwards compatibility + rrf_options["rank_window_size"] = self.rrf["window_size"] + if "rank_window_size" in self.rrf: + rrf_options["rank_window_size"] = self.rrf["rank_window_size"] + query_body = { + "retriever": { + "rrf": { + "retrievers": [ + {"standard": standard_query}, + {"knn": knn}, + ], + **rrf_options, + }, + }, + } return query_body def needs_inference(self) -> bool: diff --git a/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py b/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py index 820746acd..096beaef5 100644 --- a/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py +++ b/test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py @@ -33,6 +33,7 @@ VectorStore, ) from elasticsearch.helpers.vectorstore._sync._utils import model_is_deployed +from test_elasticsearch.utils import es_version from . import ConsistentFakeEmbeddings, FakeEmbeddings @@ -337,6 +338,9 @@ def test_search_knn_with_hybrid_search( self, sync_client: Elasticsearch, index: str ) -> None: """Test end to end construction and search with metadata.""" + if es_version(sync_client) < (8, 14): + pytest.skip("This test requires Elasticsearch 8.14 or newer") + store = VectorStore( index=index, retrieval_strategy=DenseVectorStrategy(hybrid=True), @@ -349,20 +353,48 @@ def test_search_knn_with_hybrid_search( def assert_query(query_body: dict, query: Optional[str]) -> dict: assert query_body == { - "knn": { - "field": "vector_field", - "filter": [], - "k": 1, - "num_candidates": 50, - "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - }, - "query": { - "bool": { - "filter": [], - "must": [{"match": {"text_field": {"query": "foo"}}}], + "retriever": { + "rrf": { + "retrievers": [ + { + "standard": { + "query": { + "bool": { + "filter": [], + "must": [ + { + "match": { + "text_field": {"query": "foo"} + } + } + ], + } + }, + }, + }, + { + "knn": { + "field": "vector_field", + "filter": [], + "k": 1, + "num_candidates": 50, + "query_vector": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 0.0, + ], + }, + }, + ], } - }, - "rank": {"rrf": {}}, + } } return query_body @@ -373,6 +405,9 @@ def test_search_knn_with_hybrid_search_rrf( self, sync_client: Elasticsearch, index: str ) -> None: """Test end to end construction and rrf hybrid search with metadata.""" + if es_version(sync_client) < (8, 14): + pytest.skip("This test requires Elasticsearch 8.14 or newer") + texts = ["foo", "bar", "baz"] def assert_query( @@ -380,48 +415,67 @@ def assert_query( query: Optional[str], expected_rrf: Union[dict, bool], ) -> dict: - cmp_query_body = { - "knn": { - "field": "vector_field", - "filter": [], - "k": 3, - "num_candidates": 50, - "query_vector": [ - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 0.0, - ], - }, + standard_query = { "query": { "bool": { "filter": [], "must": [{"match": {"text_field": {"query": "foo"}}}], } - }, + } + } + knn_query = { + "field": "vector_field", + "filter": [], + "k": 3, + "num_candidates": 50, + "query_vector": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 0.0, + ], } - if isinstance(expected_rrf, dict): - cmp_query_body["rank"] = {"rrf": expected_rrf} - elif isinstance(expected_rrf, bool) and expected_rrf is True: - cmp_query_body["rank"] = {"rrf": {}} + if expected_rrf is not False: + cmp_query_body = { + "retriever": { + "rrf": { + "retrievers": [ + {"standard": standard_query}, + {"knn": knn_query}, + ], + } + } + } + if isinstance(expected_rrf, dict): + cmp_query_body["retriever"]["rrf"].update(expected_rrf) + else: + cmp_query_body = { + "knn": knn_query, + **standard_query, + } assert query_body == cmp_query_body return query_body # 1. check query_body is okay - rrf_test_cases: List[Union[dict, bool]] = [ - True, - False, - {"rank_constant": 1, "window_size": 5}, - ] + if es_version(sync_client) >= (8, 14): + rrf_test_cases: List[Union[dict, bool]] = [ + True, + False, + {"rank_constant": 1, "rank_window_size": 5}, + ] + else: + # for 8.13.x and older there is no retriever query, so we can only + # run hybrid searches with rrf=False + rrf_test_cases: List[Union[dict, bool]] = [False] for rrf_test_case in rrf_test_cases: store = VectorStore( index=index, @@ -441,21 +495,47 @@ def assert_query( # 2. check query result is okay es_output = store.client.search( index=index, - query={ - "bool": { - "filter": [], - "must": [{"match": {"text_field": {"query": "foo"}}}], + retriever={ + "rrf": { + "retrievers": [ + { + "knn": { + "field": "vector_field", + "filter": [], + "k": 3, + "num_candidates": 50, + "query_vector": [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 0.0, + ], + }, + }, + { + "standard": { + "query": { + "bool": { + "filter": [], + "must": [ + {"match": {"text_field": {"query": "foo"}}} + ], + } + }, + }, + }, + ], + "rank_constant": 1, + "rank_window_size": 5, } }, - knn={ - "field": "vector_field", - "filter": [], - "k": 3, - "num_candidates": 50, - "query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0], - }, size=3, - rank={"rrf": {"rank_constant": 1, "window_size": 5}}, ) assert [o["_source"]["text_field"] for o in output] == [