kNN vector rescoring for quantized vectors #116663

Merged
62 commits
df06716
Use a FunctionScoreQuery to replace scores using a VectorSimilarity b…
carlosdelest Nov 12, 2024
be76444
Change API to use "rescore": {"oversample": 1.0}
carlosdelest Nov 13, 2024
bd920c5
Add tests
carlosdelest Nov 13, 2024
91204a1
Fix inference module
carlosdelest Nov 13, 2024
6c2c1be
Merge branch 'main' into feature/knn-vector-rescore-query
carlosdelest Nov 13, 2024
b44ec48
Fix knn query usage in other modules
carlosdelest Nov 14, 2024
2a9e300
Add rescore vector builder to KnnSearchBuilder
carlosdelest Nov 14, 2024
4497b92
Add vector rescore builder to kNN retriever
carlosdelest Nov 14, 2024
955da1f
Fix refactoring, spotless
carlosdelest Nov 14, 2024
ff2c1e9
Check oversampling is not used for quantized types
carlosdelest Nov 18, 2024
bc1e5c6
Minor refactoring to reuse KnnScoreDocQuery
carlosdelest Nov 20, 2024
a7936da
Use KnnRescoreVectorQuery to perform rescoring and limiting the numbe…
carlosdelest Nov 20, 2024
f5080a6
Small name refactoring, fix adjusting parameters
carlosdelest Nov 21, 2024
39e1676
Add testing
carlosdelest Nov 21, 2024
9946e8d
Add tests for RescoreKnnVectorQuery
carlosdelest Nov 27, 2024
4fbbadd
Spotless
carlosdelest Nov 27, 2024
0dab8ea
Add test for knn retriever
carlosdelest Nov 29, 2024
257b75d
Add tests
carlosdelest Nov 29, 2024
81384f2
Parameterize rescore knn vector query tests
carlosdelest Nov 29, 2024
1347d4b
Adds profiling, including a small refactoring of the QueryProfiler in…
carlosdelest Nov 29, 2024
916ac83
Spotless
carlosdelest Nov 29, 2024
229ce2d
Add YAML tests
carlosdelest Dec 2, 2024
90e79db
Merge branch 'main' into feature/knn-vector-rescore-query
carlosdelest Dec 2, 2024
934eedb
Minor documentation / style fixes
carlosdelest Dec 2, 2024
cca6e39
Fix test compilation
carlosdelest Dec 2, 2024
d95db48
Properly implement advanceExact()
carlosdelest Dec 2, 2024
ee904b0
Fix VectorSimilarityFloatValueSource implementation for advanceExact
carlosdelest Dec 2, 2024
0d77521
Add capability for BwC tests
carlosdelest Dec 2, 2024
095a951
Update docs/changelog/116663.yaml
carlosdelest Dec 3, 2024
732bd7d
Correctly implement profiling. Rename ProfilingQuery to QueryProfiler…
carlosdelest Dec 4, 2024
b5e6309
Fix toString()
carlosdelest Dec 4, 2024
3120c5c
YAML tests do not check doc ordering, just scores
carlosdelest Dec 4, 2024
da018b8
Add tests for rescoring on non-quantized values
carlosdelest Dec 4, 2024
1d5426b
having null index type means no quantization, as INT8 is explicit in …
carlosdelest Dec 4, 2024
2379596
Bytes can't be quantized - remove all infra for byte vectors in resco…
carlosdelest Dec 4, 2024
b0c2221
Add assertion to advanceExact()
carlosdelest Dec 4, 2024
9f4cebd
Parsing / toXContent improvements
carlosdelest Dec 4, 2024
fab7395
private access for field, use getter instead
carlosdelest Dec 4, 2024
e03c8e9
toXContent improvements
carlosdelest Dec 4, 2024
69b5451
Fix toXContent / parsing
carlosdelest Dec 4, 2024
ddc6094
Make rescore parameter mandatory
carlosdelest Dec 5, 2024
3ef07fa
Add index types to vector query builder tests
carlosdelest Dec 5, 2024
c127f6e
Merge branch 'main' into feature/knn-vector-rescore-query
carlosdelest Dec 5, 2024
2280fbd
Merge remote-tracking branch 'carlosdelest/feature/knn-vector-rescore…
carlosdelest Dec 5, 2024
fd9188f
Fix compilation on rrf plugin
carlosdelest Dec 5, 2024
863b6e6
Fix tests for vector query builders to ensure multiple dimensions / i…
carlosdelest Dec 5, 2024
9c9773f
Fix test to use just floats
carlosdelest Dec 5, 2024
7d083b7
Fix test, add coverage for byte element types
carlosdelest Dec 5, 2024
a8633e5
Fix YAML test capabilities, add profile test for similarity
carlosdelest Dec 5, 2024
ef8ac8c
Rename "rescore" to "rescore_vector"
carlosdelest Dec 5, 2024
4c65c8a
Fix sneaky bug on iterator
carlosdelest Dec 5, 2024
9907aad
Use 'num_candidates_factor' parameter and update num_candidates inste…
carlosdelest Dec 9, 2024
978cff3
Add knn retriever YAML tests
carlosdelest Dec 9, 2024
9412be0
Simplify logic for RescoreKnnVectorQuery now that k is not modifiable
carlosdelest Dec 9, 2024
7045500
Limit for rescoring factor is 1.0, so we can't have less rescored doc…
carlosdelest Dec 9, 2024
826fd3b
Merge branch 'main' into feature/knn-vector-rescore-query
carlosdelest Dec 9, 2024
497d8e2
Fix test after merge
carlosdelest Dec 10, 2024
83238a5
Vector similarity needs to wrap the new rescoring query and not the o…
carlosdelest Dec 10, 2024
fa975de
Change similarity to MIP in tests
carlosdelest Dec 10, 2024
74a22f8
Spotless
carlosdelest Dec 10, 2024
94963fc
Merge branch 'main' into feature/knn-vector-rescore-query
carlosdelest Dec 10, 2024
a256de9
Apply suggestions from code review
carlosdelest Dec 11, 2024
@@ -190,6 +190,7 @@ static TransportVersion def(int id) {
public static final TransportVersion LOGSDB_TELEMETRY_STATS = def(8_785_00_0);
public static final TransportVersion KQL_QUERY_ADDED = def(8_786_00_0);
public static final TransportVersion ROLE_MONITOR_STATS = def(8_787_00_0);
public static final TransportVersion KNN_QUERY_RESCORE_OVERSAMPLE = def(8_788_00_0);

/*
* STOP! READ THIS FIRST! No, really,
@@ -30,6 +30,7 @@
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.queries.function.FunctionScoreQuery;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
@@ -121,6 +122,8 @@ public static boolean isNotUnitVector(float magnitude) {
public static short MAX_DIMS_COUNT = 4096; // maximum allowed number of dimensions
public static int MAX_DIMS_COUNT_BIT = 4096 * Byte.SIZE; // maximum allowed number of dimensions

public static final int OVERSAMPLE_LIMIT = 10_000; // Max oversample allowed for k and num_candidates

public static short MIN_DIMS_FOR_DYNAMIC_FLOAT_MAPPING = 128; // minimum number of dims for floats to be dynamically mapped to vector
public static final int MAGNITUDE_BYTES = 4;

@@ -2000,6 +2003,7 @@ public Query createKnnQuery(
VectorData queryVector,
Integer k,
int numCands,
Float rescoreOversample,
Query filter,
Float similarityThreshold,
BitSetProducer parentFilter
@@ -2010,21 +2014,50 @@ public Query createKnnQuery(
);
}
return switch (getElementType()) {
- case BYTE -> createKnnByteQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter);
- case FLOAT -> createKnnFloatQuery(queryVector.asFloatVector(), k, numCands, filter, similarityThreshold, parentFilter);
- case BIT -> createKnnBitQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter);
case BYTE -> createKnnByteQuery(
queryVector.asByteVector(),
k,
numCands,
filter,
rescoreOversample,
similarityThreshold,
parentFilter
);
case FLOAT -> createKnnFloatQuery(
queryVector.asFloatVector(),
k,
numCands,
rescoreOversample,
filter,
similarityThreshold,
parentFilter
);
case BIT -> createKnnBitQuery(
queryVector.asByteVector(),
k,
numCands,
rescoreOversample,
filter,
similarityThreshold,
parentFilter
);
};
}

private Query createKnnBitQuery(
byte[] queryVector,
Integer k,
int numCands,
Float rescoreOversample,
Query filter,
Float similarityThreshold,
BitSetProducer parentFilter
) {
elementType.checkDimensions(dims, queryVector.length);
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude);
}
Query knnQuery = parentFilter != null
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter)
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter);
@@ -2035,6 +2068,17 @@ private Query createKnnBitQuery(
similarity.score(similarityThreshold, elementType, dims)
);
}
if (rescoreOversample != null) {
knnQuery = new FunctionScoreQuery(
knnQuery,
new VectorSimilarityByteValueSource(
name(),
queryVector,
similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.BYTE)
)
);

}
return knnQuery;
}
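
The `FunctionScoreQuery` wrapping above replaces each hit's score with the value produced by the similarity value source. A minimal, Lucene-free sketch of that idea follows. It assumes a dot-product similarity and raw vectors held in an in-memory array indexed by doc id; the class `RescoreSketch`, its method names, and the final trim to `k` are hypothetical and only illustrate the technique, not the code in this PR.

```java
import java.util.Arrays;
import java.util.Comparator;

public class RescoreSketch {

    /** Exact dot product between the query and a full-precision stored vector. */
    static float dotProduct(float[] a, float[] b) {
        float sum = 0f;
        for (int i = 0; i < a.length; i++) {
            sum += a[i] * b[i];
        }
        return sum;
    }

    /**
     * Re-rank candidates from an approximate search by exact similarity.
     * The approximate scores are discarded; only the candidate doc ids are kept.
     * Trimming to k is an assumption of this sketch.
     */
    static int[] rescore(int[] candidateDocs, float[][] rawVectors, float[] query, int k) {
        Integer[] docs = Arrays.stream(candidateDocs).boxed().toArray(Integer[]::new);
        // Highest exact similarity first.
        Arrays.sort(docs, Comparator.comparingDouble((Integer doc) -> dotProduct(query, rawVectors[doc])).reversed());
        return Arrays.stream(docs).limit(k).mapToInt(Integer::intValue).toArray();
    }
}
```

With candidates gathered for an oversampled `k`, a caller would keep only the first `k` entries of the re-ranked list, which is what the `limit(k)` step stands in for here.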

@@ -2043,6 +2087,7 @@ private Query createKnnByteQuery(
Integer k,
int numCands,
Query filter,
Float rescoreOversample,
Float similarityThreshold,
BitSetProducer parentFilter
) {
@@ -2052,23 +2097,38 @@ private Query createKnnByteQuery(
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude);
}
int adjustedK = rescoreOversample == null ? k : Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample));
int adjustedNumCands = Math.max(adjustedK, numCands);

Query knnQuery = parentFilter != null
- ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter)
- : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter);
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, adjustedK, adjustedNumCands, parentFilter)
: new ESKnnByteVectorQuery(name(), queryVector, adjustedK, adjustedNumCands, filter);
if (similarityThreshold != null) {
knnQuery = new VectorSimilarityQuery(
knnQuery,
similarityThreshold,
similarity.score(similarityThreshold, elementType, dims)
);
}
if (rescoreOversample != null) {
knnQuery = new FunctionScoreQuery(
knnQuery,
new VectorSimilarityByteValueSource(
name(),
queryVector,
similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.BYTE)
)
);

}
return knnQuery;
}

private Query createKnnFloatQuery(
float[] queryVector,
Integer k,
int numCands,
Float rescoreOversample,
Query filter,
Float similarityThreshold,
BitSetProducer parentFilter
@@ -2088,16 +2148,30 @@ && isNotUnitVector(squaredMagnitude)) {
}
}
}

int adjustedK = rescoreOversample == null ? k : Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample));
int adjustedNumCands = Math.max(adjustedK, numCands);
Query knnQuery = parentFilter != null
- ? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter)
- : new ESKnnFloatVectorQuery(name(), queryVector, k, numCands, filter);
? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, adjustedK, adjustedNumCands, parentFilter)
: new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, adjustedNumCands, filter);
if (similarityThreshold != null) {
knnQuery = new VectorSimilarityQuery(
knnQuery,
similarityThreshold,
similarity.score(similarityThreshold, elementType, dims)
);
}
if (rescoreOversample != null) {
knnQuery = new FunctionScoreQuery(
knnQuery,
new VectorSimilarityFloatValueSource(
name(),
queryVector,
similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT)
)
);

}
return knnQuery;
}
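
The `adjustedK` and `adjustedNumCands` arithmetic in this hunk can be tried in isolation. The sketch below copies the two expressions and the `OVERSAMPLE_LIMIT` constant from the diff; the wrapper class `OversampleSketch` and its method names are hypothetical.

```java
public class OversampleSketch {

    // Same cap as OVERSAMPLE_LIMIT in the diff.
    static final int OVERSAMPLE_LIMIT = 10_000;

    /** k is scaled by the oversample factor, rounded up, and capped; a null factor leaves k unchanged. */
    static int adjustedK(int k, Float rescoreOversample) {
        return rescoreOversample == null ? k : Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * rescoreOversample));
    }

    /** num_candidates is raised when needed so it is never smaller than the adjusted k. */
    static int adjustedNumCands(int adjustedK, int numCands) {
        return Math.max(adjustedK, numCands);
    }
}
```

For example, `k = 10` with an oversample of `1.5` gives an adjusted `k` of 15, and a `numCands` of 10 is raised to 15 to match, while `k = 5000` with an oversample of `3.0` is capped at 10,000.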

@@ -0,0 +1,95 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.index.mapper.vectors;

import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.DoubleValues;
import org.apache.lucene.search.DoubleValuesSource;
import org.apache.lucene.search.IndexSearcher;

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;

public class VectorSimilarityByteValueSource extends DoubleValuesSource {

Review comment (Member): We don't need this. Byte values cannot be quantized.

Reply (Member Author): 🤦 Thanks for the catch - it's just floats that get quantized. Removing it in 2379596

private final String field;
private final byte[] target;
private final VectorSimilarityFunction vectorSimilarityFunction;

public VectorSimilarityByteValueSource(String field, byte[] target, VectorSimilarityFunction vectorSimilarityFunction) {
this.field = field;
this.target = target;
this.vectorSimilarityFunction = vectorSimilarityFunction;
}

@Override
public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
final LeafReader reader = ctx.reader();

ByteVectorValues vectorValues = reader.getByteVectorValues(field);
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();

return new DoubleValues() {
private int docId = -1;

@Override
public double doubleValue() throws IOException {
return vectorSimilarityFunction.compare(target, vectorValues.vectorValue(docId));
}

@Override
public boolean advanceExact(int doc) throws IOException {
docId = doc;
return iterator.advance(docId) != DocIdSetIterator.NO_MORE_DOCS;
}
};
}

@Override
public boolean needsScores() {
return false;
}

@Override
public DoubleValuesSource rewrite(IndexSearcher reader) throws IOException {
return this;
}

@Override
public int hashCode() {
return Objects.hash(field, Arrays.hashCode(target), vectorSimilarityFunction);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
VectorSimilarityByteValueSource that = (VectorSimilarityByteValueSource) o;
return Objects.equals(field, that.field)
&& Objects.deepEquals(target, that.target)
&& vectorSimilarityFunction == that.vectorSimilarityFunction;
}

@Override
public String toString() {
return "VectorSimilarityByteValueSource(" + field + ", " + Arrays.toString(target) + ", " + vectorSimilarityFunction + ")";
}

@Override
public boolean isCacheable(LeafReaderContext ctx) {
return false;
}
}
@@ -0,0 +1,95 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.index.mapper.vectors;

import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.DoubleValues;
import org.apache.lucene.search.DoubleValuesSource;
import org.apache.lucene.search.IndexSearcher;

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;

public class VectorSimilarityFloatValueSource extends DoubleValuesSource {

private final String field;
private final float[] target;
private final VectorSimilarityFunction vectorSimilarityFunction;

public VectorSimilarityFloatValueSource(String field, float[] target, VectorSimilarityFunction vectorSimilarityFunction) {
this.field = field;
this.target = target;
this.vectorSimilarityFunction = vectorSimilarityFunction;
}

@Override
public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
final LeafReader reader = ctx.reader();

FloatVectorValues vectorValues = reader.getFloatVectorValues(field);
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();

return new DoubleValues() {
private int docId = -1;

@Override
public double doubleValue() throws IOException {
return vectorSimilarityFunction.compare(target, vectorValues.vectorValue(docId));
}

@Override
public boolean advanceExact(int doc) throws IOException {
docId = doc;

Review comment (Member): Let's add an assert doc > iterator.docID()

Reply (Member Author): Sounds good - b0c2221

return iterator.advance(docId) != DocIdSetIterator.NO_MORE_DOCS;
}
};
}

@Override
public boolean needsScores() {
return false;
}

@Override
public DoubleValuesSource rewrite(IndexSearcher reader) throws IOException {
return this;
}

@Override
public int hashCode() {
return Objects.hash(field, Arrays.hashCode(target), vectorSimilarityFunction);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
VectorSimilarityFloatValueSource that = (VectorSimilarityFloatValueSource) o;
return Objects.equals(field, that.field)
&& Objects.deepEquals(target, that.target)
&& vectorSimilarityFunction == that.vectorSimilarityFunction;
}

@Override
public String toString() {
return "VectorSimilarityFloatValueSource(" + field + ", " + Arrays.toString(target) + ", " + vectorSimilarityFunction + ")";
}

@Override
public boolean isCacheable(LeafReaderContext ctx) {
return false;
}
}
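
As written in this commit, `advanceExact` returns true whenever the iterator is not exhausted, even if `advance` lands on a later document than the one requested. Later commits in the list ("Properly implement advanceExact()", "Fix VectorSimilarityFloatValueSource implementation for advanceExact") address this. The following Lucene-free sketch shows one way to get exact-match semantics; it is an illustration under stated assumptions, not the code from those commits. `DocsWithVectors` is a hypothetical stand-in for a forward-only doc id iterator, and `NO_MORE_DOCS` mirrors the sentinel used in the diff.

```java
public class AdvanceExactSketch {

    // Sentinel returned once the iterator is exhausted.
    static final int NO_MORE_DOCS = Integer.MAX_VALUE;

    /** Hypothetical forward-only iterator over the sorted ids of documents that have a vector. */
    static class DocsWithVectors {
        private final int[] docs;
        private int index = -1;

        DocsWithVectors(int... sortedDocs) {
            this.docs = sortedDocs;
        }

        /** Current doc id: -1 before the first advance, NO_MORE_DOCS once exhausted. */
        int docID() {
            if (index < 0) {
                return -1;
            }
            return index < docs.length ? docs[index] : NO_MORE_DOCS;
        }

        /** Moves forward to the first doc id that is >= target and returns it. */
        int advance(int target) {
            do {
                index++;
            } while (index < docs.length && docs[index] < target);
            return docID();
        }
    }

    /** True only when the iterator is positioned exactly on doc, i.e. doc has a vector. */
    static boolean advanceExact(DocsWithVectors iterator, int doc) {
        int current = iterator.docID();
        if (current < doc) {
            // Only advance when behind; a previous call may already have landed on doc.
            current = iterator.advance(doc);
        }
        return current == doc;
    }
}
```

The check before advancing matters because an earlier call for a document without a vector can leave the iterator already positioned on a later document; advancing again would skip it.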