Skip to content

Add rescore knn vector test coverage #122801

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.search.query;

import org.apache.lucene.index.VectorSimilarityFunction;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorIndexType;
import org.elasticsearch.index.mapper.vectors.DenseVectorScriptDocValues;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.functionscore.ScriptScoreQueryBuilder;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.script.MockScriptPlugin;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.retriever.KnnRetrieverBuilder;
import org.elasticsearch.search.vectors.KnnSearchBuilder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.search.vectors.RescoreVectorBuilder;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.junit.Before;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailuresAndResponse;
import static org.hamcrest.Matchers.equalTo;

public class RescoreKnnVectorQueryIT extends ESIntegTestCase {

public static final String INDEX_NAME = "test";
public static final String VECTOR_FIELD = "vector";
public static final String VECTOR_SCORE_SCRIPT = "vector_scoring";
public static final String QUERY_VECTOR_PARAM = "query_vector";

@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Collections.singleton(CustomScriptPlugin.class);
}

public static class CustomScriptPlugin extends MockScriptPlugin {
private static final VectorSimilarityFunction SIMILARITY_FUNCTION = DenseVectorFieldMapper.VectorSimilarity.L2_NORM
.vectorSimilarityFunction(IndexVersion.current(), DenseVectorFieldMapper.ElementType.FLOAT);

@Override
protected Map<String, Function<Map<String, Object>, Object>> pluginScripts() {
return Map.of(VECTOR_SCORE_SCRIPT, vars -> {
Map<?, ?> doc = (Map<?, ?>) vars.get("doc");
return SIMILARITY_FUNCTION.compare(
((DenseVectorScriptDocValues) doc.get(VECTOR_FIELD)).getVectorValue(),
(float[]) vars.get(QUERY_VECTOR_PARAM)
);
});
}
}

@Before
public void setup() throws IOException {
String type = randomFrom(
Arrays.stream(VectorIndexType.values())
.filter(VectorIndexType::isQuantized)
.map(t -> t.name().toLowerCase(Locale.ROOT))
.collect(Collectors.toCollection(ArrayList::new))
);
XContentBuilder mapping = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject(VECTOR_FIELD)
.field("type", "dense_vector")
.field("similarity", "l2_norm")
.startObject("index_options")
.field("type", type)
.endObject()
.endObject()
.endObject()
.endObject();

Settings settings = Settings.builder()
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5))
.build();
prepareCreate(INDEX_NAME).setMapping(mapping).setSettings(settings).get();
ensureGreen(INDEX_NAME);
}

private record TestParams(
int numDocs,
int numDims,
float[] queryVector,
int k,
int numCands,
RescoreVectorBuilder rescoreVectorBuilder
) {
public static TestParams generate() {
int numDims = randomIntBetween(32, 512) * 2; // Ensure even dimensions
int numDocs = randomIntBetween(10, 100);
int k = randomIntBetween(1, numDocs - 5);
return new TestParams(
numDocs,
numDims,
randomVector(numDims),
k,
(int) (k * randomFloatBetween(1.0f, 10.0f, true)),
new RescoreVectorBuilder(randomFloatBetween(1.0f, 100f, true))
);
}
}

public void testKnnSearchRescore() {
BiFunction<TestParams, SearchRequestBuilder, SearchRequestBuilder> knnSearchGenerator = (testParams, requestBuilder) -> {
KnnSearchBuilder knnSearch = new KnnSearchBuilder(
VECTOR_FIELD,
testParams.queryVector,
testParams.k,
testParams.numCands,
testParams.rescoreVectorBuilder,
null
);
return requestBuilder.setKnnSearch(List.of(knnSearch));
};
testKnnRescore(knnSearchGenerator);
}

public void testKnnQueryRescore() {
BiFunction<TestParams, SearchRequestBuilder, SearchRequestBuilder> knnQueryGenerator = (testParams, requestBuilder) -> {
KnnVectorQueryBuilder knnQuery = new KnnVectorQueryBuilder(
VECTOR_FIELD,
testParams.queryVector,
testParams.k,
testParams.numCands,
testParams.rescoreVectorBuilder,
null
);
return requestBuilder.setQuery(knnQuery);
};
testKnnRescore(knnQueryGenerator);
}

