Skip to content

Change format of MulticlassConfusionMatrix result to be more self-explanatory #48174

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.List;
import java.util.Objects;
import java.util.TreeMap;

import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;

/**
Expand Down Expand Up @@ -97,52 +97,52 @@ public int hashCode() {
public static class Result implements EvaluationMetric.Result {

private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix");
private static final ParseField OTHER_CLASSES_COUNT = new ParseField("_other_");
private static final ParseField OTHER_ACTUAL_CLASS_COUNT = new ParseField("other_actual_class_count");

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>(
"multiclass_confusion_matrix_result", true, a -> new Result((Map<String, Map<String, Long>>) a[0], (long) a[1]));
"multiclass_confusion_matrix_result", true, a -> new Result((List<ActualClass>) a[0], (Long) a[1]));

static {
PARSER.declareObject(
constructorArg(),
(p, c) -> p.map(TreeMap::new, p2 -> p2.map(TreeMap::new, XContentParser::longValue)),
CONFUSION_MATRIX);
PARSER.declareLong(constructorArg(), OTHER_CLASSES_COUNT);
PARSER.declareObjectArray(optionalConstructorArg(), ActualClass.PARSER, CONFUSION_MATRIX);
PARSER.declareLong(optionalConstructorArg(), OTHER_ACTUAL_CLASS_COUNT);
}

/**
 * Builds a {@link Result} from the content exposed by the given parser.
 *
 * Parsing is delegated to {@code PARSER}; no parser context is needed, so {@code null} is passed.
 *
 * @param parser the parser positioned on the serialized result
 * @return the parsed result
 */
public static Result fromXContent(XContentParser parser) {
    final Result parsedResult = PARSER.apply(parser, null);
    return parsedResult;
}

// Immutable
private final Map<String, Map<String, Long>> confusionMatrix;
private final long otherClassesCount;
private final List<ActualClass> confusionMatrix;
private final Long otherActualClassCount;

public Result(Map<String, Map<String, Long>> confusionMatrix, long otherClassesCount) {
this.confusionMatrix = Collections.unmodifiableMap(Objects.requireNonNull(confusionMatrix));
this.otherClassesCount = otherClassesCount;
public Result(@Nullable List<ActualClass> confusionMatrix, @Nullable Long otherActualClassCount) {
this.confusionMatrix = confusionMatrix != null ? Collections.unmodifiableList(Objects.requireNonNull(confusionMatrix)) : null;
this.otherActualClassCount = otherActualClassCount;
}

/**
 * Returns the name identifying the metric this result belongs to (the {@code NAME} constant).
 */
@Override
public String getMetricName() {
return NAME;
}

public Map<String, Map<String, Long>> getConfusionMatrix() {
public List<ActualClass> getConfusionMatrix() {
return confusionMatrix;
}

public long getOtherClassesCount() {
return otherClassesCount;
public Long getOtherActualClassCount() {
return otherActualClassCount;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix);
builder.field(OTHER_CLASSES_COUNT.getPreferredName(), otherClassesCount);
if (confusionMatrix != null) {
builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix);
}
if (otherActualClassCount != null) {
builder.field(OTHER_ACTUAL_CLASS_COUNT.getPreferredName(), otherActualClassCount);
}
builder.endObject();
return builder;
}
Expand All @@ -153,12 +153,140 @@ public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return Objects.equals(this.confusionMatrix, that.confusionMatrix)
&& this.otherClassesCount == that.otherClassesCount;
&& Objects.equals(this.otherActualClassCount, that.otherActualClassCount);
}

@Override
public int hashCode() {
return Objects.hash(confusionMatrix, otherClassesCount);
return Objects.hash(confusionMatrix, otherActualClassCount);
}
}

/**
 * A single row of the multiclass confusion matrix, describing how documents of one actual class
 * were distributed across predicted classes.
 *
 * Every field is optional when parsing, so any accessor may return {@code null}.
 */
public static class ActualClass implements ToXContentObject {

    private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
    private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
    private static final ParseField PREDICTED_CLASSES = new ParseField("predicted_classes");
    private static final ParseField OTHER_PREDICTED_CLASS_DOC_COUNT = new ParseField("other_predicted_class_doc_count");

    // The cast of a[2] to List<PredictedClass> is unchecked; the element type is fixed by the
    // declareObjectArray(...) call below, which parses each element with PredictedClass.PARSER.
    @SuppressWarnings("unchecked")
    private static final ConstructingObjectParser<ActualClass, Void> PARSER =
        new ConstructingObjectParser<>(
            "multiclass_confusion_matrix_actual_class",
            true,
            a -> new ActualClass((String) a[0], (Long) a[1], (List<PredictedClass>) a[2], (Long) a[3]));

