21
21
import org .elasticsearch .client .ml .dataframe .evaluation .EvaluationMetric ;
22
22
import org .elasticsearch .common .Nullable ;
23
23
import org .elasticsearch .common .ParseField ;
24
+ import org .elasticsearch .common .Strings ;
24
25
import org .elasticsearch .common .xcontent .ConstructingObjectParser ;
26
+ import org .elasticsearch .common .xcontent .ToXContentObject ;
25
27
import org .elasticsearch .common .xcontent .XContentBuilder ;
26
28
import org .elasticsearch .common .xcontent .XContentParser ;
27
29
28
30
import java .io .IOException ;
29
31
import java .util .Collections ;
30
- import java .util .Map ;
32
+ import java .util .List ;
31
33
import java .util .Objects ;
32
- import java .util .TreeMap ;
33
34
34
- import static org .elasticsearch .common .xcontent .ConstructingObjectParser .constructorArg ;
35
35
import static org .elasticsearch .common .xcontent .ConstructingObjectParser .optionalConstructorArg ;
36
36
37
37
/**
@@ -97,52 +97,52 @@ public int hashCode() {
97
97
public static class Result implements EvaluationMetric .Result {
98
98
99
99
private static final ParseField CONFUSION_MATRIX = new ParseField ("confusion_matrix" );
100
- private static final ParseField OTHER_CLASSES_COUNT = new ParseField ("_other_ " );
100
+ private static final ParseField OTHER_ACTUAL_CLASS_COUNT = new ParseField ("other_actual_class_count " );
101
101
102
102
@ SuppressWarnings ("unchecked" )
103
103
private static final ConstructingObjectParser <Result , Void > PARSER =
104
104
new ConstructingObjectParser <>(
105
- "multiclass_confusion_matrix_result" , true , a -> new Result ((Map < String , Map < String , Long >> ) a [0 ], (long ) a [1 ]));
105
+ "multiclass_confusion_matrix_result" , true , a -> new Result ((List < ActualClass > ) a [0 ], (Long ) a [1 ]));
106
106
107
107
static {
108
- PARSER .declareObject (
109
- constructorArg (),
110
- (p , c ) -> p .map (TreeMap ::new , p2 -> p2 .map (TreeMap ::new , XContentParser ::longValue )),
111
- CONFUSION_MATRIX );
112
- PARSER .declareLong (constructorArg (), OTHER_CLASSES_COUNT );
108
+ PARSER .declareObjectArray (optionalConstructorArg (), ActualClass .PARSER , CONFUSION_MATRIX );
109
+ PARSER .declareLong (optionalConstructorArg (), OTHER_ACTUAL_CLASS_COUNT );
113
110
}
114
111
115
112
public static Result fromXContent (XContentParser parser ) {
116
113
return PARSER .apply (parser , null );
117
114
}
118
115
119
- // Immutable
120
- private final Map <String , Map <String , Long >> confusionMatrix ;
121
- private final long otherClassesCount ;
116
+ private final List <ActualClass > confusionMatrix ;
117
+ private final Long otherActualClassCount ;
122
118
123
- public Result (Map < String , Map < String , Long >> confusionMatrix , long otherClassesCount ) {
124
- this .confusionMatrix = Collections .unmodifiableMap (Objects .requireNonNull (confusionMatrix ));
125
- this .otherClassesCount = otherClassesCount ;
119
+ public Result (@ Nullable List < ActualClass > confusionMatrix , @ Nullable Long otherActualClassCount ) {
120
+ this .confusionMatrix = confusionMatrix != null ? Collections .unmodifiableList (Objects .requireNonNull (confusionMatrix )) : null ;
121
+ this .otherActualClassCount = otherActualClassCount ;
126
122
}
127
123
128
124
@ Override
129
125
public String getMetricName () {
130
126
return NAME ;
131
127
}
132
128
133
- public Map < String , Map < String , Long > > getConfusionMatrix () {
129
+ public List < ActualClass > getConfusionMatrix () {
134
130
return confusionMatrix ;
135
131
}
136
132
137
- public long getOtherClassesCount () {
138
- return otherClassesCount ;
133
+ public Long getOtherActualClassCount () {
134
+ return otherActualClassCount ;
139
135
}
140
136
141
137
@ Override
142
138
public XContentBuilder toXContent (XContentBuilder builder , Params params ) throws IOException {
143
139
builder .startObject ();
144
- builder .field (CONFUSION_MATRIX .getPreferredName (), confusionMatrix );
145
- builder .field (OTHER_CLASSES_COUNT .getPreferredName (), otherClassesCount );
140
+ if (confusionMatrix != null ) {
141
+ builder .field (CONFUSION_MATRIX .getPreferredName (), confusionMatrix );
142
+ }
143
+ if (otherActualClassCount != null ) {
144
+ builder .field (OTHER_ACTUAL_CLASS_COUNT .getPreferredName (), otherActualClassCount );
145
+ }
146
146
builder .endObject ();
147
147
return builder ;
148
148
}
@@ -153,12 +153,140 @@ public boolean equals(Object o) {
153
153
if (o == null || getClass () != o .getClass ()) return false ;
154
154
Result that = (Result ) o ;
155
155
return Objects .equals (this .confusionMatrix , that .confusionMatrix )
156
- && this .otherClassesCount == that .otherClassesCount ;
156
+ && Objects . equals ( this .otherActualClassCount , that .otherActualClassCount ) ;
157
157
}
158
158
159
159
@ Override
160
160
public int hashCode () {
161
- return Objects .hash (confusionMatrix , otherClassesCount );
161
+ return Objects .hash (confusionMatrix , otherActualClassCount );
162
+ }
163
+ }
164
+
165
+ public static class ActualClass implements ToXContentObject {
166
+
167
+ private static final ParseField ACTUAL_CLASS = new ParseField ("actual_class" );
168
+ private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField ("actual_class_doc_count" );
169
+ private static final ParseField PREDICTED_CLASSES = new ParseField ("predicted_classes" );
170
+ private static final ParseField OTHER_PREDICTED_CLASS_DOC_COUNT = new ParseField ("other_predicted_class_doc_count" );
171
+
172
+ @ SuppressWarnings ("unchecked" )
173
+ private static final ConstructingObjectParser <ActualClass , Void > PARSER =
174
+ new ConstructingObjectParser <>(
175
+ "multiclass_confusion_matrix_actual_class" ,
176
+ true ,
177
+ a -> new ActualClass ((String ) a [0 ], (Long ) a [1 ], (List <PredictedClass >) a [2 ], (Long ) a [3 ]));
178
+
179
+ static {
180
+ PARSER .declareString (optionalConstructorArg (), ACTUAL_CLASS );
181
+ PARSER .declareLong (optionalConstructorArg (), ACTUAL_CLASS_DOC_COUNT );
182
+ PARSER .declareObjectArray (optionalConstructorArg (), PredictedClass .PARSER , PREDICTED_CLASSES );
183
+ PARSER .declareLong (optionalConstructorArg (), OTHER_PREDICTED_CLASS_DOC_COUNT );
184
+ }
185
+
186
+ private final String actualClass ;
187
+ private final Long actualClassDocCount ;
188
+ private final List <PredictedClass > predictedClasses ;
189
+ private final Long otherPredictedClassDocCount ;
190
+
191
+ public ActualClass (@ Nullable String actualClass ,
192
+ @ Nullable Long actualClassDocCount ,
193
+ @ Nullable List <PredictedClass > predictedClasses ,
194
+ @ Nullable Long otherPredictedClassDocCount ) {
195
+ this .actualClass = actualClass ;
196
+ this .actualClassDocCount = actualClassDocCount ;
197
+ this .predictedClasses = predictedClasses != null ? Collections .unmodifiableList (predictedClasses ) : null ;
198
+ this .otherPredictedClassDocCount = otherPredictedClassDocCount ;
199
+ }
200
+
201
+ @ Override
202
+ public XContentBuilder toXContent (XContentBuilder builder , Params params ) throws IOException {
203
+ builder .startObject ();
204
+ if (actualClass != null ) {
205
+ builder .field (ACTUAL_CLASS .getPreferredName (), actualClass );
206
+ }
207
+ if (actualClassDocCount != null ) {
208
+ builder .field (ACTUAL_CLASS_DOC_COUNT .getPreferredName (), actualClassDocCount );
209
+ }
210
+ if (predictedClasses != null ) {
211
+ builder .field (PREDICTED_CLASSES .getPreferredName (), predictedClasses );
212
+ }
213
+ if (otherPredictedClassDocCount != null ) {
214
+ builder .field (OTHER_PREDICTED_CLASS_DOC_COUNT .getPreferredName (), otherPredictedClassDocCount );
215
+ }
216
+ builder .endObject ();
217
+ return builder ;
218
+ }
219
+
220
+ @ Override
221
+ public boolean equals (Object o ) {
222
+ if (this == o ) return true ;
223
+ if (o == null || getClass () != o .getClass ()) return false ;
224
+ ActualClass that = (ActualClass ) o ;
225
+ return Objects .equals (this .actualClass , that .actualClass )
226
+ && Objects .equals (this .actualClassDocCount , that .actualClassDocCount )
227
+ && Objects .equals (this .predictedClasses , that .predictedClasses )
228
+ && Objects .equals (this .otherPredictedClassDocCount , that .otherPredictedClassDocCount );
229
+ }
230
+
231
+ @ Override
232
+ public int hashCode () {
233
+ return Objects .hash (actualClass , actualClassDocCount , predictedClasses , otherPredictedClassDocCount );
234
+ }
235
+
236
+ @ Override
237
+ public String toString () {
238
+ return Strings .toString (this );
239
+ }
240
+ }
241
+
242
+ public static class PredictedClass implements ToXContentObject {
243
+
244
+ private static final ParseField PREDICTED_CLASS = new ParseField ("predicted_class" );
245
+ private static final ParseField COUNT = new ParseField ("count" );
246
+
247
+ @ SuppressWarnings ("unchecked" )
248
+ private static final ConstructingObjectParser <PredictedClass , Void > PARSER =
249
+ new ConstructingObjectParser <>(
250
+ "multiclass_confusion_matrix_predicted_class" , true , a -> new PredictedClass ((String ) a [0 ], (Long ) a [1 ]));
251
+
252
+ static {
253
+ PARSER .declareString (optionalConstructorArg (), PREDICTED_CLASS );
254
+ PARSER .declareLong (optionalConstructorArg (), COUNT );
255
+ }
256
+
257
+ private final String predictedClass ;
258
+ private final Long count ;
259
+
260
+ public PredictedClass (@ Nullable String predictedClass , @ Nullable Long count ) {
261
+ this .predictedClass = predictedClass ;
262
+ this .count = count ;
263
+ }
264
+
265
+ @ Override
266
+ public XContentBuilder toXContent (XContentBuilder builder , Params params ) throws IOException {
267
+ builder .startObject ();
268
+ if (predictedClass != null ) {
269
+ builder .field (PREDICTED_CLASS .getPreferredName (), predictedClass );
270
+ }
271
+ if (count != null ) {
272
+ builder .field (COUNT .getPreferredName (), count );
273
+ }
274
+ builder .endObject ();
275
+ return builder ;
276
+ }
277
+
278
+ @ Override
279
+ public boolean equals (Object o ) {
280
+ if (this == o ) return true ;
281
+ if (o == null || getClass () != o .getClass ()) return false ;
282
+ PredictedClass that = (PredictedClass ) o ;
283
+ return Objects .equals (this .predictedClass , that .predictedClass )
284
+ && Objects .equals (this .count , that .count );
285
+ }
286
+
287
+ @ Override
288
+ public int hashCode () {
289
+ return Objects .hash (predictedClass , count );
162
290
}
163
291
}
164
292
}
0 commit comments