Skip to content

Commit c9fea1e

Browse files
authored
Change format of MulticlassConfusionMatrix result to be more self-explanatory (#48174)
1 parent ca7cb6a commit c9fea1e

File tree

9 files changed

+697
-163
lines changed

9 files changed

+697
-163
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java

+151-23
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,17 @@
2121
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
2222
import org.elasticsearch.common.Nullable;
2323
import org.elasticsearch.common.ParseField;
24+
import org.elasticsearch.common.Strings;
2425
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
26+
import org.elasticsearch.common.xcontent.ToXContentObject;
2527
import org.elasticsearch.common.xcontent.XContentBuilder;
2628
import org.elasticsearch.common.xcontent.XContentParser;
2729

2830
import java.io.IOException;
2931
import java.util.Collections;
30-
import java.util.Map;
32+
import java.util.List;
3133
import java.util.Objects;
32-
import java.util.TreeMap;
3334

34-
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
3535
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
3636

3737
/**
@@ -97,52 +97,52 @@ public int hashCode() {
9797
public static class Result implements EvaluationMetric.Result {
9898

9999
private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix");
100-
private static final ParseField OTHER_CLASSES_COUNT = new ParseField("_other_");
100+
private static final ParseField OTHER_ACTUAL_CLASS_COUNT = new ParseField("other_actual_class_count");
101101

102102
@SuppressWarnings("unchecked")
103103
private static final ConstructingObjectParser<Result, Void> PARSER =
104104
new ConstructingObjectParser<>(
105-
"multiclass_confusion_matrix_result", true, a -> new Result((Map<String, Map<String, Long>>) a[0], (long) a[1]));
105+
"multiclass_confusion_matrix_result", true, a -> new Result((List<ActualClass>) a[0], (Long) a[1]));
106106

107107
static {
108-
PARSER.declareObject(
109-
constructorArg(),
110-
(p, c) -> p.map(TreeMap::new, p2 -> p2.map(TreeMap::new, XContentParser::longValue)),
111-
CONFUSION_MATRIX);
112-
PARSER.declareLong(constructorArg(), OTHER_CLASSES_COUNT);
108+
PARSER.declareObjectArray(optionalConstructorArg(), ActualClass.PARSER, CONFUSION_MATRIX);
109+
PARSER.declareLong(optionalConstructorArg(), OTHER_ACTUAL_CLASS_COUNT);
113110
}
114111

115112
public static Result fromXContent(XContentParser parser) {
116113
return PARSER.apply(parser, null);
117114
}
118115

119-
// Immutable
120-
private final Map<String, Map<String, Long>> confusionMatrix;
121-
private final long otherClassesCount;
116+
private final List<ActualClass> confusionMatrix;
117+
private final Long otherActualClassCount;
122118

123-
public Result(Map<String, Map<String, Long>> confusionMatrix, long otherClassesCount) {
124-
this.confusionMatrix = Collections.unmodifiableMap(Objects.requireNonNull(confusionMatrix));
125-
this.otherClassesCount = otherClassesCount;
119+
/**
 * Creates a result from already-parsed values.
 *
 * Both arguments are optional because the parser declares both fields as
 * optional constructor arguments; an absent field is represented by {@code null}.
 *
 * @param confusionMatrix       rows of the confusion matrix, or {@code null} if absent;
 *                              when non-null it is exposed through an unmodifiable view
 * @param otherActualClassCount value of the {@code other_actual_class_count} field,
 *                              or {@code null} if absent
 */
public Result(@Nullable List<ActualClass> confusionMatrix, @Nullable Long otherActualClassCount) {
    // The null check already guards this branch, so no additional requireNonNull is needed.
    this.confusionMatrix = confusionMatrix != null ? Collections.unmodifiableList(confusionMatrix) : null;
    this.otherActualClassCount = otherActualClassCount;
}
127123

128124
@Override
129125
public String getMetricName() {
130126
return NAME;
131127
}
132128

133-
public Map<String, Map<String, Long>> getConfusionMatrix() {
129+
public List<ActualClass> getConfusionMatrix() {
134130
return confusionMatrix;
135131
}
136132

137-
public long getOtherClassesCount() {
138-
return otherClassesCount;
133+
public Long getOtherActualClassCount() {
134+
return otherActualClassCount;
139135
}
140136

141137
@Override
142138
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
143139
builder.startObject();
144-
builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix);
145-
builder.field(OTHER_CLASSES_COUNT.getPreferredName(), otherClassesCount);
140+
if (confusionMatrix != null) {
141+
builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix);
142+
}
143+
if (otherActualClassCount != null) {
144+
builder.field(OTHER_ACTUAL_CLASS_COUNT.getPreferredName(), otherActualClassCount);
145+
}
146146
builder.endObject();
147147
return builder;
148148
}
@@ -153,12 +153,140 @@ public boolean equals(Object o) {
153153
if (o == null || getClass() != o.getClass()) return false;
154154
Result that = (Result) o;
155155
return Objects.equals(this.confusionMatrix, that.confusionMatrix)
156-
&& this.otherClassesCount == that.otherClassesCount;
156+
&& Objects.equals(this.otherActualClassCount, that.otherActualClassCount);
157157
}
158158