    static {
        // Declaration order defines the indices of the constructor argument array used above.
        PARSER.declareString(optionalConstructorArg(), ACTUAL_CLASS);
        PARSER.declareLong(optionalConstructorArg(), ACTUAL_CLASS_DOC_COUNT);
        PARSER.declareObjectArray(optionalConstructorArg(), PredictedClass.PARSER, PREDICTED_CLASSES);
        PARSER.declareLong(optionalConstructorArg(), OTHER_PREDICTED_CLASS_DOC_COUNT);
    }

    private final String actualClass;
    private final Long actualClassDocCount;
    private final List<PredictedClass> predictedClasses;
    private final Long otherPredictedClassDocCount;

    /**
     * @param actualClass                 name of the actual class this row describes, may be {@code null}
     * @param actualClassDocCount         value of the {@code actual_class_doc_count} field, may be {@code null}
     * @param predictedClasses            per-predicted-class counts for this row, may be {@code null};
     *                                    stored behind an unmodifiable view (the list is not copied)
     * @param otherPredictedClassDocCount value of the {@code other_predicted_class_doc_count} field,
     *                                    may be {@code null}
     */
    public ActualClass(@Nullable String actualClass,
                       @Nullable Long actualClassDocCount,
                       @Nullable List<PredictedClass> predictedClasses,
                       @Nullable Long otherPredictedClassDocCount) {
        this.actualClass = actualClass;
        this.actualClassDocCount = actualClassDocCount;
        this.predictedClasses = predictedClasses != null ? Collections.unmodifiableList(predictedClasses) : null;
        this.otherPredictedClassDocCount = otherPredictedClassDocCount;
    }

    /**
     * @return the name of the actual class, or {@code null} if it was not provided
     */
    @Nullable
    public String getActualClass() {
        return actualClass;
    }

    /**
     * @return the {@code actual_class_doc_count} value, or {@code null} if it was not provided
     */
    @Nullable
    public Long getActualClassDocCount() {
        return actualClassDocCount;
    }

    /**
     * @return an unmodifiable view of the predicted classes of this row, or {@code null} if not provided
     */
    @Nullable
    public List<PredictedClass> getPredictedClasses() {
        return predictedClasses;
    }

    /**
     * @return the {@code other_predicted_class_doc_count} value, or {@code null} if it was not provided
     */
    @Nullable
    public Long getOtherPredictedClassDocCount() {
        return otherPredictedClassDocCount;
    }

    /**
     * Writes this row as an object; fields that are {@code null} are omitted entirely.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        if (actualClass != null) {
            builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
        }
        if (actualClassDocCount != null) {
            builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
        }
        if (predictedClasses != null) {
            builder.field(PREDICTED_CLASSES.getPreferredName(), predictedClasses);
        }
        if (otherPredictedClassDocCount != null) {
            builder.field(OTHER_PREDICTED_CLASS_DOC_COUNT.getPreferredName(), otherPredictedClassDocCount);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        ActualClass that = (ActualClass) o;
        return Objects.equals(this.actualClass, that.actualClass)
            && Objects.equals(this.actualClassDocCount, that.actualClassDocCount)
            && Objects.equals(this.predictedClasses, that.predictedClasses)
            && Objects.equals(this.otherPredictedClassDocCount, that.otherPredictedClassDocCount);
    }

    @Override
    public int hashCode() {
        return Objects.hash(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount);
    }

    /**
     * Renders this row through {@link Strings#toString}, i.e. using its {@code toXContent} form.
     */
    @Override
    public String toString() {
        return Strings.toString(this);
    }
}

/**
 * A single cell of the multiclass confusion matrix: a predicted class name together with a count.
 *
 * Both fields are optional when parsing, so either accessor may return {@code null}.
 */
public static class PredictedClass implements ToXContentObject {

    private static final ParseField PREDICTED_CLASS = new ParseField("predicted_class");
    private static final ParseField COUNT = new ParseField("count");

    // No @SuppressWarnings needed here: the lambda only performs checked casts to String and Long.
    private static final ConstructingObjectParser<PredictedClass, Void> PARSER =
        new ConstructingObjectParser<>(
            "multiclass_confusion_matrix_predicted_class", true, a -> new PredictedClass((String) a[0], (Long) a[1]));

    static {
        // Declaration order defines the indices of the constructor argument array used above.
        PARSER.declareString(optionalConstructorArg(), PREDICTED_CLASS);
        PARSER.declareLong(optionalConstructorArg(), COUNT);
    }

    private final String predictedClass;
    private final Long count;

