21
21
import java .io .IOException ;
22
22
import java .util .Arrays ;
23
23
import java .util .Comparator ;
24
+ import java .util .List ;
24
25
import java .util .Objects ;
26
+ import java .util .concurrent .ExecutionException ;
27
+ import java .util .concurrent .Executor ;
28
+ import java .util .concurrent .FutureTask ;
25
29
import org .apache .lucene .codecs .KnnVectorsReader ;
26
30
import org .apache .lucene .index .FieldInfo ;
27
31
import org .apache .lucene .index .IndexReader ;
28
32
import org .apache .lucene .index .LeafReaderContext ;
29
33
import org .apache .lucene .util .BitSet ;
30
34
import org .apache .lucene .util .BitSetIterator ;
31
35
import org .apache .lucene .util .Bits ;
36
+ import org .apache .lucene .util .ThreadInterruptedException ;
32
37
33
38
/**
34
39
* Uses {@link KnnVectorsReader#search} to perform nearest neighbour search.
@@ -62,9 +67,8 @@ public AbstractKnnVectorQuery(String field, int k, Query filter) {
62
67
@ Override
63
68
public Query rewrite (IndexSearcher indexSearcher ) throws IOException {
64
69
IndexReader reader = indexSearcher .getIndexReader ();
65
- TopDocs [] perLeafResults = new TopDocs [reader .leaves ().size ()];
66
70
67
- Weight filterWeight = null ;
71
+ final Weight filterWeight ;
68
72
if (filter != null ) {
69
73
BooleanQuery booleanQuery =
70
74
new BooleanQuery .Builder ()
@@ -73,17 +77,16 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
73
77
.build ();
74
78
Query rewritten = indexSearcher .rewrite (booleanQuery );
75
79
filterWeight = indexSearcher .createWeight (rewritten , ScoreMode .COMPLETE_NO_SCORES , 1f );
80
+ } else {
81
+ filterWeight = null ;
76
82
}
77
83
78
- for (LeafReaderContext ctx : reader .leaves ()) {
79
- TopDocs results = searchLeaf (ctx , filterWeight );
80
- if (ctx .docBase > 0 ) {
81
- for (ScoreDoc scoreDoc : results .scoreDocs ) {
82
- scoreDoc .doc += ctx .docBase ;
83
- }
84
- }
85
- perLeafResults [ctx .ord ] = results ;
86
- }
84
+ Executor executor = indexSearcher .getExecutor ();
85
+ TopDocs [] perLeafResults =
86
+ (executor == null )
87
+ ? sequentialSearch (reader .leaves (), filterWeight )
88
+ : parallelSearch (reader .leaves (), filterWeight , executor );
89
+
87
90
// Merge sort the results
88
91
TopDocs topK = TopDocs .merge (k , perLeafResults );
89
92
if (topK .scoreDocs .length == 0 ) {
@@ -92,7 +95,54 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
92
95
return createRewrittenQuery (reader , topK );
93
96
}
94
97
98
+ private TopDocs [] sequentialSearch (
99
+ List <LeafReaderContext > leafReaderContexts , Weight filterWeight ) {
100
+ try {
101
+ TopDocs [] perLeafResults = new TopDocs [leafReaderContexts .size ()];
102
+ for (LeafReaderContext ctx : leafReaderContexts ) {
103
+ perLeafResults [ctx .ord ] = searchLeaf (ctx , filterWeight );
104
+ }
105
+ return perLeafResults ;
106
+ } catch (Exception e ) {
107
+ throw new RuntimeException (e );
108
+ }
109
+ }
110
+
111
+ private TopDocs [] parallelSearch (
112
+ List <LeafReaderContext > leafReaderContexts , Weight filterWeight , Executor executor ) {
113
+ List <FutureTask <TopDocs >> tasks =
114
+ leafReaderContexts .stream ()
115
+ .map (ctx -> new FutureTask <>(() -> searchLeaf (ctx , filterWeight )))
116
+ .toList ();
117
+
118
+ SliceExecutor sliceExecutor = new SliceExecutor (executor );
119
+ sliceExecutor .invokeAll (tasks );
120
+
121
+ return tasks .stream ()
122
+ .map (
123
+ task -> {
124
+ try {
125
+ return task .get ();
126
+ } catch (ExecutionException e ) {
127
+ throw new RuntimeException (e .getCause ());
128
+ } catch (InterruptedException e ) {
129
+ throw new ThreadInterruptedException (e );
130
+ }
131
+ })
132
+ .toArray (TopDocs []::new );
133
+ }
134
+
95
135
private TopDocs searchLeaf (LeafReaderContext ctx , Weight filterWeight ) throws IOException {
136
+ TopDocs results = getLeafResults (ctx , filterWeight );
137
+ if (ctx .docBase > 0 ) {
138
+ for (ScoreDoc scoreDoc : results .scoreDocs ) {
139
+ scoreDoc .doc += ctx .docBase ;
140
+ }
141
+ }
142
+ return results ;
143
+ }
144
+
145
+ private TopDocs getLeafResults (LeafReaderContext ctx , Weight filterWeight ) throws IOException {
96
146
Bits liveDocs = ctx .reader ().getLiveDocs ();
97
147
int maxDoc = ctx .reader ().maxDoc ();
98
148
0 commit comments