Skip to content

Commit e901f90

Browse files
authored
Fix accuracy metric (#50310)
1 parent 0a66fef commit e901f90

File tree

14 files changed

+472
-295
lines changed

14 files changed

+472
-295
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetric.java

+53-44
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
2222
import org.elasticsearch.common.ParseField;
23+
import org.elasticsearch.common.Strings;
2324
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
2425
import org.elasticsearch.common.xcontent.ObjectParser;
2526
import org.elasticsearch.common.xcontent.ToXContent;
@@ -35,10 +36,25 @@
3536
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
3637

3738
/**
38-
* {@link AccuracyMetric} is a metric that answers the question:
39-
* "What fraction of examples have been classified correctly by the classifier?"
39+
* {@link AccuracyMetric} is a metric that answers the following two questions:
4040
*
41-
* equation: accuracy = 1/n * Σ(y == y´)
41+
* 1. What is the fraction of documents for which predicted class equals the actual class?
42+
*
43+
* equation: overall_accuracy = 1/n * Σ(y == y')
44+
* where: n = total number of documents
45+
* y = document's actual class
46+
* y' = document's predicted class
47+
*
48+
* 2. For any given class X, what is the fraction of documents for which either
49+
* a) both actual and predicted class are equal to X (true positives)
50+
* or
51+
* b) both actual and predicted class are not equal to X (true negatives)
52+
*
53+
* equation: accuracy(X) = 1/n * (TP(X) + TN(X))
54+
* where: X = class being examined
55+
* n = total number of documents
56+
* TP(X) = number of true positives wrt X
57+
* TN(X) = number of true negatives wrt X
4258
*/
4359
public class AccuracyMetric implements EvaluationMetric {
4460

@@ -78,29 +94,29 @@ public int hashCode() {
7894

7995
public static class Result implements EvaluationMetric.Result {
8096

81-
private static final ParseField ACTUAL_CLASSES = new ParseField("actual_classes");
97+
private static final ParseField CLASSES = new ParseField("classes");
8298
private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy");
8399

84100
@SuppressWarnings("unchecked")
85101
private static final ConstructingObjectParser<Result, Void> PARSER =
86-
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<ActualClass>) a[0], (double) a[1]));
102+
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<PerClassResult>) a[0], (double) a[1]));
87103

88104
static {
89-
PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, ACTUAL_CLASSES);
105+
PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES);
90106
PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY);
91107
}
92108

93109
public static Result fromXContent(XContentParser parser) {
94110
return PARSER.apply(parser, null);
95111
}
96112

97-
/** List of actual classes. */
98-
private final List<ActualClass> actualClasses;
99-
/** Fraction of documents predicted correctly. */
113+
/** List of per-class results. */
114+
private final List<PerClassResult> classes;
115+
/** Fraction of documents for which predicted class equals the actual class. */
100116
private final double overallAccuracy;
101117