159159
@Override
160160
public int hashCode() {
161-
return Objects.hash(confusionMatrix, otherClassesCount);
161+
return Objects.hash(confusionMatrix, otherActualClassCount);
162+
}
163+
}
164+
165+
public static class ActualClass implements ToXContentObject {
166+
167+
private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
168+
private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
169+
private static final ParseField PREDICTED_CLASSES = new ParseField("predicted_classes");
170+
private static final ParseField OTHER_PREDICTED_CLASS_DOC_COUNT = new ParseField("other_predicted_class_doc_count");
171+
172+
@SuppressWarnings("unchecked")
173+
private static final ConstructingObjectParser<ActualClass, Void> PARSER =
174+
new ConstructingObjectParser<>(
175+
"multiclass_confusion_matrix_actual_class",
176+
true,
177+
a -> new ActualClass((String) a[0], (Long) a[1], (List<PredictedClass>) a[2], (Long) a[3]));
178+
179+
static {
180+
PARSER.declareString(optionalConstructorArg(), ACTUAL_CLASS);
181+
PARSER.declareLong(optionalConstructorArg(), ACTUAL_CLASS_DOC_COUNT);
182+
PARSER.declareObjectArray(optionalConstructorArg(), PredictedClass.PARSER, PREDICTED_CLASSES);
183+
PARSER.declareLong(optionalConstructorArg(), OTHER_PREDICTED_CLASS_DOC_COUNT);
184+
}
185+
186+
private final String actualClass;
187+
private final Long actualClassDocCount;
188+
private final List<PredictedClass> predictedClasses;
189+
private final Long otherPredictedClassDocCount;
190+
191+
public ActualClass(@Nullable String actualClass,
192+
@Nullable Long actualClassDocCount,
193+
@Nullable List<PredictedClass> predictedClasses,
194+
@Nullable Long otherPredictedClassDocCount) {
195+
this.actualClass = actualClass;
196+
this.actualClassDocCount = actualClassDocCount;
197+
this.predictedClasses = predictedClasses != null ? Collections.unmodifiableList(predictedClasses) : null;
198+
this.otherPredictedClassDocCount = otherPredictedClassDocCount;
199+
}
200+
201+
@Override
202+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
203+
builder.startObject();
204+
if (actualClass != null) {
205+
builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
206+
}
207+
if (actualClassDocCount != null) {
208+
builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
209+
}
210+
if (predictedClasses != null) {
211+
builder.field(PREDICTED_CLASSES.getPreferredName(), predictedClasses);
212+
}
213+
if (otherPredictedClassDocCount != null) {
214+
builder.field(OTHER_PREDICTED_CLASS_DOC_COUNT.getPreferredName(), otherPredictedClassDocCount);
215+
}
216+
builder.endObject();
217+
return builder;
218+
}
219+
220+
@Override
221+
public boolean equals(Object o) {
222+
if (this == o) return true;
223+
if (o == null || getClass() != o.getClass()) return false;
224+
ActualClass that = (ActualClass) o;
225+
return Objects.equals(this.actualClass, that.actualClass)
226+
&& Objects.equals(this.actualClassDocCount, that.actualClassDocCount)
227+
&& Objects.equals(this.predictedClasses, that.predictedClasses)
228+
&& Objects.equals(this.otherPredictedClassDocCount, that.otherPredictedClassDocCount);
229+
}
230+
231+
@Override
232+
public int hashCode() {
233+
return Objects.hash(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount);
234+
}
235+
236+
@Override
237+
public String toString() {
238+
return Strings.toString(this);
239+
}
240+
}
241+
242+
public static class PredictedClass implements ToXContentObject {
243+
244+
private static final ParseField PREDICTED_CLASS = new ParseField("predicted_class");
245+
private static final ParseField COUNT = new ParseField("count");
246+
247+
@SuppressWarnings("unchecked")
248+
private static final ConstructingObjectParser<PredictedClass, Void> PARSER =
249+
new ConstructingObjectParser<>(
250+
"multiclass_confusion_matrix_predicted_class", true, a -> new PredictedClass((String) a[0], (Long) a[1]));
251+
252+
static {
253+
PARSER.declareString(optionalConstructorArg(), PREDICTED_CLASS);
254+
PARSER.declareLong(optionalConstructorArg(), COUNT);
255+
}
256+
257+
private final String predictedClass;
258+
private final Long count;
259+
260+
public PredictedClass(@Nullable String predictedClass, @Nullable Long count) {
261+
this.predictedClass = predictedClass;
262+
this.count = count;
263+
}
264+
265+
@Override
266+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
267+
builder.startObject();
268+
if (predictedClass != null) {
269+
builder.field(PREDICTED_CLASS.getPreferredName(), predictedClass);
270+
}
271+
if (count != null) {
272+
builder.field(COUNT.getPreferredName(), count);
273+
}
274+
builder.endObject();
275+
return builder;
276+
}
277+
278+
@Override
279+
public boolean equals(Object o) {
280+
if (this == o) return true;
281+
if (o == null || getClass() != o.getClass()) return false;
282+
PredictedClass that = (PredictedClass) o;
283+
return Objects.equals(this.predictedClass, that.predictedClass)
284+
&& Objects.equals(this.count, that.count);
285+
}
286+
287+
@Override
288+
public int hashCode() {
289+
return Objects.hash(predictedClass, count);
162290
}
163291
}
164292
}

