diff --git a/docs/reference/mapping/types/dense-vector.asciidoc b/docs/reference/mapping/types/dense-vector.asciidoc index 39e33bae94e12..95987477e4993 100644 --- a/docs/reference/mapping/types/dense-vector.asciidoc +++ b/docs/reference/mapping/types/dense-vector.asciidoc @@ -10,9 +10,8 @@ A `dense_vector` field stores dense vectors of float values. The maximum number of dimensions that can be in a vector should not exceed 2048. A `dense_vector` field is a single-valued field. -These vectors can be used for <>. -For example, a document score can represent a distance between -a given query vector and the indexed document vector. +`dense_vector` fields do not support querying, sorting or aggregating. They can +only be accessed in scripts through the dedicated <>. You index a dense vector as an array of floats. @@ -47,4 +46,4 @@ PUT my-index-000001/_doc/2 -------------------------------------------------- -<1> dims—the number of dimensions in the vector, required parameter. +<1> dims – the number of dimensions in the vector, required parameter. diff --git a/docs/reference/vectors/vector-functions.asciidoc b/docs/reference/vectors/vector-functions.asciidoc index e1ee37d166885..21f1761f28ce7 100644 --- a/docs/reference/vectors/vector-functions.asciidoc +++ b/docs/reference/vectors/vector-functions.asciidoc @@ -8,6 +8,16 @@ linearly scanned. Thus, expect the query time grow linearly with the number of matched documents. For this reason, we recommend to limit the number of matched documents with a `query` parameter. +This is the list of available vector functions and vector access methods: + +1. `cosineSimilarity` – calculates cosine similarity +2. `dotProduct` – calculates dot product +3. `l1norm` – calculates L^1^ distance +4. `l2norm` - calculates L^2^ distance +5. `doc[].vectorValue` – returns a vector's value as an array of floats +6. `doc[].magnitude` – returns a vector's magnitude + + Let's create an index with a `dense_vector` mapping and index a couple of documents into it. @@ -195,3 +205,51 @@ You can check if a document has a value for the field `my_vector` by "source": "doc['my_vector'].size() == 0 ? 0 : cosineSimilarity(params.queryVector, 'my_vector')" -------------------------------------------------- // NOTCONSOLE + +The recommended way to access dense vectors is through `cosineSimilarity`, +`dotProduct`, `l1norm` or `l2norm` functions. But for custom use cases, +you can access dense vectors's values directly through the following functions: + +- `doc[].vectorValue` – returns a vector's value as an array of floats + +- `doc[].magnitude` – returns a vector's magnitude as a float +(for vectors created prior to version 7.5 the magnitude is not stored. +So this function calculates it anew every time it is called). + +For example, the script below implements a cosine similarity using these +two functions: + +[source,console] +-------------------------------------------------- +GET my-index-000001/_search +{ + "query": { + "script_score": { + "query" : { + "bool" : { + "filter" : { + "term" : { + "status" : "published" + } + } + } + }, + "script": { + "source": """ + float[] v = doc['my_dense_vector'].vectorValue; + float vm = doc['my_dense_vector'].magnitude; + float dotProduct = 0; + for (int i = 0; i < v.length; i++) { + dotProduct += v[i] * params.queryVector[i]; + } + return dotProduct / (vm * (float) params.queryVectorMag); + """, + "params": { + "queryVector": [4, 3.4, -0.2], + "queryVectorMag": 5.25357 + } + } + } + } +} +-------------------------------------------------- diff --git a/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java b/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java index e3a9b44fd3197..1573f11b88eca 100644 --- a/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java +++ b/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java @@ -14,7 +14,6 @@ import org.elasticsearch.script.ExplainableScoreScript; import org.elasticsearch.script.ScoreScript; import org.elasticsearch.script.Script; -import org.elasticsearch.Version; import java.io.IOException; import java.util.Objects; @@ -42,15 +41,13 @@ public float score() { private final int shardId; private final String indexName; - private final Version indexVersion; - public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script, String indexName, int shardId, Version indexVersion) { + public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script, String indexName, int shardId) { super(CombineFunction.REPLACE); this.sScript = sScript; this.script = script; this.indexName = indexName; this.shardId = shardId; - this.indexVersion = indexVersion; } @Override @@ -60,7 +57,6 @@ public LeafScoreFunction getLeafScoreFunction(LeafReaderContext ctx) throws IOEx leafScript.setScorer(scorer); leafScript._setIndexName(indexName); leafScript._setShard(shardId); - leafScript._setIndexVersion(indexVersion); return new LeafScoreFunction() { @Override public double score(int docId, float subQueryScore) throws IOException { diff --git a/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreQuery.java b/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreQuery.java index fae31341458ee..c33f588ac3670 100644 --- a/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreQuery.java +++ b/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreQuery.java @@ -146,7 +146,6 @@ private ScoreScript makeScoreScript(LeafReaderContext context) throws IOExceptio final ScoreScript scoreScript = scriptBuilder.newInstance(context); scoreScript._setIndexName(indexName); scoreScript._setShard(shardId); - scoreScript._setIndexVersion(indexVersion); return scoreScript; } diff --git a/server/src/main/java/org/elasticsearch/index/query/functionscore/ScriptScoreFunctionBuilder.java b/server/src/main/java/org/elasticsearch/index/query/functionscore/ScriptScoreFunctionBuilder.java index 16ccc4be3a22f..4833d0c7ef727 100644 --- a/server/src/main/java/org/elasticsearch/index/query/functionscore/ScriptScoreFunctionBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/functionscore/ScriptScoreFunctionBuilder.java @@ -83,8 +83,7 @@ protected ScoreFunction doToFunction(SearchExecutionContext context) { try { ScoreScript.Factory factory = context.compile(script, ScoreScript.CONTEXT); ScoreScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup()); - return new ScriptScoreFunction(script, searchScript, - context.index().getName(), context.getShardId(), context.indexVersionCreated()); + return new ScriptScoreFunction(script, searchScript, context.index().getName(), context.getShardId()); } catch (Exception e) { throw new QueryShardException(context, "script_score: the script could not be loaded", e); } diff --git a/server/src/main/java/org/elasticsearch/script/ScoreScript.java b/server/src/main/java/org/elasticsearch/script/ScoreScript.java index c90fd52050611..c4ce79d6c241f 100644 --- a/server/src/main/java/org/elasticsearch/script/ScoreScript.java +++ b/server/src/main/java/org/elasticsearch/script/ScoreScript.java @@ -10,7 +10,6 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.Scorable; -import org.elasticsearch.Version; import org.elasticsearch.common.logging.DeprecationCategory; import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.index.fielddata.ScriptDocValues; @@ -85,7 +84,6 @@ public Explanation get(double score, Explanation subQueryExplanation) { private int docId; private int shardId = -1; private String indexName = null; - private Version indexVersion = null; public ScoreScript(Map params, SearchLookup lookup, LeafReaderContext leafContext) { // null check needed b/c of expression engine subclass @@ -185,19 +183,6 @@ public String _getIndex() { } } - /** - * Starting a name with underscore, so that the user cannot access this function directly through a script - * It is only used within predefined painless functions. - * @return index version or throws an exception if the index version is not set up for this script instance - */ - public Version _getIndexVersion() { - if (indexVersion != null) { - return indexVersion; - } else { - throw new IllegalArgumentException("index version can not be looked up!"); - } - } - /** * Starting a name with underscore, so that the user cannot access this function directly through a script */ @@ -212,13 +197,6 @@ public void _setIndexName(String indexName) { this.indexName = indexName; } - /** - * Starting a name with underscore, so that the user cannot access this function directly through a script - */ - public void _setIndexVersion(Version indexVersion) { - this.indexVersion = indexVersion; - } - /** A factory to construct {@link ScoreScript} instances. */ public interface LeafFactory { diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/vectors/30_dense_vector_script_access.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/vectors/30_dense_vector_script_access.yml new file mode 100644 index 0000000000000..ef670b004507f --- /dev/null +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/vectors/30_dense_vector_script_access.yml @@ -0,0 +1,65 @@ +--- +"Access to values of dense_vector in script": + - skip: + version: " - 7.12.99" + reason: "Access to values of dense_vector in script was added in 7.13" + - do: + indices.create: + index: test-index + body: + mappings: + properties: + v: + type: dense_vector + dims: 3 + + - do: + bulk: + index: test-index + refresh: true + body: + - '{"index": {"_id": "1"}}' + - '{"v": [1, 1, 1]}' + - '{"index": {"_id": "2"}}' + - '{"v": [1, 1, 2]}' + - '{"index": {"_id": "3"}}' + - '{"v": [1, 1, 3]}' + - '{"index": {"_id": "missing_vector"}}' + - '{}' + + # vector functions in loop – return the index of the closest parameter vector based on cosine similarity + - do: + search: + body: + query: + script_score: + query: { "exists": { "field": "v" } } + script: + source: | + float[] v = doc['v'].vectorValue; + float vm = doc['v'].magnitude; + + int closestPv = 0; + float maxCosSim = -1; + for (int i = 0; i < params.pvs.length; i++) { + float dotProduct = 0; + for (int j = 0; j < v.length; j++) { + dotProduct += v[j] * params.pvs[i][j]; + } + float cosSim = dotProduct / (vm * (float) params.pvs_magnts[i]); + if (maxCosSim < cosSim) { + maxCosSim = cosSim; + closestPv = i; + } + } + closestPv; + params: + pvs: [ [ 1, 1, 1 ], [ 1, 1, 2 ], [ 1, 1, 3 ] ] + pvs_magnts: [1.7320, 2.4495, 3.3166] + + - match: { hits.hits.0._id: "3" } + - match: { hits.hits.0._score: 2 } + - match: { hits.hits.1._id: "2" } + - match: { hits.hits.1._score: 1 } + - match: { hits.hits.2._id: "1" } + - match: { hits.hits.2._score: 0 } diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapper.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapper.java index 5fc3d3c235893..d4ae6cf4093fc 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapper.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapper.java @@ -82,7 +82,7 @@ protected List> getParameters() { public DenseVectorFieldMapper build(ContentPath contentPath) { return new DenseVectorFieldMapper( name, - new DenseVectorFieldType(buildFullName(contentPath), dims.getValue(), meta.getValue()), + new DenseVectorFieldType(buildFullName(contentPath), indexVersionCreated, dims.getValue(), meta.getValue()), dims.getValue(), indexVersionCreated, multiFieldsBuilder.build(this, contentPath), @@ -94,10 +94,12 @@ public DenseVectorFieldMapper build(ContentPath contentPath) { public static final class DenseVectorFieldType extends MappedFieldType { private final int dims; + private final Version indexVersionCreated; - public DenseVectorFieldType(String name, int dims, Map meta) { + public DenseVectorFieldType(String name, Version indexVersionCreated, int dims, Map meta) { super(name, false, false, true, TextSearchInfo.NONE, meta); this.dims = dims; + this.indexVersionCreated = indexVersionCreated; } int dims() { @@ -124,7 +126,7 @@ protected Object parseSourceValue(Object value) { @Override public DocValueFormat docValueFormat(String format, ZoneId timeZone) { - throw new UnsupportedOperationException( + throw new IllegalArgumentException( "Field [" + name() + "] of type [" + typeName() + "] doesn't support docvalue_fields or aggregations"); } @@ -135,7 +137,7 @@ public boolean isAggregatable() { @Override public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, Supplier searchLookup) { - return new VectorIndexFieldData.Builder(name(), CoreValuesSourceType.KEYWORD); + return new VectorIndexFieldData.Builder(name(), CoreValuesSourceType.KEYWORD, indexVersionCreated, dims); } @Override diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoder.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoder.java index 65073bdda472f..56bd68ad08cb0 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoder.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoder.java @@ -13,6 +13,7 @@ import java.nio.ByteBuffer; + public final class VectorEncoderDecoder { public static final byte INT_BYTES = 4; @@ -29,9 +30,51 @@ public static int denseVectorLength(Version indexVersion, BytesRef vectorBR) { * NOTE: this function can only be called on vectors from an index version greater than or * equal to 7.5.0, since vectors created prior to that do not store the magnitude. */ - public static float decodeVectorMagnitude(Version indexVersion, BytesRef vectorBR) { + public static float decodeMagnitude(Version indexVersion, BytesRef vectorBR) { assert indexVersion.onOrAfter(Version.V_7_5_0); ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length); - return byteBuffer.getFloat(vectorBR.offset + vectorBR.length - 4); + return byteBuffer.getFloat(vectorBR.offset + vectorBR.length - INT_BYTES); + } + + /** + * Calculates vector magnitude + */ + private static float calculateMagnitude(Version indexVersion, BytesRef vectorBR) { + final int length = denseVectorLength(indexVersion, vectorBR); + ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length); + double magnitude = 0.0f; + for (int i = 0; i < length; i++) { + float value = byteBuffer.getFloat(); + magnitude += value * value; + } + magnitude = Math.sqrt(magnitude); + return (float) magnitude; + } + + public static float getMagnitude(Version indexVersion, BytesRef vectorBR) { + if (vectorBR == null) { + throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); + } + if (indexVersion.onOrAfter(Version.V_7_5_0)) { + return decodeMagnitude(indexVersion, vectorBR); + } else { + return calculateMagnitude(indexVersion, vectorBR); + } + } + + /** + * Decodes a BytesRef into the provided array of floats + * @param vectorBR - dense vector encoded in BytesRef + * @param vector - array of floats where the decoded vector should be stored + */ + public static void decodeDenseVector(BytesRef vectorBR, float[] vector) { + if (vectorBR == null) { + throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); + } + ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length); + for (int dim = 0; dim < vector.length; dim++) { + vector[dim] = byteBuffer.getFloat(); + } } + } diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DenseVectorScriptDocValues.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DenseVectorScriptDocValues.java index ba00639cacbd0..18e2e80a090bc 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DenseVectorScriptDocValues.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DenseVectorScriptDocValues.java @@ -10,17 +10,26 @@ import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.Version; import org.elasticsearch.index.fielddata.ScriptDocValues; +import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder; import java.io.IOException; public class DenseVectorScriptDocValues extends ScriptDocValues { private final BinaryDocValues in; + private final Version indexVersion; + private final int dims; + private final float[] vector; private BytesRef value; - DenseVectorScriptDocValues(BinaryDocValues in) { + + DenseVectorScriptDocValues(BinaryDocValues in, Version indexVersion, int dims) { this.in = in; + this.indexVersion = indexVersion; + this.dims = dims; + this.vector = new float[dims]; } @Override @@ -37,9 +46,30 @@ BytesRef getEncodedValue() { return value; } + // package private access only for {@link ScoreScriptUtils} + int dims() { + return dims; + } + @Override public BytesRef get(int index) { - throw new UnsupportedOperationException("accessing a vector field's value through 'get' or 'value' is not supported"); + throw new UnsupportedOperationException("accessing a vector field's value through 'get' or 'value' is not supported!" + + "Use 'vectorValue' or 'magnitude' instead!'"); + } + + /** + * Get dense vector's value as an array of floats + */ + public float[] getVectorValue() { + VectorEncoderDecoder.decodeDenseVector(value, vector); + return vector; + } + + /** + * Get dense vector's magnitude + */ + public float getMagnitude() { + return VectorEncoderDecoder.getMagnitude(indexVersion, value); } @Override diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java index d7dc51c13b359..a917d7ddf52cc 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java @@ -10,9 +10,7 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.ExceptionsHelper; -import org.elasticsearch.Version; import org.elasticsearch.script.ScoreScript; -import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder; import java.io.IOException; import java.nio.ByteBuffer; @@ -45,6 +43,11 @@ public DenseVectorFunction(ScoreScript scoreScript, this.scoreScript = scoreScript; this.docValues = (DenseVectorScriptDocValues) scoreScript.getDoc().get(field); + if (docValues.dims() != queryVector.size()){ + throw new IllegalArgumentException("The query vector has a different number of dimensions [" + + queryVector.size() + "] than the document vectors [" + docValues.dims() + "]."); + } + this.queryVector = new float[queryVector.size()]; double queryMagnitude = 0.0; for (int i = 0; i < queryVector.size(); i++) { @@ -67,18 +70,10 @@ BytesRef getEncodedVector() { } catch (IOException e) { throw ExceptionsHelper.convertToElastic(e); } - - // Validate the encoded vector's length. BytesRef vector = docValues.getEncodedValue(); if (vector == null) { throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); } - - int vectorLength = VectorEncoderDecoder.denseVectorLength(scoreScript._getIndexVersion(), vector); - if (queryVector.length != vectorLength) { - throw new IllegalArgumentException("The query vector has a different number of dimensions [" + - queryVector.length + "] than the document vectors [" + vectorLength + "]."); - } return vector; } } @@ -152,23 +147,11 @@ public CosineSimilarity(ScoreScript scoreScript, List queryVector, Strin public double cosineSimilarity() { BytesRef vector = getEncodedVector(); ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length); - double dotProduct = 0.0; - double vectorMagnitude = 0.0f; - if (scoreScript._getIndexVersion().onOrAfter(Version.V_7_5_0)) { - for (float queryValue : queryVector) { - dotProduct += queryValue * byteBuffer.getFloat(); - } - vectorMagnitude = VectorEncoderDecoder.decodeVectorMagnitude(scoreScript._getIndexVersion(), vector); - } else { - for (float queryValue : queryVector) { - float docValue = byteBuffer.getFloat(); - dotProduct += queryValue * docValue; - vectorMagnitude += docValue * docValue; - } - vectorMagnitude = (float) Math.sqrt(vectorMagnitude); + for (float queryValue : queryVector) { + dotProduct += queryValue * byteBuffer.getFloat(); } - return dotProduct / vectorMagnitude; + return dotProduct / docValues.getMagnitude(); } } } diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorDVLeafFieldData.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorDVLeafFieldData.java index edab9eecc848e..080f5a10b750c 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorDVLeafFieldData.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorDVLeafFieldData.java @@ -13,6 +13,7 @@ import org.apache.lucene.index.LeafReader; import org.apache.lucene.util.Accountable; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.Version; import org.elasticsearch.index.fielddata.LeafFieldData; import org.elasticsearch.index.fielddata.ScriptDocValues; import org.elasticsearch.index.fielddata.SortedBinaryDocValues; @@ -25,10 +26,14 @@ final class VectorDVLeafFieldData implements LeafFieldData { private final LeafReader reader; private final String field; + private final Version indexVersion; + private final int dims; - VectorDVLeafFieldData(LeafReader reader, String field) { + VectorDVLeafFieldData(LeafReader reader, String field, Version indexVersion, int dims) { this.reader = reader; this.field = field; + this.indexVersion = indexVersion; + this.dims = dims; } @Override @@ -50,7 +55,7 @@ public SortedBinaryDocValues getBytesValues() { public ScriptDocValues getScriptValues() { try { final BinaryDocValues values = DocValues.getBinary(reader, field); - return new DenseVectorScriptDocValues(values); + return new DenseVectorScriptDocValues(values, indexVersion, dims); } catch (IOException e) { throw new IllegalStateException("Cannot load doc values for vector field!", e); } diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorIndexFieldData.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorIndexFieldData.java index 13fb8391a1982..2c6df884cb96d 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorIndexFieldData.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorIndexFieldData.java @@ -10,6 +10,7 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.SortField; +import org.elasticsearch.Version; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.index.fielddata.IndexFieldData; @@ -21,16 +22,21 @@ import org.elasticsearch.search.aggregations.support.ValuesSourceType; import org.elasticsearch.search.sort.BucketedSort; import org.elasticsearch.search.sort.SortOrder; +import org.elasticsearch.xpack.vectors.mapper.DenseVectorFieldMapper; public class VectorIndexFieldData implements IndexFieldData { protected final String fieldName; protected final ValuesSourceType valuesSourceType; + private final Version indexVersion; + private final int dims; - public VectorIndexFieldData(String fieldName, ValuesSourceType valuesSourceType) { + public VectorIndexFieldData(String fieldName, ValuesSourceType valuesSourceType, Version indexVersion, int dims) { this.fieldName = fieldName; this.valuesSourceType = valuesSourceType; + this.indexVersion = indexVersion; + this.dims = dims; } @Override @@ -45,7 +51,8 @@ public ValuesSourceType getValuesSourceType() { @Override public SortField sortField(@Nullable Object missingValue, MultiValueMode sortMode, Nested nested, boolean reverse) { - throw new IllegalArgumentException("can't sort on the vector field"); + throw new IllegalArgumentException("Field [" + fieldName + "] of type [" + + DenseVectorFieldMapper.CONTENT_TYPE + "] doesn't support sort"); } @Override @@ -56,7 +63,7 @@ public BucketedSort newBucketedSort(BigArrays bigArrays, Object missingValue, Mu @Override public VectorDVLeafFieldData load(LeafReaderContext context) { - return new VectorDVLeafFieldData(context.reader(), fieldName); + return new VectorDVLeafFieldData(context.reader(), fieldName, indexVersion, dims); } @Override @@ -67,15 +74,19 @@ public VectorDVLeafFieldData loadDirect(LeafReaderContext context) { public static class Builder implements IndexFieldData.Builder { private final String name; private final ValuesSourceType valuesSourceType; + private final Version indexVersion; + private final int dims; - public Builder(String name, ValuesSourceType valuesSourceType) { + public Builder(String name, ValuesSourceType valuesSourceType, Version indexVersion, int dims) { this.name = name; this.valuesSourceType = valuesSourceType; + this.indexVersion = indexVersion; + this.dims = dims; } @Override public IndexFieldData build(IndexFieldDataCache cache, CircuitBreakerService breakerService) { - return new VectorIndexFieldData(name, valuesSourceType); + return new VectorIndexFieldData(name, valuesSourceType, indexVersion, dims); } } diff --git a/x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt b/x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt index 121e74933a5cd..86583d77264a2 100644 --- a/x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt +++ b/x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt @@ -5,6 +5,8 @@ # 2.0. # class org.elasticsearch.xpack.vectors.query.DenseVectorScriptDocValues { + float[] getVectorValue() + float getMagnitude() } class org.elasticsearch.script.ScoreScript @no_import { } diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapperTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapperTests.java index f5a33dde9ebc6..4458e88678c4e 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapperTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapperTests.java @@ -95,7 +95,7 @@ public void testDefaults() throws Exception { // assert that after decoding the indexed value is equal to expected BytesRef vectorBR = fields[0].binaryValue(); float[] decodedValues = decodeDenseVector(Version.CURRENT, vectorBR); - float decodedMagnitude = VectorEncoderDecoder.decodeVectorMagnitude(Version.CURRENT, vectorBR); + float decodedMagnitude = VectorEncoderDecoder.decodeMagnitude(Version.CURRENT, vectorBR); assertEquals(expectedMagnitude, decodedMagnitude, 0.001f); assertArrayEquals( "Decoded dense vector values is not equal to the indexed one.", diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldTypeTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldTypeTests.java index 5b16d30af7f5f..bc1e3557ee559 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldTypeTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldTypeTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.vectors.mapper; +import org.elasticsearch.Version; import org.elasticsearch.index.mapper.FieldTypeTestCase; import java.io.IOException; @@ -16,29 +17,34 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase { public void testHasDocValues() { - DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType("f", 1, Collections.emptyMap()); + DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType( + "f", Version.CURRENT, 1, Collections.emptyMap()); assertTrue(ft.hasDocValues()); } public void testIsAggregatable() { - DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType("f", 1, Collections.emptyMap()); + DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType( + "f", Version.CURRENT,1, Collections.emptyMap()); assertFalse(ft.isAggregatable()); } public void testFielddataBuilder() { - DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType("f", 1, Collections.emptyMap()); + DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType( + "f", Version.CURRENT,1, Collections.emptyMap()); assertNotNull(ft.fielddataBuilder("index", () -> { throw new UnsupportedOperationException(); })); } public void testDocValueFormat() { - DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType("f", 1, Collections.emptyMap()); - expectThrows(UnsupportedOperationException.class, () -> ft.docValueFormat(null, null)); + DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType( + "f", Version.CURRENT,1, Collections.emptyMap()); + expectThrows(IllegalArgumentException.class, () -> ft.docValueFormat(null, null)); } public void testFetchSourceValue() throws IOException { - DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType("f", 5, Collections.emptyMap()); + DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType( + "f", Version.CURRENT, 5, Collections.emptyMap()); List vector = List.of(0.0, 1.0, 2.0, 3.0, 4.0); assertEquals(vector, fetchSourceValue(ft, vector)); } diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorFunctionTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorFunctionTests.java index b03b5b9f4a903..c6e5540d210ef 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorFunctionTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorFunctionTests.java @@ -44,11 +44,14 @@ public void setUpVectors() { public void testVectorFunctions() { for (Version indexVersion : Arrays.asList(Version.V_7_4_0, Version.CURRENT)) { BytesRef encodedDocVector = mockEncodeDenseVector(docVector, indexVersion); + float magnitude = VectorEncoderDecoder.getMagnitude(indexVersion, encodedDocVector); + DenseVectorScriptDocValues docValues = mock(DenseVectorScriptDocValues.class); when(docValues.getEncodedValue()).thenReturn(encodedDocVector); + when(docValues.getMagnitude()).thenReturn(magnitude); + when(docValues.dims()).thenReturn(docVector.length); ScoreScript scoreScript = mock(ScoreScript.class); - when(scoreScript._getIndexVersion()).thenReturn(indexVersion); when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, docValues)); testDotProduct(scoreScript); @@ -63,8 +66,8 @@ private void testDotProduct(ScoreScript scoreScript) { double result = function.dotProduct(); assertEquals("dotProduct result is not equal to the expected value!", 65425.624, result, 0.001); - DotProduct invalidFunction = new DotProduct(scoreScript, invalidQueryVector, field); - IllegalArgumentException e = expectThrows(IllegalArgumentException.class, invalidFunction::dotProduct); + IllegalArgumentException e = + expectThrows(IllegalArgumentException.class, () -> new DotProduct(scoreScript, invalidQueryVector, field)); assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); } @@ -73,8 +76,8 @@ private void testCosineSimilarity(ScoreScript scoreScript) { double result = function.cosineSimilarity(); assertEquals("cosineSimilarity result is not equal to the expected value!", 0.790, result, 0.001); - CosineSimilarity invalidFunction = new CosineSimilarity(scoreScript, invalidQueryVector, field); - IllegalArgumentException e = expectThrows(IllegalArgumentException.class, invalidFunction::cosineSimilarity); + IllegalArgumentException e = + expectThrows(IllegalArgumentException.class, () -> new CosineSimilarity(scoreScript, invalidQueryVector, field)); assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); } @@ -83,8 +86,8 @@ private void testL1Norm(ScoreScript scoreScript) { double result = function.l1norm(); assertEquals("l1norm result is not equal to the expected value!", 485.184, result, 0.001); - L1Norm invalidFunction = new L1Norm(scoreScript, invalidQueryVector, field); - IllegalArgumentException e = expectThrows(IllegalArgumentException.class, invalidFunction::l1norm); + IllegalArgumentException e = + expectThrows(IllegalArgumentException.class, () -> new L1Norm(scoreScript, invalidQueryVector, field)); assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); } @@ -93,12 +96,11 @@ private void testL2Norm(ScoreScript scoreScript) { double result = function.l2norm(); assertEquals("l2norm result is not equal to the expected value!", 301.361, result, 0.001); - L2Norm invalidFunction = new L2Norm(scoreScript, invalidQueryVector, field); - IllegalArgumentException e = expectThrows(IllegalArgumentException.class, invalidFunction::l2norm); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> new L2Norm(scoreScript, invalidQueryVector, field)); assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); } - private static BytesRef mockEncodeDenseVector(float[] values, Version indexVersion) { + static BytesRef mockEncodeDenseVector(float[] values, Version indexVersion) { byte[] bytes = indexVersion.onOrAfter(Version.V_7_5_0) ? new byte[VectorEncoderDecoder.INT_BYTES * values.length + VectorEncoderDecoder.INT_BYTES] : new byte[VectorEncoderDecoder.INT_BYTES * values.length]; diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorScriptDocValuesTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorScriptDocValuesTests.java new file mode 100644 index 0000000000000..7a77b93adba3c --- /dev/null +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorScriptDocValuesTests.java @@ -0,0 +1,105 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.vectors.query; + +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.Version; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; +import java.util.Arrays; + +import static org.hamcrest.Matchers.containsString; + +public class DenseVectorScriptDocValuesTests extends ESTestCase { + + private static BinaryDocValues wrap(float[][] vectors, Version indexVersion) { + return new BinaryDocValues() { + int idx = -1; + int maxIdx = vectors.length; + @Override + public BytesRef binaryValue() { + if (idx >= maxIdx) { + throw new IllegalStateException("max index exceeded"); + } + return DenseVectorFunctionTests.mockEncodeDenseVector(vectors[idx], indexVersion); + } + + @Override + public boolean advanceExact(int target) { + idx = target; + if (target < maxIdx) { + return true; + } + return false; + } + + @Override + public int docID() { + return idx; + } + + @Override + public int nextDoc() { + return idx++; + } + + @Override + public int advance(int target) { + throw new IllegalArgumentException("not defined!"); + } + + @Override + public long cost() { + throw new IllegalArgumentException("not defined!"); + } + }; + } + + public void testGetVectorValueAndGetMagnitude() throws IOException { + final int dims = 3; + float[][] vectors = {{ 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }; + float[] expectedMagnitudes = { 1.7320f, 2.4495f, 3.3166f }; + + for (Version indexVersion : Arrays.asList(Version.V_7_4_0, Version.CURRENT)) { + BinaryDocValues docValues = wrap(vectors, indexVersion); + final DenseVectorScriptDocValues scriptDocValues = new DenseVectorScriptDocValues(docValues, indexVersion, dims); + for (int i = 0; i < vectors.length; i++) { + scriptDocValues.setNextDocId(i); + assertArrayEquals(vectors[i], scriptDocValues.getVectorValue(), 0.0001f); + assertEquals(expectedMagnitudes[i], scriptDocValues.getMagnitude(), 0.0001f); + } + } + } + + public void testMissingValues() throws IOException { + final int dims = 3; + float[][] vectors = {{ 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }; + BinaryDocValues docValues = wrap(vectors, Version.CURRENT); + final DenseVectorScriptDocValues scriptDocValues = new DenseVectorScriptDocValues(docValues, Version.CURRENT, dims); + + scriptDocValues.setNextDocId(3); + Exception e = expectThrows(IllegalArgumentException.class, () -> scriptDocValues.getVectorValue()); + assertEquals("A document doesn't have a value for a vector field!", e.getMessage()); + + e = expectThrows(IllegalArgumentException.class, () -> scriptDocValues.getMagnitude()); + assertEquals("A document doesn't have a value for a vector field!", e.getMessage()); + } + + public void testGetFunctionIsNotAccessible() throws IOException { + final int dims = 3; + float[][] vectors = {{ 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }; + BinaryDocValues docValues = wrap(vectors, Version.CURRENT); + final DenseVectorScriptDocValues scriptDocValues = new DenseVectorScriptDocValues(docValues, Version.CURRENT, dims); + + scriptDocValues.setNextDocId(0); + Exception e = expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0)); + assertThat(e.getMessage(), containsString("accessing a vector field's value through 'get' or 'value' is not supported!")); + } +}