Skip to content

Commit e675035

Browse files
Add back support for rrf=False
1 parent 203f678 commit e675035

File tree

3 files changed

+133
-121
lines changed

3 files changed

+133
-121
lines changed

Diff for: elasticsearch/helpers/vectorstore/_async/strategies.py

+41-35
Original file line numberDiff line numberDiff line change
@@ -285,44 +285,50 @@ def _hybrid(
285285
# RRF is used to even out the scores from the knn query and the text query
286286
# RRF has two optional parameters: {'rank_constant':int, 'rank_window_size':int}
287287
# https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html
288-
rrf_options = {}
289-
if isinstance(self.rrf, Dict):
290-
if "rank_constant" in self.rrf:
291-
rrf_options["rank_constant"] = self.rrf["rank_constant"]
292-
if "window_size" in self.rrf:
293-
# 'window_size' was renamed to 'rank_window_size', but we support
294-
# the older name for backwards compatibility
295-
rrf_options["rank_window_size"] = self.rrf["window_size"]
296-
if "rank_window_size" in self.rrf:
297-
rrf_options["rank_window_size"] = self.rrf["rank_window_size"]
298-
query_body = {
299-
"retriever": {
300-
"rrf": {
301-
"retrievers": [
288+
standard_query = {
289+
"query": {
290+
"bool": {
291+
"must": [
302292
{
303-
"standard": {
304-
"query": {
305-
"bool": {
306-
"must": [
307-
{
308-
"match": {
309-
self.text_field: {
310-
"query": query,
311-
}
312-
}
313-
}
314-
],
315-
"filter": filter,
316-
}
317-
},
318-
},
319-
},
320-
{"knn": knn},
293+
"match": {
294+
self.text_field: {
295+
"query": query,
296+
}
297+
}
298+
}
321299
],
322-
**rrf_options,
323-
},
324-
},
300+
"filter": filter,
301+
}
302+
}
325303
}
304+
305+
if self.rrf is False:
306+
query_body = {
307+
"knn": knn,
308+
**standard_query,
309+
}
310+
else:
311+
rrf_options = {}
312+
if isinstance(self.rrf, Dict):
313+
if "rank_constant" in self.rrf:
314+
rrf_options["rank_constant"] = self.rrf["rank_constant"]
315+
if "window_size" in self.rrf:
316+
# 'window_size' was renamed to 'rank_window_size', but we support
317+
# the older name for backwards compatibility
318+
rrf_options["rank_window_size"] = self.rrf["window_size"]
319+
if "rank_window_size" in self.rrf:
320+
rrf_options["rank_window_size"] = self.rrf["rank_window_size"]
321+
query_body = {
322+
"retriever": {
323+
"rrf": {
324+
"retrievers": [
325+
{"standard": standard_query},
326+
{"knn": knn},
327+
],
328+
**rrf_options,
329+
},
330+
},
331+
}
326332
return query_body
327333

328334
def needs_inference(self) -> bool:

Diff for: elasticsearch/helpers/vectorstore/_sync/strategies.py

+41-35
Original file line numberDiff line numberDiff line change
@@ -285,44 +285,50 @@ def _hybrid(
285285
# RRF is used to even out the scores from the knn query and the text query
286286
# RRF has two optional parameters: {'rank_constant':int, 'rank_window_size':int}
287287
# https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html
288-
rrf_options = {}
289-
if isinstance(self.rrf, Dict):
290-
if "rank_constant" in self.rrf:
291-
rrf_options["rank_constant"] = self.rrf["rank_constant"]
292-
if "window_size" in self.rrf:
293-
# 'window_size' was renamed to 'rank_window_size', but we support
294-
# the older name for backwards compatibility
295-
rrf_options["rank_window_size"] = self.rrf["window_size"]
296-
if "rank_window_size" in self.rrf:
297-
rrf_options["rank_window_size"] = self.rrf["rank_window_size"]
298-
query_body = {
299-
"retriever": {
300-
"rrf": {
301-
"retrievers": [
288+
standard_query = {
289+
"query": {
290+
"bool": {
291+
"must": [
302292
{
303-
"standard": {
304-
"query": {
305-
"bool": {
306-
"must": [
307-
{
308-
"match": {
309-
self.text_field: {
310-
"query": query,
311-
}
312-
}
313-
}
314-
],
315-
"filter": filter,
316-
}
317-
},
318-
},
319-
},
320-
{"knn": knn},
293+
"match": {
294+
self.text_field: {
295+
"query": query,
296+
}
297+
}
298+
}
321299
],
322-
**rrf_options,
323-
},
324-
},
300+
"filter": filter,
301+
}
302+
}
325303
}
304+
305+
if self.rrf is False:
306+
query_body = {
307+
"knn": knn,
308+
**standard_query,
309+
}
310+
else:
311+
rrf_options = {}
312+
if isinstance(self.rrf, Dict):
313+
if "rank_constant" in self.rrf:
314+
rrf_options["rank_constant"] = self.rrf["rank_constant"]
315+
if "window_size" in self.rrf:
316+
# 'window_size' was renamed to 'rank_window_size', but we support
317+
# the older name for backwards compatibility
318+
rrf_options["rank_window_size"] = self.rrf["window_size"]
319+
if "rank_window_size" in self.rrf:
320+
rrf_options["rank_window_size"] = self.rrf["rank_window_size"]
321+
query_body = {
322+
"retriever": {
323+
"rrf": {
324+
"retrievers": [
325+
{"standard": standard_query},
326+
{"knn": knn},
327+
],
328+
**rrf_options,
329+
},
330+
},
331+
}
326332
return query_body
327333

328334
def needs_inference(self) -> bool:

Diff for: test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py

+51-51
Original file line numberDiff line numberDiff line change
@@ -405,74 +405,74 @@ def test_search_knn_with_hybrid_search_rrf(
405405
self, sync_client: Elasticsearch, index: str
406406
) -> None:
407407
"""Test end to end construction and rrf hybrid search with metadata."""
408-
if es_version(sync_client) < (8, 14):
409-
pytest.skip("This test requires Elasticsearch 8.14 or newer")
410-
411408
texts = ["foo", "bar", "baz"]
412409

413410
def assert_query(
414411
query_body: dict,
415412
query: Optional[str],
416413
expected_rrf: Union[dict, bool],
417414
) -> dict:
418-
cmp_query_body = {
419-
"retriever": {
420-
"rrf": {
421-
"retrievers": [
422-
{
423-
"standard": {
424-
"query": {
425-
"bool": {
426-
"filter": [],
427-
"must": [
428-
{
429-
"match": {
430-
"text_field": {"query": "foo"}
431-
}
432-
}
433-
],
434-
}
435-
},
436-
},
437-
},
438-
{
439-
"knn": {
440-
"field": "vector_field",
441-
"filter": [],
442-
"k": 3,
443-
"num_candidates": 50,
444-
"query_vector": [
445-
1.0,
446-
1.0,
447-
1.0,
448-
1.0,
449-
1.0,
450-
1.0,
451-
1.0,
452-
1.0,
453-
1.0,
454-
0.0,
455-
],
456-
},
457-
},
458-
],
415+
standard_query = {
416+
"query": {
417+
"bool": {
418+
"filter": [],
419+
"must": [{"match": {"text_field": {"query": "foo"}}}],
459420
}
460421
}
461422
}
423+
knn_query = {
424+
"field": "vector_field",
425+
"filter": [],
426+
"k": 3,
427+
"num_candidates": 50,
428+
"query_vector": [
429+
1.0,
430+
1.0,
431+
1.0,
432+
1.0,
433+
1.0,
434+
1.0,
435+
1.0,
436+
1.0,
437+
1.0,
438+
0.0,
439+
],
440+
}
462441

463-
if isinstance(expected_rrf, dict):
464-
cmp_query_body["retriever"]["rrf"].update(expected_rrf)
442+
if expected_rrf is not False:
443+
cmp_query_body = {
444+
"retriever": {
445+
"rrf": {
446+
"retrievers": [
447+
{"standard": standard_query},
448+
{"knn": knn_query},
449+
],
450+
}
451+
}
452+
}
453+
if isinstance(expected_rrf, dict):
454+
cmp_query_body["retriever"]["rrf"].update(expected_rrf)
455+
else:
456+
cmp_query_body = {
457+
"knn": knn_query,
458+
**standard_query,
459+
}
465460

466461
assert query_body == cmp_query_body
467462

468463
return query_body
469464

470465
# 1. check query_body is okay
471-
rrf_test_cases: List[Union[dict, bool]] = [
472-
True,
473-
False,
474-
{"rank_constant": 1, "rank_window_size": 5},
475-
]
466+
if es_version(sync_client) >= (8, 14):
467+
rrf_test_cases: List[Union[dict, bool]] = [
468+
True,
469+
False,
470+
{"rank_constant": 1, "rank_window_size": 5},
471+
]
472+
else:
473+
# for 8.13.x and older there is no retriever query, so we can only
474+
# run hybrid searches with rrf=False
475+
rrf_test_cases: List[Union[dict, bool]] = [False]
476476
for rrf_test_case in rrf_test_cases:
477477
store = VectorStore(
478478
index=index,

0 commit comments

Comments
 (0)