Skip to content

Commit 38aa474

Browse files
authored
Implement pseudo Huber loss (PseudoHuber) evaluation metric for regression analysis (#58734)
1 parent ad0436f commit 38aa474

File tree

16 files changed

+606
-12
lines changed

16 files changed

+606
-12
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
2424
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
2525
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
26+
import org.elasticsearch.client.ml.dataframe.evaluation.regression.PseudoHuberMetric;
2627
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
2728
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
2829
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
@@ -102,6 +103,10 @@ Evaluation.class, new ParseField(BinarySoftClassification.NAME), BinarySoftClass
102103
EvaluationMetric.class,
103104
new ParseField(registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME)),
104105
MeanSquaredLogarithmicErrorMetric::fromXContent),
106+
new NamedXContentRegistry.Entry(
107+
EvaluationMetric.class,
108+
new ParseField(registeredMetricName(Regression.NAME, PseudoHuberMetric.NAME)),
109+
PseudoHuberMetric::fromXContent),
105110
new NamedXContentRegistry.Entry(
106111
EvaluationMetric.class,
107112
new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)),
@@ -149,6 +154,10 @@ Evaluation.class, new ParseField(BinarySoftClassification.NAME), BinarySoftClass
149154
EvaluationMetric.Result.class,
150155
new ParseField(registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME)),
151156
MeanSquaredLogarithmicErrorMetric.Result::fromXContent),
157+
new NamedXContentRegistry.Entry(
158+
EvaluationMetric.Result.class,
159+
new ParseField(registeredMetricName(Regression.NAME, PseudoHuberMetric.NAME)),
160+
PseudoHuberMetric.Result::fromXContent),
152161
new NamedXContentRegistry.Entry(
153162
EvaluationMetric.Result.class,
154163
new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)),
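client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/PseudoHuberMetric.java

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change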
@@ -0,0 +1,142 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
20+
21+
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
22+
import org.elasticsearch.common.Nullable;
23+
import org.elasticsearch.common.ParseField;
24+
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
25+
import org.elasticsearch.common.xcontent.XContentBuilder;
26+
import org.elasticsearch.common.xcontent.XContentParser;
27+
28+
import java.io.IOException;
29+
import java.util.Objects;
30+
31+
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
32+
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
33+
34+
/**
35+
* Calculates the pseudo Huber loss function.
36+
*
37+
* equation: pseudohuber = 1/n * Σ(δ^2 * (sqrt(1 + a^2 / δ^2) - 1))
38+
* where: a = y - y´
39+
* δ - parameter that controls the steepness
40+
*/
41+
public class PseudoHuberMetric implements EvaluationMetric {
42+
43+
public static final String NAME = "pseudo_huber";
44+
45+
public static final ParseField DELTA = new ParseField("delta");
46+
47+
private static final ConstructingObjectParser<PseudoHuberMetric, Void> PARSER =
48+
new ConstructingObjectParser<>(NAME, true, args -> new PseudoHuberMetric((Double) args[0]));
49+
50+
static {
51+
PARSER.declareDouble(optionalConstructorArg(), DELTA);
52+
}
53+
54+
public static PseudoHuberMetric fromXContent(XContentParser parser) {
55+
return PARSER.apply(parser, null);
56+
}
57+
58+
private final Double delta;
59+
60+
public PseudoHuberMetric(@Nullable Double delta) {
61+
this.delta = delta;
62+
}
63+
64+
@Override
65+
public String getName() {
66+
return NAME;
67+
}
68+
69+
@Override
70+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
71+
builder.startObject();
72+
if (delta != null) {
73+
builder.field(DELTA.getPreferredName(), delta);
74+
}
75+
builder.endObject();
76+
return builder;
77+
}
78+
79+
@Override
80+
public boolean equals(Object o) {
81+
if (this == o) return true;
82+
if (o == null || getClass() != o.getClass()) return false;
83+
PseudoHuberMetric that = (PseudoHuberMetric) o;
84+
return Objects.equals(this.delta, that.delta);
85+
}
86+
87+
@Override
88+
public int hashCode() {
89+
return Objects.hash(delta);
90+
}
91+
92+
public static class Result implements EvaluationMetric.Result {
93+
94+
public static final ParseField VALUE = new ParseField("value");
95+
private final double value;
96+
97+
public static Result fromXContent(XContentParser parser) {
98+
return PARSER.apply(parser, null);
99+
}
100+
101+
private static final ConstructingObjectParser<Result, Void> PARSER =
102+
new ConstructingObjectParser<>("pseudo_huber_result", true, args -> new Result((double) args[0]));
103+
104+
static {
105+
PARSER.declareDouble(constructorArg(), VALUE);
106+
}
107+
108+
public Result(double value) {
109+
this.value = value;
110+
}
111+
112+
@Override
113+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
114+
builder.startObject();
115+
builder.field(VALUE.getPreferredName(), value);
116+
builder.endObject();
117+
return builder;
118+
}
119+
120+
public double getValue() {
121+
return value;
122+
}
123+
124+
@Override
125+
public String getMetricName() {
126+
return NAME;
127+
}
128+
129+
@Override
130+
public boolean equals(Object o) {
131+
if (this == o) return true;
132+
if (o == null || getClass() != o.getClass()) return false;
133+
Result that = (Result) o;
134+
return Objects.equals(that.value, this.value);
135+
}
136+
137+
@Override
138+
public int hashCode() {
139+
return Double.hashCode(value);
140+
}
141+
}
142+
}

client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@
143143
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
144144
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
145145
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
146+
import org.elasticsearch.client.ml.dataframe.evaluation.regression.PseudoHuberMetric;
146147
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
147148
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
148149
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
@@ -1856,12 +1857,15 @@ public void testEvaluateDataFrame_Regression() throws IOException {
18561857
new Regression(
18571858
actualRegression,
18581859
predictedRegression,
1859-
new MeanSquaredErrorMetric(), new MeanSquaredLogarithmicErrorMetric(1.0), new RSquaredMetric()));
1860+
new MeanSquaredErrorMetric(),
1861+
new MeanSquaredLogarithmicErrorMetric(1.0),
1862+
new PseudoHuberMetric(1.0),
1863+
new RSquaredMetric()));
18601864

18611865
EvaluateDataFrameResponse evaluateDataFrameResponse =
18621866
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
18631867
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME));
1864-
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(3));
1868+
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(4));
18651869

18661870
MeanSquaredErrorMetric.Result mseResult = evaluateDataFrameResponse.getMetricByName(MeanSquaredErrorMetric.NAME);
18671871
assertThat(mseResult.getMetricName(), equalTo(MeanSquaredErrorMetric.NAME));
@@ -1872,6 +1876,10 @@ public void testEvaluateDataFrame_Regression() throws IOException {
18721876
assertThat(msleResult.getMetricName(), equalTo(MeanSquaredLogarithmicErrorMetric.NAME));
18731877
assertThat(msleResult.getError(), closeTo(0.02759231770210426, 1e-9));
18741878

1879+
PseudoHuberMetric.Result pseudoHuberResult = evaluateDataFrameResponse.getMetricByName(PseudoHuberMetric.NAME);
1880+
assertThat(pseudoHuberResult.getMetricName(), equalTo(PseudoHuberMetric.NAME));
1881+
assertThat(pseudoHuberResult.getValue(), closeTo(0.029669771640929276, 1e-9));
1882+
18751883
RSquaredMetric.Result rSquaredResult = evaluateDataFrameResponse.getMetricByName(RSquaredMetric.NAME);
18761884
assertThat(rSquaredResult.getMetricName(), equalTo(RSquaredMetric.NAME));
18771885
assertThat(rSquaredResult.getValue(), closeTo(-5.1000000000000005, 1e-9));

client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
6363
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
6464
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
65+
import org.elasticsearch.client.ml.dataframe.evaluation.regression.PseudoHuberMetric;
6566
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
6667
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
6768
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
@@ -702,7 +703,7 @@ public void testDefaultNamedXContents() {
702703

703704
public void testProvidedNamedXContents() {
704705
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
705-
assertEquals(66, namedXContents.size());
706+
assertEquals(68, namedXContents.size());
706707
Map<Class<?>, Integer> categories = new HashMap<>();
707708
List<String> names = new ArrayList<>();
708709
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@@ -749,7 +750,7 @@ public void testProvidedNamedXContents() {
749750
assertTrue(names.contains(TimeSyncConfig.NAME));
750751
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
751752
assertThat(names, hasItems(BinarySoftClassification.NAME, Classification.NAME, Regression.NAME));
752-
assertEquals(Integer.valueOf(11), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
753+
assertEquals(Integer.valueOf(12), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
753754
assertThat(names,
754755
hasItems(
755756
registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME),
@@ -764,8 +765,9 @@ public void testProvidedNamedXContents() {
764765
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
765766
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
766767
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
768+
registeredMetricName(Regression.NAME, PseudoHuberMetric.NAME),
767769
registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
768-
assertEquals(Integer.valueOf(11), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
770+
assertEquals(Integer.valueOf(12), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
769771
assertThat(names,
770772
hasItems(
771773
registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME),
@@ -780,6 +782,7 @@ public void testProvidedNamedXContents() {
780782
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
781783
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
782784
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
785+
registeredMetricName(Regression.NAME, PseudoHuberMetric.NAME),
783786
registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
784787
assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
785788
assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME));

client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@
162162
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
163163
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
164164
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
165+
import org.elasticsearch.client.ml.dataframe.evaluation.regression.PseudoHuberMetric;
165166
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
166167
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
167168
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
@@ -3572,7 +3573,8 @@ public void testEvaluateDataFrame_Regression() throws Exception {
35723573
// Evaluation metrics // <4>
35733574
new MeanSquaredErrorMetric(), // <5>
35743575
new MeanSquaredLogarithmicErrorMetric(1.0), // <6>
3575-
new RSquaredMetric()); // <7>
3576+
new PseudoHuberMetric(1.0), // <7>
3577+
new RSquaredMetric()); // <8>
35763578
// end::evaluate-data-frame-evaluation-regression
35773579

35783580
EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation);
@@ -3586,12 +3588,16 @@ public void testEvaluateDataFrame_Regression() throws Exception {
35863588
response.getMetricByName(MeanSquaredLogarithmicErrorMetric.NAME); // <3>
35873589
double meanSquaredLogarithmicError = meanSquaredLogarithmicErrorResult.getError(); // <4>
35883590

3589-
RSquaredMetric.Result rSquaredResult = response.getMetricByName(RSquaredMetric.NAME); // <5>
3590-
double rSquared = rSquaredResult.getValue(); // <6>
3591+
PseudoHuberMetric.Result pseudoHuberResult = response.getMetricByName(PseudoHuberMetric.NAME); // <5>
3592+
double pseudoHuber = pseudoHuberResult.getValue(); // <6>
3593+
3594+
RSquaredMetric.Result rSquaredResult = response.getMetricByName(RSquaredMetric.NAME); // <7>
3595+
double rSquared = rSquaredResult.getValue(); // <8>
35913596
// end::evaluate-data-frame-results-regression
35923597

35933598
assertThat(meanSquaredError, closeTo(0.021, 1e-3));
35943599
assertThat(meanSquaredLogarithmicError, closeTo(0.003, 1e-3));
3600+
assertThat(pseudoHuber, closeTo(0.01, 1e-3));
35953601
assertThat(rSquared, closeTo(0.941, 1e-3));
35963602
}
35973603
}
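client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/PseudoHuberMetricResultTests.java

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change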
@@ -0,0 +1,53 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
20+
21+
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
22+
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
23+
import org.elasticsearch.common.xcontent.XContentParser;
24+
import org.elasticsearch.test.AbstractXContentTestCase;
25+
26+
import java.io.IOException;
27+
28+
public class PseudoHuberMetricResultTests extends AbstractXContentTestCase<PseudoHuberMetric.Result> {
29+
30+
public static PseudoHuberMetric.Result randomResult() {
31+
return new PseudoHuberMetric.Result(randomDouble());
32+
}
33+
34+
@Override
35+
protected PseudoHuberMetric.Result createTestInstance() {
36+
return randomResult();
37+
}
38+
39+
@Override
40+
protected PseudoHuberMetric.Result doParseInstance(XContentParser parser) throws IOException {
41+
return PseudoHuberMetric.Result.fromXContent(parser);
42+
}
43+
44+
@Override
45+
protected boolean supportsUnknownFields() {
46+
return true;
47+
}
48+
49+
@Override
50+
protected NamedXContentRegistry xContentRegistry() {
51+
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
52+
}
53+
}
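client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/PseudoHuberMetricTests.java

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change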
@@ -0,0 +1,49 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
20+
21+
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
22+
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
23+
import org.elasticsearch.common.xcontent.XContentParser;
24+
import org.elasticsearch.test.AbstractXContentTestCase;
25+
26+
import java.io.IOException;
27+
28+
public class PseudoHuberMetricTests extends AbstractXContentTestCase<PseudoHuberMetric> {
29+
30+
@Override
31+
protected NamedXContentRegistry xContentRegistry() {
32+
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
33+
}
34+
35+
@Override
36+
protected PseudoHuberMetric createTestInstance() {
37+
return new PseudoHuberMetric(randomBoolean() ? randomDouble() : null);
38+
}
39+
40+
@Override
41+
protected PseudoHuberMetric doParseInstance(XContentParser parser) throws IOException {
42+
return PseudoHuberMetric.fromXContent(parser);
43+
}
44+
45+
@Override
46+
protected boolean supportsUnknownFields() {
47+
return true;
48+
}
49+
}

0 commit comments

Comments
 (0)