Skip to content

Commit ae1ef5f

Browse files
committed
Refactor unit tests for vector functions. (#48662)
This PR performs the following changes: * Split `ScoreScriptUtilsTests` into `DenseVectorFunctionTests` and `SparseVectorFunctionTests`. This will make it easier to delete all sparse vector function tests once we remove support on 8.x. * As much as possible, break up the large test methods into individual tests for each vector function (`cosineSimilarity`, `l2norm`, etc.).
1 parent ede1681 commit ae1ef5f

File tree

3 files changed

+328
-358
lines changed

3 files changed

+328
-358
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
7+
package org.elasticsearch.xpack.vectors.query;
8+
9+
import org.apache.lucene.util.BytesRef;
10+
import org.elasticsearch.Version;
11+
import org.elasticsearch.script.ScoreScript;
12+
import org.elasticsearch.test.ESTestCase;
13+
import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.CosineSimilarity;
14+
import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.DotProduct;
15+
import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L1Norm;
16+
import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L2Norm;
17+
import org.elasticsearch.xpack.vectors.query.VectorScriptDocValues.DenseVectorScriptDocValues;
18+
import org.junit.Before;
19+
20+
import java.util.Arrays;
21+
import java.util.Collections;
22+
import java.util.List;
23+
24+
import static org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoderTests.mockEncodeDenseVector;
25+
import static org.hamcrest.Matchers.containsString;
26+
import static org.mockito.Mockito.mock;
27+
import static org.mockito.Mockito.when;
28+
29+
public class DenseVectorFunctionTests extends ESTestCase {
30+
private String field;
31+
private float[] docVector;
32+
private List<Number> queryVector;
33+
private List<Number> invalidQueryVector;
34+
35+
@Before
36+
public void setUpVectors() {
37+
field = "vector";
38+
docVector = new float[] {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f};
39+
queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f);
40+
invalidQueryVector = Arrays.asList(0.5, 111.3);
41+
}
42+
43+
public void testDenseVectorFunctions() {
44+
for (Version indexVersion : Arrays.asList(Version.V_7_4_0, Version.CURRENT)) {
45+
BytesRef encodedDocVector = mockEncodeDenseVector(docVector, indexVersion);
46+
DenseVectorScriptDocValues docValues = mock(DenseVectorScriptDocValues.class);
47+
when(docValues.getEncodedValue()).thenReturn(encodedDocVector);
48+
49+
ScoreScript scoreScript = mock(ScoreScript.class);
50+
when(scoreScript._getIndexVersion()).thenReturn(indexVersion);
51+
when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, docValues));
52+
53+
testDotProduct(docValues, scoreScript);
54+
testCosineSimilarity(docValues, scoreScript);
55+
testL1Norm(docValues, scoreScript);
56+
testL2Norm(docValues, scoreScript);
57+
}
58+
}
59+
60+
private void testDotProduct(DenseVectorScriptDocValues docValues, ScoreScript scoreScript) {
61+
DotProduct function = new DotProduct(scoreScript, queryVector, field);
62+
double result = function.dotProduct();
63+
assertEquals("dotProduct result is not equal to the expected value!", 65425.624, result, 0.001);
64+
65+
DotProduct deprecatedFunction = new DotProduct(scoreScript, queryVector, docValues);
66+
double deprecatedResult = deprecatedFunction.dotProduct();
67+
assertEquals("dotProduct result is not equal to the expected value!", 65425.624, deprecatedResult, 0.001);
68+
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
69+
70+
DotProduct invalidFunction = new DotProduct(scoreScript, invalidQueryVector, field);
71+
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, invalidFunction::dotProduct);
72+
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
73+
}
74+
75+
private void testCosineSimilarity(DenseVectorScriptDocValues docValues, ScoreScript scoreScript) {
76+
CosineSimilarity function = new CosineSimilarity(scoreScript, queryVector, field);
77+
double result = function.cosineSimilarity();
78+
assertEquals("cosineSimilarity result is not equal to the expected value!", 0.790, result, 0.001);
79+
80+
CosineSimilarity deprecatedFunction = new CosineSimilarity(scoreScript, queryVector, docValues);
81+
double deprecatedResult = deprecatedFunction.cosineSimilarity();
82+
assertEquals("cosineSimilarity result is not equal to the expected value!", 0.790, deprecatedResult, 0.001);
83+
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
84+
85+
CosineSimilarity invalidFunction = new CosineSimilarity(scoreScript, invalidQueryVector, field);
86+
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, invalidFunction::cosineSimilarity);
87+
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
88+
}
89+
90+
private void testL1Norm(DenseVectorScriptDocValues docValues, ScoreScript scoreScript) {
91+
L1Norm function = new L1Norm(scoreScript, queryVector, field);
92+
double result = function.l1norm();
93+
assertEquals("l1norm result is not equal to the expected value!", 485.184, result, 0.001);
94+
95+
L1Norm deprecatedFunction = new L1Norm(scoreScript, queryVector, docValues);
96+
double deprecatedResult = deprecatedFunction.l1norm();
97+
assertEquals("l1norm result is not equal to the expected value!", 485.184, deprecatedResult, 0.001);
98+
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
99+
100+
L1Norm invalidFunction = new L1Norm(scoreScript, invalidQueryVector, field);
101+
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, invalidFunction::l1norm);
102+
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
103+
}
104+
105+
private void testL2Norm(DenseVectorScriptDocValues docValues, ScoreScript scoreScript) {
106+
L2Norm function = new L2Norm(scoreScript, queryVector, field);
107+
double result = function.l2norm();
108+
assertEquals("l2norm result is not equal to the expected value!", 301.361, result, 0.001);
109+
110+
L2Norm deprecatedFunction = new L2Norm(scoreScript, queryVector, docValues);
111+
double deprecatedResult = deprecatedFunction.l2norm();
112+
assertEquals("l2norm result is not equal to the expected value!", 301.361, deprecatedResult, 0.001);
113+
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
114+
115+
L2Norm invalidFunction = new L2Norm(scoreScript, invalidQueryVector, field);
116+
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, invalidFunction::l2norm);
117+
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
118+
}
119+
}

0 commit comments

Comments
 (0)