102-
public Result(List<ActualClass> actualClasses, double overallAccuracy) {
103-
this.actualClasses = Collections.unmodifiableList(Objects.requireNonNull(actualClasses));
118+
public Result(List<PerClassResult> classes, double overallAccuracy) {
119+
this.classes = Collections.unmodifiableList(Objects.requireNonNull(classes));
104120
this.overallAccuracy = overallAccuracy;
105121
}
106122

@@ -109,8 +125,8 @@ public String getMetricName() {
109125
return NAME;
110126
}
111127

112-
public List<ActualClass> getActualClasses() {
113-
return actualClasses;
128+
public List<PerClassResult> getClasses() {
129+
return classes;
114130
}
115131

116132
public double getOverallAccuracy() {
@@ -120,7 +136,7 @@ public double getOverallAccuracy() {
120136
@Override
121137
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
122138
builder.startObject();
123-
builder.field(ACTUAL_CLASSES.getPreferredName(), actualClasses);
139+
builder.field(CLASSES.getPreferredName(), classes);
124140
builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy);
125141
builder.endObject();
126142
return builder;
@@ -131,52 +147,42 @@ public boolean equals(Object o) {
131147
if (this == o) return true;
132148
if (o == null || getClass() != o.getClass()) return false;
133149
Result that = (Result) o;
134-
return Objects.equals(this.actualClasses, that.actualClasses)
150+
return Objects.equals(this.classes, that.classes)
135151
&& this.overallAccuracy == that.overallAccuracy;
136152
}
137153

138154
@Override
139155
public int hashCode() {
140-
return Objects.hash(actualClasses, overallAccuracy);
156+
return Objects.hash(classes, overallAccuracy);
141157
}
142158
}
143159

144-
public static class ActualClass implements ToXContentObject {
160+
public static class PerClassResult implements ToXContentObject {
145161

146-
private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
147-
private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
162+
private static final ParseField CLASS_NAME = new ParseField("class_name");
148163
private static final ParseField ACCURACY = new ParseField("accuracy");
149164

150165
@SuppressWarnings("unchecked")
151-
private static final ConstructingObjectParser<ActualClass, Void> PARSER =
152-
new ConstructingObjectParser<>("accuracy_actual_class", true, a -> new ActualClass((String) a[0], (long) a[1], (double) a[2]));
166+
private static final ConstructingObjectParser<PerClassResult, Void> PARSER =
167+
new ConstructingObjectParser<>("accuracy_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1]));
153168

154169
static {
155-
PARSER.declareString(constructorArg(), ACTUAL_CLASS);
156-
PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT);
170+
PARSER.declareString(constructorArg(), CLASS_NAME);
157171
PARSER.declareDouble(constructorArg(), ACCURACY);
158172
}
159173

160-
/** Name of the actual class. */
161-
private final String actualClass;
162-
/** Number of documents (examples) belonging to the {code actualClass} class. */
163-
private final long actualClassDocCount;
164-
/** Fraction of documents belonging to the {code actualClass} class predicted correctly. */
174+
/** Name of the class. */
175+
private final String className;
176+
/** Fraction of documents that are either true positives or true negatives wrt {@code className}. */
165177
private final double accuracy;
166178

167-
public ActualClass(
168-
String actualClass, long actualClassDocCount, double accuracy) {
169-
this.actualClass = Objects.requireNonNull(actualClass);
170-
this.actualClassDocCount = actualClassDocCount;
179+
public PerClassResult(String className, double accuracy) {
180+
this.className = Objects.requireNonNull(className);
171181
this.accuracy = accuracy;
172182
}
173183

174-
public String getActualClass() {
175-
return actualClass;
176-
}
177-
178-
public long getActualClassDocCount() {
179-
return actualClassDocCount;
184+
public String getClassName() {
185+
return className;
180186
}
181187

182188
public double getAccuracy() {
@@ -186,8 +192,7 @@ public double getAccuracy() {
186192
@Override
187193
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
188194
builder.startObject();
189-
builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
190-
builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
195+
builder.field(CLASS_NAME.getPreferredName(), className);
191196
builder.field(ACCURACY.getPreferredName(), accuracy);
192197
builder.endObject();
193198
return builder;
@@ -197,15 +202,19 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
197202
public boolean equals(Object o) {
198203
if (this == o) return true;
199204
if (o == null || getClass() != o.getClass()) return false;
200-
ActualClass that = (ActualClass) o;
201-
return Objects.equals(this.actualClass, that.actualClass)
202-
&& this.actualClassDocCount == that.actualClassDocCount
205+
PerClassResult that = (PerClassResult) o;
206+
return Objects.equals(this.className, that.className)
203207
&& this.accuracy == that.accuracy;
204208
}
205209

206210
@Override
207211
public int hashCode() {
208-
return Objects.hash(actualClass, actualClassDocCount, accuracy);
212+
return Objects.hash(className, accuracy);
213+
}
214+
215+
@Override
216+
public String toString() {
217+
return Strings.toString(this);
209218
}
210219
}
211220
}

client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

+7-7
Original file line numberDiff line numberDiff line change
@@ -1819,15 +1819,15 @@ public void testEvaluateDataFrame_Classification() throws IOException {
18191819
AccuracyMetric.Result accuracyResult = evaluateDataFrameResponse.getMetricByName(AccuracyMetric.NAME);
18201820
assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
18211821
assertThat(
1822-
accuracyResult.getActualClasses(),
1822+
accuracyResult.getClasses(),
18231823
equalTo(
18241824
List.of(
1825-
// 3 out of 5 examples labeled as "cat" were classified correctly
1826-
new AccuracyMetric.ActualClass("cat", 5, 0.6),
1827-
// 3 out of 4 examples labeled as "dog" were classified correctly
1828-
new AccuracyMetric.ActualClass("dog", 4, 0.75),
1829-
// no examples labeled as "ant" were classified correctly
1830-
new AccuracyMetric.ActualClass("ant", 1, 0.0))));
1825+
// 9 out of 10 examples were classified correctly
1826+
new AccuracyMetric.PerClassResult("ant", 0.9),
1827+
// 6 out of 10 examples were classified correctly
1828+
new AccuracyMetric.PerClassResult("cat", 0.6),
1829+
// 8 out of 10 examples were classified correctly
1830+
new AccuracyMetric.PerClassResult("dog", 0.8))));
18311831
assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.6)); // 6 out of 10 examples were classified correctly
18321832
}
18331833
{ // Precision

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricResultTests.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
2020

2121
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
22-
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.ActualClass;
22+
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.PerClassResult;
2323
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.Result;
2424
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
2525
import org.elasticsearch.common.xcontent.XContentParser;
@@ -41,13 +41,13 @@ protected NamedXContentRegistry xContentRegistry() {
4141
public static Result randomResult() {
4242
int numClasses = randomIntBetween(2, 100);
4343
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
44-
List<ActualClass> actualClasses = new ArrayList<>(numClasses);
44+
List<PerClassResult> classes = new ArrayList<>(numClasses);
4545
for (int i = 0; i < numClasses; i++) {
4646
double accuracy = randomDoubleBetween(0.0, 1.0, true);
47-
actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), accuracy));
47+
classes.add(new PerClassResult(classNames.get(i), accuracy));
4848
}
4949
double overallAccuracy = randomDoubleBetween(0.0, 1.0, true);
50-
return new Result(actualClasses, overallAccuracy);
50+
return new Result(classes, overallAccuracy);
5151
}
5252

5353
@Override

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,5 +44,5 @@ public interface EvaluationMetric extends ToXContentObject, NamedWriteable {
4444
* Gets the evaluation result for this metric.
4545
* @return {@code Optional.empty()} if the result is not available yet, {@code Optional.of(result)} otherwise
4646
*/
47-
Optional<EvaluationMetricResult> getResult();
47+
Optional<? extends EvaluationMetricResult> getResult();
4848
}

0 commit comments

Comments
 (0)