Skip to content

Commit 2c7ff81

Browse files
authored
[ML] Add r_squared eval metric to regression (#44248) (#44378)
* [ML] Add r_squared eval metric to regression * fixing tests and binarysoftclassification class * Update RSquared.java * Update x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java Co-Authored-By: David Kyle <[email protected]> * removing unnecessary debug test
1 parent 858dbfc commit 2c7ff81

File tree

17 files changed

+694
-20
lines changed

17 files changed

+694
-20
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java

+5
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
package org.elasticsearch.client.ml.dataframe.evaluation;
2020

2121
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
22+
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
2223
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
2324
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
2425
import org.elasticsearch.common.ParseField;
@@ -49,13 +50,17 @@ Evaluation.class, new ParseField(BinarySoftClassification.NAME), BinarySoftClass
4950
EvaluationMetric.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent),
5051
new NamedXContentRegistry.Entry(
5152
EvaluationMetric.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric::fromXContent),
53+
new NamedXContentRegistry.Entry(
54+
EvaluationMetric.class, new ParseField(RSquaredMetric.NAME), RSquaredMetric::fromXContent),
5255
// Evaluation metrics results
5356
new NamedXContentRegistry.Entry(
5457
EvaluationMetric.Result.class, new ParseField(AucRocMetric.NAME), AucRocMetric.Result::fromXContent),
5558
new NamedXContentRegistry.Entry(
5659
EvaluationMetric.Result.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric.Result::fromXContent),
5760
new NamedXContentRegistry.Entry(
5861
EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent),
62+
new NamedXContentRegistry.Entry(
63+
EvaluationMetric.Result.class, new ParseField(RSquaredMetric.NAME), RSquaredMetric.Result::fromXContent),
5964
new NamedXContentRegistry.Entry(
6065
EvaluationMetric.Result.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric.Result::fromXContent),
6166
new NamedXContentRegistry.Entry(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
20+
21+
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
22+
import org.elasticsearch.common.ParseField;
23+
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
24+
import org.elasticsearch.common.xcontent.ObjectParser;
25+
import org.elasticsearch.common.xcontent.XContentBuilder;
26+
import org.elasticsearch.common.xcontent.XContentParser;
27+
28+
import java.io.IOException;
29+
import java.util.Objects;
30+
31+
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
32+
33+
/**
34+
* Calculates R-Squared between two known numerical fields.
35+
*
36+
* equation: mse = 1 - SSres/SStot
37+
* such that,
38+
* SSres = Σ(y - y´)^2
39+
* SStot = Σ(y - y_mean)^2
40+
*/
41+
public class RSquaredMetric implements EvaluationMetric {
42+
43+
public static final String NAME = "r_squared";
44+
45+
private static final ObjectParser<RSquaredMetric, Void> PARSER =
46+
new ObjectParser<>("r_squared", true, RSquaredMetric::new);
47+
48+
public static RSquaredMetric fromXContent(XContentParser parser) {
49+
return PARSER.apply(parser, null);
50+
}
51+
52+
public RSquaredMetric() {
53+
54+
}
55+
56+
@Override
57+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
58+
builder.startObject();
59+
builder.endObject();
60+
return builder;
61+
}
62+
63+
@Override
64+
public boolean equals(Object o) {
65+
if (this == o) return true;
66+
if (o == null || getClass() != o.getClass()) return false;
67+
return true;
68+
}
69+
70+
@Override
71+
public int hashCode() {
72+
// create static hash code from name as there are currently no unique fields per class instance
73+
return Objects.hashCode(NAME);
74+
}
75+
76+
@Override
77+
public String getName() {
78+
return NAME;
79+
}
80+
81+
public static class Result implements EvaluationMetric.Result {
82+
83+
public static final ParseField VALUE = new ParseField("value");
84+
private final double value;
85+
86+
public static Result fromXContent(XContentParser parser) {
87+
return PARSER.apply(parser, null);
88+
}
89+
90+
private static final ConstructingObjectParser<Result, Void> PARSER =
91+
new ConstructingObjectParser<>("r_squared_result", true, args -> new Result((double) args[0]));
92+
93+
static {
94+
PARSER.declareDouble(constructorArg(), VALUE);
95+
}
96+
97+
public Result(double value) {
98+
this.value = value;
99+
}
100+
101+
@Override
102+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
103+
builder.startObject();
104+
builder.field(VALUE.getPreferredName(), value);
105+
builder.endObject();
106+
return builder;
107+
}
108+
109+
public double getValue() {
110+
return value;
111+
}
112+
113+
@Override
114+
public String getMetricName() {
115+
return NAME;
116+
}
117+
118+
@Override
119+
public boolean equals(Object o) {
120+
if (this == o) return true;
121+
if (o == null || getClass() != o.getClass()) return false;
122+
Result that = (Result) o;
123+
return Objects.equals(that.value, this.value);
124+
}
125+
126+
@Override
127+
public int hashCode() {
128+
return Objects.hash(value);
129+
}
130+
}
131+
}

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/Regression.java