    /**
     * @param predictedClass name of the predicted class, may be {@code null}
     * @param count          value of the {@code count} field, may be {@code null}
     */
    public PredictedClass(@Nullable String predictedClass, @Nullable Long count) {
        this.predictedClass = predictedClass;
        this.count = count;
    }

    /**
     * @return the name of the predicted class, or {@code null} if it was not provided
     */
    @Nullable
    public String getPredictedClass() {
        return predictedClass;
    }

    /**
     * @return the {@code count} value, or {@code null} if it was not provided
     */
    @Nullable
    public Long getCount() {
        return count;
    }

    /**
     * Writes this cell as an object; fields that are {@code null} are omitted entirely.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        if (predictedClass != null) {
            builder.field(PREDICTED_CLASS.getPreferredName(), predictedClass);
        }
        if (count != null) {
            builder.field(COUNT.getPreferredName(), count);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        PredictedClass that = (PredictedClass) o;
        return Objects.equals(this.predictedClass, that.predictedClass)
            && Objects.equals(this.count, that.count);
    }

    @Override
    public int hashCode() {
        return Objects.hash(predictedClass, count);
    }

    /**
     * Renders this cell through {@link Strings#toString}, matching {@code ActualClass#toString()},
     * so that failed assertions and logs show the field values instead of an identity hash.
     */
    @Override
    public String toString() {
        return Strings.toString(this);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@
import org.elasticsearch.client.ml.dataframe.QueryConfig;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
Expand Down Expand Up @@ -1777,7 +1779,7 @@ public void testEvaluateDataFrame_Classification() throws IOException {
.add(docForClassification(indexName, "dog", "dog"))
.add(docForClassification(indexName, "dog", "dog"))
.add(docForClassification(indexName, "dog", "dog"))
.add(docForClassification(indexName, "horse", "cat"));
.add(docForClassification(indexName, "ant", "cat"));
highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);

MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
Expand All @@ -1800,11 +1802,23 @@ public void testEvaluateDataFrame_Classification() throws IOException {
assertThat(
mcmResult.getConfusionMatrix(),
equalTo(
Map.of(
"cat", Map.of("cat", 3L, "dog", 1L, "horse", 0L, "_other_", 1L),
"dog", Map.of("cat", 1L, "dog", 3L, "horse", 0L),
"horse", Map.of("cat", 1L, "dog", 0L, "horse", 0L))));
assertThat(mcmResult.getOtherClassesCount(), equalTo(0L));
List.of(
new ActualClass(
"ant",
1L,
List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 0L)),
0L),
new ActualClass(
"cat",
5L,
List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)),
1L),
new ActualClass(
"dog",
4L,
List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)),
0L))));
assertThat(mcmResult.getOtherActualClassCount(), equalTo(0L));
}
{ // Explicit size provided for MulticlassConfusionMatrixMetric metric
EvaluateDataFrameRequest evaluateDataFrameRequest =
Expand All @@ -1824,10 +1838,11 @@ public void testEvaluateDataFrame_Classification() throws IOException {
assertThat(
mcmResult.getConfusionMatrix(),
equalTo(
Map.of(
"cat", Map.of("cat", 3L, "dog", 1L, "_other_", 1L),
"dog", Map.of("cat", 1L, "dog", 3L))));
assertThat(mcmResult.getOtherClassesCount(), equalTo(1L));
List.of(
new ActualClass("cat", 5L, List.of(new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), 1L),
new ActualClass("dog", 4L, List.of(new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), 0L)
)));
assertThat(mcmResult.getOtherActualClassCount(), equalTo(1L));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
Expand Down Expand Up @@ -3355,18 +3357,30 @@ public void testEvaluateDataFrame_Classification() throws Exception {
MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix =
response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <1>

Map<String, Map<String, Long>> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <2>
long otherClassesCount = multiclassConfusionMatrix.getOtherClassesCount(); // <3>
List<ActualClass> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <2>
long otherClassesCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <3>
// end::evaluate-data-frame-results-classification

assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
assertThat(
confusionMatrix,
equalTo(
Map.of(
"cat", Map.of("cat", 3L, "dog", 1L, "ant", 0L, "_other_", 1L),
"dog", Map.of("cat", 1L, "dog", 3L, "ant", 0L),
"ant", Map.of("cat", 1L, "dog", 0L, "ant", 0L))));
List.of(
new ActualClass(
"ant",
1L,
List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 0L)),
0L),
new ActualClass(
"cat",
5L,
List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)),
1L),
new ActualClass(
"dog",
4L,
List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)),
0L))));
assertThat(otherClassesCount, equalTo(0L));
}
}
Expand Down
Loading