public void testKnnRetriever() {
BiFunction<TestParams, SearchRequestBuilder, SearchRequestBuilder> knnQueryGenerator = (testParams, requestBuilder) -> {
KnnRetrieverBuilder knnRetriever = new KnnRetrieverBuilder(
VECTOR_FIELD,
testParams.queryVector,
null,
testParams.k,
testParams.numCands,
testParams.rescoreVectorBuilder,
null
);
return requestBuilder.setSource(new SearchSourceBuilder().retriever(knnRetriever));
};
testKnnRescore(knnQueryGenerator);
}

private void testKnnRescore(BiFunction<TestParams, SearchRequestBuilder, SearchRequestBuilder> searchRequestGenerator) {
TestParams testParams = TestParams.generate();

int numDocs = testParams.numDocs;
IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];

for (int i = 0; i < numDocs; i++) {
docs[i] = prepareIndex(INDEX_NAME).setId("" + i).setSource(VECTOR_FIELD, randomVector(testParams.numDims));
}
indexRandom(true, docs);

float[] queryVector = testParams.queryVector;
float oversample = randomFloatBetween(1.0f, 100f, true);
RescoreVectorBuilder rescoreVectorBuilder = new RescoreVectorBuilder(oversample);

SearchRequestBuilder requestBuilder = searchRequestGenerator.apply(
testParams,
prepareSearch(INDEX_NAME).setSize(numDocs).setTrackTotalHits(randomBoolean())
);

assertNoFailuresAndResponse(requestBuilder, knnResponse -> { compareWithExactSearch(knnResponse, queryVector, numDocs); });
}

private static void compareWithExactSearch(SearchResponse knnResponse, float[] queryVector, int docCount) {
// Do an exact query and compare
Script script = new Script(
ScriptType.INLINE,
CustomScriptPlugin.NAME,
VECTOR_SCORE_SCRIPT,
Map.of(QUERY_VECTOR_PARAM, queryVector)
);
ScriptScoreQueryBuilder scriptScoreQueryBuilder = QueryBuilders.scriptScoreQuery(new MatchAllQueryBuilder(), script);
assertNoFailuresAndResponse(prepareSearch(INDEX_NAME).setQuery(scriptScoreQueryBuilder).setSize(docCount), exactResponse -> {
assertHitCount(exactResponse, docCount);

int i = 0;
SearchHit[] exactHits = exactResponse.getHits().getHits();
for (SearchHit knnHit : knnResponse.getHits().getHits()) {
while (i < exactHits.length && exactHits[i].getId().equals(knnHit.getId()) == false) {
i++;
}
if (i >= exactHits.length) {
fail("Knn doc not found in exact search");
}
assertThat("Real score is not the same as rescored score", knnHit.getScore(), equalTo(exactHits[i].getScore()));
}
});
}

private static float[] randomVector(int numDimensions) {
float[] vector = new float[numDimensions];
for (int j = 0; j < numDimensions; j++) {
vector[j] = randomFloatBetween(0, 1, true);
}
return vector;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1225,7 +1225,7 @@ public final int hashCode() {
}
}

