Skip to content

Commit b9bf952

Browse files
Add back support for rrf=False
1 parent 203f678 commit b9bf952

File tree

3 files changed

+133
-118
lines changed

3 files changed

+133
-118
lines changed

Diff for: elasticsearch/helpers/vectorstore/_async/strategies.py

+41-35
Original file line numberDiff line numberDiff line change
@@ -285,44 +285,50 @@ def _hybrid(
285285
# RRF is used to even out the scores from the knn query and the text query
286286
# RRF has two optional parameters: {'rank_constant':int, 'rank_window_size':int}
287287
# https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html
288-
rrf_options = {}
289-
if isinstance(self.rrf, Dict):
290-
if "rank_constant" in self.rrf:
291-
rrf_options["rank_constant"] = self.rrf["rank_constant"]
292-
if "window_size" in self.rrf:
293-
# 'window_size' was renamed to 'rank_window_size', but we support
294-
# the older name for backwards compatibility
295-
rrf_options["rank_window_size"] = self.rrf["window_size"]
296-
if "rank_window_size" in self.rrf:
297-
rrf_options["rank_window_size"] = self.rrf["rank_window_size"]
298-
query_body = {
299-
"retriever": {
300-
"rrf": {
301-
"retrievers": [
288+
standard_query = {
289+
"query": {
290+
"bool": {
291+
"must": [
302292
{
303-
"standard": {
304-
"query": {
305-
"bool": {
306-
"must": [
307-
{
308-
"match": {
309-
self.text_field: {
310-
"query": query,
311-
}
312-
}
313-
}
314-
],
315-
"filter": filter,
316-
}
317-
},
318-
},
319-
},
320-
{"knn": knn},
293+
"match": {
294+
self.text_field: {
295+
"query": query,
296+
}
297+
}
298+
}
321299
],
322-
**rrf_options,
323-
},
324-
},
300+
"filter": filter,
301+
}
302+
}
325303
}
304+
305+
if self.rrf is False:
306+
query_body = {
307+
"knn": knn,
308+
**standard_query,
309+
}
310+
else:
311+
rrf_options = {}
312+
if isinstance(self.rrf, Dict):
313+
if "rank_constant" in self.rrf:
314+
rrf_options["rank_constant"] = self.rrf["rank_constant"]
315+
if "window_size" in self.rrf:
316+
# 'window_size' was renamed to 'rank_window_size', but we support
317+
# the older name for backwards compatibility
318+
rrf_options["rank_window_size"] = self.rrf["window_size"]
319+
if "rank_window_size" in self.rrf:
320+
rrf_options["rank_window_size"] = self.rrf["rank_window_size"]
321+
query_body = {
322+
"retriever": {
323+
"rrf": {
324+
"retrievers": [
325+
{"standard": standard_query},
326+
{"knn": knn},
327+
],
328+
**rrf_options,
329+
},
330+
},
331+
}
326332
return query_body
327333

328334
def needs_inference(self) -> bool:

Diff for: elasticsearch/helpers/vectorstore/_sync/strategies.py

+41-35
Original file line numberDiff line numberDiff line change
@@ -285,44 +285,50 @@ def _hybrid(
285285
# RRF is used to even out the scores from the knn query and the text query
286286
# RRF has two optional parameters: {'rank_constant':int, 'rank_window_size':int}
287287
# https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html
288-
rrf_options = {}
289-
if isinstance(self.rrf, Dict):
290-
if "rank_constant" in self.rrf:
291-
rrf_options["rank_constant"] = self.rrf["rank_constant"]
292-
if "window_size" in self.rrf:
293-
# 'window_size' was renamed to 'rank_window_size', but we support
294-
# the older name for backwards compatibility
295-
rrf_options["rank_window_size"] = self.rrf["window_size"]
296-
if "rank_window_size" in self.rrf:
297-
rrf_options["rank_window_size"] = self.rrf["rank_window_size"]
298-
query_body = {
299-
"retriever": {
300-
"rrf": {
301-
"retrievers": [
288+
standard_query = {
289+
"query": {
290+
"bool": {
291+
"must": [
302292
{
303-
"standard": {
304-
"query": {
305-
"bool": {
306-
"must": [
307-
{
308-
"match": {
309-
self.text_field: {
310-
"query": query,
311-
}
312-
}
313-
}
314-
],
315-
"filter": filter,
316-
}
317-
},
318-
},
319-
},
320-
{"knn": knn},
293+
"match": {
294+
self.text_field: {
295+
"query": query,
296+
}
297+
}
298+
}
321299
],
322-
**rrf_options,
323-
},
324-
},
300+
"filter": filter,
301+
}
302+
}
325303
}
304+
305+
if self.rrf is False:
306+
query_body = {
307+
"knn": knn,
308+
**standard_query,
309+
}
310+
else:
311+
rrf_options = {}
312+
if isinstance(self.rrf, Dict):
313+
if "rank_constant" in self.rrf:
314+
rrf_options["rank_constant"] = self.rrf["rank_constant"]
315+
if "window_size" in self.rrf:
316+
# 'window_size' was renamed to 'rank_window_size', but we support
317+
# the older name for backwards compatibility
318+
rrf_options["rank_window_size"] = self.rrf["window_size"]
319+
if "rank_window_size" in self.rrf:
320+
rrf_options["rank_window_size"] = self.rrf["rank_window_size"]
321+
query_body = {
322+
"retriever": {
323+
"rrf": {
324+
"retrievers": [
325+
{"standard": standard_query},
326+
{"knn": knn},
327+
],
328+
**rrf_options,
329+
},
330+
},
331+
}
326332
return query_body
327333

328334
def needs_inference(self) -> bool:

Diff for: test_elasticsearch/test_server/test_vectorstore/test_vectorstore.py

+51-48
Original file line numberDiff line numberDiff line change
@@ -415,64 +415,67 @@ def assert_query(
415415
query: Optional[str],
416416
expected_rrf: Union[dict, bool],
417417
) -> dict:
418-
cmp_query_body = {
419-
"retriever": {
420-
"rrf": {
421-
"retrievers": [
422-
{
423-
"standard": {
424-
"query": {
425-
"bool": {
426-
"filter": [],
427-
"must": [
428-
{
429-
"match": {
430-
"text_field": {"query": "foo"}
431-
}
432-
}
433-
],
434-
}
435-
},
436-
},
437-
},
438-
{
439-
"knn": {
440-
"field": "vector_field",
441-
"filter": [],
442-
"k": 3,
443-
"num_candidates": 50,
444-
"query_vector": [
445-
1.0,
446-
1.0,
447-
1.0,
448-
1.0,
449-
1.0,
450-
1.0,
451-
1.0,
452-
1.0,
453-
1.0,
454-
0.0,
455-
],
456-
},
457-
},
458-
],
418+
standard_query = {
419+
"query": {
420+
"bool": {
421+
"filter": [],
422+
"must": [{"match": {"text_field": {"query": "foo"}}}],
459423
}
460424
}
461425
}
426+
knn_query = {
427+
"field": "vector_field",
428+
"filter": [],
429+
"k": 3,
430+
"num_candidates": 50,
431+
"query_vector": [
432+
1.0,
433+
1.0,
434+
1.0,
435+
1.0,
436+
1.0,
437+
1.0,
438+
1.0,
439+
1.0,
440+
1.0,
441+
0.0,
442+
],
443+
}
462444

463-
if isinstance(expected_rrf, dict):
464-
cmp_query_body["retriever"]["rrf"].update(expected_rrf)
445+
if expected_rrf is not False:
446+
cmp_query_body = {
447+
"retriever": {
448+
"rrf": {
449+
"retrievers": [
450+
{"standard": standard_query},
451+
{"knn": knn_query},
452+
],
453+
}
454+
}
455+
}
456+
if isinstance(expected_rrf, dict):
457+
cmp_query_body["retriever"]["rrf"].update(expected_rrf)
458+
else:
459+
cmp_query_body = {
460+
"knn": knn_query,
461+
**standard_query,
462+
}
465463

466464
assert query_body == cmp_query_body
467465

468466
return query_body
469467

470468
# 1. check query_body is okay
471-
rrf_test_cases: List[Union[dict, bool]] = [
472-
True,
473-
False,
474-
{"rank_constant": 1, "rank_window_size": 5},
475-
]
469+
if es_version(sync_client) >= (8, 14):
470+
rrf_test_cases: List[Union[dict, bool]] = [
471+
True,
472+
False,
473+
{"rank_constant": 1, "rank_window_size": 5},
474+
]
475+
else:
476+
# for 8.13.x and older there is no retriever query, so we can only
477+
# run hybrid searches with rrf=False
478+
rrf_test_cases: List[Union[dict, bool]] = [False]
476479
for rrf_test_case in rrf_test_cases:
477480
store = VectorStore(
478481
index=index,

0 commit comments

Comments
 (0)