Add linear function to rank_feature query #67438

Merged
44 changes: 40 additions & 4 deletions docs/reference/query-dsl/rank-feature-query.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ query supports the following mathematical functions:
* <<rank-feature-query-saturation,Saturation>>
* <<rank-feature-query-logarithm,Logarithm>>
* <<rank-feature-query-sigmoid,Sigmoid>>
* <<rank-feature-query-linear,Linear>>

If you don't know where to start, we recommend using the `saturation` function.
If no function is provided, the `rank_feature` query uses the `saturation`
Expand Down Expand Up @@ -126,7 +127,7 @@ The following query searches for `2016` and boosts relevance scores based on

[source,console]
----
GET /test/_search
GET /test/_search
{
"query": {
"bool": {
Expand Down Expand Up @@ -190,7 +191,7 @@ value of the rank feature `field`. If no function is provided, the `rank_feature
query defaults to the `saturation` function. See
<<rank-feature-query-saturation,Saturation>> for more information.

Only one function `saturation`, `log`, or `sigmoid` can be provided.
Only one function `saturation`, `log`, `sigmoid` or `linear` can be provided.
--

`log`::
Expand All @@ -201,7 +202,7 @@ function used to boost <<relevance-scores,relevance scores>> based on the
value of the rank feature `field`. See
<<rank-feature-query-logarithm,Logarithm>> for more information.

Only one function `saturation`, `log`, or `sigmoid` can be provided.
Only one function `saturation`, `log`, `sigmoid` or `linear` can be provided.
--

`sigmoid`::
Expand All @@ -212,7 +213,18 @@ to boost <<relevance-scores,relevance scores>> based on the value of the
rank feature `field`. See <<rank-feature-query-sigmoid,Sigmoid>> for more
information.

Only one function `saturation`, `log`, or `sigmoid` can be provided.
Only one function `saturation`, `log`, `sigmoid` or `linear` can be provided.
--

`linear`::
+
--
(Optional, <<rank-feature-query-linear,function object>>) Linear function used
to boost <<relevance-scores,relevance scores>> based on the value of the
rank feature `field`. See <<rank-feature-query-linear,Linear>> for more
information.

Only one function `saturation`, `log`, `sigmoid` or `linear` can be provided.
--


Expand Down Expand Up @@ -311,3 +323,27 @@ GET /test/_search
}
}
--------------------------------------------------
[[rank-feature-query-linear]]
===== Linear
The `linear` function is the simplest function. It gives a score equal
to the indexed value of `S`, where `S` is the value of the rank feature
field.
If a rank feature field is indexed with `"positive_score_impact": true`,
its indexed value is equal to `S`, rounded to preserve only
9 significant bits of precision.
If a rank feature field is indexed with `"positive_score_impact": false`,
its indexed value is equal to `1/S`, rounded to preserve only 9 significant
bits of precision.
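
To make the "9 significant bits" statement concrete, here is a minimal Java sketch of what such a reduction could look like. It assumes the low 15 bits of the IEEE 754 float representation are dropped, which leaves the sign, the 8 exponent bits, and 8 explicit mantissa bits (9 significant bits with the implicit leading one). The class and method names are hypothetical and are not part of this change; the actual encoding lives in Lucene's `FeatureField` and may differ in detail.

```java
public final class FeaturePrecisionSketch {

    // Hypothetical helper: keep the sign bit, the 8 exponent bits and the
    // top 8 mantissa bits of a float by clearing its low 15 bits.
    public static float truncateTo9SignificantBits(float value) {
        int bits = Float.floatToIntBits(value);
        return Float.intBitsToFloat((bits >>> 15) << 15);
    }

    // Hypothetical helper: the value a linear score would be based on,
    // S for a positive score impact and 1/S for a negative one.
    public static float indexedValue(float s, boolean positiveScoreImpact) {
        float raw = positiveScoreImpact ? s : 1f / s;
        return truncateTo9SignificantBits(raw);
    }

    public static void main(String[] args) {
        System.out.println(indexedValue(10.0f, true));   // prints 10.0
        System.out.println(indexedValue(1000.7f, true)); // prints 1000.0
        // Slightly below 0.02 because of the reduced precision.
        System.out.println(indexedValue(50.0f, false));
    }
}
```

Under these assumptions, a field with a negative score impact still ranks smaller raw values higher, because `1/S` grows as `S` shrinks.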

[source,console]
--------------------------------------------------
GET /test/_search
{
"query": {
"rank_feature": {
"field": "pagerank",
"linear": {}
}
}
}
--------------------------------------------------
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.RankFeatureFieldMapper.RankFeatureFieldType;
Expand Down Expand Up @@ -104,7 +105,7 @@ void doXContent(XContentBuilder builder) throws IOException {
}

@Override
Query toQuery(String field, String feature, boolean positiveScoreImpact) throws IOException {
Query toQuery(String field, String feature, boolean positiveScoreImpact) {
if (positiveScoreImpact == false) {
throw new IllegalArgumentException("Cannot use the [log] function with a field that has a negative score impact as " +
"it would trigger negative scores");
Expand Down Expand Up @@ -175,7 +176,7 @@ void doXContent(XContentBuilder builder) throws IOException {
}

@Override
Query toQuery(String field, String feature, boolean positiveScoreImpact) throws IOException {
Query toQuery(String field, String feature, boolean positiveScoreImpact) {
if (pivot == null) {
return FeatureField.newSaturationQuery(field, feature);
} else {
Expand Down Expand Up @@ -240,10 +241,55 @@ void doXContent(XContentBuilder builder) throws IOException {
}

@Override
Query toQuery(String field, String feature, boolean positiveScoreImpact) throws IOException {
Query toQuery(String field, String feature, boolean positiveScoreImpact) {
return FeatureField.newSigmoidQuery(field, feature, DEFAULT_BOOST, pivot, exp);
}
}

/**
* A scoring function that scores documents as simply {@code S}
* where S is the indexed value of the static feature.
*/
public static class Linear extends ScoreFunction {

private static final ObjectParser<Linear, Void> PARSER = new ObjectParser<>("linear", Linear::new);

public Linear() {
}

private Linear(StreamInput in) {
this();
}

@Override
public boolean equals(Object obj) {
if (obj == null || getClass() != obj.getClass()) {
return false;
}
return true;
}

@Override
public int hashCode() {
return getClass().hashCode();
}

@Override
void writeTo(StreamOutput out) throws IOException {
out.writeByte((byte) 3);
}

@Override
void doXContent(XContentBuilder builder) throws IOException {
builder.startObject("linear");
builder.endObject();
}

@Override
Query toQuery(String field, String feature, boolean positiveScoreImpact) {
return FeatureField.newLinearQuery(field, feature, DEFAULT_BOOST);
}
}
}

private static ScoreFunction readScoreFunction(StreamInput in) throws IOException {
Expand All @@ -255,6 +301,8 @@ private static ScoreFunction readScoreFunction(StreamInput in) throws IOExceptio
return new ScoreFunction.Saturation(in);
case 2:
return new ScoreFunction.Sigmoid(in);
case 3:
return new ScoreFunction.Linear(in);
default:
throw new IOException("Illegal score function id: " + b);
}
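
The id byte written by `Linear.writeTo` (3) is what the `readScoreFunction` switch dispatches on. The following self-contained sketch illustrates that round trip with plain `java.io` streams instead of `StreamInput`/`StreamOutput`. All class names here are hypothetical stand-ins, and the `Sigmoid` payload layout (two floats after the id) is an assumption included only for contrast with the parameterless `Linear` case.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

public final class ScoreFunctionWireSketch {

    // Hypothetical stand-in for the ScoreFunction base class.
    public interface Fn {
        void writeTo(DataOutputStream out) throws IOException;
    }

    // Assumed layout: id byte 2 followed by pivot and exponent.
    public static final class Sigmoid implements Fn {
        final float pivot;
        final float exp;
        public Sigmoid(float pivot, float exp) { this.pivot = pivot; this.exp = exp; }
        @Override public void writeTo(DataOutputStream out) throws IOException {
            out.writeByte(2);
            out.writeFloat(pivot);
            out.writeFloat(exp);
        }
    }

    // Linear has no parameters, so the id byte is the whole payload.
    public static final class Linear implements Fn {
        @Override public void writeTo(DataOutputStream out) throws IOException {
            out.writeByte(3);
        }
    }

    public static byte[] serialize(Fn fn) {
        try (ByteArrayOutputStream bytes = new ByteArrayOutputStream();
             DataOutputStream out = new DataOutputStream(bytes)) {
            fn.writeTo(out);
            out.flush();
            return bytes.toByteArray();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    // Reads the id byte first, then dispatches to the matching type.
    public static Fn deserialize(byte[] data) {
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(data))) {
            byte b = in.readByte();
            switch (b) {
                case 2:
                    return new Sigmoid(in.readFloat(), in.readFloat());
                case 3:
                    return new Linear();
                default:
                    throw new IOException("Illegal score function id: " + b);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        byte[] wire = serialize(new Linear());
        System.out.println(wire.length); // prints 1
        System.out.println(deserialize(wire) instanceof Linear); // prints true
    }
}
```

Because the id is the only thing on the wire for `Linear`, adding it is backward compatible only as long as older nodes never receive id 3; an unknown id fails in the `default` branch.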
Expand All @@ -268,7 +316,7 @@ private static ScoreFunction readScoreFunction(StreamInput in) throws IOExceptio
long numNonNulls = Arrays.stream(args, 3, args.length).filter(Objects::nonNull).count();
final RankFeatureQueryBuilder query;
if (numNonNulls > 1) {
throw new IllegalArgumentException("Can only specify one of [log], [saturation] and [sigmoid]");
throw new IllegalArgumentException("Can only specify one of [log], [saturation], [sigmoid] and [linear]");
} else if (numNonNulls == 0) {
query = new RankFeatureQueryBuilder(field, new ScoreFunction.Saturation());
} else {
Expand All @@ -292,6 +340,8 @@ private static ScoreFunction readScoreFunction(StreamInput in) throws IOExceptio
ScoreFunction.Saturation.PARSER, new ParseField("saturation"));
PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(),
ScoreFunction.Sigmoid.PARSER, new ParseField("sigmoid"));
PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(),
ScoreFunction.Linear.PARSER, new ParseField("linear"));
}

public static final String NAME = "rank_feature";
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,13 @@ public static RankFeatureQueryBuilder sigmoid(String fieldName, float pivot, flo
return new RankFeatureQueryBuilder(fieldName, new RankFeatureQueryBuilder.ScoreFunction.Sigmoid(pivot, exp));
}

/**
* Return a new {@link RankFeatureQueryBuilder} that will score documents as
* {@code S} where S is the indexed value of the static feature.
* @param fieldName field that stores features
*/
public static RankFeatureQueryBuilder linear(String fieldName) {
return new RankFeatureQueryBuilder(fieldName, new RankFeatureQueryBuilder.ScoreFunction.Linear());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ protected Collection<Class<? extends Plugin>> getPlugins() {
protected RankFeatureQueryBuilder doCreateTestQueryBuilder() {
ScoreFunction function;
boolean mayUseNegativeField = true;
switch (random().nextInt(3)) {
switch (random().nextInt(4)) {
case 0:
mayUseNegativeField = false;
function = new ScoreFunction.Log(1 + randomFloat());
Expand All @@ -75,6 +75,9 @@ protected RankFeatureQueryBuilder doCreateTestQueryBuilder() {
case 2:
function = new ScoreFunction.Sigmoid(randomFloat(), randomFloat());
break;
case 3:
function = new ScoreFunction.Linear();
break;
default:
throw new AssertionError();
}
Expand Down Expand Up @@ -106,7 +109,7 @@ public void testDefaultScoreFunction() throws IOException {
assertEquals(FeatureField.newSaturationQuery("_feature", "my_feature_field"), parsedQuery);
}

public void testIllegalField() throws IOException {
public void testIllegalField() {
String query = "{\n" +
" \"rank_feature\" : {\n" +
" \"field\": \"" + TEXT_FIELD_NAME + "\"\n" +
Expand All @@ -118,7 +121,7 @@ public void testIllegalField() throws IOException {
e.getMessage());
}

public void testIllegalCombination() throws IOException {
public void testIllegalCombination() {
String query = "{\n" +
" \"rank_feature\" : {\n" +
" \"field\": \"my_negative_feature_field\",\n" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ setup:

- do:
search:
rest_total_hits_as_int: true
index: test
body:
query:
rank_feature:
Expand All @@ -46,7 +46,7 @@ setup:
scaling_factor: 3

- match:
hits.total: 2
hits.total.value: 2

- match:
hits.hits.0._id: "2"
Expand All @@ -59,7 +59,7 @@ setup:

- do:
search:
rest_total_hits_as_int: true
index: test
body:
query:
rank_feature:
Expand All @@ -68,7 +68,7 @@ setup:
pivot: 20

- match:
hits.total: 2
hits.total.value: 2

- match:
hits.hits.0._id: "2"
Expand All @@ -81,7 +81,7 @@ setup:

- do:
search:
rest_total_hits_as_int: true
index: test
body:
query:
rank_feature:
Expand All @@ -91,7 +91,27 @@ setup:
exponent: 0.6

- match:
hits.total: 2
hits.total.value: 2

- match:
hits.hits.0._id: "2"

- match:
hits.hits.1._id: "1"

---
"Positive linear":
- do:
search:
index: test
body:
query:
rank_feature:
field: pagerank
linear: {}

- match:
hits.total.value: 2

- match:
hits.hits.0._id: "2"
Expand All @@ -105,7 +125,7 @@ setup:
- do:
catch: bad_request
search:
rest_total_hits_as_int: true
index: test
body:
query:
rank_feature:
Expand All @@ -118,7 +138,7 @@ setup:

- do:
search:
rest_total_hits_as_int: true
index: test
body:
query:
rank_feature:
Expand All @@ -127,7 +147,7 @@ setup:
pivot: 20

- match:
hits.total: 2
hits.total.value: 2

- match:
hits.hits.0._id: "2"
Expand All @@ -140,7 +160,7 @@ setup:

- do:
search:
rest_total_hits_as_int: true
index: test
body:
query:
rank_feature:
Expand All @@ -150,7 +170,28 @@ setup:
exponent: 0.6

- match:
hits.total: 2
hits.total.value: 2

- match:
hits.hits.0._id: "2"

- match:
hits.hits.1._id: "1"

---
"Negative linear":

- do:
search:
index: test
body:
query:
rank_feature:
field: url_length
linear: {}

- match:
hits.total.value: 2

- match:
hits.hits.0._id: "2"