Skip to content

Commit d94c4dc

Browse files
committed
Use float instead of double for query vectors. (#46004)
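Currently, when using script_score functions like cosineSimilarity, the query vector is treated as an array of doubles. Since the stored document vectors use floats, the least surprising behavior is for the query vectors to be float arrays as well. In addition to improving consistency, this change may help with some optimizations we have been considering for the vector dot product. Because query values are now converted to float, a few expected scores in the tests shift slightly (presumably due to float rounding), so the affected test bounds and expected values are adjusted accordingly.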
1 parent 1249e6b commit d94c4dc

File tree

5 files changed

+21
-21
lines changed

5 files changed

+21
-21
lines changed

x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/15_dense_vector_l1l2.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ setup:
6565

6666
- match: {hits.hits.1._id: "2"}
6767
- gte: {hits.hits.1._score: 12.29}
68-
- lte: {hits.hits.1._score: 12.30}
68+
- lte: {hits.hits.1._score: 12.31}
6969

7070
- match: {hits.hits.2._id: "3"}
7171
- gte: {hits.hits.2._score: 0.00}

x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/35_sparse_vector_l1l2.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ setup:
6363

6464
- match: {hits.hits.1._id: "2"}
6565
- gte: {hits.hits.1._score: 12.29}
66-
- lte: {hits.hits.1._score: 12.30}
66+
- lte: {hits.hits.1._score: 12.31}
6767

6868
- match: {hits.hits.2._id: "3"}
6969
- gte: {hits.hits.2._score: 0.00}

x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoder.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ public void swap(int i, int j) {
130130
* @param values - values for the sparse query vector
131131
* @param n - number of dimensions
132132
*/
133-
public static void sortSparseDimsDoubleValues(int[] dims, double[] values, int n) {
133+
public static void sortSparseDimsFloatValues(int[] dims, float[] values, int n) {
134134
new InPlaceMergeSorter() {
135135
@Override
136136
public int compare(int i, int j) {
@@ -143,7 +143,7 @@ public void swap(int i, int j) {
143143
dims[i] = dims[j];
144144
dims[j] = tempDim;
145145

146-
double tempValue = values[j];
146+
float tempValue = values[j];
147147
values[j] = values[i];
148148
values[i] = tempValue;
149149
}

x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java

+12-12
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import java.util.List;
1515
import java.util.Map;
1616

17-
import static org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder.sortSparseDimsDoubleValues;
17+
import static org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder.sortSparseDimsFloatValues;
1818

1919
public class ScoreScriptUtils {
2020

@@ -37,7 +37,7 @@ public static double l1norm(List<Number> queryVector, VectorScriptDocValues.Dens
3737
Iterator<Number> queryVectorIter = queryVector.iterator();
3838
double l1norm = 0;
3939
for (int dim = 0; dim < docVector.length; dim++){
40-
l1norm += Math.abs(queryVectorIter.next().doubleValue() - docVector[dim]);
40+
l1norm += Math.abs(queryVectorIter.next().floatValue() - docVector[dim]);
4141
}
4242
return l1norm;
4343
}
@@ -59,7 +59,7 @@ public static double l2norm(List<Number> queryVector, VectorScriptDocValues.Dens
5959
Iterator<Number> queryVectorIter = queryVector.iterator();
6060
double l2norm = 0;
6161
for (int dim = 0; dim < docVector.length; dim++){
62-
double diff = queryVectorIter.next().doubleValue() - docVector[dim];
62+
double diff = queryVectorIter.next().floatValue() - docVector[dim];
6363
l2norm += diff * diff;
6464
}
6565
return Math.sqrt(l2norm);
@@ -97,11 +97,11 @@ public static final class CosineSimilarity {
9797
// calculate queryVectorMagnitude once per query execution
9898
public CosineSimilarity(List<Number> queryVector) {
9999
this.queryVector = queryVector;
100-
double doubleValue;
100+
101101
double dotProduct = 0;
102102
for (Number value : queryVector) {
103-
doubleValue = value.doubleValue();
104-
dotProduct += doubleValue * doubleValue;
103+
float floatValue = value.floatValue();
104+
dotProduct += floatValue * floatValue;
105105
}
106106
this.queryVectorMagnitude = Math.sqrt(dotProduct);
107107
}
@@ -130,7 +130,7 @@ private static double intDotProduct(List<Number> v1, float[] v2){
130130
double v1v2DotProduct = 0;
131131
Iterator<Number> v1Iter = v1.iterator();
132132
for (int dim = 0; dim < v2.length; dim++) {
133-
v1v2DotProduct += v1Iter.next().doubleValue() * v2[dim];
133+
v1v2DotProduct += v1Iter.next().floatValue() * v2[dim];
134134
}
135135
return v1v2DotProduct;
136136
}
@@ -139,15 +139,15 @@ private static double intDotProduct(List<Number> v1, float[] v2){
139139
//**************FUNCTIONS FOR SPARSE VECTORS
140140

141141
public static class VectorSparseFunctions {
142-
final double[] queryValues;
142+
final float[] queryValues;
143143
final int[] queryDims;
144144

145145
// prepare queryVector once per script execution
146146
// queryVector represents a map of dimensions to values
147147
public VectorSparseFunctions(Map<String, Number> queryVector) {
148148
//break vector into two arrays dims and values
149149
int n = queryVector.size();
150-
queryValues = new double[n];
150+
queryValues = new float[n];
151151
queryDims = new int[n];
152152
int i = 0;
153153
for (Map.Entry<String, Number> dimValue : queryVector.entrySet()) {
@@ -156,11 +156,11 @@ public VectorSparseFunctions(Map<String, Number> queryVector) {
156156
} catch (final NumberFormatException e) {
157157
throw new IllegalArgumentException("Failed to parse a query vector dimension, it must be an integer!", e);
158158
}
159-
queryValues[i] = dimValue.getValue().doubleValue();
159+
queryValues[i] = dimValue.getValue().floatValue();
160160
i++;
161161
}
162162
// Sort dimensions in the ascending order and sort values in the same order as their corresponding dimensions
163-
sortSparseDimsDoubleValues(queryDims, queryValues, n);
163+
sortSparseDimsFloatValues(queryDims, queryValues, n);
164164
}
165165
}
166166

@@ -317,7 +317,7 @@ public double cosineSimilaritySparse(VectorScriptDocValues.SparseVectorScriptDoc
317317
}
318318
}
319319

320-
private static double intDotProductSparse(double[] v1Values, int[] v1Dims, float[] v2Values, int[] v2Dims) {
320+
private static double intDotProductSparse(float[] v1Values, int[] v1Dims, float[] v2Values, int[] v2Dims) {
321321
double v1v2DotProduct = 0;
322322
int v1Index = 0;
323323
int v2Index = 0;

x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtilsTests.java

+5-5
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@ public void testDenseVectorFunctions() {
3636
BytesRef encodedDocVector = mockEncodeDenseVector(docVector);
3737
VectorScriptDocValues.DenseVectorScriptDocValues dvs = mock(VectorScriptDocValues.DenseVectorScriptDocValues.class);
3838
when(dvs.getEncodedValue()).thenReturn(encodedDocVector);
39-
List<Number> queryVector = Arrays.asList(0.5, 111.3, -13.0, 14.8, -156.0);
39+
List<Number> queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f);
4040

4141
// test dotProduct
4242
double result = dotProduct(queryVector, dvs);
43-
assertEquals("dotProduct result is not equal to the expected value!", 65425.626, result, 0.001);
43+
assertEquals("dotProduct result is not equal to the expected value!", 65425.624, result, 0.001);
4444

4545
// test cosineSimilarity
4646
CosineSimilarity cosineSimilarity = new CosineSimilarity(queryVector);
@@ -91,7 +91,7 @@ public void testSparseVectorFunctions() {
9191
// test dotProduct
9292
DotProductSparse docProductSparse = new DotProductSparse(queryVector);
9393
double result = docProductSparse.dotProductSparse(dvs);
94-
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.626, result, 0.001);
94+
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);
9595

9696
// test cosineSimilarity
9797
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(queryVector);
@@ -128,7 +128,7 @@ public void testSparseVectorMissingDimensions1() {
128128
// test dotProduct
129129
DotProductSparse docProductSparse = new DotProductSparse(queryVector);
130130
double result = docProductSparse.dotProductSparse(dvs);
131-
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.626, result, 0.001);
131+
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);
132132

133133
// test cosineSimilarity
134134
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(queryVector);
@@ -165,7 +165,7 @@ public void testSparseVectorMissingDimensions2() {
165165
// test dotProduct
166166
DotProductSparse docProductSparse = new DotProductSparse(queryVector);
167167
double result = docProductSparse.dotProductSparse(dvs);
168-
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.626, result, 0.001);
168+
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);
169169

170170
// test cosineSimilarity
171171
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(queryVector);

0 commit comments

Comments
 (0)