Skip to content

Commit c7ac201

Browse files
authored
[7.x] Implement accuracy metric for multiclass classification (#47772) (#49430)
1 parent 03600e4 commit c7ac201

File tree

18 files changed

+913
-77
lines changed

18 files changed

+913
-77
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java

+5
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
*/
1919
package org.elasticsearch.client.ml.dataframe.evaluation;
2020

21+
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
2122
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
2223
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
2324
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
@@ -51,6 +52,8 @@ Evaluation.class, new ParseField(BinarySoftClassification.NAME), BinarySoftClass
5152
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallMetric.NAME), RecallMetric::fromXContent),
5253
new NamedXContentRegistry.Entry(
5354
EvaluationMetric.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent),
55+
new NamedXContentRegistry.Entry(
56+
EvaluationMetric.class, new ParseField(AccuracyMetric.NAME), AccuracyMetric::fromXContent),
5457
new NamedXContentRegistry.Entry(
5558
EvaluationMetric.class,
5659
new ParseField(MulticlassConfusionMatrixMetric.NAME),
@@ -68,6 +71,8 @@ EvaluationMetric.Result.class, new ParseField(PrecisionMetric.NAME), PrecisionMe
6871
EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent),
6972
new NamedXContentRegistry.Entry(
7073
EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent),
74+
new NamedXContentRegistry.Entry(
75+
EvaluationMetric.Result.class, new ParseField(AccuracyMetric.NAME), AccuracyMetric.Result::fromXContent),
7176
new NamedXContentRegistry.Entry(
7277
EvaluationMetric.Result.class,
7378
new ParseField(MulticlassConfusionMatrixMetric.NAME),
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
20+
21+
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
22+
import org.elasticsearch.common.ParseField;
23+
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
24+
import org.elasticsearch.common.xcontent.ObjectParser;
25+
import org.elasticsearch.common.xcontent.ToXContent;
26+
import org.elasticsearch.common.xcontent.ToXContentObject;
27+
import org.elasticsearch.common.xcontent.XContentBuilder;
28+
import org.elasticsearch.common.xcontent.XContentParser;
29+
30+
import java.io.IOException;
31+
import java.util.Collections;
32+
import java.util.List;
33+
import java.util.Objects;
34+
35+
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
36+
37+
/**
38+
* {@link AccuracyMetric} is a metric that answers the question:
39+
* "What fraction of examples have been classified correctly by the classifier?"
40+
*
41+
* equation: accuracy = 1/n * Σ(y == y´)
42+
*/
43+
public class AccuracyMetric implements EvaluationMetric {
44+
45+
public static final String NAME = "accuracy";
46+
47+
private static final ObjectParser<AccuracyMetric, Void> PARSER = new ObjectParser<>(NAME, true, AccuracyMetric::new);
48+
49+
public static AccuracyMetric fromXContent(XContentParser parser) {
50+
return PARSER.apply(parser, null);
51+
}
52+
53+
public AccuracyMetric() {}
54+
55+
@Override
56+
public String getName() {
57+
return NAME;
58+
}
59+
60+
@Override
61+
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
62+
builder.startObject();
63+
builder.endObject();
64+
return builder;
65+
}
66+
67+
@Override
68+
public boolean equals(Object o) {
69+
if (this == o) return true;
70+
if (o == null || getClass() != o.getClass()) return false;
71+
return true;
72+
}
73+
74+
@Override
75+
public int hashCode() {
76+
return Objects.hashCode(NAME);
77+
}
78+
79+
public static class Result implements EvaluationMetric.Result {
80+
81+
private static final ParseField ACTUAL_CLASSES = new ParseField("actual_classes");
82+
private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy");
83+
84+
@SuppressWarnings("unchecked")
85+
private static final ConstructingObjectParser<Result, Void> PARSER =
86+
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<ActualClass>) a[0], (double) a[1]));
87+
88+
static {
89+
PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, ACTUAL_CLASSES);
90+
PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY);
91+
}
92+
93+
public static Result fromXContent(XContentParser parser) {
94+
return PARSER.apply(parser, null);
95+
}
96+
97+
/** List of actual classes. */
98+
private final List<ActualClass> actualClasses;
99+
/** Fraction of documents predicted correctly. */
100+
private final double overallAccuracy;
101+
102+
public Result(List<ActualClass> actualClasses, double overallAccuracy) {
103+
this.actualClasses = Collections.unmodifiableList(Objects.requireNonNull(actualClasses));
104+
this.overallAccuracy = overallAccuracy;
105+
}
106+
107+
@Override
108+
public String getMetricName() {
109+
return NAME;
110+
}
111+
112+
public List<ActualClass> getActualClasses() {
113+
return actualClasses;
114+
}
115+
116+
public double getOverallAccuracy() {
117+
return overallAccuracy;
118+
}
119+
120+
@Override
121+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
122+
builder.startObject();
123+
builder.field(ACTUAL_CLASSES.getPreferredName(), actualClasses);
124+
builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy);
125+
builder.endObject();
126+
return builder;
127+
}
128+
129+
@Override
130+
public boolean equals(Object o) {
131+
if (this == o) return true;
132+
if (o == null || getClass() != o.getClass()) return false;
133+
Result that = (Result) o;
134+
return Objects.equals(this.actualClasses, that.actualClasses)
135+
&& this.overallAccuracy == that.overallAccuracy;
136+
}
137+
138+
@Override
139+
public int hashCode() {
140+
return Objects.hash(actualClasses, overallAccuracy);
141+
}
142+
}
143+
144+
public static class ActualClass implements ToXContentObject {
145+
146+
private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
147+
private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
148+
private static final ParseField ACCURACY = new ParseField("accuracy");
149+
150+
@SuppressWarnings("unchecked")
151+
private static final ConstructingObjectParser<ActualClass, Void> PARSER =
152+
new ConstructingObjectParser<>("accuracy_actual_class", true, a -> new ActualClass((String) a[0], (long) a[1], (double) a[2]));
153+
154+
static {
155+
PARSER.declareString(constructorArg(), ACTUAL_CLASS);
156+
PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT);
157+
PARSER.declareDouble(constructorArg(), ACCURACY);
158+
}
159+
160+
/** Name of the actual class. */
161+
private final String actualClass;
162+
/** Number of documents (examples) belonging to the {code actualClass} class. */
163+
private final long actualClassDocCount;
164+
/** Fraction of documents belonging to the {code actualClass} class predicted correctly. */
165+
private final double accuracy;
166+
167+
public ActualClass(
168+
String actualClass, long actualClassDocCount, double accuracy) {
169+
this.actualClass = Objects.requireNonNull(actualClass);
170+
this.actualClassDocCount = actualClassDocCount;
171+
this.accuracy = accuracy;
172+
}
173+
174+
public String getActualClass() {
175+
return actualClass;
176+
}
177+
178+
public long getActualClassDocCount() {
179+
return actualClassDocCount;
180+
}
181+
182+
public double getAccuracy() {
183+
return accuracy;
184+
}
185+
186+
@Override
187+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
188+
builder.startObject();
189+
builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
190+
builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
191+
builder.field(ACCURACY.getPreferredName(), accuracy);
192+
builder.endObject();
193+
return builder;
194+
}
195+
196+
@Override
197+
public boolean equals(Object o) {
198+
if (this == o) return true;
199+
if (o == null || getClass() != o.getClass()) return false;
200+
ActualClass that = (ActualClass) o;
201+
return Objects.equals(this.actualClass, that.actualClass)
202+
&& this.actualClassDocCount == that.actualClassDocCount
203+
&& this.accuracy == that.accuracy;
204+
}
205+
206+
@Override
207+
public int hashCode() {
208+
return Objects.hash(actualClass, actualClassDocCount, accuracy);
209+
}
210+
}
211+
}

