Add rescore knn vector test coverage #122801
Merged: carlosdelest merged 10 commits into elastic:main from carlosdelest:tests/rescore-knn-vector-query-test-coverage on Feb 24, 2025
Commits
7083868 Add IT test for RescoreKnnVectorQuery (carlosdelest)
dddf168 Add assertion to check docs are in order (carlosdelest)
46a983a Refactor RescoreKnnVectorQueryTests to create multiple segments, add … (carlosdelest)
3513d20 Minor fixes (carlosdelest)
be6010c Merge remote-tracking branch 'origin/main' into tests/rescore-knn-vec… (carlosdelest)
813f001 Changing plugin class visibility (carlosdelest)
a073f43 Sort in query constructor, use ScoreDoc[] vs building individual arra… (carlosdelest)
a579ff5 Merge remote-tracking branch 'origin/main' into tests/rescore-knn-vec… (carlosdelest)
ee464fe Clarify javadoc (carlosdelest)
8bc0d4d Merge branch 'main' into tests/rescore-knn-vector-query-test-coverage (carlosdelest)
238 changes: 238 additions & 0 deletions
.../src/internalClusterTest/java/org/elasticsearch/search/query/RescoreKnnVectorQueryIT.java
@@ -0,0 +1,238 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the "Elastic License
 * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
 * Public License v 1"; you may not use this file except in compliance with, at
 * your election, the "Elastic License 2.0", the "GNU Affero General Public
 * License v3.0 only", or the "Server Side Public License, v 1".
 */

package org.elasticsearch.search.query;

import org.apache.lucene.index.VectorSimilarityFunction;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorIndexType;
import org.elasticsearch.index.mapper.vectors.DenseVectorScriptDocValues;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.functionscore.ScriptScoreQueryBuilder;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.script.MockScriptPlugin;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.retriever.KnnRetrieverBuilder;
import org.elasticsearch.search.vectors.KnnSearchBuilder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.search.vectors.RescoreVectorBuilder;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.junit.Before;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailuresAndResponse;
import static org.hamcrest.Matchers.equalTo;

public class RescoreKnnVectorQueryIT extends ESIntegTestCase {

    public static final String INDEX_NAME = "test";
    public static final String VECTOR_FIELD = "vector";
    public static final String VECTOR_SCORE_SCRIPT = "vector_scoring";
    public static final String QUERY_VECTOR_PARAM = "query_vector";

    @Override
    protected Collection<Class<? extends Plugin>> nodePlugins() {
        return Collections.singleton(CustomScriptPlugin.class);
    }

    public static class CustomScriptPlugin extends MockScriptPlugin {
        private static final VectorSimilarityFunction SIMILARITY_FUNCTION = DenseVectorFieldMapper.VectorSimilarity.L2_NORM
            .vectorSimilarityFunction(IndexVersion.current(), DenseVectorFieldMapper.ElementType.FLOAT);

        @Override
        protected Map<String, Function<Map<String, Object>, Object>> pluginScripts() {
            return Map.of(VECTOR_SCORE_SCRIPT, vars -> {
                Map<?, ?> doc = (Map<?, ?>) vars.get("doc");
                return SIMILARITY_FUNCTION.compare(
                    ((DenseVectorScriptDocValues) doc.get(VECTOR_FIELD)).getVectorValue(),
                    (float[]) vars.get(QUERY_VECTOR_PARAM)
                );
            });
        }
    }

    @Before
    public void setup() throws IOException {
        String type = randomFrom(
            Arrays.stream(VectorIndexType.values())
                .filter(VectorIndexType::isQuantized)
                .map(t -> t.name().toLowerCase(Locale.ROOT))
                .collect(Collectors.toCollection(ArrayList::new))
        );
        XContentBuilder mapping = XContentFactory.jsonBuilder()
            .startObject()
            .startObject("properties")
            .startObject(VECTOR_FIELD)
            .field("type", "dense_vector")
            .field("similarity", "l2_norm")
            .startObject("index_options")
            .field("type", type)
            .endObject()
            .endObject()
            .endObject()
            .endObject();

        Settings settings = Settings.builder()
            .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
            .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5))
            .build();
        prepareCreate(INDEX_NAME).setMapping(mapping).setSettings(settings).get();
        ensureGreen(INDEX_NAME);
    }

    private record TestParams(
        int numDocs,
        int numDims,
        float[] queryVector,
        int k,
        int numCands,
        RescoreVectorBuilder rescoreVectorBuilder
    ) {
        public static TestParams generate() {
            int numDims = randomIntBetween(32, 512) * 2; // Ensure even dimensions
            int numDocs = randomIntBetween(10, 100);
            int k = randomIntBetween(1, numDocs - 5);
            return new TestParams(
                numDocs,
                numDims,
                randomVector(numDims),
                k,
                (int) (k * randomFloatBetween(1.0f, 10.0f, true)),
                new RescoreVectorBuilder(randomFloatBetween(1.0f, 100f, true))
            );
        }
    }

    public void testKnnSearchRescore() {
        BiFunction<TestParams, SearchRequestBuilder, SearchRequestBuilder> knnSearchGenerator = (testParams, requestBuilder) -> {
            KnnSearchBuilder knnSearch = new KnnSearchBuilder(
                VECTOR_FIELD,
                testParams.queryVector,
                testParams.k,
                testParams.numCands,
                testParams.rescoreVectorBuilder,
                null
            );
            return requestBuilder.setKnnSearch(List.of(knnSearch));
        };
        testKnnRescore(knnSearchGenerator);
    }

    public void testKnnQueryRescore() {
        BiFunction<TestParams, SearchRequestBuilder, SearchRequestBuilder> knnQueryGenerator = (testParams, requestBuilder) -> {
            KnnVectorQueryBuilder knnQuery = new KnnVectorQueryBuilder(
                VECTOR_FIELD,
                testParams.queryVector,
                testParams.k,
                testParams.numCands,
                testParams.rescoreVectorBuilder,
                null
            );
            return requestBuilder.setQuery(knnQuery);
        };
        testKnnRescore(knnQueryGenerator);
    }

    public void testKnnRetriever() {
        BiFunction<TestParams, SearchRequestBuilder, SearchRequestBuilder> knnQueryGenerator = (testParams, requestBuilder) -> {
            KnnRetrieverBuilder knnRetriever = new KnnRetrieverBuilder(
                VECTOR_FIELD,
                testParams.queryVector,
                null,
                testParams.k,
                testParams.numCands,
                testParams.rescoreVectorBuilder,
                null
            );
            return requestBuilder.setSource(new SearchSourceBuilder().retriever(knnRetriever));
        };
        testKnnRescore(knnQueryGenerator);
    }

    private void testKnnRescore(BiFunction<TestParams, SearchRequestBuilder, SearchRequestBuilder> searchRequestGenerator) {
        TestParams testParams = TestParams.generate();

        int numDocs = testParams.numDocs;
        IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];

        for (int i = 0; i < numDocs; i++) {
            docs[i] = prepareIndex(INDEX_NAME).setId("" + i).setSource(VECTOR_FIELD, randomVector(testParams.numDims));
        }
        indexRandom(true, docs);

        float[] queryVector = testParams.queryVector;
        float oversample = randomFloatBetween(1.0f, 100f, true);
        RescoreVectorBuilder rescoreVectorBuilder = new RescoreVectorBuilder(oversample);

        SearchRequestBuilder requestBuilder = searchRequestGenerator.apply(
            testParams,
            prepareSearch(INDEX_NAME).setSize(numDocs).setTrackTotalHits(randomBoolean())
        );

        assertNoFailuresAndResponse(requestBuilder, knnResponse -> { compareWithExactSearch(knnResponse, queryVector, numDocs); });
    }

    private static void compareWithExactSearch(SearchResponse knnResponse, float[] queryVector, int docCount) {
        // Do an exact query and compare
        Script script = new Script(
            ScriptType.INLINE,
            CustomScriptPlugin.NAME,
            VECTOR_SCORE_SCRIPT,
            Map.of(QUERY_VECTOR_PARAM, queryVector)
        );
        ScriptScoreQueryBuilder scriptScoreQueryBuilder = QueryBuilders.scriptScoreQuery(new MatchAllQueryBuilder(), script);
        assertNoFailuresAndResponse(prepareSearch(INDEX_NAME).setQuery(scriptScoreQueryBuilder).setSize(docCount), exactResponse -> {
            assertHitCount(exactResponse, docCount);

            int i = 0;
            SearchHit[] exactHits = exactResponse.getHits().getHits();
            for (SearchHit knnHit : knnResponse.getHits().getHits()) {
                while (i < exactHits.length && exactHits[i].getId().equals(knnHit.getId()) == false) {
                    i++;
                }
                if (i >= exactHits.length) {
                    fail("Knn doc not found in exact search");
                }
                assertThat("Real score is not the same as rescored score", knnHit.getScore(), equalTo(exactHits[i].getScore()));
            }
        });
    }

    private static float[] randomVector(int numDimensions) {
        float[] vector = new float[numDimensions];
        for (int j = 0; j < numDimensions; j++) {
            vector[j] = randomFloatBetween(0, 1, true);
        }
        return vector;
    }
}
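The comparison performed by compareWithExactSearch can be illustrated without a cluster. The following is a standalone sketch, not code from this PR: the Hit record and the exactSearch and isOrderedSubsequence helpers are hypothetical names, and the scoring assumes the L2 similarity maps a squared Euclidean distance d to 1 / (1 + d), which is how Lucene's VectorSimilarityFunction.EUCLIDEAN scores vectors. It shows the single forward scan the test relies on: every kNN hit must appear in the exact results, in the same relative order, with an identical score.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

public class ExactRescoreCheck {
    // Hypothetical stand-in for a search hit: document id plus score.
    record Hit(String id, float score) {}

    // Assumed scoring: 1 / (1 + squared Euclidean distance).
    static float l2Score(float[] a, float[] b) {
        float sum = 0f;
        for (int i = 0; i < a.length; i++) {
            float diff = a[i] - b[i];
            sum += diff * diff;
        }
        return 1f / (1f + sum);
    }

    // Brute-force "exact search": score every vector, best score first.
    static List<Hit> exactSearch(float[][] vectors, float[] query) {
        List<Hit> hits = new ArrayList<>();
        for (int i = 0; i < vectors.length; i++) {
            hits.add(new Hit(Integer.toString(i), l2Score(vectors[i], query)));
        }
        hits.sort(Comparator.comparingDouble((Hit h) -> h.score()).reversed());
        return hits;
    }

    // Scans the exact hits once, never moving backwards: each kNN hit must be
    // found at or after the previous match and carry exactly the same score.
    static boolean isOrderedSubsequence(List<Hit> knnHits, List<Hit> exactHits) {
        int i = 0;
        for (Hit knnHit : knnHits) {
            while (i < exactHits.size() && !exactHits.get(i).id().equals(knnHit.id())) {
                i++;
            }
            if (i >= exactHits.size() || exactHits.get(i).score() != knnHit.score()) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        float[][] vectors = { { 0f, 0f }, { 1f, 0f }, { 3f, 4f } };
        float[] query = { 0f, 0f };
        List<Hit> exact = exactSearch(vectors, query);
        // A kNN result that skips the middle document still passes.
        System.out.println(isOrderedSubsequence(List.of(exact.get(0), exact.get(2)), exact)); // true
        // The same hits in reverse order fail the check.
        System.out.println(isOrderedSubsequence(List.of(exact.get(2), exact.get(0)), exact)); // false
    }
}
```

Because the scan never rewinds, a kNN response whose hits are correctly scored but returned out of order is rejected, which is the kind of ordering regression this PR aims to catch.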
It sounds like this is redundant if we have appropriate test coverage? I was also wondering whether it may be worth changing the first two arguments into a `ScoreDoc[]`, given that's how stuff comes in, and perhaps unifying the sorting here. I realize, though, that this is a copy of a Lucene class and the change I am suggesting will make it diverge from its original source.
My thinking was to provide a way to understand a test failure in an easier way in case someone provided a non-sorted array, instead of going through all the investigations that you had to do 😓
I'm happy with removing the assertion in case you think it's unnecessary, but I think it helps to understand what the preconditions for this constructor are.
I think that's a good idea. I will give it a try.
It already diverges a bit in terms of making it easier to create - as long as it's on the constructor stuff I think we should be good for doing the change.
I'll give it a go and come back for feedback.
We should have a test for `RescoreKnnVectorQuery` that indexes a bunch of random vectors, searches with a random vector, and asserts the rewrite is a `KnnScoreDocQuery` with the appropriately ordered values.
It seems we are almost there in `RescoreKnnVectorQueryTests`, but maybe add some assertions there. Maybe via package private methods?
@javanna I gave it a try in a073f43 - I like it more, it simplifies how clients create this query plus we enforce the invariant in the constructor itself 💯
@benwtrent I think the change in a073f43 makes it unnecessary. We're already checking via random insertions in the test. Do you think we need to add something else to make sure this doesn't bite us again?
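The approach discussed above, accepting the hits as a single array and sorting them inside the constructor so callers cannot violate the ordering precondition, can be sketched as follows. This is a simplified illustration and not the code from a073f43: `ScoreDoc` here is a local record standing in for Lucene's class, `SortedDocsQuery` is a hypothetical name, and sorting ascending by doc id is an assumption about the order the downstream query expects.

```java
import java.util.Arrays;
import java.util.Comparator;

public class SortedDocsQuery {
    // Local stand-in for Lucene's ScoreDoc (doc id plus score); an assumption for this sketch.
    record ScoreDoc(int doc, float score) {}

    private final int[] docs;
    private final float[] scores;

    // Callers may pass hits in any order; the constructor establishes the invariant.
    SortedDocsQuery(ScoreDoc[] scoreDocs) {
        // Copy first so the caller's array is left untouched, then sort by doc id.
        ScoreDoc[] sorted = scoreDocs.clone();
        Arrays.sort(sorted, Comparator.comparingInt(ScoreDoc::doc));
        this.docs = new int[sorted.length];
        this.scores = new float[sorted.length];
        for (int i = 0; i < sorted.length; i++) {
            docs[i] = sorted[i].doc();
            scores[i] = sorted[i].score();
        }
    }

    // Defensive copies keep the sorted arrays immutable from the outside.
    int[] docs() {
        return docs.clone();
    }

    float[] scores() {
        return scores.clone();
    }

    public static void main(String[] args) {
        ScoreDoc[] hits = { new ScoreDoc(7, 0.2f), new ScoreDoc(2, 0.9f), new ScoreDoc(5, 0.5f) };
        SortedDocsQuery query = new SortedDocsQuery(hits);
        System.out.println(Arrays.toString(query.docs()));   // [2, 5, 7]
        System.out.println(Arrays.toString(query.scores())); // [0.9, 0.5, 0.2]
    }
}
```

Keeping each doc id and its score in one element until after the sort means the two parallel arrays cannot drift out of alignment, and a test that feeds shuffled input would fail if the sort were ever removed.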
Doing the sort in the ctor is fine, and as long as we have tests that will fail if somebody removes that sort, I am happy.
`RescoreKnnVectorQueryIT` adds those tests. I checked, by removing the sort that Luca added back in #122653, that the newly added tests catch this.
I like it. It also allows sharing some code between the two consumers. Perhaps make it clear in the javadocs that this is no longer a straight copy of its Lucene sibling. Thanks!
👍 I've clarified that in ee464fe