Skip to content

Commit e22de7e

Browse files
Vectorstore: use a retriever query for hybrid search (#2666)
* Vectorstore: use a retriever query for hybrid search (fixes #2651)
* Only run hybrid search tests when using a stack version >= 8.14
* Add support for rrf=False back
1 parent 14e6265 commit e22de7e

File tree

3 files changed

+194
-72
lines changed

3 files changed

+194
-72
lines changed

Diff for: elasticsearch/helpers/vectorstore/_async/strategies.py

+30-9
Original file line numberDiff line numberDiff line change
@@ -283,10 +283,9 @@ def _hybrid(
283283
) -> Dict[str, Any]:
284284
# Add a query to the knn query.
285285
# RRF is used to even the score from the knn query and text query
286-
# RRF has two optional parameters: {'rank_constant':int, 'window_size':int}
286+
# RRF has two optional parameters: {'rank_constant':int, 'rank_window_size':int}
287287
# https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html
288-
query_body = {
289-
"knn": knn,
288+
standard_query = {
290289
"query": {
291290
"bool": {
292291
"must": [
@@ -300,14 +299,36 @@ def _hybrid(
300299
],
301300
"filter": filter,
302301
}
303-
},
302+
}
304303
}
305304

306-
if isinstance(self.rrf, Dict):
307-
query_body["rank"] = {"rrf": self.rrf}
308-
elif isinstance(self.rrf, bool) and self.rrf is True:
309-
query_body["rank"] = {"rrf": {}}
310-
305+
if self.rrf is False:
306+
query_body = {
307+
"knn": knn,
308+
**standard_query,
309+
}
310+
else:
311+
rrf_options = {}
312+
if isinstance(self.rrf, Dict):
313+
if "rank_constant" in self.rrf:
314+
rrf_options["rank_constant"] = self.rrf["rank_constant"]
315+
if "window_size" in self.rrf:
316+
# 'window_size' was renamed to 'rank_window_size', but we support
317+
# the older name for backwards compatibility
318+
rrf_options["rank_window_size"] = self.rrf["window_size"]
319+
if "rank_window_size" in self.rrf:
320+
rrf_options["rank_window_size"] = self.rrf["rank_window_size"]
321+
query_body = {
322+
"retriever": {
323+
"rrf": {
324+
"retrievers": [
325+
{"standard": standard_query},
326+
{"knn": knn},
327+
],
328+
**rrf_options,
329+
},
330+
},
331+
}
311332
return query_body
312333

313334
def needs_inference(self) -> bool:

Diff for: elasticsearch/helpers/vectorstore/_sync/strategies.py

+30-9
Original file line numberDiff line numberDiff line change
@@ -283,10 +283,9 @@ def _hybrid(
283283
) -> Dict[str, Any]:
284284
# Add a query to the knn query.
285285
# RRF is used to even the score from the knn query and text query
286-
# RRF has two optional parameters: {'rank_constant':int, 'window_size':int}
286+
# RRF has two optional parameters: {'rank_constant':int, 'rank_window_size':int}
287287
# https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html
288-
query_body = {
289-
"knn": knn,
288+
standard_query = {
290289
"query": {
291290
"bool": {
292291
"must": [
@@ -300,14 +299,36 @@ def _hybrid(
300299
],
301300
"filter": filter,
302301
}
303-
},
302+
}
304303
}
305304

306-
if isinstance(self.rrf, Dict):
307-
query_body["rank"] = {"rrf": self.rrf}
308-
elif isinstance(self.rrf, bool) and self.rrf is True:
309-
query_body["rank"] = {"rrf": {}}
310-
305+
if self.rrf is False:
306+
query_body = {
307+
"knn": knn,
308+
**standard_query,
309+
}
310+
else:
311+
rrf_options = {}
312+
if isinstance(self.rrf, Dict):
313+
if "rank_constant" in self.rrf:
314+
rrf_options["rank_constant"] = self.rrf["rank_constant"]
315+
if "window_size" in self.rrf:
316+
# 'window_size' was renamed to 'rank_window_size', but we support
317+
# the older name for backwards compatibility
318+
rrf_options["rank_window_size"] = self.rrf["window_size"]
319+
if "rank_window_size" in self.rrf:
320+
rrf_options["rank_window_size"] = self.rrf["rank_window_size"]
321+
query_body = {
322+
"retriever": {
323+
"rrf": {
324+
"retrievers": [
325+
{"standard": standard_query},
326+
{"knn": knn},
327+
],
328+
**rrf_options,
329+
},
330+
},
331+
}
311332
return query_body
312333

313334
def needs_inference(self) -> bool:

Diff for: test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py

+134-54
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
VectorStore,
3434
)
3535
from elasticsearch.helpers.vectorstore._sync._utils import model_is_deployed
36+
from test_elasticsearch.utils import es_version
3637

3738
from . import ConsistentFakeEmbeddings, FakeEmbeddings
3839

@@ -337,6 +338,9 @@ def test_search_knn_with_hybrid_search(
337338
self, sync_client: Elasticsearch, index: str
338339
) -> None:
339340
"""Test end to end construction and search with metadata."""
341+
if es_version(sync_client) < (8, 14):
342+
pytest.skip("This test requires Elasticsearch 8.14 or newer")
343+
340344
store = VectorStore(
341345
index=index,
342346
retrieval_strategy=DenseVectorStrategy(hybrid=True),
@@ -349,20 +353,48 @@ def test_search_knn_with_hybrid_search(
349353

350354
def assert_query(query_body: dict, query: Optional[str]) -> dict:
351355
assert query_body == {
352-
"knn": {
353-
"field": "vector_field",
354-
"filter": [],
355-
"k": 1,
356-
"num_candidates": 50,
357-
"query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
358-
},
359-
"query": {
360-
"bool": {
361-
"filter": [],
362-
"must": [{"match": {"text_field": {"query": "foo"}}}],
356+
"retriever": {
357+
"rrf": {
358+
"retrievers": [
359+
{
360+
"standard": {
361+
"query": {
362+
"bool": {
363+
"filter": [],
364+
"must": [
365+
{
366+
"match": {
367+
"text_field": {"query": "foo"}
368+
}
369+
}
370+
],
371+
}
372+
},
373+
},
374+
},
375+
{
376+
"knn": {
377+
"field": "vector_field",
378+
"filter": [],
379+
"k": 1,
380+
"num_candidates": 50,
381+
"query_vector": [
382+
1.0,
383+
1.0,
384+
1.0,
385+
1.0,
386+
1.0,
387+
1.0,
388+
1.0,
389+
1.0,
390+
1.0,
391+
0.0,
392+
],
393+
},
394+
},
395+
],
363396
}
364-
},
365-
"rank": {"rrf": {}},
397+
}
366398
}
367399
return query_body
368400

@@ -373,55 +405,77 @@ def test_search_knn_with_hybrid_search_rrf(
373405
self, sync_client: Elasticsearch, index: str
374406
) -> None:
375407
"""Test end to end construction and rrf hybrid search with metadata."""
408+
if es_version(sync_client) < (8, 14):
409+
pytest.skip("This test requires Elasticsearch 8.14 or newer")
410+
376411
texts = ["foo", "bar", "baz"]
377412

378413
def assert_query(
379414
query_body: dict,
380415
query: Optional[str],
381416
expected_rrf: Union[dict, bool],
382417
) -> dict:
383-
cmp_query_body = {
384-
"knn": {
385-
"field": "vector_field",
386-
"filter": [],
387-
"k": 3,
388-
"num_candidates": 50,
389-
"query_vector": [
390-
1.0,
391-
1.0,
392-
1.0,
393-
1.0,
394-
1.0,
395-
1.0,
396-
1.0,
397-
1.0,
398-
1.0,
399-
0.0,
400-
],
401-
},
418+
standard_query = {
402419
"query": {
403420
"bool": {
404421
"filter": [],
405422
"must": [{"match": {"text_field": {"query": "foo"}}}],
406423
}
407-
},
424+
}
425+
}
426+
knn_query = {
427+
"field": "vector_field",
428+
"filter": [],
429+
"k": 3,
430+
"num_candidates": 50,
431+
"query_vector": [
432+
1.0,
433+
1.0,
434+
1.0,
435+
1.0,
436+
1.0,
437+
1.0,
438+
1.0,
439+
1.0,
440+
1.0,
441+
0.0,
442+
],
408443
}
409444

410-
if isinstance(expected_rrf, dict):
411-
cmp_query_body["rank"] = {"rrf": expected_rrf}
412-
elif isinstance(expected_rrf, bool) and expected_rrf is True:
413-
cmp_query_body["rank"] = {"rrf": {}}
445+
if expected_rrf is not False:
446+
cmp_query_body = {
447+
"retriever": {
448+
"rrf": {
449+
"retrievers": [
450+
{"standard": standard_query},
451+
{"knn": knn_query},
452+
],
453+
}
454+
}
455+
}
456+
if isinstance(expected_rrf, dict):
457+
cmp_query_body["retriever"]["rrf"].update(expected_rrf)
458+
else:
459+
cmp_query_body = {
460+
"knn": knn_query,
461+
**standard_query,
462+
}
414463

415464
assert query_body == cmp_query_body
416465

417466
return query_body
418467

419468
# 1. check query_body is okay
420-
rrf_test_cases: List[Union[dict, bool]] = [
421-
True,
422-
False,
423-
{"rank_constant": 1, "window_size": 5},
424-
]
469+
if es_version(sync_client) >= (8, 14):
470+
rrf_test_cases: List[Union[dict, bool]] = [
471+
True,
472+
False,
473+
{"rank_constant": 1, "rank_window_size": 5},
474+
]
475+
else:
476+
# for 8.13.x and older there is no retriever query, so we can only
477+
# run hybrid searches with rrf=False
478+
rrf_test_cases: List[Union[dict, bool]] = [False]
425479
for rrf_test_case in rrf_test_cases:
426480
store = VectorStore(
427481
index=index,
@@ -441,21 +495,47 @@ def assert_query(
441495
# 2. check query result is okay
442496
es_output = store.client.search(
443497
index=index,
444-
query={
445-
"bool": {
446-
"filter": [],
447-
"must": [{"match": {"text_field": {"query": "foo"}}}],
498+
retriever={
499+
"rrf": {
500+
"retrievers": [
501+
{
502+
"knn": {
503+
"field": "vector_field",
504+
"filter": [],
505+
"k": 3,
506+
"num_candidates": 50,
507+
"query_vector": [
508+
1.0,
509+
1.0,
510+
1.0,
511+
1.0,
512+
1.0,
513+
1.0,
514+
1.0,
515+
1.0,
516+
1.0,
517+
0.0,
518+
],
519+
},
520+
},
521+
{
522+
"standard": {
523+
"query": {
524+
"bool": {
525+
"filter": [],
526+
"must": [
527+
{"match": {"text_field": {"query": "foo"}}}
528+
],
529+
}
530+
},
531+
},
532+
},
533+
],
534+
"rank_constant": 1,
535+
"rank_window_size": 5,
448536
}
449537
},
450-
knn={
451-
"field": "vector_field",
452-
"filter": [],
453-
"k": 3,
454-
"num_candidates": 50,
455-
"query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
456-
},
457538
size=3,
458-
rank={"rrf": {"rank_constant": 1, "window_size": 5}},
459539
)
460540

461541
assert [o["_source"]["text_field"] for o in output] == [

0 commit comments

Comments
 (0)