client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

+22
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@
125125
import org.elasticsearch.client.ml.dataframe.OutlierDetection;
126126
import org.elasticsearch.client.ml.dataframe.PhaseProgress;
127127
import org.elasticsearch.client.ml.dataframe.QueryConfig;
128+
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
128129
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
129130
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
130131
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
@@ -1813,6 +1814,27 @@ public void testEvaluateDataFrame_Classification() throws IOException {
18131814

18141815
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
18151816

1817+
{ // Accuracy
1818+
EvaluateDataFrameRequest evaluateDataFrameRequest =
1819+
new EvaluateDataFrameRequest(
1820+
indexName, null, new Classification(actualClassField, predictedClassField, new AccuracyMetric()));
1821+
1822+
EvaluateDataFrameResponse evaluateDataFrameResponse =
1823+
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
1824+
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
1825+
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
1826+
1827+
AccuracyMetric.Result accuracyResult = evaluateDataFrameResponse.getMetricByName(AccuracyMetric.NAME);
1828+
assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
1829+
assertThat(
1830+
accuracyResult.getActualClasses(),
1831+
equalTo(
1832+
Arrays.asList(
1833+
new AccuracyMetric.ActualClass("cat", 5, 0.6), // 3 out of 5 examples labeled as "cat" were classified correctly
1834+
new AccuracyMetric.ActualClass("dog", 4, 0.75), // 3 out of 4 examples labeled as "dog" were classified correctly
1835+
new AccuracyMetric.ActualClass("ant", 1, 0.0)))); // no examples labeled as "ant" were classified correctly
1836+
assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.6)); // 6 out of 10 examples were classified correctly
1837+
}
18161838
{ // No size provided for MulticlassConfusionMatrixMetric, default used instead
18171839
EvaluateDataFrameRequest evaluateDataFrameRequest =
18181840
new EvaluateDataFrameRequest(

client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java

+6-3
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
import org.elasticsearch.client.indexlifecycle.UnfollowAction;
5858
import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis;
5959
import org.elasticsearch.client.ml.dataframe.OutlierDetection;
60+
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
6061
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
6162
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
6263
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
@@ -687,7 +688,7 @@ public void testDefaultNamedXContents() {
687688

688689
public void testProvidedNamedXContents() {
689690
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
690-
assertEquals(49, namedXContents.size());
691+
assertEquals(51, namedXContents.size());
691692
Map<Class<?>, Integer> categories = new HashMap<>();
692693
List<String> names = new ArrayList<>();
693694
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@@ -729,21 +730,23 @@ public void testProvidedNamedXContents() {
729730
assertTrue(names.contains(TimeSyncConfig.NAME));
730731
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
731732
assertThat(names, hasItems(BinarySoftClassification.NAME, Classification.NAME, Regression.NAME));
732-
assertEquals(Integer.valueOf(7), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
733+
assertEquals(Integer.valueOf(8), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
733734
assertThat(names,
734735
hasItems(AucRocMetric.NAME,
735736
PrecisionMetric.NAME,
736737
RecallMetric.NAME,
737738
ConfusionMatrixMetric.NAME,
739+
AccuracyMetric.NAME,
738740
MulticlassConfusionMatrixMetric.NAME,
739741
MeanSquaredErrorMetric.NAME,
740742
RSquaredMetric.NAME));
741-
assertEquals(Integer.valueOf(7), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
743+
assertEquals(Integer.valueOf(8), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
742744
assertThat(names,
743745
hasItems(AucRocMetric.NAME,
744746
PrecisionMetric.NAME,
745747
RecallMetric.NAME,
746748
ConfusionMatrixMetric.NAME,
749+
AccuracyMetric.NAME,
747750
MulticlassConfusionMatrixMetric.NAME,
748751
MeanSquaredErrorMetric.NAME,
749752
RSquaredMetric.NAME));

client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

+12-4
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@
141141
import org.elasticsearch.client.ml.dataframe.QueryConfig;
142142
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
143143
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
144+
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
144145
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
145146
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
146147
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
@@ -3347,20 +3348,27 @@ public void testEvaluateDataFrame_Classification() throws Exception {
33473348
"actual_class", // <2>
33483349
"predicted_class", // <3>
33493350
// Evaluation metrics // <4>
3350-
new MulticlassConfusionMatrixMetric(3)); // <5>
3351+
new AccuracyMetric(), // <5>
3352+
new MulticlassConfusionMatrixMetric(3)); // <6>
33513353
// end::evaluate-data-frame-evaluation-classification
33523354

33533355
EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation);
33543356
EvaluateDataFrameResponse response = client.machineLearning().evaluateDataFrame(request, RequestOptions.DEFAULT);
33553357

33563358
// tag::evaluate-data-frame-results-classification
3359+
AccuracyMetric.Result accuracyResult = response.getMetricByName(AccuracyMetric.NAME); // <1>
3360+
double accuracy = accuracyResult.getOverallAccuracy(); // <2>
3361+
33573362
MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix =
3358-
response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <1>
3363+
response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <3>
33593364

3360-
List<ActualClass> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <2>
3361-
long otherActualClassCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <3>
3365+
List<ActualClass> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <4>
3366+
long otherActualClassCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <5>
33623367
// end::evaluate-data-frame-results-classification
33633368

3369+
assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
3370+
assertThat(accuracy, equalTo(0.6));
3371+
33643372
assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
33653373
assertThat(
33663374
confusionMatrix,

0 commit comments

Comments
 (0)