Skip to content

Commit d53e83c

Browse files
Add access to dense_vector values (#71847)
Allow direct access to a dense_vector's values in script through the following functions: - getVectorValue – returns a vector's value as an array of floats - getMagnitude – returns a vector's magnitude Closes #51964 Backport for #71313
1 parent d190d58 commit d53e83c

File tree

23 files changed

+436
-139
lines changed

23 files changed

+436
-139
lines changed

docs/reference/mapping/types/dense-vector.asciidoc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,8 @@ A `dense_vector` field stores dense vectors of float values.
1010
The maximum number of dimensions that can be in a vector should
1111
not exceed 2048. A `dense_vector` field is a single-valued field.
1212

13-
These vectors can be used for <<vector-functions,document scoring>>.
14-
For example, a document score can represent a distance between
15-
a given query vector and the indexed document vector.
13+
`dense_vector` fields do not support querying, sorting or aggregating. They can
14+
only be accessed in scripts through the dedicated <<vector-functions,vector functions>>.
1615

1716
You index a dense vector as an array of floats.
1817

@@ -47,4 +46,4 @@ PUT my-index-000001/_doc/2
4746
4847
--------------------------------------------------
4948

50-
<1> dims – the number of dimensions in the vector, required parameter.
49+
<1> dims – the number of dimensions in the vector, required parameter.

docs/reference/vectors/vector-functions.asciidoc

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,15 @@ to limit the number of matched documents with a `query` parameter.
1010

1111
====== `dense_vector` functions
1212

13+
This is the list of available vector functions and vector access methods:
14+
15+
1. `cosineSimilarity` – calculates cosine similarity
16+
2. `dotProduct` – calculates dot product
17+
3. `l1norm` – calculates L^1^ distance
18+
4. `l2norm` – calculates L^2^ distance
19+
5. `doc[<field>].vectorValue` – returns a vector's value as an array of floats
20+
6. `doc[<field>].magnitude` – returns a vector's magnitude
21+
1322
Let's create an index with a `dense_vector` mapping and index a couple
1423
of documents into it.
1524

@@ -198,6 +207,54 @@ You can check if a document has a value for the field `my_vector` by
198207
--------------------------------------------------
199208
// NOTCONSOLE
200209

210+
The recommended way to access dense vectors is through `cosineSimilarity`,
211+
`dotProduct`, `l1norm` or `l2norm` functions. But for custom use cases,
212+
you can access dense vectors' values directly through the following functions:
213+
214+
- `doc[<field>].vectorValue` – returns a vector's value as an array of floats
215+
216+
- `doc[<field>].magnitude` – returns a vector's magnitude as a float
217+
(for vectors created prior to version 7.5 the magnitude is not stored,
218+
so this function calculates it anew every time it is called).
219+
220+
For example, the script below implements a cosine similarity using these
221+
two functions:
222+
223+
[source,console]
224+
--------------------------------------------------
225+
GET my-index-000001/_search
226+
{
227+
"query": {
228+
"script_score": {
229+
"query" : {
230+
"bool" : {
231+
"filter" : {
232+
"term" : {
233+
"status" : "published"
234+
}
235+
}
236+
}
237+
},
238+
"script": {
239+
"source": """
240+
float[] v = doc['my_dense_vector'].vectorValue;
241+
float vm = doc['my_dense_vector'].magnitude;
242+
float dotProduct = 0;
243+
for (int i = 0; i < v.length; i++) {
244+
dotProduct += v[i] * params.queryVector[i];
245+
}
246+
return dotProduct / (vm * (float) params.queryVectorMag);
247+
""",
248+
"params": {
249+
"queryVector": [4, 3.4, -0.2],
250+
"queryVectorMag": 5.25357
251+
}
252+
}
253+
}
254+
}
255+
}
256+
--------------------------------------------------
257+
201258
====== `sparse_vector` functions
202259

203260
deprecated[7.6, The `sparse_vector` type is deprecated and will be removed in 8.0.]

server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import org.elasticsearch.script.ExplainableScoreScript;
1515
import org.elasticsearch.script.ScoreScript;
1616
import org.elasticsearch.script.Script;
17-
import org.elasticsearch.Version;
1817

1918
import java.io.IOException;
2019
import java.util.Objects;
@@ -42,15 +41,13 @@ public float score() {
4241

4342
private final int shardId;
4443
private final String indexName;
45-
private final Version indexVersion;
4644

47-
public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script, String indexName, int shardId, Version indexVersion) {
45+
public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script, String indexName, int shardId) {
4846
super(CombineFunction.REPLACE);
4947
this.sScript = sScript;
5048
this.script = script;
5149
this.indexName = indexName;
5250
this.shardId = shardId;
53-
this.indexVersion = indexVersion;
5451
}
5552

5653
@Override
@@ -60,7 +57,6 @@ public LeafScoreFunction getLeafScoreFunction(LeafReaderContext ctx) throws IOEx
6057
leafScript.setScorer(scorer);
6158
leafScript._setIndexName(indexName);
6259
leafScript._setShard(shardId);
63-
leafScript._setIndexVersion(indexVersion);
6460
return new LeafScoreFunction() {
6561
@Override
6662
public double score(int docId, float subQueryScore) throws IOException {

server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreQuery.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,6 @@ private ScoreScript makeScoreScript(LeafReaderContext context) throws IOExceptio
146146
final ScoreScript scoreScript = scriptBuilder.newInstance(context);
147147
scoreScript._setIndexName(indexName);
148148
scoreScript._setShard(shardId);
149-
scoreScript._setIndexVersion(indexVersion);
150149
return scoreScript;
151150
}
152151

server/src/main/java/org/elasticsearch/index/query/functionscore/ScriptScoreFunctionBuilder.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,7 @@ protected ScoreFunction doToFunction(SearchExecutionContext context) {
8383
try {
8484
ScoreScript.Factory factory = context.compile(script, ScoreScript.CONTEXT);
8585
ScoreScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup());
86-
return new ScriptScoreFunction(script, searchScript,
87-
context.index().getName(), context.getShardId(), context.indexVersionCreated());
86+
return new ScriptScoreFunction(script, searchScript, context.index().getName(), context.getShardId());
8887
} catch (Exception e) {
8988
throw new QueryShardException(context, "script_score: the script could not be loaded", e);
9089
}

server/src/main/java/org/elasticsearch/script/ScoreScript.java

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import org.apache.lucene.index.LeafReaderContext;
1111
import org.apache.lucene.search.Explanation;
1212
import org.apache.lucene.search.Scorable;
13-
import org.elasticsearch.Version;
1413
import org.elasticsearch.common.logging.DeprecationCategory;
1514
import org.elasticsearch.common.logging.DeprecationLogger;
1615
import org.elasticsearch.index.fielddata.ScriptDocValues;
@@ -85,7 +84,6 @@ public Explanation get(double score, Explanation subQueryExplanation) {
8584
private int docId;
8685
private int shardId = -1;
8786
private String indexName = null;
88-
private Version indexVersion = null;
8987

9088
public ScoreScript(Map<String, Object> params, SearchLookup lookup, LeafReaderContext leafContext) {
9189
// null check needed b/c of expression engine subclass
@@ -185,19 +183,6 @@ public String _getIndex() {
185183
}
186184
}
187185

188-
/**
189-
* Starting a name with underscore, so that the user cannot access this function directly through a script
190-
* It is only used within predefined painless functions.
191-
* @return index version or throws an exception if the index version is not set up for this script instance
192-
*/
193-
public Version _getIndexVersion() {
194-
if (indexVersion != null) {
195-
return indexVersion;
196-
} else {
197-
throw new IllegalArgumentException("index version can not be looked up!");
198-
}
199-
}
200-
201186
/**
202187
* Starting a name with underscore, so that the user cannot access this function directly through a script
203188
*/
@@ -212,13 +197,6 @@ public void _setIndexName(String indexName) {
212197
this.indexName = indexName;
213198
}
214199

215-
/**
216-
* Starting a name with underscore, so that the user cannot access this function directly through a script
217-
*/
218-
public void _setIndexVersion(Version indexVersion) {
219-
this.indexVersion = indexVersion;
220-
}
221-
222200

223201
/** A factory to construct {@link ScoreScript} instances. */
224202
public interface LeafFactory {
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
---
2+
"Access to values of dense_vector in script":
3+
- skip:
4+
version: " - 7.12.99"
5+
reason: "Access to values of dense_vector in script was added in 7.13"
6+
- do:
7+
indices.create:
8+
index: test-index
9+
body:
10+
mappings:
11+
properties:
12+
v:
13+
type: dense_vector
14+
dims: 3
15+
16+
- do:
17+
bulk:
18+
index: test-index
19+
refresh: true
20+
body:
21+
- '{"index": {"_id": "1"}}'
22+
- '{"v": [1, 1, 1]}'
23+
- '{"index": {"_id": "2"}}'
24+
- '{"v": [1, 1, 2]}'
25+
- '{"index": {"_id": "3"}}'
26+
- '{"v": [1, 1, 3]}'
27+
- '{"index": {"_id": "missing_vector"}}'
28+
- '{}'
29+
30+
# direct vector access in a loop – return the index of the closest parameter vector based on cosine similarity
31+
- do:
32+
search:
33+
body:
34+
query:
35+
script_score:
36+
query: { "exists": { "field": "v" } }
37+
script:
38+
source: |
39+
float[] v = doc['v'].vectorValue;
40+
float vm = doc['v'].magnitude;
41+
42+
int closestPv = 0;
43+
float maxCosSim = -1;
44+
for (int i = 0; i < params.pvs.length; i++) {
45+
float dotProduct = 0;
46+
for (int j = 0; j < v.length; j++) {
47+
dotProduct += v[j] * params.pvs[i][j];
48+
}
49+
float cosSim = dotProduct / (vm * (float) params.pvs_magnts[i]);
50+
if (maxCosSim < cosSim) {
51+
maxCosSim = cosSim;
52+
closestPv = i;
53+
}
54+
}
55+
closestPv;
56+
params:
57+
pvs: [ [ 1, 1, 1 ], [ 1, 1, 2 ], [ 1, 1, 3 ] ]
58+
pvs_magnts: [1.7320, 2.4495, 3.3166]
59+
60+
- match: { hits.hits.0._id: "3" }
61+
- match: { hits.hits.0._score: 2 }
62+
- match: { hits.hits.1._id: "2" }
63+
- match: { hits.hits.1._score: 1 }
64+
- match: { hits.hits.2._id: "1" }
65+
- match: { hits.hits.2._score: 0 }

x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapper.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ protected List<Parameter<?>> getParameters() {
8383
public DenseVectorFieldMapper build(ContentPath contentPath) {
8484
return new DenseVectorFieldMapper(
8585
name,
86-
new DenseVectorFieldType(buildFullName(contentPath), dims.getValue(), meta.getValue()),
86+
new DenseVectorFieldType(buildFullName(contentPath), indexVersionCreated, dims.getValue(), meta.getValue()),
8787
dims.getValue(),
8888
indexVersionCreated,
8989
multiFieldsBuilder.build(this, contentPath),
@@ -95,10 +95,12 @@ public DenseVectorFieldMapper build(ContentPath contentPath) {
9595

9696
public static final class DenseVectorFieldType extends MappedFieldType {
9797
private final int dims;
98+
private final Version indexVersionCreated;
9899

99-
public DenseVectorFieldType(String name, int dims, Map<String, String> meta) {
100+
public DenseVectorFieldType(String name, Version indexVersionCreated, int dims, Map<String, String> meta) {
100101
super(name, false, false, true, TextSearchInfo.NONE, meta);
101102
this.dims = dims;
103+
this.indexVersionCreated = indexVersionCreated;
102104
}
103105

104106
int dims() {
@@ -125,7 +127,7 @@ protected Object parseSourceValue(Object value) {
125127

126128
@Override
127129
public DocValueFormat docValueFormat(String format, ZoneId timeZone) {
128-
throw new UnsupportedOperationException(
130+
throw new IllegalArgumentException(
129131
"Field [" + name() + "] of type [" + typeName() + "] doesn't support docvalue_fields or aggregations");
130132
}
131133

@@ -136,7 +138,7 @@ public boolean isAggregatable() {
136138

137139
@Override
138140
public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, Supplier<SearchLookup> searchLookup) {
139-
return new VectorIndexFieldData.Builder(name(), true, CoreValuesSourceType.KEYWORD);
141+
return new VectorIndexFieldData.Builder(name(), true, CoreValuesSourceType.KEYWORD, indexVersionCreated, dims);
140142
}
141143

142144
@Override

x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/SparseVectorFieldMapper.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ protected List<Parameter<?>> getParameters() {
7171
@Override
7272
public SparseVectorFieldMapper build(ContentPath contentPath) {
7373
return new SparseVectorFieldMapper(
74-
name, new SparseVectorFieldType(buildFullName(contentPath), meta.getValue()),
74+
name, new SparseVectorFieldType(buildFullName(contentPath), indexCreatedVersion, meta.getValue()),
7575
multiFieldsBuilder.build(this, contentPath), copyTo.build(), indexCreatedVersion);
7676
}
7777
}
@@ -83,8 +83,10 @@ name, new SparseVectorFieldType(buildFullName(contentPath), meta.getValue()),
8383

8484
public static final class SparseVectorFieldType extends MappedFieldType {
8585

86-
public SparseVectorFieldType(String name, Map<String, String> meta) {
86+
private final Version indexVersionCreated;
87+
public SparseVectorFieldType(String name, Version indexVersionCreated, Map<String, String> meta) {
8788
super(name, false, false, true, TextSearchInfo.NONE, meta);
89+
this.indexVersionCreated = indexVersionCreated;
8890
}
8991

9092
@Override
@@ -94,7 +96,7 @@ public String typeName() {
9496

9597
@Override
9698
public DocValueFormat docValueFormat(String format, ZoneId timeZone) {
97-
throw new UnsupportedOperationException(
99+
throw new IllegalArgumentException(
98100
"Field [" + name() + "] of type [" + typeName() + "] doesn't support docvalue_fields or aggregations");
99101
}
100102

@@ -118,7 +120,7 @@ public Query existsQuery(SearchExecutionContext context) {
118120

119121
@Override
120122
public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, Supplier<SearchLookup> searchLookup) {
121-
return new VectorIndexFieldData.Builder(name(), false, CoreValuesSourceType.KEYWORD);
123+
return new VectorIndexFieldData.Builder(name(), false, CoreValuesSourceType.KEYWORD, indexVersionCreated, -1);
122124
}
123125

124126
@Override

x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoder.java

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import java.nio.ByteBuffer;
1616

17-
// static utility functions for encoding and decoding dense_vector and sparse_vector fields
1817
public final class VectorEncoderDecoder {
1918
static final byte INT_BYTES = 4;
2019
static final byte SHORT_BYTES = 2;
@@ -168,9 +167,51 @@ public static int denseVectorLength(Version indexVersion, BytesRef vectorBR) {
168167
* NOTE: this function can only be called on vectors from an index version greater than or
169168
* equal to 7.5.0, since vectors created prior to that do not store the magnitude.
170169
*/
171-
public static float decodeVectorMagnitude(Version indexVersion, BytesRef vectorBR) {
170+
public static float decodeMagnitude(Version indexVersion, BytesRef vectorBR) {
172171
assert indexVersion.onOrAfter(Version.V_7_5_0);
173172
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
174-
return byteBuffer.getFloat(vectorBR.offset + vectorBR.length - 4);
173+
return byteBuffer.getFloat(vectorBR.offset + vectorBR.length - INT_BYTES);
175174
}
175+
176+
/**
177+
* Calculates vector magnitude
178+
*/
179+
private static float calculateMagnitude(Version indexVersion, BytesRef vectorBR) {
180+
final int length = denseVectorLength(indexVersion, vectorBR);
181+
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
182+
double magnitude = 0.0f;
183+
for (int i = 0; i < length; i++) {
184+
float value = byteBuffer.getFloat();
185+
magnitude += value * value;
186+
}
187+
magnitude = Math.sqrt(magnitude);
188+
return (float) magnitude;
189+
}
190+
191+
public static float getMagnitude(Version indexVersion, BytesRef vectorBR) {
192+
if (vectorBR == null) {
193+
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
194+
}
195+
if (indexVersion.onOrAfter(Version.V_7_5_0)) {
196+
return decodeMagnitude(indexVersion, vectorBR);
197+
} else {
198+
return calculateMagnitude(indexVersion, vectorBR);
199+
}
200+
}
201+
202+
/**
203+
* Decodes a BytesRef into the provided array of floats
204+
* @param vectorBR - dense vector encoded in BytesRef
205+
* @param vector - array of floats where the decoded vector should be stored
206+
*/
207+
public static void decodeDenseVector(BytesRef vectorBR, float[] vector) {
208+
if (vectorBR == null) {
209+
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
210+
}
211+
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
212+
for (int dim = 0; dim < vector.length; dim++) {
213+
vector[dim] = byteBuffer.getFloat();
214+
}
215+
}
216+
176217
}

0 commit comments

Comments
 (0)