private enum VectorIndexType {
public enum VectorIndexType {
HNSW("hnsw", false) {
@Override
public IndexOptions parseIndexOptions(String fieldName, Map<String, ?> indexOptionsMap) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,24 @@
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.Weight;

import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Objects;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

/**
* A query that matches the provided docs with their scores.
*
* Note: this query was adapted from Lucene's DocAndScoreQuery from the class
* Note: this query was originally adapted from Lucene's DocAndScoreQuery from the class
* {@link org.apache.lucene.search.KnnFloatVectorQuery}, which is package-private.
* There are no changes to the behavior, just some renames.
*/
public class KnnScoreDocQuery extends Query {
private final int[] docs;
Expand All @@ -50,13 +51,18 @@ public class KnnScoreDocQuery extends Query {
/**
* Creates a query.
*
* @param docs the global doc IDs of documents that match, in ascending order
* @param scores the scores of the matching documents
* @param scoreDocs an array of ScoreDocs to use for the query
* @param reader IndexReader
*/
KnnScoreDocQuery(int[] docs, float[] scores, IndexReader reader) {
this.docs = docs;
this.scores = scores;
KnnScoreDocQuery(ScoreDoc[] scoreDocs, IndexReader reader) {
// Ensure that the docs are sorted by docId, as they are later searched using binary search
Arrays.sort(scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc));
this.docs = new int[scoreDocs.length];
this.scores = new float[scoreDocs.length];
for (int i = 0; i < scoreDocs.length; i++) {
docs[i] = scoreDocs[i].doc;
scores[i] = scoreDocs[i].score;
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it sounds like this is redundant if we have appropriate test coverage? I was also wondering if it may be worth changing the two first arguments into a ScoreDoc[] given that's how stuff comes in, and perhaps unifying the sorting here. I realize though that this is a copy of a Lucene class and the change I am suggesting will make it diverge from its original source.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it sounds like this is redundant if we have appropriate test coverage?

My thinking was to provide a way to understand a test failure in an easier way in case someone provided a non-sorted array, instead of going through all the investigations that you had to do 😓

I'm happy with removing the assertion in case you think it's unnecessary, but I think it helps to understand what the preconditions for this constructor are.

I was also wondering if it may be worth changing the two first arguments into a ScoreDoc[] given that's how stuff comes in, and perhaps unifying the sorting here. I realize though that this is a copy of a Lucene class and the change I am suggesting will make it diverge from its original source.

I think that's a good idea. I will give it a try.

I realize though that this is a copy of a Lucene class and the change I am suggesting will make it diverge from its original source.

It already diverges a bit in terms of making it easier to create - as long as it's on the constructor stuff I think we should be good for doing the change.

I'll give it a go and come back for feedback.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should have a test for RescoreKnnVectorQuery that indexes a bunch of random vectors, searches with a random vector and asserts the rewrite is a KnnScoreDocQuery with the appropriately ordered values.

It seems we are almost there in RescoreKnnVectorQueryTests, but maybe add some assertions there. Maybe via package private methods?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it may be worth changing the two first arguments into a ScoreDoc[] given that's how stuff comes in, and perhaps unifying the sorting here

@javanna I gave it a try in a073f43 - I like it more, it simplifies how clients create this query plus we enforce the invariant in the constructor itself 💯

We should have a test for RescoreKnnVectorQuery that indexes a bunch of random vectors, searches with a random vector and asserts the rewrite is a KnnScoreDocQuery with the appropriately ordered values.

@benwtrent I think the change in a073f43 makes it unnecessary. We're already checking via random insertions in the test. Do you think we need to add something else to make sure this doesn't bite us again?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doing the sort in the ctor is fine and as long as we have tests that fill fail if somebody removes that sort, I am happy.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RescoreKnnVectorQueryIT add those tests. I checked by removing the sort that Luca added back in #122653 that this was caught by the newly added tests.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like it. It also allows to share some code between the two consumers. Perhaps make it clear in the javadocs that this is no longer a straight copy of its lucene sibling. Thanks!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps make it clear in the javadocs that this is no longer a straight copy of its lucene sibling.

👍 I've clarified that in ee464fe

this.segmentStarts = findSegmentStarts(reader, docs);
this.contextIdentity = reader.getContext().id();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,7 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep

@Override
protected Query doToQuery(SearchExecutionContext context) throws IOException {
int numDocs = scoreDocs.length;
int[] docs = new int[numDocs];
float[] scores = new float[numDocs];
for (int i = 0; i < numDocs; i++) {
docs[i] = scoreDocs[i].doc;
scores[i] = scoreDocs[i].score;
}

return new KnnScoreDocQuery(docs, scores, context.getIndexReader());
return new KnnScoreDocQuery(scoreDocs, context.getIndexReader());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,12 @@
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource;
import org.elasticsearch.search.profile.query.QueryProfiler;

import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Objects;

/**
Expand Down Expand Up @@ -60,16 +58,7 @@ public Query rewrite(IndexSearcher searcher) throws IOException {
// Retrieve top k documents from the rescored query
TopDocs topDocs = searcher.search(query, k);
vectorOperations = topDocs.totalHits.value();
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
Arrays.sort(scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc));
int[] docIds = new int[scoreDocs.length];
float[] scores = new float[scoreDocs.length];
for (int i = 0; i < scoreDocs.length; i++) {
docIds[i] = scoreDocs[i].doc;
scores[i] = scoreDocs[i].score;
}

return new KnnScoreDocQuery(docIds, scores, searcher.getIndexReader());
return new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader());
}

public Query innerQuery() {
Expand Down
Loading