diff --git a/docs/reference/mapping/types/dense-vector.asciidoc b/docs/reference/mapping/types/dense-vector.asciidoc index f79bdde9cc4dd..7ea04b42330c2 100644 --- a/docs/reference/mapping/types/dense-vector.asciidoc +++ b/docs/reference/mapping/types/dense-vector.asciidoc @@ -7,9 +7,7 @@ experimental[] A `dense_vector` field stores dense vectors of float values. The maximum number of dimensions that can be in a vector should -not exceed 1024. The number of dimensions can be -different across documents. A `dense_vector` field is -a single-valued field. +not exceed 1024. A `dense_vector` field is a single-valued field. These vectors can be used for <>. For example, a document score can represent a distance between @@ -24,7 +22,8 @@ PUT my_index "mappings": { "properties": { "my_vector": { - "type": "dense_vector" + "type": "dense_vector", + "dims": 3 <1> }, "my_text" : { "type" : "keyword" @@ -42,13 +41,14 @@ PUT my_index/_doc/1 PUT my_index/_doc/2 { "my_text" : "text2", - "my_vector" : [-0.5, 10, 10, 4] + "my_vector" : [-0.5, 10, 10] } -------------------------------------------------- // CONSOLE +<1> dims—the number of dimensions in the vector, required parameter. + Internally, each document's dense vector is encoded as a binary doc value. Its size in bytes is equal to -`4 * NUMBER_OF_DIMENSIONS`, where `NUMBER_OF_DIMENSIONS` - -number of the vector's dimensions. \ No newline at end of file +`4 * dims`, where `dims`—the number of the vector's dimensions. \ No newline at end of file diff --git a/docs/reference/query-dsl/script-score-query.asciidoc b/docs/reference/query-dsl/script-score-query.asciidoc index 42e0ec083d560..401d323f6fff4 100644 --- a/docs/reference/query-dsl/script-score-query.asciidoc +++ b/docs/reference/query-dsl/script-score-query.asciidoc @@ -199,8 +199,7 @@ a vector function is executed, 0 is returned as a result for this document. NOTE: If a document's dense vector field has a number of dimensions -different from the query's vector, 0 is used for missing dimensions -in the calculations of vector functions. +different from the query's vector, an error will be thrown. [[random-score-function]] diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/10_dense_vector_basic.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/10_dense_vector_basic.yml index 018c5546c02dd..5ac20486fe3e7 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/10_dense_vector_basic.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/10_dense_vector_basic.yml @@ -1,8 +1,8 @@ setup: - skip: features: headers - version: " - 7.2.99" - reason: "dense_vector functions were introduced in 7.3.0" + version: " - 7.99.99" # TODO: change to 7.2.99 after backport + reason: "dense_vector dims parameter was added from 8.0" - do: indices.create: @@ -15,6 +15,7 @@ setup: properties: my_dense_vector: type: dense_vector + dims: 5 - do: index: index: test-index diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/20_dense_vector_special_cases.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/20_dense_vector_special_cases.yml index 685bf2ae97a2f..40d8cdd636aec 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/20_dense_vector_special_cases.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/20_dense_vector_special_cases.yml @@ -1,8 +1,8 @@ setup: - skip: features: headers - version: " - 7.2.99" - reason: "dense_vector functions were introduced in 7.3.0" + version: " - 7.99.99" # TODO: change to 7.2.99 after backport + reason: "dense_vector dims parameter was added from 8.0" - do: indices.create: @@ -17,31 +17,36 @@ setup: properties: my_dense_vector: type: dense_vector + dims: 3 --- -"Vectors of different dimensions and data types": -# document vectors of different dimensions +"Indexing of Dense vectors should error when dims don't match defined in the mapping": + - do: + catch: bad_request index: index: test-index id: 1 body: - my_dense_vector: [10] + my_dense_vector: [10, 2] + - match: { error.type: "mapper_parsing_exception" } +--- +"Vectors of mixed integers and floats": - do: index: index: test-index - id: 2 + id: 1 body: - my_dense_vector: [10, 10.5] + my_dense_vector: [10, 10, 10] - do: index: index: test-index - id: 3 + id: 2 body: - my_dense_vector: [10, 10.5, 100.5] + my_dense_vector: [10.9, 10.9, 10.9] - do: indices.refresh: {} @@ -59,14 +64,13 @@ setup: script: source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" params: - query_vector: [10] + query_vector: [10, 10, 10] - - match: {hits.total: 3} + - match: {hits.total: 2} - match: {hits.hits.0._id: "1"} - match: {hits.hits.1._id: "2"} - - match: {hits.hits.2._id: "3"} -# query vector of type double +# query vector of type float - do: headers: Content-Type: application/json @@ -79,12 +83,52 @@ setup: script: source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" params: - query_vector: [10.0] + query_vector: [10.0, 10.0, 10.0] - - match: {hits.total: 3} + - match: {hits.total: 2} - match: {hits.hits.0._id: "1"} - match: {hits.hits.1._id: "2"} - - match: {hits.hits.2._id: "3"} + + +--- +"Functions with query vectors with dims different from docs vectors should error": + - do: + index: + index: test-index + id: 1 + body: + my_dense_vector: [1, 2, 3] + + - do: + indices.refresh: {} + + - do: + catch: bad_request + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" + params: + query_vector: [1, 2, 3, 4] + - match: { error.root_cause.0.type: "script_exception" } + + - do: + catch: bad_request + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "dotProduct(params.query_vector, doc['my_dense_vector'])" + params: + query_vector: [1, 2, 3, 4] + - match: { error.root_cause.0.type: "script_exception" } --- "Distance functions for documents missing vector field should return 0": @@ -93,7 +137,7 @@ setup: index: test-index id: 1 body: - my_dense_vector: [10] + my_dense_vector: [10, 10, 10] - do: index: @@ -117,7 +161,7 @@ setup: script: source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" params: - query_vector: [10.0] + query_vector: [10.0, 10.0, 10.0] - match: {hits.total: 2} - match: {hits.hits.0._id: "1"} @@ -149,5 +193,5 @@ setup: script: source: "dotProductSparse(params.query_vector, doc['my_dense_vector'])" params: - query_vector: {"2": 0.5, "10" : 111.3} + query_vector: {"2": 0.5, "10" : 111.3, "3": 44} - match: { error.root_cause.0.type: "script_exception" } diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapper.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapper.java index 70e597a9a489b..a7773e3e3c527 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapper.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapper.java @@ -12,10 +12,11 @@ import org.apache.lucene.index.IndexableField; import org.apache.lucene.search.DocValuesFieldExistsQuery; import org.apache.lucene.search.Query; -import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser.Token; +import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.mapper.ArrayValueMapperParser; import org.elasticsearch.index.mapper.FieldMapper; @@ -56,12 +57,28 @@ public static class Defaults { } public static class Builder extends FieldMapper.Builder { + private int dims = 0; public Builder(String name) { super(name, Defaults.FIELD_TYPE, Defaults.FIELD_TYPE); builder = this; } + public Builder dims(int dims) { + if ((dims > MAX_DIMS_COUNT) || (dims < 1)) { + throw new MapperParsingException("The number of dimensions for field [" + name + + "] should be in the range [1, " + MAX_DIMS_COUNT + "]"); + } + this.dims = dims; + return this; + } + + @Override + protected void setupFieldType(BuilderContext context) { + super.setupFieldType(context); + fieldType().setDims(dims); + } + @Override public DenseVectorFieldType fieldType() { return (DenseVectorFieldType) super.fieldType(); @@ -80,11 +97,17 @@ public static class TypeParser implements Mapper.TypeParser { @Override public Mapper.Builder parse(String name, Map node, ParserContext parserContext) throws MapperParsingException { DenseVectorFieldMapper.Builder builder = new DenseVectorFieldMapper.Builder(name); - return builder; + Object dimsField = node.remove("dims"); + if (dimsField == null) { + throw new MapperParsingException("The [dims] property must be specified for field [" + name + "]."); + } + int dims = XContentMapValues.nodeIntegerValue(dimsField); + return builder.dims(dims); } } public static final class DenseVectorFieldType extends MappedFieldType { + private int dims; public DenseVectorFieldType() {} @@ -96,6 +119,14 @@ public DenseVectorFieldType clone() { return new DenseVectorFieldType(this); } + int dims() { + return dims; + } + + void setDims(int dims) { + this.dims = dims; + } + @Override public String typeName() { return CONTENT_TYPE; @@ -145,28 +176,30 @@ public void parse(ParseContext context) throws IOException { if (context.externalValueSet()) { throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] can't be used in multi-fields"); } + int dims = fieldType().dims(); //number of vector dimensions // encode array of floats as array of integers and store into buf // this code is here and not int the VectorEncoderDecoder so not to create extra arrays - byte[] buf = new byte[0]; + byte[] buf = new byte[dims * INT_BYTES]; int offset = 0; int dim = 0; for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) { + if (dim++ >= dims) { + throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] of doc [" + + context.sourceToParse().id() + "] has exceeded the number of dimensions [" + dims + "] defined in mapping"); + } ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser()::getTokenLocation); float value = context.parser().floatValue(true); - if (buf.length < (offset + INT_BYTES)) { - buf = ArrayUtil.grow(buf, (offset + INT_BYTES)); - } int intValue = Float.floatToIntBits(value); - buf[offset] = (byte) (intValue >> 24); - buf[offset+1] = (byte) (intValue >> 16); - buf[offset+2] = (byte) (intValue >> 8); - buf[offset+3] = (byte) intValue; - offset += INT_BYTES; - if (dim++ >= MAX_DIMS_COUNT) { - throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + - "] has exceeded the maximum allowed number of dimensions of [" + MAX_DIMS_COUNT + "]"); - } + buf[offset++] = (byte) (intValue >> 24); + buf[offset++] = (byte) (intValue >> 16); + buf[offset++] = (byte) (intValue >> 8); + buf[offset++] = (byte) intValue; + } + if (dim != dims) { + throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] of doc [" + + context.sourceToParse().id() + "] has number of dimensions [" + dim + + "] less than defined in the mapping [" + dims +"]"); } BinaryDocValuesField field = new BinaryDocValuesField(fieldType().name(), new BytesRef(buf, 0, offset)); if (context.doc().getByKey(fieldType().name()) != null) { @@ -176,6 +209,12 @@ public void parse(ParseContext context) throws IOException { context.doc().addWithKey(fieldType().name(), field); } + @Override + protected void doXContentBody(XContentBuilder builder, boolean includeDefaults, Params params) throws IOException { + super.doXContentBody(builder, includeDefaults, params); + builder.field("dims", fieldType().dims()); + } + @Override protected void parseCreateField(ParseContext context, List fields) { throw new AssertionError("parse is implemented directly"); diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoder.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoder.java index 3058ff51434db..31b94ae108e63 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoder.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoder.java @@ -162,12 +162,11 @@ public static float[] decodeDenseVector(BytesRef vectorBR) { float[] vector = new float[dimCount]; int offset = vectorBR.offset; for (int dim = 0; dim < dimCount; dim++) { - int intValue = ((vectorBR.bytes[offset] & 0xFF) << 24) | - ((vectorBR.bytes[offset+1] & 0xFF) << 16) | - ((vectorBR.bytes[offset+2] & 0xFF) << 8) | - (vectorBR.bytes[offset+3] & 0xFF); + int intValue = ((vectorBR.bytes[offset++] & 0xFF) << 24) | + ((vectorBR.bytes[offset++] & 0xFF) << 16) | + ((vectorBR.bytes[offset++] & 0xFF) << 8) | + (vectorBR.bytes[offset++] & 0xFF); vector[dim] = Float.intBitsToFloat(intValue); - offset = offset + INT_BYTES; } return vector; } diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java index e73f40f900bac..fcb02cee68822 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java @@ -30,6 +30,10 @@ public static double dotProduct(List queryVector, VectorScriptDocValues. BytesRef value = dvs.getEncodedValue(); if (value == null) return 0; float[] docVector = VectorEncoderDecoder.decodeDenseVector(value); + if (queryVector.size() != docVector.length) { + throw new IllegalArgumentException("Can't calculate dotProduct! The number of dimensions of the query vector [" + + queryVector.size() + "] is different from the documents' vectors [" + docVector.length + "]."); + } return intDotProduct(queryVector, docVector); } @@ -61,6 +65,10 @@ public double cosineSimilarity(VectorScriptDocValues.DenseVectorScriptDocValues BytesRef value = dvs.getEncodedValue(); if (value == null) return 0; float[] docVector = VectorEncoderDecoder.decodeDenseVector(value); + if (queryVector.size() != docVector.length) { + throw new IllegalArgumentException("Can't calculate cosineSimilarity! The number of dimensions of the query vector [" + + queryVector.size() + "] is different from the documents' vectors [" + docVector.length + "]."); + } // calculate docVector magnitude double dotProduct = 0f; @@ -75,13 +83,10 @@ public double cosineSimilarity(VectorScriptDocValues.DenseVectorScriptDocValues } private static double intDotProduct(List v1, float[] v2){ - int dims = Math.min(v1.size(), v2.length); double v1v2DotProduct = 0; - int dim = 0; Iterator v1Iter = v1.iterator(); - while(dim < dims) { + for (int dim = 0; dim < v2.length; dim++) { v1v2DotProduct += v1Iter.next().doubleValue() * v2[dim]; - dim++; } return v1v2DotProduct; } diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapperTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapperTests.java index b8a804effe78f..d1b37c73a246e 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapperTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapperTests.java @@ -26,8 +26,6 @@ import org.elasticsearch.xpack.core.XPackPlugin; import org.elasticsearch.xpack.vectors.Vectors; -import org.junit.Before; - import java.io.IOException; import java.util.Collection; @@ -35,67 +33,93 @@ import static org.hamcrest.Matchers.instanceOf; public class DenseVectorFieldMapperTests extends ESSingleNodeTestCase { - private DocumentMapper mapper; + @Override + protected Collection> getPlugins() { + return pluginList(Vectors.class, XPackPlugin.class); + } - @Before - public void setUpMapper() throws Exception { - IndexService indexService = createIndex("test-index"); + public void testMappingExceedDimsLimit() throws IOException { + IndexService indexService = createIndex("test-index"); DocumentMapperParser parser = indexService.mapperService().documentMapperParser(); String mapping = Strings.toString(XContentFactory.jsonBuilder() .startObject() - .startObject("_doc") - .startObject("properties") - .startObject("my-dense-vector").field("type", "dense_vector") - .endObject() - .endObject() - .endObject() + .startObject("_doc") + .startObject("properties") + .startObject("my-dense-vector").field("type", "dense_vector").field("dims", DenseVectorFieldMapper.MAX_DIMS_COUNT + 1) + .endObject() + .endObject() + .endObject() .endObject()); - mapper = parser.parse("_doc", new CompressedXContent(mapping)); - } - - @Override - protected Collection> getPlugins() { - return pluginList(Vectors.class, XPackPlugin.class); + MapperParsingException e = expectThrows(MapperParsingException.class, () -> parser.parse("_doc", new CompressedXContent(mapping))); + assertEquals(e.getMessage(), "The number of dimensions for field [my-dense-vector] should be in the range [1, 1024]"); } public void testDefaults() throws Exception { - float[] expectedArray = {-12.1f, 100.7f, -4}; + IndexService indexService = createIndex("test-index"); + DocumentMapperParser parser = indexService.mapperService().documentMapperParser(); + String mapping = Strings.toString(XContentFactory.jsonBuilder() + .startObject() + .startObject("_doc") + .startObject("properties") + .startObject("my-dense-vector").field("type", "dense_vector").field("dims", 3) + .endObject() + .endObject() + .endObject() + .endObject()); + DocumentMapper mapper = parser.parse("_doc", new CompressedXContent(mapping)); + + float[] validVector = {-12.1f, 100.7f, -4}; ParsedDocument doc1 = mapper.parse(new SourceToParse("test-index", "_doc", "1", BytesReference .bytes(XContentFactory.jsonBuilder() .startObject() - .startArray("my-dense-vector").value(expectedArray[0]).value(expectedArray[1]).value(expectedArray[2]).endArray() + .startArray("my-dense-vector").value(validVector[0]).value(validVector[1]).value(validVector[2]).endArray() .endObject()), XContentType.JSON)); IndexableField[] fields = doc1.rootDoc().getFields("my-dense-vector"); assertEquals(1, fields.length); assertThat(fields[0], instanceOf(BinaryDocValuesField.class)); - // assert that after decoding the indexed value is equal to expected - BytesRef vectorBR = ((BinaryDocValuesField) fields[0]).binaryValue(); + BytesRef vectorBR = fields[0].binaryValue(); float[] decodedValues = VectorEncoderDecoder.decodeDenseVector(vectorBR); assertArrayEquals( "Decoded dense vector values is not equal to the indexed one.", - expectedArray, + validVector, decodedValues, 0.001f ); } - public void testDimensionLimit() throws IOException { - float[] validVector = new float[DenseVectorFieldMapper.MAX_DIMS_COUNT]; - BytesReference validDoc = BytesReference.bytes( - XContentFactory.jsonBuilder().startObject() - .array("my-dense-vector", validVector) + public void testDocumentsWithIncorrectDims() throws Exception { + IndexService indexService = createIndex("test-index"); + int dims = 3; + DocumentMapperParser parser = indexService.mapperService().documentMapperParser(); + String mapping = Strings.toString(XContentFactory.jsonBuilder() + .startObject() + .startObject("_doc") + .startObject("properties") + .startObject("my-dense-vector").field("type", "dense_vector").field("dims", dims) + .endObject() + .endObject() + .endObject() .endObject()); - mapper.parse(new SourceToParse("test-index", "_doc", "1", validDoc, XContentType.JSON)); + DocumentMapper mapper = parser.parse("_doc", new CompressedXContent(mapping)); - float[] invalidVector = new float[DenseVectorFieldMapper.MAX_DIMS_COUNT + 1]; - BytesReference invalidDoc = BytesReference.bytes( - XContentFactory.jsonBuilder().startObject() - .array("my-dense-vector", invalidVector) - .endObject()); + // test that error is thrown when a document has number of dims more than defined in the mapping + float[] invalidVector = new float[dims + 1]; + BytesReference invalidDoc = BytesReference.bytes(XContentFactory.jsonBuilder().startObject() + .array("my-dense-vector", invalidVector) + .endObject()); MapperParsingException e = expectThrows(MapperParsingException.class, () -> mapper.parse( new SourceToParse("test-index", "_doc", "1", invalidDoc, XContentType.JSON))); - assertThat(e.getDetailedMessage(), containsString("has exceeded the maximum allowed number of dimensions")); + assertThat(e.getCause().getMessage(), containsString("has exceeded the number of dimensions [3] defined in mapping")); + + // test that error is thrown when a document has number of dims less than defined in the mapping + float[] invalidVector2 = new float[dims - 1]; + BytesReference invalidDoc2 = BytesReference.bytes(XContentFactory.jsonBuilder().startObject() + .array("my-dense-vector", invalidVector2) + .endObject()); + MapperParsingException e2 = expectThrows(MapperParsingException.class, () -> mapper.parse( + new SourceToParse("test-index", "_doc", "2", invalidDoc2, XContentType.JSON))); + assertThat(e2.getCause().getMessage(), containsString("has number of dimensions [2] less than defined in the mapping [3]")); } } diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoderTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoderTests.java index 9acbe44630d99..939d999b0d9aa 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoderTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoderTests.java @@ -16,14 +16,14 @@ public class VectorEncoderDecoderTests extends ESTestCase { public void testDenseVectorEncodingDecoding() { - int dimCount = randomIntBetween(0, 300); + int dimCount = randomIntBetween(0, DenseVectorFieldMapper.MAX_DIMS_COUNT); float[] expectedValues = new float[dimCount]; for (int i = 0; i < dimCount; i++) { expectedValues[i] = randomFloat(); } // test that values that went through encoding and decoding are equal to their original - BytesRef encodedDenseVector = mockEncodeDenseVector(expectedValues); + BytesRef encodedDenseVector = mockEncodeDenseVector(expectedValues); float[] decodedValues = VectorEncoderDecoder.decodeDenseVector(encodedDenseVector); assertArrayEquals( "Decoded dense vector values are not equal to their original.", @@ -31,7 +31,6 @@ public void testDenseVectorEncodingDecoding() { decodedValues, 0.001f ); - } public void testSparseVectorEncodingDecoding() { @@ -70,18 +69,17 @@ public void testSparseVectorEncodingDecoding() { } // imitates the code in DenseVectorFieldMapper::parse - public static BytesRef mockEncodeDenseVector(float[] dims) { + public static BytesRef mockEncodeDenseVector(float[] values) { final short INT_BYTES = VectorEncoderDecoder.INT_BYTES; - byte[] buf = new byte[INT_BYTES * dims.length]; + byte[] buf = new byte[INT_BYTES * values.length]; int offset = 0; int intValue; - for (float value: dims) { + for (float value: values) { intValue = Float.floatToIntBits(value); - buf[offset] = (byte) (intValue >> 24); - buf[offset+1] = (byte) (intValue >> 16); - buf[offset+2] = (byte) (intValue >> 8); - buf[offset+3] = (byte) intValue; - offset += INT_BYTES; + buf[offset++] = (byte) (intValue >> 24); + buf[offset++] = (byte) (intValue >> 16); + buf[offset++] = (byte) (intValue >> 8); + buf[offset++] = (byte) intValue; } return new BytesRef(buf, 0, offset); } diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtilsTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtilsTests.java index 538b49977e108..699a9b09fb537 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtilsTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtilsTests.java @@ -20,6 +20,7 @@ import static org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoderTests.mockEncodeDenseVector; import static org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.dotProduct; +import static org.hamcrest.Matchers.containsString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -40,6 +41,17 @@ public void testDenseVectorFunctions() { CosineSimilarity cosineSimilarity = new CosineSimilarity(queryVector); double result2 = cosineSimilarity.cosineSimilarity(dvs); assertEquals("cosineSimilarity result is not equal to the expected value!", 0.78, result2, 0.1); + + // test dotProduct fails when queryVector has wrong number of dims + List invalidQueryVector = Arrays.asList(0.5, 111.3); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> dotProduct(invalidQueryVector, dvs)); + assertThat(e.getMessage(), containsString("dimensions of the query vector [2] is different from the documents' vectors [5]")); + + // test cosineSimilarity fails when queryVector has wrong number of dims + CosineSimilarity cosineSimilarity2 = new CosineSimilarity(invalidQueryVector); + e = expectThrows(IllegalArgumentException.class, () -> cosineSimilarity2.cosineSimilarity(dvs)); + assertThat(e.getMessage(), containsString("dimensions of the query vector [2] is different from the documents' vectors [5]")); + } public void testSparseVectorFunctions() {