Skip to content

Commit a66f755

Browse files
Address Julie's feedback 3
1 parent d8edd03 commit a66f755

File tree

8 files changed

+112
-113
lines changed

8 files changed

+112
-113
lines changed

docs/reference/vectors/vector-functions.asciidoc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,8 @@ you can access dense vectors' values directly through the following functions:
213213
- `doc[<field>].vectorValue` – returns a vector's value as an array of floats
214214

215215
- `doc[<field>].magnitude` – returns a vector's magnitude as a float
216-
(for vectors created prior version 7.5 magnitude is not stored.
217-
So this function calculates it anew every time is called).
216+
(for vectors created prior to version 7.5 the magnitude is not stored.
217+
So this function calculates it anew every time it is called).
218218

219219
For example, the script below implements a cosine similarity using these
220220
two functions:

server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import org.elasticsearch.script.ExplainableScoreScript;
1515
import org.elasticsearch.script.ScoreScript;
1616
import org.elasticsearch.script.Script;
17-
import org.elasticsearch.Version;
1817

1918
import java.io.IOException;
2019
import java.util.Objects;
@@ -42,15 +41,13 @@ public float score() {
4241

4342
private final int shardId;
4443
private final String indexName;
45-
private final Version indexVersion;
4644

47-
public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script, String indexName, int shardId, Version indexVersion) {
45+
public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script, String indexName, int shardId) {
4846
super(CombineFunction.REPLACE);
4947
this.sScript = sScript;
5048
this.script = script;
5149
this.indexName = indexName;
5250
this.shardId = shardId;
53-
this.indexVersion = indexVersion;
5451
}
5552

