Skip to content

Commit 32ab831

Browse files
Vectorstore: use a retriever query for hybrid search
Fixes #2651
1 parent 14e6265 commit 32ab831

File tree

3 files changed

+190
-90
lines changed

3 files changed

+190
-90
lines changed

Diff for: elasticsearch/helpers/vectorstore/_async/strategies.py

+34-19
Original file line numberDiff line numberDiff line change
@@ -283,31 +283,46 @@ def _hybrid(
283283
) -> Dict[str, Any]:
284284
# Add a query to the knn query.
285285
# RRF is used to even the score from the knn query and text query
286-
# RRF has two optional parameters: {'rank_constant':int, 'window_size':int}
286+
# RRF has two optional parameters: {'rank_constant':int, 'rank_window_size':int}
287287
# https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html
288+
rrf_options = {}
289+
if isinstance(self.rrf, Dict):
290+
if "rank_constant" in self.rrf:
291+
rrf_options["rank_constant"] = self.rrf["rank_constant"]
292+
if "window_size" in self.rrf:
293+
# 'window_size' was renamed to 'rank_window_size', but we support
294+
# the older name for backwards compatibility
295+
rrf_options["rank_window_size"] = self.rrf["window_size"]
296+
if "rank_window_size" in self.rrf:
297+
rrf_options["rank_window_size"] = self.rrf["rank_window_size"]
288298
query_body = {
289-
"knn": knn,
290-
"query": {
291-
"bool": {
292-
"must": [
299+
"retriever": {
300+
"rrf": {
301+
"retrievers": [
293302
{
294-
"match": {
295-
self.text_field: {
296-
"query": query,
297-
}
298-
}
299-
}
303+
"standard": {
304+
"query": {
305+
"bool": {
306+
"must": [
307+
{
308+
"match": {
309+
self.text_field: {
310+
"query": query,
311+
}
312+
}
313+
}
314+
],
315+
"filter": filter,
316+
}
317+
},
318+
},
319+
},
320+
{"knn": knn},
300321
],
301-
"filter": filter,
302-
}
322+
**rrf_options,
323+
},
303324
},
304325
}
305-
306-
if isinstance(self.rrf, Dict):
307-
query_body["rank"] = {"rrf": self.rrf}
308-
elif isinstance(self.rrf, bool) and self.rrf is True:
309-
query_body["rank"] = {"rrf": {}}
310-
311326
return query_body
312327

313328
def needs_inference(self) -> bool:

Diff for: elasticsearch/helpers/vectorstore/_sync/strategies.py

+34-19
Original file line numberDiff line numberDiff line change
@@ -283,31 +283,46 @@ def _hybrid(
283283
) -> Dict[str, Any]:
284284
# Add a query to the knn query.
285285
# RRF is used to even the score from the knn query and text query
286-
# RRF has two optional parameters: {'rank_constant':int, 'window_size':int}
286+
# RRF has two optional parameters: {'rank_constant':int, 'rank_window_size':int}
287287
# https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html
288+
rrf_options = {}
289+
if isinstance(self.rrf, Dict):
290+
if "rank_constant" in self.rrf:
291+
rrf_options["rank_constant"] = self.rrf["rank_constant"]
292+
if "window_size" in self.rrf:
293+
# 'window_size' was renamed to 'rank_window_size', but we support
294+
# the older name for backwards compatibility
295+
rrf_options["rank_window_size"] = self.rrf["window_size"]
296+
if "rank_window_size" in self.rrf:
297+
rrf_options["rank_window_size"] = self.rrf["rank_window_size"]
288298
query_body = {
289-
"knn": knn,
290-
"query": {
291-
"bool": {
292-
"must": [
299+
"retriever": {
300+
"rrf": {
301+
"retrievers": [
293302
{
294-
"match": {
295-
self.text_field: {
296-
"query": query,
297-
}
298-
}
299-
}
303+
"standard": {
304+
"query": {
305+
"bool": {
306+
"must": [
307+
{
308+
"match": {
309+
self.text_field: {
310+
"query": query,
311+
}
312+
}
313+
}
314+
],
315+
"filter": filter,
316+
}
317+
},
318+
},
319+
},
320+
{"knn": knn},
300321
],
301-
"filter": filter,
302-
}
322+
**rrf_options,
323+
},
303324
},
304325
}
305-
306-
if isinstance(self.rrf, Dict):
307-
query_body["rank"] = {"rrf": self.rrf}
308-
elif isinstance(self.rrf, bool) and self.rrf is True:
309-
query_body["rank"] = {"rrf": {}}
310-
311326
return query_body
312327

313328
def needs_inference(self) -> bool:

Diff for: test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py

+122-52
Original file line numberDiff line numberDiff line change
@@ -349,20 +349,48 @@ def test_search_knn_with_hybrid_search(
349349

350350
def assert_query(query_body: dict, query: Optional[str]) -> dict:
351351
assert query_body == {
352-
"knn": {
353-
"field": "vector_field",
354-
"filter": [],
355-
"k": 1,
356-
"num_candidates": 50,
357-
"query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
358-
},
359-
"query": {
360-
"bool": {
361-
"filter": [],
362-
"must": [{"match": {"text_field": {"query": "foo"}}}],
352+
"retriever": {
353+
"rrf": {
354+
"retrievers": [
355+
{
356+
"standard": {
357+
"query": {
358+
"bool": {
359+
"filter": [],
360+
"must": [
361+
{
362+
"match": {
363+
"text_field": {"query": "foo"}
364+
}
365+
}
366+
],
367+
}
368+
},
369+
},
370+
},
371+
{
372+
"knn": {
373+
"field": "vector_field",
374+
"filter": [],
375+
"k": 1,
376+
"num_candidates": 50,
377+
"query_vector": [
378+
1.0,
379+
1.0,
380+
1.0,
381+
1.0,
382+
1.0,
383+
1.0,
384+
1.0,
385+
1.0,
386+
1.0,
387+
0.0,
388+
],
389+
},
390+
},
391+
],
363392
}
364-
},
365-
"rank": {"rrf": {}},
393+
}
366394
}
367395
return query_body
368396

@@ -381,36 +409,52 @@ def assert_query(
381409
expected_rrf: Union[dict, bool],
382410
) -> dict:
383411
cmp_query_body = {
384-
"knn": {
385-
"field": "vector_field",
386-
"filter": [],
387-
"k": 3,
388-
"num_candidates": 50,
389-
"query_vector": [
390-
1.0,
391-
1.0,
392-
1.0,
393-
1.0,
394-
1.0,
395-
1.0,
396-
1.0,
397-
1.0,
398-
1.0,
399-
0.0,
400-
],
401-
},
402-
"query": {
403-
"bool": {
404-
"filter": [],
405-
"must": [{"match": {"text_field": {"query": "foo"}}}],
412+
"retriever": {
413+
"rrf": {
414+
"retrievers": [
415+
{
416+
"standard": {
417+
"query": {
418+
"bool": {
419+
"filter": [],
420+
"must": [
421+
{
422+
"match": {
423+
"text_field": {"query": "foo"}
424+
}
425+
}
426+
],
427+
}
428+
},
429+
},
430+
},
431+
{
432+
"knn": {
433+
"field": "vector_field",
434+
"filter": [],
435+
"k": 3,
436+
"num_candidates": 50,
437+
"query_vector": [
438+
1.0,
439+
1.0,
440+
1.0,
441+
1.0,
442+
1.0,
443+
1.0,
444+
1.0,
445+
1.0,
446+
1.0,
447+
0.0,
448+
],
449+
},
450+
},
451+
],
406452
}
407-
},
453+
}
408454
}
409455

410456
if isinstance(expected_rrf, dict):
411-
cmp_query_body["rank"] = {"rrf": expected_rrf}
412-
elif isinstance(expected_rrf, bool) and expected_rrf is True:
413-
cmp_query_body["rank"] = {"rrf": {}}
457+
cmp_query_body["retriever"]["rrf"].update(expected_rrf)
414458

415459
assert query_body == cmp_query_body
416460

@@ -420,7 +464,7 @@ def assert_query(
420464
rrf_test_cases: List[Union[dict, bool]] = [
421465
True,
422466
False,
423-
{"rank_constant": 1, "window_size": 5},
467+
{"rank_constant": 1, "rank_window_size": 5},
424468
]
425469
for rrf_test_case in rrf_test_cases:
426470
store = VectorStore(
@@ -441,21 +485,47 @@ def assert_query(
441485
# 2. check query result is okay
442486
es_output = store.client.search(
443487
index=index,
444-
query={
445-
"bool": {
446-
"filter": [],
447-
"must": [{"match": {"text_field": {"query": "foo"}}}],
488+
retriever={
489+
"rrf": {
490+
"retrievers": [
491+
{
492+
"knn": {
493+
"field": "vector_field",
494+
"filter": [],
495+
"k": 3,
496+
"num_candidates": 50,
497+
"query_vector": [
498+
1.0,
499+
1.0,
500+
1.0,
501+
1.0,
502+
1.0,
503+
1.0,
504+
1.0,
505+
1.0,
506+
1.0,
507+
0.0,
508+
],
509+
},
510+
},
511+
{
512+
"standard": {
513+
"query": {
514+
"bool": {
515+
"filter": [],
516+
"must": [
517+
{"match": {"text_field": {"query": "foo"}}}
518+
],
519+
}
520+
},
521+
},
522+
},
523+
],
524+
"rank_constant": 1,
525+
"rank_window_size": 5,
448526
}
449527
},
450-
knn={
451-
"field": "vector_field",
452-
"filter": [],
453-
"k": 3,
454-
"num_candidates": 50,
455-
"query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
456-
},
457528
size=3,
458-
rank={"rrf": {"rank_constant": 1, "window_size": 5}},
459529
)
460530

461531
assert [o["_source"]["text_field"] for o in output] == [

0 commit comments

Comments
 (0)