Skip to content

Commit e0d92ee

Browse files
kaivalnpKaival Parikh
and
Kaival Parikh
authored
Concurrent rewrite for KnnVectorQuery (#12160)
- Reduce overhead of non-concurrent search by preserving original execution - Improve readability by factoring into separate functions --------- Co-authored-by: Kaival Parikh <[email protected]>
1 parent 569533b commit e0d92ee

File tree

3 files changed

+71
-12
lines changed

3 files changed

+71
-12
lines changed

lucene/CHANGES.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ Optimizations
7070

7171
* GITHUB#11857, GITHUB#11859, GITHUB#11893, GITHUB#11909: Hunspell: improved suggestion performance (Peter Gromov)
7272

73+
* GITHUB#12160: Concurrent rewrite for AbstractKnnVectorQuery. (Kaival Parikh)
74+
7375
Bug Fixes
7476
---------------------
7577

lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java

Lines changed: 61 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,19 @@
2121
import java.io.IOException;
2222
import java.util.Arrays;
2323
import java.util.Comparator;
24+
import java.util.List;
2425
import java.util.Objects;
26+
import java.util.concurrent.ExecutionException;
27+
import java.util.concurrent.Executor;
28+
import java.util.concurrent.FutureTask;
2529
import org.apache.lucene.codecs.KnnVectorsReader;
2630
import org.apache.lucene.index.FieldInfo;
2731
import org.apache.lucene.index.IndexReader;
2832
import org.apache.lucene.index.LeafReaderContext;
2933
import org.apache.lucene.util.BitSet;
3034
import org.apache.lucene.util.BitSetIterator;
3135
import org.apache.lucene.util.Bits;
36+
import org.apache.lucene.util.ThreadInterruptedException;
3237

3338
/**
3439
* Uses {@link KnnVectorsReader#search} to perform nearest neighbour search.
@@ -62,9 +67,8 @@ public AbstractKnnVectorQuery(String field, int k, Query filter) {
6267
@Override
6368
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
6469
IndexReader reader = indexSearcher.getIndexReader();
65-
TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
6670

67-
Weight filterWeight = null;
71+
final Weight filterWeight;
6872
if (filter != null) {
6973
BooleanQuery booleanQuery =
7074
new BooleanQuery.Builder()
@@ -73,17 +77,16 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
7377
.build();
7478
Query rewritten = indexSearcher.rewrite(booleanQuery);
7579
filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
80+
} else {
81+
filterWeight = null;
7682
}
7783

78-
for (LeafReaderContext ctx : reader.leaves()) {
79-
TopDocs results = searchLeaf(ctx, filterWeight);
80-
if (ctx.docBase > 0) {
81-
for (ScoreDoc scoreDoc : results.scoreDocs) {
82-
scoreDoc.doc += ctx.docBase;
83-
}
84-
}
85-
perLeafResults[ctx.ord] = results;
86-
}
84+
Executor executor = indexSearcher.getExecutor();
85+
TopDocs[] perLeafResults =
86+
(executor == null)
87+
? sequentialSearch(reader.leaves(), filterWeight)
88+
: parallelSearch(reader.leaves(), filterWeight, executor);
89+
8790
// Merge sort the results
8891
TopDocs topK = TopDocs.merge(k, perLeafResults);
8992
if (topK.scoreDocs.length == 0) {
@@ -92,7 +95,54 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
9295
return createRewrittenQuery(reader, topK);
9396
}
9497

98+
private TopDocs[] sequentialSearch(
99+
List<LeafReaderContext> leafReaderContexts, Weight filterWeight) {
100+
try {
101+
TopDocs[] perLeafResults = new TopDocs[leafReaderContexts.size()];
102+
for (LeafReaderContext ctx : leafReaderContexts) {
103+
perLeafResults[ctx.ord] = searchLeaf(ctx, filterWeight);
104+
}
105+
return perLeafResults;
106+
} catch (Exception e) {
107+
throw new RuntimeException(e);
108+
}
109+
}
110+
111+
private TopDocs[] parallelSearch(
112+
List<LeafReaderContext> leafReaderContexts, Weight filterWeight, Executor executor) {
113+
List<FutureTask<TopDocs>> tasks =
114+
leafReaderContexts.stream()
115+
.map(ctx -> new FutureTask<>(() -> searchLeaf(ctx, filterWeight)))
116+
.toList();
117+
118+
SliceExecutor sliceExecutor = new SliceExecutor(executor);
119+
sliceExecutor.invokeAll(tasks);
120+
121+
return tasks.stream()
122+
.map(
123+
task -> {
124+
try {
125+
return task.get();
126+
} catch (ExecutionException e) {
127+
throw new RuntimeException(e.getCause());
128+
} catch (InterruptedException e) {
129+
throw new ThreadInterruptedException(e);
130+
}
131+
})
132+
.toArray(TopDocs[]::new);
133+
}
134+
95135
private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException {
136+
TopDocs results = getLeafResults(ctx, filterWeight);
137+
if (ctx.docBase > 0) {
138+
for (ScoreDoc scoreDoc : results.scoreDocs) {
139+
scoreDoc.doc += ctx.docBase;
140+
}
141+
}
142+
return results;
143+
}
144+
145+
private TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight) throws IOException {
96146
Bits liveDocs = ctx.reader().getLiveDocs();
97147
int maxDoc = ctx.reader().maxDoc();
98148

lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,10 @@ public void testDimensionMismatch() throws IOException {
210210
IndexSearcher searcher = newSearcher(reader);
211211
AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {0}, 10);
212212
IllegalArgumentException e =
213-
expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10));
213+
expectThrows(
214+
RuntimeException.class,
215+
IllegalArgumentException.class,
216+
() -> searcher.search(kvq, 10));
214217
assertEquals("vector query dimension: 1 differs from field dimension: 2", e.getMessage());
215218
}
216219
}
@@ -495,6 +498,7 @@ public void testRandomWithFilter() throws IOException {
495498
assertEquals(9, results.totalHits.value);
496499
assertEquals(results.totalHits.value, results.scoreDocs.length);
497500
expectThrows(
501+
RuntimeException.class,
498502
UnsupportedOperationException.class,
499503
() ->
500504
searcher.search(
@@ -509,6 +513,7 @@ public void testRandomWithFilter() throws IOException {
509513
assertEquals(5, results.totalHits.value);
510514
assertEquals(results.totalHits.value, results.scoreDocs.length);
511515
expectThrows(
516+
RuntimeException.class,
512517
UnsupportedOperationException.class,
513518
() ->
514519
searcher.search(
@@ -536,6 +541,7 @@ public void testRandomWithFilter() throws IOException {
536541
// Test a filter that exhausts visitedLimit in upper levels, and switches to exact search
537542
Query filter4 = IntPoint.newRangeQuery("tag", lower, lower + 2);
538543
expectThrows(
544+
RuntimeException.class,
539545
UnsupportedOperationException.class,
540546
() ->
541547
searcher.search(
@@ -708,6 +714,7 @@ public void testBitSetQuery() throws IOException {
708714

709715
Query filter = new ThrowingBitSetQuery(new FixedBitSet(numDocs));
710716
expectThrows(
717+
RuntimeException.class,
711718
UnsupportedOperationException.class,
712719
() ->
713720
searcher.search(

0 commit comments

Comments
 (0)