5653
@Override
@@ -60,7 +57,6 @@ public LeafScoreFunction getLeafScoreFunction(LeafReaderContext ctx) throws IOEx
6057
leafScript.setScorer(scorer);
6158
leafScript._setIndexName(indexName);
6259
leafScript._setShard(shardId);
63-
leafScript._setIndexVersion(indexVersion);
6460
return new LeafScoreFunction() {
6561
@Override
6662
public double score(int docId, float subQueryScore) throws IOException {

server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreQuery.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,6 @@ private ScoreScript makeScoreScript(LeafReaderContext context) throws IOExceptio
146146
final ScoreScript scoreScript = scriptBuilder.newInstance(context);
147147
scoreScript._setIndexName(indexName);
148148
scoreScript._setShard(shardId);
149-
scoreScript._setIndexVersion(indexVersion);
150149
return scoreScript;
151150
}
152151

server/src/main/java/org/elasticsearch/index/query/functionscore/ScriptScoreFunctionBuilder.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,7 @@ protected ScoreFunction doToFunction(SearchExecutionContext context) {
8383
try {
8484
ScoreScript.Factory factory = context.compile(script, ScoreScript.CONTEXT);
8585
ScoreScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup());
86-
return new ScriptScoreFunction(script, searchScript,
87-
context.index().getName(), context.getShardId(), context.indexVersionCreated());
86+
return new ScriptScoreFunction(script, searchScript, context.index().getName(), context.getShardId());
8887
} catch (Exception e) {
8988
throw new QueryShardException(context, "script_score: the script could not be loaded", e);
9089
}

server/src/main/java/org/elasticsearch/script/ScoreScript.java

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import org.apache.lucene.index.LeafReaderContext;
1111
import org.apache.lucene.search.Explanation;
1212
import org.apache.lucene.search.Scorable;
13-
import org.elasticsearch.Version;
1413
import org.elasticsearch.common.logging.DeprecationCategory;
1514
import org.elasticsearch.common.logging.DeprecationLogger;
1615
import org.elasticsearch.index.fielddata.ScriptDocValues;
@@ -85,7 +84,6 @@ public Explanation get(double score, Explanation subQueryExplanation) {
8584
private int docId;
8685
private int shardId = -1;
8786
private String indexName = null;
88-
private Version indexVersion = null;
8987

9088
public ScoreScript(Map<String, Object> params, SearchLookup lookup, LeafReaderContext leafContext) {
9189
// null check needed b/c of expression engine subclass
@@ -185,19 +183,6 @@ public String _getIndex() {
185183
}
186184
}
187185

188-
/**
189-
* Starting a name with underscore, so that the user cannot access this function directly through a script
190-
* It is only used within predefined painless functions.
191-
* @return index version or throws an exception if the index version is not set up for this script instance
192-
*/
193-
public Version _getIndexVersion() {
194-
if (indexVersion != null) {
195-
return indexVersion;
196-
} else {
197-
throw new IllegalArgumentException("index version can not be looked up!");
198-
}
199-
}
200-
201186
/**
202187
* Starting a name with underscore, so that the user cannot access this function directly through a script
203188
*/
@@ -212,13 +197,6 @@ public void _setIndexName(String indexName) {
212197
this.indexName = indexName;
213198
}
214199

215-
/**
216-
* Starting a name with underscore, so that the user cannot access this function directly through a script
217-
*/
218-
public void _setIndexVersion(Version indexVersion) {
219-
this.indexVersion = indexVersion;
220-
}
221-
222200

223201
/** A factory to construct {@link ScoreScript} instances. */
224202
public interface LeafFactory {

x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/vectors/30_dense_vector_script_access.yml

Lines changed: 2 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
---
22
"Access to values of dense_vector in script":
33
- skip:
4-
features: headers
54
version: " - 7.12.99"
65
reason: "Access to values of dense_vector in script was added in 7.13"
76
- do:
@@ -28,84 +27,8 @@
2827
- '{"index": {"_id": "missing_vector"}}'
2928
- '{}'
3029

31-
# check getVectorValue() API
32-
- do:
33-
search:
34-
body:
35-
query:
36-
script_score:
37-
query: { "exists" : { "field" : "v" } }
38-
script:
39-
source: |
40-
float s = 0;
41-
for (def el : doc['v'].vectorValue) {
42-
s += el;
43-
}
44-
s;
45-
46-
- match: { hits.hits.0._id: "3" }
47-
- match: { hits.hits.0._score: 5 }
48-
- match: { hits.hits.1._id: "2" }
49-
- match: { hits.hits.1._score: 4 }
50-
- match: { hits.hits.2._id: "1" }
51-
- match: { hits.hits.2._score: 3 }
52-
53-
54-
# check getMagnitude() API
55-
- do:
56-
headers:
57-
Content-Type: application/json
58-
search:
59-
body:
60-
query:
61-
script_score:
62-
query: { "exists" : { "field" : "v" } }
63-
script:
64-
source: "doc['v'].magnitude"
65-
66-
- match: { hits.hits.0._id: "3" }
67-
- gte: {hits.hits.0._score: 3.3166}
68-
- lte: {hits.hits.0._score: 3.3167}
69-
- match: { hits.hits.1._id: "2" }
70-
- gte: {hits.hits.1._score: 2.4494}
71-
- lte: {hits.hits.1._score: 2.4495}
72-
- match: { hits.hits.2._id: "1" }
73-
- gte: {hits.hits.2._score: 1.7320}
74-
- lte: {hits.hits.2._score: 1.7321}
75-
76-
# check failed request on missing values
77-
- do:
78-
catch: bad_request
79-
search:
80-
body:
81-
query:
82-
script_score:
83-
query: { match_all: { } }
84-
script:
85-
source: "doc['v'].vectorValue[0]"
86-
87-
- match: { status: 400 }
88-
- match: { error.root_cause.0.type: "script_exception" }
89-
90-
# check failed request on missing values
91-
- do:
92-
catch: bad_request
93-
search:
94-
body:
95-
query:
96-
script_score:
97-
query: { match_all: { } }
98-
script:
99-
source: "doc['v'].magnitude"
100-
101-
- match: { status: 400 }
102-
- match: { error.root_cause.0.type: "script_exception" }
103-
104-
10530
# vector functions in loop – return the index of the closest parameter vector based on cosine similarity
10631
- do:
107-
headers:
108-
Content-Type: application/json
10932
search:
11033
body:
11134
query:
@@ -123,7 +46,7 @@
12346
for (int j = 0; j < v.length; j++) {
12447
dotProduct += v[j] * params.pvs[i][j];
12548
}
126-
float cosSim = dotProduct / (vm * (float) params.pvs_lengths[i]);
49+
float cosSim = dotProduct / (vm * (float) params.pvs_magnts[i]);
12750
if (maxCosSim < cosSim) {
12851
maxCosSim = cosSim;
12952
closestPv = i;
@@ -132,7 +55,7 @@
13255
closestPv;
13356
params:
13457
pvs: [ [ 1, 1, 1 ], [ 1, 1, 2 ], [ 1, 1, 3 ] ]
135-
pvs_lengths: [1.7320, 2.4495, 3.3166]
58+
pvs_magnts: [1.7320, 2.4495, 3.3166]
13659

13760
- match: { hits.hits.0._id: "3" }
13861
- match: { hits.hits.0._score: 2 }

x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorFunctionTests.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ public void testVectorFunctions() {
5252
when(docValues.dims()).thenReturn(docVector.length);
5353

5454
ScoreScript scoreScript = mock(ScoreScript.class);
55-
when(scoreScript._getIndexVersion()).thenReturn(indexVersion);
5655
when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, docValues));
5756

5857
testDotProduct(scoreScript);
@@ -101,7 +100,7 @@ private void testL2Norm(ScoreScript scoreScript) {
101100
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
102101
}
103102

104-
private static BytesRef mockEncodeDenseVector(float[] values, Version indexVersion) {
103+
static BytesRef mockEncodeDenseVector(float[] values, Version indexVersion) {
105104
byte[] bytes = indexVersion.onOrAfter(Version.V_7_5_0)
106105
? new byte[VectorEncoderDecoder.INT_BYTES * values.length + VectorEncoderDecoder.INT_BYTES]
107106
: new byte[VectorEncoderDecoder.INT_BYTES * values.length];
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.vectors.query;
9+
10+
import org.apache.lucene.index.BinaryDocValues;
11+
import org.apache.lucene.util.BytesRef;
12+
import org.elasticsearch.Version;
13+
import org.elasticsearch.test.ESTestCase;
14+
15+
import java.io.IOException;
16+
import java.util.Arrays;
17+
18+
import static org.hamcrest.Matchers.containsString;
19+
20+
public class DenseVectorScriptDocValuesTests extends ESTestCase {
21+
22+
private static BinaryDocValues wrap(float[][] vectors, Version indexVersion) {
23+
return new BinaryDocValues() {
24+
int idx = -1;
25+
int maxIdx = vectors.length;
26+
@Override
27+
public BytesRef binaryValue() {
28+
if (idx >= maxIdx) {
29+
throw new IllegalStateException("max index exceeded");
30+
}
31+
return DenseVectorFunctionTests.mockEncodeDenseVector(vectors[idx], indexVersion);
32+
}
33+
34+
@Override
35+
public boolean advanceExact(int target) {
36+
idx = target;
37+
if (target < maxIdx) {
38+
return true;
39+
}
40+
return false;
41+
}
42+
43+
@Override
44+
public int docID() {
45+
return idx;
46+
}
47+
48+
@Override
49+
public int nextDoc() {
50+
return ++idx; // pre-increment: return the doc ID just advanced to, so the result agrees with docID()
51+
}
52+
53+
@Override
54+
public int advance(int target) {
55+
throw new IllegalArgumentException("not defined!");
56+
}
57+
58+
@Override
59+
public long cost() {
60+
throw new IllegalArgumentException("not defined!");
61+
}
62+
};
63+
}
64+
65+
public void testGetVectorValueAndGetMagnitude() throws IOException {
66+
final int dims = 3;
67+
float[][] vectors = {{ 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } };
68+
float[] expectedMagnitudes = { 1.7320f, 2.4495f, 3.3166f };
69+
70+
for (Version indexVersion : Arrays.asList(Version.V_7_4_0, Version.CURRENT)) {
71+
BinaryDocValues docValues = wrap(vectors, indexVersion);
72+
final DenseVectorScriptDocValues scriptDocValues = new DenseVectorScriptDocValues(docValues, indexVersion, dims);
73+
for (int i = 0; i < vectors.length; i++) {
74+
scriptDocValues.setNextDocId(i);
75+
assertArrayEquals(vectors[i], scriptDocValues.getVectorValue(), 0.0001f);
76+
assertEquals(expectedMagnitudes[i], scriptDocValues.getMagnitude(), 0.0001f);
77+
}
78+
}
79+
}
80+
81+
public void testMissingValues() throws IOException {
82+
final int dims = 3;
83+
float[][] vectors = {{ 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } };
84+
BinaryDocValues docValues = wrap(vectors, Version.CURRENT);
85+
final DenseVectorScriptDocValues scriptDocValues = new DenseVectorScriptDocValues(docValues, Version.CURRENT, dims);
86+
87+
scriptDocValues.setNextDocId(3);
88+
Exception e = expectThrows(IllegalArgumentException.class, () -> scriptDocValues.getVectorValue());
89+
assertEquals("A document doesn't have a value for a vector field!", e.getMessage());
90+
91+
e = expectThrows(IllegalArgumentException.class, () -> scriptDocValues.getMagnitude());
92+
assertEquals("A document doesn't have a value for a vector field!", e.getMessage());
93+
}
94+
95+
public void testGetFunctionIsNotAccessible() throws IOException {
96+
final int dims = 3;
97+
float[][] vectors = {{ 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } };
98+
BinaryDocValues docValues = wrap(vectors, Version.CURRENT);
99+
final DenseVectorScriptDocValues scriptDocValues = new DenseVectorScriptDocValues(docValues, Version.CURRENT, dims);
100+
101+
scriptDocValues.setNextDocId(0);
102+
Exception e = expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0));
103+
assertThat(e.getMessage(), containsString("accessing a vector field's value through 'get' or 'value' is not supported!"));
104+
}
105+
}

0 commit comments

Comments
 (0)