+6-2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
import java.io.IOException;
3131
import java.util.Arrays;
32+
import java.util.Comparator;
3233
import java.util.List;
3334
import java.util.Objects;
3435

@@ -84,8 +85,11 @@ public Regression(String actualField, String predictedField, EvaluationMetric...
8485
}
8586

8687
public Regression(String actualField, String predictedField, @Nullable List<EvaluationMetric> metrics) {
87-
this.actualField = actualField;
88-
this.predictedField = predictedField;
88+
this.actualField = Objects.requireNonNull(actualField);
89+
this.predictedField = Objects.requireNonNull(predictedField);
90+
if (metrics != null) {
91+
metrics.sort(Comparator.comparing(EvaluationMetric::getName));
92+
}
8993
this.metrics = metrics;
9094
}
9195

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java

+16-5
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
import java.io.IOException;
3131
import java.util.Arrays;
32+
import java.util.Comparator;
3233
import java.util.List;
3334
import java.util.Objects;
3435

@@ -52,6 +53,7 @@ public class BinarySoftClassification implements Evaluation {
5253
public static final ConstructingObjectParser<BinarySoftClassification, Void> PARSER =
5354
new ConstructingObjectParser<>(
5455
NAME,
56+
true,
5557
args -> new BinarySoftClassification((String) args[0], (String) args[1], (List<EvaluationMetric>) args[2]));
5658

5759
static {
@@ -80,6 +82,10 @@ public static BinarySoftClassification fromXContent(XContentParser parser) {
8082
*/
8183
private final List<EvaluationMetric> metrics;
8284

85+
public BinarySoftClassification(String actualField, String predictedField) {
86+
this(actualField, predictedField, (List<EvaluationMetric>)null);
87+
}
88+
8389
public BinarySoftClassification(String actualField, String predictedProbabilityField, EvaluationMetric... metric) {
8490
this(actualField, predictedProbabilityField, Arrays.asList(metric));
8591
}
@@ -88,7 +94,10 @@ public BinarySoftClassification(String actualField, String predictedProbabilityF
8894
@Nullable List<EvaluationMetric> metrics) {
8995
this.actualField = Objects.requireNonNull(actualField);
9096
this.predictedProbabilityField = Objects.requireNonNull(predictedProbabilityField);
91-
this.metrics = Objects.requireNonNull(metrics);
97+
if (metrics != null) {
98+
metrics.sort(Comparator.comparing(EvaluationMetric::getName));
99+
}
100+
this.metrics = metrics;
92101
}
93102

94103
@Override
@@ -102,11 +111,13 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
102111
builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
103112
builder.field(PREDICTED_PROBABILITY_FIELD.getPreferredName(), predictedProbabilityField);
104113

105-
builder.startObject(METRICS.getPreferredName());
106-
for (EvaluationMetric metric : metrics) {
107-
builder.field(metric.getName(), metric);
114+
if (metrics != null) {
115+
builder.startObject(METRICS.getPreferredName());
116+
for (EvaluationMetric metric : metrics) {
117+
builder.field(metric.getName(), metric);
118+
}
119+
builder.endObject();
108120
}
109-
builder.endObject();
110121

111122
builder.endObject();
112123
return builder;

client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

+8-2
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@
124124
import org.elasticsearch.client.ml.dataframe.OutlierDetection;
125125
import org.elasticsearch.client.ml.dataframe.QueryConfig;
126126
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
127+
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
127128
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
128129
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
129130
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
@@ -1597,16 +1598,21 @@ public void testEvaluateDataFrame() throws IOException {
15971598
.add(docForRegression(regressionIndex, 0.5, 0.9)); // #9
15981599
highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
15991600

1600-
evaluateDataFrameRequest = new EvaluateDataFrameRequest(regressionIndex, new Regression(actualRegression, probabilityRegression));
1601+
evaluateDataFrameRequest = new EvaluateDataFrameRequest(regressionIndex,
1602+
new Regression(actualRegression, probabilityRegression, new MeanSquaredErrorMetric(), new RSquaredMetric()));
16011603

16021604
evaluateDataFrameResponse =
16031605
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
16041606
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME));
1605-
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
1607+
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(2));
16061608

16071609
MeanSquaredErrorMetric.Result mseResult = evaluateDataFrameResponse.getMetricByName(MeanSquaredErrorMetric.NAME);
16081610
assertThat(mseResult.getMetricName(), equalTo(MeanSquaredErrorMetric.NAME));
16091611
assertThat(mseResult.getError(), closeTo(0.061000000, 1e-9));
1612+
1613+
RSquaredMetric.Result rSquaredResult = evaluateDataFrameResponse.getMetricByName(RSquaredMetric.NAME);
1614+
assertThat(rSquaredResult.getMetricName(), equalTo(RSquaredMetric.NAME));
1615+
assertThat(rSquaredResult.getValue(), closeTo(-5.1000000000000005, 1e-9));
16101616
}
16111617

16121618
private static XContentBuilder defaultMappingForTest() throws IOException {

client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java

+16-5
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis;
6262
import org.elasticsearch.client.ml.dataframe.OutlierDetection;
6363
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
64+
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
6465
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
6566
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
6667
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
@@ -676,7 +677,7 @@ public void testDefaultNamedXContents() {
676677

677678
public void testProvidedNamedXContents() {
678679
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
679-
assertEquals(34, namedXContents.size());
680+
assertEquals(36, namedXContents.size());
680681
Map<Class<?>, Integer> categories = new HashMap<>();
681682
List<String> names = new ArrayList<>();
682683
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@@ -716,12 +717,22 @@ public void testProvidedNamedXContents() {
716717
assertTrue(names.contains(TimeSyncConfig.NAME));
717718
assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
718719
assertThat(names, hasItems(BinarySoftClassification.NAME, Regression.NAME));
719-
assertEquals(Integer.valueOf(5), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
720+
assertEquals(Integer.valueOf(6), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
720721
assertThat(names,
721-
hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, MeanSquaredErrorMetric.NAME));
722-
assertEquals(Integer.valueOf(5), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
722+
hasItems(AucRocMetric.NAME,
723+
PrecisionMetric.NAME,
724+
RecallMetric.NAME,
725+
ConfusionMatrixMetric.NAME,
726+
MeanSquaredErrorMetric.NAME,
727+
RSquaredMetric.NAME));
728+
assertEquals(Integer.valueOf(6), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
723729
assertThat(names,
724-
hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, MeanSquaredErrorMetric.NAME));
730+
hasItems(AucRocMetric.NAME,
731+
PrecisionMetric.NAME,
732+
RecallMetric.NAME,
733+
ConfusionMatrixMetric.NAME,
734+
MeanSquaredErrorMetric.NAME,
735+
RSquaredMetric.NAME));
725736
}
726737

727738
public void testApiNamingConventions() throws Exception {

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricConfusionMatrixTests.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
public class ConfusionMatrixMetricConfusionMatrixTests extends AbstractXContentTestCase<ConfusionMatrixMetric.ConfusionMatrix> {
2828

29-
static ConfusionMatrixMetric.ConfusionMatrix randomConfusionMatrix() {
29+
public static ConfusionMatrixMetric.ConfusionMatrix randomConfusionMatrix() {
3030
return new ConfusionMatrixMetric.ConfusionMatrix(randomInt(), randomInt(), randomInt(), randomInt());
3131
}
3232

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
20+
21+
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
22+
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
23+
import org.elasticsearch.common.xcontent.XContentParser;
24+
import org.elasticsearch.test.AbstractXContentTestCase;
25+
26+
import java.io.IOException;
27+
28+
public class RSquaredMetricResultTests extends AbstractXContentTestCase<RSquaredMetric.Result> {
29+
30+
public static RSquaredMetric.Result randomResult() {
31+
return new RSquaredMetric.Result(randomDouble());
32+
}
33+
34+
@Override
35+
protected RSquaredMetric.Result createTestInstance() {
36+
return randomResult();
37+
}
38+
39+
@Override
40+
protected RSquaredMetric.Result doParseInstance(XContentParser parser) throws IOException {
41+
return RSquaredMetric.Result.fromXContent(parser);
42+
}
43+
44+
@Override
45+
protected boolean supportsUnknownFields() {
46+
return true;
47+
}
48+
49+
@Override
50+
protected NamedXContentRegistry xContentRegistry() {
51+
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
52+
}
53+
}

0 commit comments

Comments
 (0)