Skip to content

Commit ee952da

Browse files
authored
[7.x] Implement evaluation API for multiclass classification problem (#47126) (#47343)
1 parent e3aab12 commit ee952da

File tree

20 files changed

+1963
-41
lines changed

20 files changed

+1963
-41
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
*/
1919
package org.elasticsearch.client.ml.dataframe.evaluation;
2020

21+
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
2122
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
23+
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
2224
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
2325
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
2426
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
@@ -41,13 +43,18 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
4143
// Evaluations
4244
new NamedXContentRegistry.Entry(
4345
Evaluation.class, new ParseField(BinarySoftClassification.NAME), BinarySoftClassification::fromXContent),
46+
new NamedXContentRegistry.Entry(Evaluation.class, new ParseField(Classification.NAME), Classification::fromXContent),
4447
new NamedXContentRegistry.Entry(Evaluation.class, new ParseField(Regression.NAME), Regression::fromXContent),
4548
// Evaluation metrics
4649
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(AucRocMetric.NAME), AucRocMetric::fromXContent),
4750
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric::fromXContent),
4851
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallMetric.NAME), RecallMetric::fromXContent),
4952
new NamedXContentRegistry.Entry(
5053
EvaluationMetric.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent),
54+
new NamedXContentRegistry.Entry(
55+
EvaluationMetric.class,
56+
new ParseField(MulticlassConfusionMatrixMetric.NAME),
57+
MulticlassConfusionMatrixMetric::fromXContent),
5158
new NamedXContentRegistry.Entry(
5259
EvaluationMetric.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric::fromXContent),
5360
new NamedXContentRegistry.Entry(
@@ -60,10 +67,14 @@ EvaluationMetric.Result.class, new ParseField(PrecisionMetric.NAME), PrecisionMe
6067
new NamedXContentRegistry.Entry(
6168
EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent),
6269
new NamedXContentRegistry.Entry(
63-
EvaluationMetric.Result.class, new ParseField(RSquaredMetric.NAME), RSquaredMetric.Result::fromXContent),
70+
EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent),
71+
new NamedXContentRegistry.Entry(
72+
EvaluationMetric.Result.class,
73+
new ParseField(MulticlassConfusionMatrixMetric.NAME),
74+
MulticlassConfusionMatrixMetric.Result::fromXContent),
6475
new NamedXContentRegistry.Entry(
6576
EvaluationMetric.Result.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric.Result::fromXContent),
6677
new NamedXContentRegistry.Entry(
67-
EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent));
78+
EvaluationMetric.Result.class, new ParseField(RSquaredMetric.NAME), RSquaredMetric.Result::fromXContent));
6879
}
6980
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
20+
21+
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
22+
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
23+
import org.elasticsearch.common.Nullable;
24+
import org.elasticsearch.common.ParseField;
25+
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
26+
import org.elasticsearch.common.xcontent.XContentBuilder;
27+
import org.elasticsearch.common.xcontent.XContentParser;
28+
29+
import java.io.IOException;
30+
import java.util.Arrays;
31+
import java.util.Comparator;
32+
import java.util.List;
33+
import java.util.Objects;
34+
35+
/**
36+
* Evaluation of classification results.
37+
*/
38+
public class Classification implements Evaluation {
39+
40+
public static final String NAME = "classification";
41+
42+
private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
43+
private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
44+
private static final ParseField METRICS = new ParseField("metrics");
45+
46+
@SuppressWarnings("unchecked")
47+
public static final ConstructingObjectParser<Classification, Void> PARSER = new ConstructingObjectParser<>(
48+
NAME, true, a -> new Classification((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
49+
50+
static {
51+
PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
52+
PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
53+
PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
54+
(p, c, n) -> p.namedObject(EvaluationMetric.class, n, c), METRICS);
55+
}
56+
57+
public static Classification fromXContent(XContentParser parser) {
58+
return PARSER.apply(parser, null);
59+
}
60+
61+
/**
62+
* The field containing the actual value
63+
* The value of this field is assumed to be numeric
64+
*/
65+
private final String actualField;
66+
67+
/**
68+
* The field containing the predicted value
69+
* The value of this field is assumed to be numeric
70+
*/
71+
private final String predictedField;
72+
73+
/**
74+
* The list of metrics to calculate
75+
*/
76+
private final List<EvaluationMetric> metrics;
77+
78+
public Classification(String actualField, String predictedField) {
79+
this(actualField, predictedField, (List<EvaluationMetric>)null);
80+
}
81+
82+
public Classification(String actualField, String predictedField, EvaluationMetric... metrics) {
83+
this(actualField, predictedField, Arrays.asList(metrics));
84+
}
85+
86+
public Classification(String actualField, String predictedField, @Nullable List<EvaluationMetric> metrics) {
87+
this.actualField = Objects.requireNonNull(actualField);
88+
this.predictedField = Objects.requireNonNull(predictedField);
89+
if (metrics != null) {
90+
metrics.sort(Comparator.comparing(EvaluationMetric::getName));
91+
}
92+
this.metrics = metrics;
93+
}
94+
95+
@Override
96+
public String getName() {
97+
return NAME;
98+
}
99+
100+
@Override
101+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
102+
builder.startObject();
103+
builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
104+
builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
105+
106+
if (metrics != null) {
107+
builder.startObject(METRICS.getPreferredName());
108+
for (EvaluationMetric metric : metrics) {
109+
builder.field(metric.getName(), metric);
110+
}
111+
builder.endObject();
112+
}
113+
114+
builder.endObject();
115+
return builder;
116+
}
117+
118+
@Override
119+
public boolean equals(Object o) {
120+
if (this == o) return true;
121+
if (o == null || getClass() != o.getClass()) return false;
122+
Classification that = (Classification) o;
123+
return Objects.equals(that.actualField, this.actualField)
124+
&& Objects.equals(that.predictedField, this.predictedField)
125+
&& Objects.equals(that.metrics, this.metrics);
126+
}
127+
128+
@Override
129+
public int hashCode() {
130+
return Objects.hash(actualField, predictedField, metrics);
131+
}
132+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
20+
21+
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
22+
import org.elasticsearch.common.Nullable;
23+
import org.elasticsearch.common.ParseField;
24+
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
25+
import org.elasticsearch.common.xcontent.XContentBuilder;
26+
import org.elasticsearch.common.xcontent.XContentParser;
27+
28+
import java.io.IOException;
29+
import java.util.Collections;
30+
import java.util.Map;
31+
import java.util.Objects;
32+
import java.util.TreeMap;
33+
34+
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
35+
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
36+
37+
/**
38+
* Calculates the multiclass confusion matrix.
39+
*/
40+
public class MulticlassConfusionMatrixMetric implements EvaluationMetric {
41+
42+
public static final String NAME = "multiclass_confusion_matrix";
43+
44+
public static final ParseField SIZE = new ParseField("size");
45+
46+
private static final ConstructingObjectParser<MulticlassConfusionMatrixMetric, Void> PARSER = createParser();
47+
48+
private static ConstructingObjectParser<MulticlassConfusionMatrixMetric, Void> createParser() {
49+
ConstructingObjectParser<MulticlassConfusionMatrixMetric, Void> parser =
50+
new ConstructingObjectParser<>(NAME, true, args -> new MulticlassConfusionMatrixMetric((Integer) args[0]));
51+
parser.declareInt(optionalConstructorArg(), SIZE);
52+
return parser;
53+
}
54+
55+
public static MulticlassConfusionMatrixMetric fromXContent(XContentParser parser) {
56+
return PARSER.apply(parser, null);
57+
}
58+
59+
private final Integer size;
60+
61+
public MulticlassConfusionMatrixMetric() {
62+
this(null);
63+
}
64+
65+
public MulticlassConfusionMatrixMetric(@Nullable Integer size) {
66+
this.size = size;
67+
}
68+
69+
@Override
70+
public String getName() {
71+
return NAME;
72+
}
73+
74+
@Override
75+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
76+
builder.startObject();
77+
if (size != null) {
78+
builder.field(SIZE.getPreferredName(), size);
79+
}
80+
builder.endObject();
81+
return builder;
82+
}
83+
84+
@Override
85+
public boolean equals(Object o) {
86+
if (this == o) return true;
87+
if (o == null || getClass() != o.getClass()) return false;
88+
MulticlassConfusionMatrixMetric that = (MulticlassConfusionMatrixMetric) o;
89+
return Objects.equals(this.size, that.size);
90+
}
91+
92+
@Override
93+
public int hashCode() {
94+
return Objects.hash(size);
95+
}
96+
97+
public static class Result implements EvaluationMetric.Result {
98+
99+
private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix");
100+
private static final ParseField OTHER_CLASSES_COUNT = new ParseField("_other_");
101+
102+
@SuppressWarnings("unchecked")
103+
private static final ConstructingObjectParser<Result, Void> PARSER =
104+
new ConstructingObjectParser<>(
105+
"multiclass_confusion_matrix_result", true, a -> new Result((Map<String, Map<String, Long>>) a[0], (long) a[1]));
106+
107+
static {
108+
PARSER.declareObject(
109+
constructorArg(),
110+
(p, c) -> p.map(TreeMap::new, p2 -> p2.map(TreeMap::new, XContentParser::longValue)),
111+
CONFUSION_MATRIX);
112+
PARSER.declareLong(constructorArg(), OTHER_CLASSES_COUNT);
113+
}
114+
115+
public static Result fromXContent(XContentParser parser) {
116+
return PARSER.apply(parser, null);
117+
}
118+
119+
// Immutable
120+
private final Map<String, Map<String, Long>> confusionMatrix;
121+
private final long otherClassesCount;
122+
123+
public Result(Map<String, Map<String, Long>> confusionMatrix, long otherClassesCount) {
124+
this.confusionMatrix = Collections.unmodifiableMap(Objects.requireNonNull(confusionMatrix));
125+
this.otherClassesCount = otherClassesCount;
126+
}
127+
128+
@Override
129+
public String getMetricName() {
130+
return NAME;
131+
}
132+
133+
public Map<String, Map<String, Long>> getConfusionMatrix() {
134+
return confusionMatrix;
135+
}
136+
137+
public long getOtherClassesCount() {
138+
return otherClassesCount;
139+
}
140+
141+
@Override
142+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
143+
builder.startObject();
144+
builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix);
145+
builder.field(OTHER_CLASSES_COUNT.getPreferredName(), otherClassesCount);
146+
builder.endObject();
147+
return builder;
148+
}
149+
150+
@Override
151+
public boolean equals(Object o) {
152+
if (this == o) return true;
153+
if (o == null || getClass() != o.getClass()) return false;
154+
Result that = (Result) o;
155+
return Objects.equals(this.confusionMatrix, that.confusionMatrix)
156+
&& this.otherClassesCount == that.otherClassesCount;
157+
}
158+
159+
@Override
160+
public int hashCode() {
161+
return Objects.hash(confusionMatrix, otherClassesCount);
162+
}
163+
}
164+
}

0 commit comments

Comments
 (0)