diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/10_dense_vector_basic.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/10_dense_vector_basic.yml index 903b9dc3de3b0..57b1f07a157e0 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/10_dense_vector_basic.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/10_dense_vector_basic.yml @@ -68,7 +68,7 @@ setup: - match: {hits.hits.2._id: "2"} - gte: {hits.hits.2._score: 35853.78} - - lte: {hits.hits.2._score: 35853.79} + - lte: {hits.hits.2._score: 35853.791} --- "Cosine Similarity": @@ -99,3 +99,77 @@ setup: - match: {hits.hits.2._id: "1"} - gte: {hits.hits.2._score: 0.78} - lte: {hits.hits.2._score: 0.791} + +--- +"Dot Product in Sort Context": + - skip: + features: headers + version: " - 7.4.99" + reason: "vector functions were added in the sort context from 7.5" + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: {match_all: {} } + sort: + _script: + type: number + order: desc + script: + source: "dotProduct(params.query_vector, doc['my_dense_vector'])" + params: + query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] + + - match: {hits.total: 3} + + - match: {hits.hits.0._id: "1"} + - gte: {hits.hits.0.sort.0: 65425.62} + - lte: {hits.hits.0.sort.0: 65425.63} + + - match: {hits.hits.1._id: "3"} + - gte: {hits.hits.1.sort.0: 37111.98} + - lte: {hits.hits.1.sort.0: 37111.99} + + - match: {hits.hits.2._id: "2"} + - gte: {hits.hits.2.sort.0: 35853.78} + - lte: {hits.hits.2.sort.0: 35853.791} + +--- +"Cosine Similarity in Sort Context": + - skip: + features: headers + version: " - 7.4.99" + reason: "vector functions were added in the sort context from 7.5" + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: {match_all: {} } + sort: + _script: + type: number + order: desc + script: + source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" + params: + query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] + + - match: {hits.total: 3} + + - match: {hits.hits.0._id: "3"} + - gte: {hits.hits.0.sort.0: 0.999} + - lte: {hits.hits.0.sort.0: 1.001} + + - match: {hits.hits.1._id: "2"} + - gte: {hits.hits.1.sort.0: 0.998} + - lte: {hits.hits.1.sort.0: 1.0} + + - match: {hits.hits.2._id: "1"} + - gte: {hits.hits.2.sort.0: 0.78} + - lte: {hits.hits.2.sort.0: 0.791} diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DocValuesWhitelistExtension.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DocValuesWhitelistExtension.java index 7d8d1d97da92e..d07bdeaaed555 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DocValuesWhitelistExtension.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DocValuesWhitelistExtension.java @@ -3,21 +3,21 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ - - package org.elasticsearch.xpack.vectors.query; - import org.elasticsearch.painless.spi.PainlessExtension; import org.elasticsearch.painless.spi.Whitelist; import org.elasticsearch.painless.spi.WhitelistLoader; +import org.elasticsearch.script.NumberSortScript; import org.elasticsearch.script.ScoreScript; import org.elasticsearch.script.ScriptContext; -import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; +import static java.util.Collections.singletonList; + public class DocValuesWhitelistExtension implements PainlessExtension { private static final Whitelist WHITELIST = @@ -25,6 +25,10 @@ public class DocValuesWhitelistExtension implements PainlessExtension { @Override public Map, List> getContextWhitelists() { - return Collections.singletonMap(ScoreScript.CONTEXT, Collections.singletonList(WHITELIST)); + Map, List> whitelist = new HashMap<>(); + List list = singletonList(WHITELIST); + whitelist.put(ScoreScript.CONTEXT, list); + whitelist.put(NumberSortScript.CONTEXT, list); + return whitelist; } }