client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

+25-10
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@
127127
import org.elasticsearch.client.ml.dataframe.QueryConfig;
128128
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
129129
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
130+
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
131+
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
130132
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
131133
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
132134
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
@@ -1777,7 +1779,7 @@ public void testEvaluateDataFrame_Classification() throws IOException {
17771779
.add(docForClassification(indexName, "dog", "dog"))
17781780
.add(docForClassification(indexName, "dog", "dog"))
17791781
.add(docForClassification(indexName, "dog", "dog"))
1780-
.add(docForClassification(indexName, "horse", "cat"));
1782+
.add(docForClassification(indexName, "ant", "cat"));
17811783
highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
17821784

17831785
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
@@ -1800,11 +1802,23 @@ public void testEvaluateDataFrame_Classification() throws IOException {
18001802
assertThat(
18011803
mcmResult.getConfusionMatrix(),
18021804
equalTo(
1803-
Map.of(
1804-
"cat", Map.of("cat", 3L, "dog", 1L, "horse", 0L, "_other_", 1L),
1805-
"dog", Map.of("cat", 1L, "dog", 3L, "horse", 0L),
1806-
"horse", Map.of("cat", 1L, "dog", 0L, "horse", 0L))));
1807-
assertThat(mcmResult.getOtherClassesCount(), equalTo(0L));
1805+
List.of(
1806+
new ActualClass(
1807+
"ant",
1808+
1L,
1809+
List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 0L)),
1810+
0L),
1811+
new ActualClass(
1812+
"cat",
1813+
5L,
1814+
List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)),
1815+
1L),
1816+
new ActualClass(
1817+
"dog",
1818+
4L,
1819+
List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)),
1820+
0L))));
1821+
assertThat(mcmResult.getOtherActualClassCount(), equalTo(0L));
18081822
}
18091823
{ // Explicit size provided for MulticlassConfusionMatrixMetric metric
18101824
EvaluateDataFrameRequest evaluateDataFrameRequest =
@@ -1824,10 +1838,11 @@ public void testEvaluateDataFrame_Classification() throws IOException {
18241838
assertThat(
18251839
mcmResult.getConfusionMatrix(),
18261840
equalTo(
1827-
Map.of(
1828-
"cat", Map.of("cat", 3L, "dog", 1L, "_other_", 1L),
1829-
"dog", Map.of("cat", 1L, "dog", 3L))));
1830-
assertThat(mcmResult.getOtherClassesCount(), equalTo(1L));
1841+
List.of(
1842+
new ActualClass("cat", 5L, List.of(new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), 1L),
1843+
new ActualClass("dog", 4L, List.of(new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), 0L)
1844+
)));
1845+
assertThat(mcmResult.getOtherActualClassCount(), equalTo(1L));
18311846
}
18321847
}
18331848

client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

+20-6
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@
142142
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
143143
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
144144
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
145+
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
146+
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
145147
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
146148
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
147149
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
@@ -3355,18 +3357,30 @@ public void testEvaluateDataFrame_Classification() throws Exception {
33553357
MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix =
33563358
response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <1>
33573359

3358-
Map<String, Map<String, Long>> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <2>
3359-
long otherClassesCount = multiclassConfusionMatrix.getOtherClassesCount(); // <3>
3360+
List<ActualClass> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <2>
3361+
long otherClassesCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <3>
33603362
// end::evaluate-data-frame-results-classification
33613363

33623364
assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
33633365
assertThat(
33643366
confusionMatrix,
33653367
equalTo(
3366-
Map.of(
3367-
"cat", Map.of("cat", 3L, "dog", 1L, "ant", 0L, "_other_", 1L),
3368-
"dog", Map.of("cat", 1L, "dog", 3L, "ant", 0L),
3369-
"ant", Map.of("cat", 1L, "dog", 0L, "ant", 0L))));
3368+
List.of(
3369+
new ActualClass(
3370+
"ant",
3371+
1L,
3372+
List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 0L)),
3373+
0L),
3374+
new ActualClass(
3375+
"cat",
3376+
5L,
3377+
List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)),
3378+
1L),
3379+
new ActualClass(
3380+
"dog",
3381+
4L,
3382+
List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)),
3383+
0L))));
33703384
assertThat(otherClassesCount, equalTo(0L));
33713385
}
33723386
}

0 commit comments

Comments
 (0)