33
33
VectorStore ,
34
34
)
35
35
from elasticsearch .helpers .vectorstore ._sync ._utils import model_is_deployed
36
+ from test_elasticsearch .utils import es_version
36
37
37
38
from . import ConsistentFakeEmbeddings , FakeEmbeddings
38
39
@@ -337,6 +338,9 @@ def test_search_knn_with_hybrid_search(
337
338
self , sync_client : Elasticsearch , index : str
338
339
) -> None :
339
340
"""Test end to end construction and search with metadata."""
341
+ if es_version (sync_client ) < (8 , 14 ):
342
+ pytest .skip ("This test requires Elasticsearch 8.14 or newer" )
343
+
340
344
store = VectorStore (
341
345
index = index ,
342
346
retrieval_strategy = DenseVectorStrategy (hybrid = True ),
@@ -349,20 +353,48 @@ def test_search_knn_with_hybrid_search(
349
353
350
354
def assert_query(query_body: dict, query: Optional[str]) -> dict:
    """Verify that a hybrid search request uses the RRF retriever layout.

    The expected request nests a ``standard`` (text match) retriever and a
    ``knn`` retriever under ``retriever.rrf.retrievers``, instead of sending
    top-level ``knn``/``query``/``rank`` keys.

    Args:
        query_body: The search request body to check.
        query: The text query; accepted as part of the callback signature
            but not inspected here.

    Returns:
        ``query_body`` unchanged, so the helper can be used as a
        pass-through hook.

    Raises:
        AssertionError: If ``query_body`` differs from the expected body.
    """
    expected_body = {
        "retriever": {
            "rrf": {
                "retrievers": [
                    {
                        "standard": {
                            "query": {
                                "bool": {
                                    "filter": [],
                                    "must": [
                                        {
                                            "match": {
                                                "text_field": {"query": "foo"}
                                            }
                                        }
                                    ],
                                }
                            },
                        },
                    },
                    {
                        "knn": {
                            "field": "vector_field",
                            "filter": [],
                            "k": 1,
                            "num_candidates": 50,
                            "query_vector": [
                                1.0,
                                1.0,
                                1.0,
                                1.0,
                                1.0,
                                1.0,
                                1.0,
                                1.0,
                                1.0,
                                0.0,
                            ],
                        },
                    },
                ],
            }
        }
    }
    assert query_body == expected_body
    return query_body
368
400
@@ -373,55 +405,77 @@ def test_search_knn_with_hybrid_search_rrf(
373
405
self , sync_client : Elasticsearch , index : str
374
406
) -> None :
375
407
"""Test end to end construction and rrf hybrid search with metadata."""
408
+ if es_version (sync_client ) < (8 , 14 ):
409
+ pytest .skip ("This test requires Elasticsearch 8.14 or newer" )
410
+
376
411
texts = ["foo" , "bar" , "baz" ]
377
412
378
413
def assert_query(
    query_body: dict,
    query: Optional[str],
    expected_rrf: Union[dict, bool],
) -> dict:
    """Verify the hybrid search request body for a given RRF setting.

    Args:
        query_body: The search request body to check.
        query: The text query; accepted as part of the callback signature
            but not inspected here.
        expected_rrf: ``False`` expects a request with top-level ``knn`` and
            ``query`` keys and no retriever. ``True`` expects an ``rrf``
            retriever wrapping the ``standard`` and ``knn`` retrievers. A
            dict expects the same retriever with the dict's entries merged
            into the ``rrf`` object.

    Returns:
        ``query_body`` unchanged, so the helper can be used as a
        pass-through hook.

    Raises:
        AssertionError: If ``query_body`` differs from the expected body.
    """
    standard_query = {
        "query": {
            "bool": {
                "filter": [],
                "must": [{"match": {"text_field": {"query": "foo"}}}],
            }
        }
    }
    knn_query = {
        "field": "vector_field",
        "filter": [],
        "k": 3,
        "num_candidates": 50,
        "query_vector": [
            1.0,
            1.0,
            1.0,
            1.0,
            1.0,
            1.0,
            1.0,
            1.0,
            1.0,
            0.0,
        ],
    }

    cmp_query_body: dict
    if expected_rrf is not False:
        cmp_query_body = {
            "retriever": {
                "rrf": {
                    "retrievers": [
                        {"standard": standard_query},
                        {"knn": knn_query},
                    ],
                }
            }
        }
        if isinstance(expected_rrf, dict):
            # Extra RRF options sit beside "retrievers" inside the rrf object.
            cmp_query_body["retriever"]["rrf"].update(expected_rrf)
    else:
        # Without RRF the knn and text queries are sent as top-level keys.
        cmp_query_body = {
            "knn": knn_query,
            **standard_query,
        }

    assert query_body == cmp_query_body

    return query_body
418
467
419
468
# 1. check query_body is okay
420
- rrf_test_cases : List [Union [dict , bool ]] = [
421
- True ,
422
- False ,
423
- {"rank_constant" : 1 , "window_size" : 5 },
424
- ]
469
+ if es_version (sync_client ) >= (8 , 14 ):
470
+ rrf_test_cases : List [Union [dict , bool ]] = [
471
+ True ,
472
+ False ,
473
+ {"rank_constant" : 1 , "rank_window_size" : 5 },
474
+ ]
475
+ else :
476
+ # for 8.13.x and older there is no retriever query, so we can only
477
+ # run hybrid searches with rrf=False
478
+ rrf_test_cases : List [Union [dict , bool ]] = [False ]
425
479
for rrf_test_case in rrf_test_cases :
426
480
store = VectorStore (
427
481
index = index ,
@@ -441,21 +495,47 @@ def assert_query(
441
495
# 2. check query result is okay
442
496
es_output = store .client .search (
443
497
index = index ,
444
- query = {
445
- "bool" : {
446
- "filter" : [],
447
- "must" : [{"match" : {"text_field" : {"query" : "foo" }}}],
498
+ retriever = {
499
+ "rrf" : {
500
+ "retrievers" : [
501
+ {
502
+ "knn" : {
503
+ "field" : "vector_field" ,
504
+ "filter" : [],
505
+ "k" : 3 ,
506
+ "num_candidates" : 50 ,
507
+ "query_vector" : [
508
+ 1.0 ,
509
+ 1.0 ,
510
+ 1.0 ,
511
+ 1.0 ,
512
+ 1.0 ,
513
+ 1.0 ,
514
+ 1.0 ,
515
+ 1.0 ,
516
+ 1.0 ,
517
+ 0.0 ,
518
+ ],
519
+ },
520
+ },
521
+ {
522
+ "standard" : {
523
+ "query" : {
524
+ "bool" : {
525
+ "filter" : [],
526
+ "must" : [
527
+ {"match" : {"text_field" : {"query" : "foo" }}}
528
+ ],
529
+ }
530
+ },
531
+ },
532
+ },
533
+ ],
534
+ "rank_constant" : 1 ,
535
+ "rank_window_size" : 5 ,
448
536
}
449
537
},
450
- knn = {
451
- "field" : "vector_field" ,
452
- "filter" : [],
453
- "k" : 3 ,
454
- "num_candidates" : 50 ,
455
- "query_vector" : [1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 0.0 ],
456
- },
457
538
size = 3 ,
458
- rank = {"rrf" : {"rank_constant" : 1 , "window_size" : 5 }},
459
539
)
460
540
461
541
assert [o ["_source" ]["text_field" ] for o in output ] == [
0 commit comments