Skip to content

Commit cd19598

Browse files
authored
Add support for similarity-based vector searches (#12679)
### Description Background in #12579 Add support for getting "all vectors within a radius" as opposed to getting the "topK closest vectors" in the current system ### Considerations I've tried to keep this change minimal and non-invasive by not modifying any APIs and re-using existing HNSW graphs -- changing the graph traversal and result collection criteria to: 1. Visit all nodes (reachable from the entry node in the last level) that are within an outer "traversal" radius 2. Collect all nodes that are within an inner "result" radius ### Advantages 1. Queries that have a high number of "relevant" results will get all of those (not limited by `topK`) 2. Conversely, arbitrary queries where many results are not "relevant" will not waste time in getting all `topK` (when some of them will be removed later) 3. Results of HNSW searches need not be sorted - and we can store them in a plain list as opposed to min-max heaps (saving on `heapify` calls). Merging results from segments is also cheaper, where we just concatenate results as opposed to calculating the index-level `topK` On a higher level, finding `topK` results needed HNSW searches to happen in `#rewrite` because of an interdependence of results between segments - where we want to find the index-level `topK` from multiple segment-level results. This is kind of against Lucene's concept of segments being independently searchable sub-indexes? Moreover, we needed explicit concurrency (#12160) to perform these in parallel, and these shortcomings would be naturally overcome with the new objective of finding "all vectors within a radius" - inherently independent of results from another segment (so we can move searches to a more fitting place?) 
### Caveats I could not find much precedent in using HNSW graphs this way (or even the radius-based search for that matter - please add links to existing work if someone is aware) and consequently marked all classes as `@lucene.experimental` For now I have re-used lots of functionality from `AbstractKnnVectorQuery` to keep this minimal, but if the use-case is accepted more widely we can look into writing more suitable queries (as mentioned above briefly)
1 parent 1630ed4 commit cd19598

9 files changed

+1403
-1
lines changed

lucene/CHANGES.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,11 @@ API Changes
171171

172172
New Features
173173
---------------------
174-
(No changes)
174+
175+
* GITHUB#12679: Add support for similarity-based vector searches using [Byte|Float]VectorSimilarityQuery. Uses a new
176+
VectorSimilarityCollector to find all vectors scoring above a `resultSimilarity` while traversing the HNSW graph till
177+
better-scoring nodes are available, or the best candidate is below a score of `traversalSimilarity` in the lowest
178+
level. (Aditya Prakash, Kaival Parikh)
175179

176180
Improvements
177181
---------------------
Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.lucene.search;
18+
19+
import java.io.IOException;
20+
import java.util.Arrays;
21+
import java.util.Comparator;
22+
import java.util.Objects;
23+
import org.apache.lucene.index.LeafReader;
24+
import org.apache.lucene.index.LeafReaderContext;
25+
import org.apache.lucene.util.BitSet;
26+
import org.apache.lucene.util.BitSetIterator;
27+
import org.apache.lucene.util.Bits;
28+
29+
/**
30+
* Search for all (approximate) vectors above a similarity threshold.
31+
*
32+
* @lucene.experimental
33+
*/
34+
abstract class AbstractVectorSimilarityQuery extends Query {
35+
protected final String field;
36+
protected final float traversalSimilarity, resultSimilarity;
37+
protected final Query filter;
38+
39+
/**
 * Builds a query that matches all (approximate) vectors scoring at or above a similarity
 * threshold, collected through {@link VectorSimilarityCollector}. When a filter is supplied, the
 * approximate search is allowed to visit as many nodes as the filter has matching docs; if that
 * budget runs out before the results are complete, an exact search over the filtered docs is used
 * instead.
 *
 * @param field a field that has been indexed as a vector field; must not be {@code null}.
 * @param traversalSimilarity (lower) similarity score used to bound graph traversal.
 * @param resultSimilarity (higher) similarity score a vector must reach to be collected.
 * @param filter a filter applied before the vector search, or {@code null} for no filter.
 * @throws IllegalArgumentException if {@code traversalSimilarity} is greater than {@code
 *     resultSimilarity}.
 * @throws NullPointerException if {@code field} is {@code null}.
 */
AbstractVectorSimilarityQuery(
    String field, float traversalSimilarity, float resultSimilarity, Query filter) {
  // The thresholds are validated before the field so that an invalid pair of similarities is
  // reported even when the field is also missing.
  if (resultSimilarity < traversalSimilarity) {
    throw new IllegalArgumentException("traversalSimilarity should be <= resultSimilarity");
  }

  this.field = Objects.requireNonNull(field, "field");
  this.filter = filter;
  this.resultSimilarity = resultSimilarity;
  this.traversalSimilarity = traversalSimilarity;
}
59+
60+
abstract VectorScorer createVectorScorer(LeafReaderContext context) throws IOException;
61+
62+
protected abstract TopDocs approximateSearch(
63+
LeafReaderContext context, Bits acceptDocs, int visitLimit) throws IOException;
64+
65+
@Override
66+
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
67+
throws IOException {
68+
return new Weight(this) {
69+
final Weight filterWeight =
70+
filter == null
71+
? null
72+
: searcher.createWeight(searcher.rewrite(filter), ScoreMode.COMPLETE_NO_SCORES, 1);
73+
74+
@Override
75+
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
76+
if (filterWeight != null) {
77+
Scorer filterScorer = filterWeight.scorer(context);
78+
if (filterScorer == null || filterScorer.iterator().advance(doc) > doc) {
79+
return Explanation.noMatch("Doc does not match the filter");
80+
}
81+
}
82+
83+
VectorScorer scorer = createVectorScorer(context);
84+
if (scorer == null) {
85+
return Explanation.noMatch("Not indexed as the correct vector field");
86+
} else if (scorer.advanceExact(doc)) {
87+
float score = scorer.score();
88+
if (score >= resultSimilarity) {
89+
return Explanation.match(boost * score, "Score above threshold");
90+
} else {
91+
return Explanation.noMatch("Score below threshold");
92+
}
93+
} else {
94+
return Explanation.noMatch("No vector found for doc");
95+
}
96+
}
97+
98+
@Override
99+
public Scorer scorer(LeafReaderContext context) throws IOException {
100+
@SuppressWarnings("resource")
101+
LeafReader leafReader = context.reader();
102+
Bits liveDocs = leafReader.getLiveDocs();
103+
104+
// If there is no filter
105+
if (filterWeight == null) {
106+
// Return exhaustive results
107+
TopDocs results = approximateSearch(context, liveDocs, Integer.MAX_VALUE);
108+
return VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs);
109+
}
110+
111+
Scorer scorer = filterWeight.scorer(context);
112+
if (scorer == null) {
113+
// If the filter does not match any documents
114+
return null;
115+
}
116+
117+
BitSet acceptDocs;
118+
if (liveDocs == null && scorer.iterator() instanceof BitSetIterator bitSetIterator) {
119+
// If there are no deletions, and matching docs are already cached
120+
acceptDocs = bitSetIterator.getBitSet();
121+
} else {
122+
// Else collect all matching docs
123+
FilteredDocIdSetIterator filtered =
124+
new FilteredDocIdSetIterator(scorer.iterator()) {
125+
@Override
126+
protected boolean match(int doc) {
127+
return liveDocs == null || liveDocs.get(doc);
128+
}
129+
};
130+
acceptDocs = BitSet.of(filtered, leafReader.maxDoc());
131+
}
132+
133+
int cardinality = acceptDocs.cardinality();
134+
if (cardinality == 0) {
135+
// If there are no live matching docs
136+
return null;
137+
}
138+
139+
// Perform an approximate search
140+
TopDocs results = approximateSearch(context, acceptDocs, cardinality);
141+
142+
// If the limit was exhausted
143+
if (results.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO) {
144+
// Return a lazy-loading iterator
145+
return VectorSimilarityScorer.fromAcceptDocs(
146+
this,
147+
boost,
148+
createVectorScorer(context),
149+
new BitSetIterator(acceptDocs, cardinality),
150+
resultSimilarity);
151+
} else {
152+
// Return an iterator over the collected results
153+
return VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs);
154+
}
155+
}
156+
157+
@Override
158+
public boolean isCacheable(LeafReaderContext ctx) {
159+
return true;
160+
}
161+
};
162+
}
163+
164+
@Override
165+
public void visit(QueryVisitor visitor) {
166+
if (visitor.acceptField(field)) {
167+
visitor.visitLeaf(this);
168+
}
169+
}
170+
171+
@Override
172+
public boolean equals(Object o) {
173+
return sameClassAs(o)
174+
&& Objects.equals(field, ((AbstractVectorSimilarityQuery) o).field)
175+
&& Float.compare(
176+
((AbstractVectorSimilarityQuery) o).traversalSimilarity, traversalSimilarity)
177+
== 0
178+
&& Float.compare(((AbstractVectorSimilarityQuery) o).resultSimilarity, resultSimilarity)
179+
== 0
180+
&& Objects.equals(filter, ((AbstractVectorSimilarityQuery) o).filter);
181+
}
182+
183+
@Override
184+
public int hashCode() {
185+
return Objects.hash(field, traversalSimilarity, resultSimilarity, filter);
186+
}
187+
188+
private static class VectorSimilarityScorer extends Scorer {
189+
final DocIdSetIterator iterator;
190+
final float[] cachedScore;
191+
192+
VectorSimilarityScorer(Weight weight, DocIdSetIterator iterator, float[] cachedScore) {
193+
super(weight);
194+
this.iterator = iterator;
195+
this.cachedScore = cachedScore;
196+
}
197+
198+
static VectorSimilarityScorer fromScoreDocs(Weight weight, float boost, ScoreDoc[] scoreDocs) {
199+
// Sort in ascending order of docid
200+
Arrays.sort(scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc));
201+
202+
float[] cachedScore = new float[1];
203+
DocIdSetIterator iterator =
204+
new DocIdSetIterator() {
205+
int index = -1;
206+
207+
@Override
208+
public int docID() {
209+
if (index < 0) {
210+
return -1;
211+
} else if (index >= scoreDocs.length) {
212+
return NO_MORE_DOCS;
213+
} else {
214+
cachedScore[0] = boost * scoreDocs[index].score;
215+
return scoreDocs[index].doc;
216+
}
217+
}
218+
219+
@Override
220+
public int nextDoc() {
221+
index++;
222+
return docID();
223+
}
224+
225+
@Override
226+
public int advance(int target) {
227+
index =
228+
Arrays.binarySearch(
229+
scoreDocs,
230+
new ScoreDoc(target, 0),
231+
Comparator.comparingInt(scoreDoc -> scoreDoc.doc));
232+
if (index < 0) {
233+
index = -1 - index;
234+
}
235+
return docID();
236+
}
237+
238+
@Override
239+
public long cost() {
240+
return scoreDocs.length;
241+
}
242+
};
243+
244+
return new VectorSimilarityScorer(weight, iterator, cachedScore);
245+
}
246+
247+
static VectorSimilarityScorer fromAcceptDocs(
248+
Weight weight,
249+
float boost,
250+
VectorScorer scorer,
251+
DocIdSetIterator acceptDocs,
252+
float threshold) {
253+
float[] cachedScore = new float[1];
254+
DocIdSetIterator iterator =
255+
new FilteredDocIdSetIterator(acceptDocs) {
256+
@Override
257+
protected boolean match(int doc) throws IOException {
258+
// Compute the dot product
259+
float score = scorer.score();
260+
cachedScore[0] = score * boost;
261+
return score >= threshold;
262+
}
263+
};
264+
265+
return new VectorSimilarityScorer(weight, iterator, cachedScore);
266+
}
267+
268+
@Override
269+
public int docID() {
270+
return iterator.docID();
271+
}
272+
273+
@Override
274+
public DocIdSetIterator iterator() {
275+
return iterator;
276+
}
277+
278+
@Override
279+
public float getMaxScore(int upTo) {
280+
return Float.POSITIVE_INFINITY;
281+
}
282+
283+
@Override
284+
public float score() {
285+
return cachedScore[0];
286+
}
287+
}
288+
}

0 commit comments

Comments
 (0)