20
20
21
21
import org .elasticsearch .client .ml .dataframe .evaluation .EvaluationMetric ;
22
22
import org .elasticsearch .common .ParseField ;
23
+ import org .elasticsearch .common .Strings ;
23
24
import org .elasticsearch .common .xcontent .ConstructingObjectParser ;
24
25
import org .elasticsearch .common .xcontent .ObjectParser ;
25
26
import org .elasticsearch .common .xcontent .ToXContent ;
35
36
import static org .elasticsearch .common .xcontent .ConstructingObjectParser .constructorArg ;
36
37
37
38
/**
38
- * {@link AccuracyMetric} is a metric that answers the question:
39
- * "What fraction of examples have been classified correctly by the classifier?"
39
+ * {@link AccuracyMetric} is a metric that answers the following two questions:
40
40
*
41
- * equation: accuracy = 1/n * Σ(y == y´)
41
+ * 1. What is the fraction of documents for which predicted class equals the actual class?
42
+ *
43
+ * equation: overall_accuracy = 1/n * Σ(y == y')
44
+ * where: n = total number of documents
45
+ * y = document's actual class
46
+ * y' = document's predicted class
47
+ *
48
+ * 2. For any given class X, what is the fraction of documents for which either
49
+ * a) both actual and predicted class are equal to X (true positives)
50
+ * or
51
+ * b) both actual and predicted class are not equal to X (true negatives)
52
+ *
53
+ * equation: accuracy(X) = 1/n * (TP(X) + TN(X))
54
+ * where: X = class being examined
55
+ * n = total number of documents
56
+ * TP(X) = number of true positives wrt X
57
+ * TN(X) = number of true negatives wrt X
42
58
*/
43
59
public class AccuracyMetric implements EvaluationMetric {
44
60
@@ -78,29 +94,29 @@ public int hashCode() {
78
94
79
95
public static class Result implements EvaluationMetric .Result {
80
96
81
- private static final ParseField ACTUAL_CLASSES = new ParseField ("actual_classes " );
97
+ private static final ParseField CLASSES = new ParseField ("classes " );
82
98
private static final ParseField OVERALL_ACCURACY = new ParseField ("overall_accuracy" );
83
99
84
100
@ SuppressWarnings ("unchecked" )
85
101
private static final ConstructingObjectParser <Result , Void > PARSER =
86
- new ConstructingObjectParser <>("accuracy_result" , true , a -> new Result ((List <ActualClass >) a [0 ], (double ) a [1 ]));
102
+ new ConstructingObjectParser <>("accuracy_result" , true , a -> new Result ((List <PerClassResult >) a [0 ], (double ) a [1 ]));
87
103
88
104
static {
89
- PARSER .declareObjectArray (constructorArg (), ActualClass .PARSER , ACTUAL_CLASSES );
105
+ PARSER .declareObjectArray (constructorArg (), PerClassResult .PARSER , CLASSES );
90
106
PARSER .declareDouble (constructorArg (), OVERALL_ACCURACY );
91
107
}
92
108
93
109
public static Result fromXContent (XContentParser parser ) {
94
110
return PARSER .apply (parser , null );
95
111
}
96
112
97
- /** List of actual classes . */
98
- private final List <ActualClass > actualClasses ;
99
- /** Fraction of documents predicted correctly . */
113
+ /** List of per-class results . */
114
+ private final List <PerClassResult > classes ;
115
+ /** Fraction of documents for which predicted class equals the actual class . */
100
116
private final double overallAccuracy ;
101
117
102
- public Result (List <ActualClass > actualClasses , double overallAccuracy ) {
103
- this .actualClasses = Collections .unmodifiableList (Objects .requireNonNull (actualClasses ));
118
+ public Result (List <PerClassResult > classes , double overallAccuracy ) {
119
+ this .classes = Collections .unmodifiableList (Objects .requireNonNull (classes ));
104
120
this .overallAccuracy = overallAccuracy ;
105
121
}
106
122
@@ -109,8 +125,8 @@ public String getMetricName() {
109
125
return NAME ;
110
126
}
111
127
112
- public List <ActualClass > getActualClasses () {
113
- return actualClasses ;
128
+ public List <PerClassResult > getClasses () {
129
+ return classes ;
114
130
}
115
131
116
132
public double getOverallAccuracy () {
@@ -120,7 +136,7 @@ public double getOverallAccuracy() {
120
136
@ Override
121
137
public XContentBuilder toXContent (XContentBuilder builder , Params params ) throws IOException {
122
138
builder .startObject ();
123
- builder .field (ACTUAL_CLASSES .getPreferredName (), actualClasses );
139
+ builder .field (CLASSES .getPreferredName (), classes );
124
140
builder .field (OVERALL_ACCURACY .getPreferredName (), overallAccuracy );
125
141
builder .endObject ();
126
142
return builder ;
@@ -131,52 +147,42 @@ public boolean equals(Object o) {
131
147
if (this == o ) return true ;
132
148
if (o == null || getClass () != o .getClass ()) return false ;
133
149
Result that = (Result ) o ;
134
- return Objects .equals (this .actualClasses , that .actualClasses )
150
+ return Objects .equals (this .classes , that .classes )
135
151
&& this .overallAccuracy == that .overallAccuracy ;
136
152
}
137
153
138
154
@ Override
139
155
public int hashCode () {
140
- return Objects .hash (actualClasses , overallAccuracy );
156
+ return Objects .hash (classes , overallAccuracy );
141
157
}
142
158
}
143
159
144
- public static class ActualClass implements ToXContentObject {
160
+ public static class PerClassResult implements ToXContentObject {
145
161
146
- private static final ParseField ACTUAL_CLASS = new ParseField ("actual_class" );
147
- private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField ("actual_class_doc_count" );
162
+ private static final ParseField CLASS_NAME = new ParseField ("class_name" );
148
163
private static final ParseField ACCURACY = new ParseField ("accuracy" );
149
164
150
165
@ SuppressWarnings ("unchecked" )
151
- private static final ConstructingObjectParser <ActualClass , Void > PARSER =
152
- new ConstructingObjectParser <>("accuracy_actual_class " , true , a -> new ActualClass ((String ) a [0 ], (long ) a [ 1 ], ( double ) a [2 ]));
166
+ private static final ConstructingObjectParser <PerClassResult , Void > PARSER =
167
+ new ConstructingObjectParser <>("accuracy_per_class_result " , true , a -> new PerClassResult ((String ) a [0 ], (double ) a [1 ]));
153
168
154
169
static {
155
- PARSER .declareString (constructorArg (), ACTUAL_CLASS );
156
- PARSER .declareLong (constructorArg (), ACTUAL_CLASS_DOC_COUNT );
170
+ PARSER .declareString (constructorArg (), CLASS_NAME );
157
171
PARSER .declareDouble (constructorArg (), ACCURACY );
158
172
}
159
173
160
- /** Name of the actual class. */
161
- private final String actualClass ;
162
- /** Number of documents (examples) belonging to the {code actualClass} class. */
163
- private final long actualClassDocCount ;
164
- /** Fraction of documents belonging to the {code actualClass} class predicted correctly. */
174
+ /** Name of the class. */
175
+ private final String className ;
176
+ /** Fraction of documents that are either true positives or true negatives wrt {@code className}. */
165
177
private final double accuracy ;
166
178
167
- public ActualClass (
168
- String actualClass , long actualClassDocCount , double accuracy ) {
169
- this .actualClass = Objects .requireNonNull (actualClass );
170
- this .actualClassDocCount = actualClassDocCount ;
179
+ public PerClassResult (String className , double accuracy ) {
180
+ this .className = Objects .requireNonNull (className );
171
181
this .accuracy = accuracy ;
172
182
}
173
183
174
- public String getActualClass () {
175
- return actualClass ;
176
- }
177
-
178
- public long getActualClassDocCount () {
179
- return actualClassDocCount ;
184
+ public String getClassName () {
185
+ return className ;
180
186
}
181
187
182
188
public double getAccuracy () {
@@ -186,8 +192,7 @@ public double getAccuracy() {
186
192
@ Override
187
193
public XContentBuilder toXContent (XContentBuilder builder , Params params ) throws IOException {
188
194
builder .startObject ();
189
- builder .field (ACTUAL_CLASS .getPreferredName (), actualClass );
190
- builder .field (ACTUAL_CLASS_DOC_COUNT .getPreferredName (), actualClassDocCount );
195
+ builder .field (CLASS_NAME .getPreferredName (), className );
191
196
builder .field (ACCURACY .getPreferredName (), accuracy );
192
197
builder .endObject ();
193
198
return builder ;
@@ -197,15 +202,19 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
197
202
public boolean equals (Object o ) {
198
203
if (this == o ) return true ;
199
204
if (o == null || getClass () != o .getClass ()) return false ;
200
- ActualClass that = (ActualClass ) o ;
201
- return Objects .equals (this .actualClass , that .actualClass )
202
- && this .actualClassDocCount == that .actualClassDocCount
205
+ PerClassResult that = (PerClassResult ) o ;
206
+ return Objects .equals (this .className , that .className )
203
207
&& this .accuracy == that .accuracy ;
204
208
}
205
209
206
210
@ Override
207
211
public int hashCode () {
208
- return Objects .hash (actualClass , actualClassDocCount , accuracy );
212
+ return Objects .hash (className , accuracy );
213
+ }
214
+
215
+ @ Override
216
+ public String toString () {
217
+ return Strings .toString (this );
209
218
}
210
219
}
211
220
}
0 commit comments