diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java index e4534c5603bd8..9b77518474e67 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java @@ -24,6 +24,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.common.AucRocResult; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection; import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric; @@ -122,10 +123,9 @@ Evaluation.class, new ParseField(OutlierDetection.NAME), OutlierDetection::fromX // Evaluation metrics results new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, - new ParseField( - registeredMetricName( - OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME)), - org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric.Result::fromXContent), + new ParseField(registeredMetricName( + OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME)), + AucRocResult::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, new ParseField( @@ -145,7 +145,7 @@ Evaluation.class, new ParseField(OutlierDetection.NAME), OutlierDetection::fromX new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, new ParseField(registeredMetricName(Classification.NAME, AucRocMetric.NAME)), - AucRocMetric.Result::fromXContent), + AucRocResult::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, new ParseField(registeredMetricName(Classification.NAME, AccuracyMetric.NAME)), diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetric.java index 151783499e46b..9f496d4cf1245 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetric.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetric.java @@ -20,11 +20,9 @@ import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.common.ParseField; -import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContent; -import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; @@ -99,10 +97,10 @@ public static class Result implements EvaluationMetric.Result { @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List) a[0], (double) a[1])); + new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List) a[0], (double) a[1])); static { - PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES); + PARSER.declareObjectArray(constructorArg(), PerClassSingleValue.PARSER, CLASSES); PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY); } @@ -111,11 +109,11 @@ public static Result fromXContent(XContentParser parser) { } /** List of per-class results. */ - private final List classes; + private final List classes; /** Fraction of documents for which predicted class equals the actual class. */ private final double overallAccuracy; - public Result(List classes, double overallAccuracy) { + public Result(List classes, double overallAccuracy) { this.classes = Collections.unmodifiableList(Objects.requireNonNull(classes)); this.overallAccuracy = overallAccuracy; } @@ -125,7 +123,7 @@ public String getMetricName() { return NAME; } - public List getClasses() { + public List getClasses() { return classes; } @@ -156,65 +154,4 @@ public int hashCode() { return Objects.hash(classes, overallAccuracy); } } - - public static class PerClassResult implements ToXContentObject { - - private static final ParseField CLASS_NAME = new ParseField("class_name"); - private static final ParseField ACCURACY = new ParseField("accuracy"); - - @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>("accuracy_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1])); - - static { - PARSER.declareString(constructorArg(), CLASS_NAME); - PARSER.declareDouble(constructorArg(), ACCURACY); - } - - /** Name of the class. */ - private final String className; - /** Fraction of documents that are either true positives or true negatives wrt {@code className}. */ - private final double accuracy; - - public PerClassResult(String className, double accuracy) { - this.className = Objects.requireNonNull(className); - this.accuracy = accuracy; - } - - public String getClassName() { - return className; - } - - public double getAccuracy() { - return accuracy; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(CLASS_NAME.getPreferredName(), className); - builder.field(ACCURACY.getPreferredName(), accuracy); - builder.endObject(); - return builder; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - PerClassResult that = (PerClassResult) o; - return Objects.equals(this.className, that.className) - && this.accuracy == that.accuracy; - } - - @Override - public int hashCode() { - return Objects.hash(className, accuracy); - } - - @Override - public String toString() { - return Strings.toString(this); - } - } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetric.java index 5e9b28303c977..aec1d8655db7c 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetric.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetric.java @@ -19,18 +19,14 @@ package org.elasticsearch.client.ml.dataframe.evaluation.classification; import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; -import org.elasticsearch.common.Nullable; +import org.elasticsearch.client.ml.dataframe.evaluation.common.AucRocResult; import org.elasticsearch.common.ParseField; -import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ToXContent; -import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import java.io.IOException; -import java.util.Collections; -import java.util.List; import java.util.Objects; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; @@ -43,12 +39,11 @@ */ public class AucRocMetric implements EvaluationMetric { - public static final String NAME = "auc_roc"; + public static final String NAME = AucRocResult.NAME; public static final ParseField CLASS_NAME = new ParseField("class_name"); public static final ParseField INCLUDE_CURVE = new ParseField("include_curve"); - @SuppressWarnings("unchecked") public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, true, args -> new AucRocMetric((String) args[0], (Boolean) args[1])); @@ -106,149 +101,4 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(className, includeCurve); } - - public static class Result implements EvaluationMetric.Result { - - public static Result fromXContent(XContentParser parser) { - return PARSER.apply(parser, null); - } - - private static final ParseField SCORE = new ParseField("score"); - private static final ParseField CURVE = new ParseField("curve"); - - @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>( - "auc_roc_result", true, args -> new Result((double) args[0], (List) args[1])); - - static { - PARSER.declareDouble(constructorArg(), SCORE); - PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> AucRocPoint.fromXContent(p), CURVE); - } - - private final double score; - private final List curve; - - public Result(double score, @Nullable List curve) { - this.score = score; - this.curve = curve; - } - - @Override - public String getMetricName() { - return NAME; - } - - public double getScore() { - return score; - } - - public List getCurve() { - return curve == null ? null : Collections.unmodifiableList(curve); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { - builder.startObject(); - builder.field(SCORE.getPreferredName(), score); - if (curve != null && curve.isEmpty() == false) { - builder.field(CURVE.getPreferredName(), curve); - } - builder.endObject(); - return builder; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - Result that = (Result) o; - return score == that.score - && Objects.equals(curve, that.curve); - } - - @Override - public int hashCode() { - return Objects.hash(score, curve); - } - - @Override - public String toString() { - return Strings.toString(this); - } - } - - public static final class AucRocPoint implements ToXContentObject { - - public static AucRocPoint fromXContent(XContentParser parser) { - return PARSER.apply(parser, null); - } - - private static final ParseField TPR = new ParseField("tpr"); - private static final ParseField FPR = new ParseField("fpr"); - private static final ParseField THRESHOLD = new ParseField("threshold"); - - @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>( - "auc_roc_point", - true, - args -> new AucRocPoint((double) args[0], (double) args[1], (double) args[2])); - - static { - PARSER.declareDouble(constructorArg(), TPR); - PARSER.declareDouble(constructorArg(), FPR); - PARSER.declareDouble(constructorArg(), THRESHOLD); - } - - private final double tpr; - private final double fpr; - private final double threshold; - - public AucRocPoint(double tpr, double fpr, double threshold) { - this.tpr = tpr; - this.fpr = fpr; - this.threshold = threshold; - } - - public double getTruePositiveRate() { - return tpr; - } - - public double getFalsePositiveRate() { - return fpr; - } - - public double getThreshold() { - return threshold; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return builder - .startObject() - .field(TPR.getPreferredName(), tpr) - .field(FPR.getPreferredName(), fpr) - .field(THRESHOLD.getPreferredName(), threshold) - .endObject(); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - AucRocPoint that = (AucRocPoint) o; - return tpr == that.tpr && fpr == that.fpr && threshold == that.threshold; - } - - @Override - public int hashCode() { - return Objects.hash(tpr, fpr, threshold); - } - - @Override - public String toString() { - return Strings.toString(this); - } - } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/PerClassSingleValue.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/PerClassSingleValue.java new file mode 100644 index 0000000000000..2caf09085e751 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/PerClassSingleValue.java @@ -0,0 +1,81 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.classification; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +public class PerClassSingleValue implements ToXContentObject { + private static final ParseField CLASS_NAME = new ParseField("class_name"); + private static final ParseField VALUE = new ParseField("value"); + + public static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("per_class_result", true, a -> new PerClassSingleValue((String) a[0], (double) a[1])); + + static { + PARSER.declareString(constructorArg(), CLASS_NAME); + PARSER.declareDouble(constructorArg(), VALUE); + } + + private final String className; + private final double value; + + public PerClassSingleValue(String className, double value) { + this.className = Objects.requireNonNull(className); + this.value = value; + } + + public String getClassName() { + return className; + } + + public double getValue() { + return value; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CLASS_NAME.getPreferredName(), className); + builder.field(VALUE.getPreferredName(), value); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + PerClassSingleValue that = (PerClassSingleValue) o; + return Objects.equals(this.className, that.className) + && this.value == that.value; + } + + @Override + public int hashCode() { + return Objects.hash(className, value); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/PrecisionMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/PrecisionMetric.java index 8eff7986dcc36..64190d23b45b8 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/PrecisionMetric.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/PrecisionMetric.java @@ -22,7 +22,6 @@ import org.elasticsearch.common.ParseField; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser; -import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; @@ -85,10 +84,10 @@ public static class Result implements EvaluationMetric.Result { @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>("precision_result", true, a -> new Result((List) a[0], (double) a[1])); + new ConstructingObjectParser<>("precision_result", true, a -> new Result((List) a[0], (double) a[1])); static { - PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES); + PARSER.declareObjectArray(constructorArg(), PerClassSingleValue.PARSER, CLASSES); PARSER.declareDouble(constructorArg(), AVG_PRECISION); } @@ -97,11 +96,11 @@ public static Result fromXContent(XContentParser parser) { } /** List of per-class results. */ - private final List classes; + private final List classes; /** Average of per-class precisions. */ private final double avgPrecision; - public Result(List classes, double avgPrecision) { + public Result(List classes, double avgPrecision) { this.classes = Collections.unmodifiableList(Objects.requireNonNull(classes)); this.avgPrecision = avgPrecision; } @@ -111,7 +110,7 @@ public String getMetricName() { return NAME; } - public List getClasses() { + public List getClasses() { return classes; } @@ -142,60 +141,4 @@ public int hashCode() { return Objects.hash(classes, avgPrecision); } } - - public static class PerClassResult implements ToXContentObject { - - private static final ParseField CLASS_NAME = new ParseField("class_name"); - private static final ParseField PRECISION = new ParseField("precision"); - - @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>("precision_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1])); - - static { - PARSER.declareString(constructorArg(), CLASS_NAME); - PARSER.declareDouble(constructorArg(), PRECISION); - } - - /** Name of the class. */ - private final String className; - /** Fraction of documents predicted as belonging to the {@code predictedClass} class predicted correctly. */ - private final double precision; - - public PerClassResult(String className, double precision) { - this.className = Objects.requireNonNull(className); - this.precision = precision; - } - - public String getClassName() { - return className; - } - - public double getPrecision() { - return precision; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(CLASS_NAME.getPreferredName(), className); - builder.field(PRECISION.getPreferredName(), precision); - builder.endObject(); - return builder; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - PerClassResult that = (PerClassResult) o; - return Objects.equals(this.className, that.className) - && this.precision == that.precision; - } - - @Override - public int hashCode() { - return Objects.hash(className, precision); - } - } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/RecallMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/RecallMetric.java index d46a70da8c3f6..f973eada09599 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/RecallMetric.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/RecallMetric.java @@ -22,7 +22,6 @@ import org.elasticsearch.common.ParseField; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser; -import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; @@ -85,10 +84,10 @@ public static class Result implements EvaluationMetric.Result { @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>("recall_result", true, a -> new Result((List) a[0], (double) a[1])); + new ConstructingObjectParser<>("recall_result", true, a -> new Result((List) a[0], (double) a[1])); static { - PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES); + PARSER.declareObjectArray(constructorArg(), PerClassSingleValue.PARSER, CLASSES); PARSER.declareDouble(constructorArg(), AVG_RECALL); } @@ -97,11 +96,11 @@ public static Result fromXContent(XContentParser parser) { } /** List of per-class results. */ - private final List classes; + private final List classes; /** Average of per-class recalls. */ private final double avgRecall; - public Result(List classes, double avgRecall) { + public Result(List classes, double avgRecall) { this.classes = Collections.unmodifiableList(Objects.requireNonNull(classes)); this.avgRecall = avgRecall; } @@ -111,7 +110,7 @@ public String getMetricName() { return NAME; } - public List getClasses() { + public List getClasses() { return classes; } @@ -142,60 +141,4 @@ public int hashCode() { return Objects.hash(classes, avgRecall); } } - - public static class PerClassResult implements ToXContentObject { - - private static final ParseField CLASS_NAME = new ParseField("class_name"); - private static final ParseField RECALL = new ParseField("recall"); - - @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>("recall_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1])); - - static { - PARSER.declareString(constructorArg(), CLASS_NAME); - PARSER.declareDouble(constructorArg(), RECALL); - } - - /** Name of the class. */ - private final String className; - /** Fraction of documents actually belonging to the {@code actualClass} class predicted correctly. */ - private final double recall; - - public PerClassResult(String className, double recall) { - this.className = Objects.requireNonNull(className); - this.recall = recall; - } - - public String getClassName() { - return className; - } - - public double getRecall() { - return recall; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(CLASS_NAME.getPreferredName(), className); - builder.field(RECALL.getPreferredName(), recall); - builder.endObject(); - return builder; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - PerClassResult that = (PerClassResult) o; - return Objects.equals(this.className, that.className) - && this.recall == that.recall; - } - - @Override - public int hashCode() { - return Objects.hash(className, recall); - } - } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/common/AucRocPoint.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/common/AucRocPoint.java new file mode 100644 index 0000000000000..92abb460134dc --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/common/AucRocPoint.java @@ -0,0 +1,104 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.common; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +public class AucRocPoint implements ToXContentObject { + + public static AucRocPoint fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private static final ParseField TPR = new ParseField("tpr"); + private static final ParseField FPR = new ParseField("fpr"); + private static final ParseField THRESHOLD = new ParseField("threshold"); + + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "auc_roc_point", + true, + args -> new AucRocPoint((double) args[0], (double) args[1], (double) args[2])); + + static { + PARSER.declareDouble(constructorArg(), TPR); + PARSER.declareDouble(constructorArg(), FPR); + PARSER.declareDouble(constructorArg(), THRESHOLD); + } + + private final double tpr; + private final double fpr; + private final double threshold; + + public AucRocPoint(double tpr, double fpr, double threshold) { + this.tpr = tpr; + this.fpr = fpr; + this.threshold = threshold; + } + + public double getTruePositiveRate() { + return tpr; + } + + public double getFalsePositiveRate() { + return fpr; + } + + public double getThreshold() { + return threshold; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder + .startObject() + .field(TPR.getPreferredName(), tpr) + .field(FPR.getPreferredName(), fpr) + .field(THRESHOLD.getPreferredName(), threshold) + .endObject(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AucRocPoint that = (AucRocPoint) o; + return tpr == that.tpr && fpr == that.fpr && threshold == that.threshold; + } + + @Override + public int hashCode() { + return Objects.hash(tpr, fpr, threshold); + } + + @Override + public String toString() { + return Strings.toString(this); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/common/AucRocResult.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/common/AucRocResult.java new file mode 100644 index 0000000000000..e61730bd12380 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/common/AucRocResult.java @@ -0,0 +1,109 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.common; + +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class AucRocResult implements EvaluationMetric.Result { + + public static final String NAME = "auc_roc"; + + public static AucRocResult fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private static final ParseField VALUE = new ParseField("value"); + private static final ParseField CURVE = new ParseField("curve"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + NAME, true, args -> new AucRocResult((double) args[0], (List) args[1])); + + static { + PARSER.declareDouble(constructorArg(), VALUE); + PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> AucRocPoint.fromXContent(p), CURVE); + } + + private final double value; + private final List curve; + + public AucRocResult(double value, @Nullable List curve) { + this.value = value; + this.curve = curve; + } + + @Override + public String getMetricName() { + return NAME; + } + + public double getValue() { + return value; + } + + public List getCurve() { + return curve == null ? null : Collections.unmodifiableList(curve); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.field(VALUE.getPreferredName(), value); + if (curve != null && curve.isEmpty() == false) { + builder.field(CURVE.getPreferredName(), curve); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AucRocResult that = (AucRocResult) o; + return value == that.value + && Objects.equals(curve, that.curve); + } + + @Override + public int hashCode() { + return Objects.hash(value, curve); + } + + @Override + public String toString() { + return Strings.toString(this); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/outlierdetection/AucRocMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/outlierdetection/AucRocMetric.java index 76d8c514daef6..aebb10792f0b4 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/outlierdetection/AucRocMetric.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/outlierdetection/AucRocMetric.java @@ -19,6 +19,7 @@ package org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection; import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.common.AucRocResult; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; @@ -36,7 +37,7 @@ */ public class AucRocMetric implements EvaluationMetric { - public static final String NAME = "auc_roc"; + public static final String NAME = AucRocResult.NAME; public static final ParseField INCLUDE_CURVE = new ParseField("include_curve"); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index a065be7eae0e1..0ef00e8bb0e23 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -141,8 +141,11 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.PerClassSingleValue; import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.common.AucRocPoint; +import org.elasticsearch.client.ml.dataframe.evaluation.common.AucRocResult; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection; import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric; @@ -1789,17 +1792,17 @@ public void testEvaluateDataFrame_OutlierDetection() throws IOException { assertThat(confusionMatrix.getFalseNegatives(), equalTo(3L)); // docs #5, #6 and #7 assertNull(confusionMatrixResult.getScoreByThreshold("0.1")); - AucRocMetric.Result aucRocResult = + AucRocResult aucRocResult = evaluateDataFrameResponse.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME); assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME)); - assertThat(aucRocResult.getScore(), closeTo(0.70025, 1e-9)); + assertThat(aucRocResult.getValue(), closeTo(0.70025, 1e-9)); assertNotNull(aucRocResult.getCurve()); - List curve = aucRocResult.getCurve(); - AucRocMetric.AucRocPoint curvePointAtThreshold0 = curve.stream().filter(p -> p.getThreshold() == 0.0).findFirst().get(); + List curve = aucRocResult.getCurve(); + AucRocPoint curvePointAtThreshold0 = curve.stream().filter(p -> p.getThreshold() == 0.0).findFirst().get(); assertThat(curvePointAtThreshold0.getTruePositiveRate(), equalTo(1.0)); assertThat(curvePointAtThreshold0.getFalsePositiveRate(), equalTo(1.0)); assertThat(curvePointAtThreshold0.getThreshold(), equalTo(0.0)); - AucRocMetric.AucRocPoint curvePointAtThreshold1 = curve.stream().filter(p -> p.getThreshold() == 1.0).findFirst().get(); + AucRocPoint curvePointAtThreshold1 = curve.stream().filter(p -> p.getThreshold() == 1.0).findFirst().get(); assertThat(curvePointAtThreshold1.getTruePositiveRate(), equalTo(0.0)); assertThat(curvePointAtThreshold1.getFalsePositiveRate(), equalTo(0.0)); assertThat(curvePointAtThreshold1.getThreshold(), equalTo(1.0)); @@ -1925,9 +1928,9 @@ public void testEvaluateDataFrame_Classification() throws IOException { assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME)); assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); - AucRocMetric.Result aucRocResult = evaluateDataFrameResponse.getMetricByName(AucRocMetric.NAME); + AucRocResult aucRocResult = evaluateDataFrameResponse.getMetricByName(AucRocMetric.NAME); assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME)); - assertThat(aucRocResult.getScore(), closeTo(0.6425, 1e-9)); + assertThat(aucRocResult.getValue(), closeTo(0.6425, 1e-9)); assertNotNull(aucRocResult.getCurve()); } { // Accuracy @@ -1947,11 +1950,11 @@ public void testEvaluateDataFrame_Classification() throws IOException { equalTo( List.of( // 9 out of 10 examples were classified correctly - new AccuracyMetric.PerClassResult("ant", 0.9), + new PerClassSingleValue("ant", 0.9), // 6 out of 10 examples were classified correctly - new AccuracyMetric.PerClassResult("cat", 0.6), + new PerClassSingleValue("cat", 0.6), // 8 out of 10 examples were classified correctly - new AccuracyMetric.PerClassResult("dog", 0.8)))); + new PerClassSingleValue("dog", 0.8)))); assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.6)); // 6 out of 10 examples were classified correctly } { // Precision @@ -1971,9 +1974,9 @@ public void testEvaluateDataFrame_Classification() throws IOException { equalTo( List.of( // 3 out of 5 examples labeled as "cat" were classified correctly - new PrecisionMetric.PerClassResult("cat", 0.6), + new PerClassSingleValue("cat", 0.6), // 3 out of 4 examples labeled as "dog" were classified correctly - new PrecisionMetric.PerClassResult("dog", 0.75)))); + new PerClassSingleValue("dog", 0.75)))); assertThat(precisionResult.getAvgPrecision(), equalTo(0.675)); } { // Recall @@ -1993,11 +1996,11 @@ public void testEvaluateDataFrame_Classification() throws IOException { equalTo( List.of( // 3 out of 5 examples labeled as "cat" were classified correctly - new RecallMetric.PerClassResult("cat", 0.6), + new PerClassSingleValue("cat", 0.6), // 3 out of 4 examples labeled as "dog" were classified correctly - new RecallMetric.PerClassResult("dog", 0.75), + new PerClassSingleValue("dog", 0.75), // no examples labeled as "ant" were classified correctly - new RecallMetric.PerClassResult("ant", 0.0)))); + new PerClassSingleValue("ant", 0.0)))); assertThat(recallResult.getAvgRecall(), equalTo(0.45)); } { // No size provided for MulticlassConfusionMatrixMetric, default used instead diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index b9add106b82d7..4e753365c1e4d 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -162,6 +162,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass; import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.common.AucRocResult; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric.ConfusionMatrix; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection; @@ -3529,8 +3530,8 @@ public void testEvaluateDataFrame_Classification() throws Exception { List confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <8> long otherClassesCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <9> - AucRocMetric.Result aucRocResult = response.getMetricByName(AucRocMetric.NAME); // <10> - double aucRocScore = aucRocResult.getScore(); // <11> + AucRocResult aucRocResult = response.getMetricByName(AucRocMetric.NAME); // <10> + double aucRocScore = aucRocResult.getValue(); // <11> // end::evaluate-data-frame-results-classification assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME)); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java index 50fe97b51956d..a3d23f820d259 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java @@ -26,7 +26,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetricResultTests; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetricResultTests; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; -import org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetricResultTests; +import org.elasticsearch.client.ml.dataframe.evaluation.common.AucRocResultTests; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetricResultTests; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetricResultTests; @@ -49,7 +49,7 @@ public static EvaluateDataFrameResponse randomResponse() { case OutlierDetection.NAME: metrics = randomSubsetOf( Arrays.asList( - AucRocMetricResultTests.randomResult(), + AucRocResultTests.randomResult(), PrecisionMetricResultTests.randomResult(), RecallMetricResultTests.randomResult(), ConfusionMatrixMetricResultTests.randomResult())); @@ -63,6 +63,7 @@ public static EvaluateDataFrameResponse randomResponse() { case Classification.NAME: metrics = randomSubsetOf( Arrays.asList( + AucRocResultTests.randomResult(), AccuracyMetricResultTests.randomResult(), org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetricResultTests.randomResult(), org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetricResultTests.randomResult(), diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricResultTests.java index 8758cea86c451..00b254fe3d863 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricResultTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricResultTests.java @@ -19,7 +19,6 @@ package org.elasticsearch.client.ml.dataframe.evaluation.classification; import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; -import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.PerClassResult; import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.Result; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; @@ -41,10 +40,10 @@ protected NamedXContentRegistry xContentRegistry() { public static Result randomResult() { int numClasses = randomIntBetween(2, 100); List classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); - List classes = new ArrayList<>(numClasses); + List classes = new ArrayList<>(numClasses); for (int i = 0; i < numClasses; i++) { double accuracy = randomDoubleBetween(0.0, 1.0, true); - classes.add(new PerClassResult(classNames.get(i), accuracy)); + classes.add(new PerClassSingleValue(classNames.get(i), accuracy)); } double overallAccuracy = randomDoubleBetween(0.0, 1.0, true); return new Result(classes, overallAccuracy); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetricAucRocPointTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/PerClassSingleValueTests.java similarity index 66% rename from client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetricAucRocPointTests.java rename to client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/PerClassSingleValueTests.java index d3242906e34c4..0a3d3db5d68b4 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetricAucRocPointTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/PerClassSingleValueTests.java @@ -7,7 +7,7 @@ * not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ + package org.elasticsearch.client.ml.dataframe.evaluation.classification; import org.elasticsearch.common.xcontent.XContentParser; @@ -23,20 +24,16 @@ import java.io.IOException; -public class AucRocMetricAucRocPointTests extends AbstractXContentTestCase { - - static AucRocMetric.AucRocPoint randomPoint() { - return new AucRocMetric.AucRocPoint(randomDouble(), randomDouble(), randomDouble()); - } +public class PerClassSingleValueTests extends AbstractXContentTestCase { @Override - protected AucRocMetric.AucRocPoint createTestInstance() { - return randomPoint(); + protected PerClassSingleValue createTestInstance() { + return new PerClassSingleValue(randomAlphaOfLength(10), randomDouble()); } @Override - protected AucRocMetric.AucRocPoint doParseInstance(XContentParser parser) throws IOException { - return AucRocMetric.AucRocPoint.fromXContent(parser); + protected PerClassSingleValue doParseInstance(XContentParser parser) throws IOException { + return PerClassSingleValue.PARSER.apply(parser, null); } @Override diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/PrecisionMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/PrecisionMetricResultTests.java index ef6e41e78f0e8..50d023bf5e3e7 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/PrecisionMetricResultTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/PrecisionMetricResultTests.java @@ -19,7 +19,6 @@ package org.elasticsearch.client.ml.dataframe.evaluation.classification; import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; -import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.PerClassResult; import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; @@ -41,10 +40,10 @@ protected NamedXContentRegistry xContentRegistry() { public static Result randomResult() { int numClasses = randomIntBetween(2, 100); List classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); - List classes = new ArrayList<>(numClasses); + List classes = new ArrayList<>(numClasses); for (int i = 0; i < numClasses; i++) { double precision = randomDoubleBetween(0.0, 1.0, true); - classes.add(new PerClassResult(classNames.get(i), precision)); + classes.add(new PerClassSingleValue(classNames.get(i), precision)); } double avgPrecision = randomDoubleBetween(0.0, 1.0, true); return new Result(classes, avgPrecision); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/RecallMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/RecallMetricResultTests.java index f8fffb405ea1b..1f001e6014fb9 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/RecallMetricResultTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/RecallMetricResultTests.java @@ -19,7 +19,6 @@ package org.elasticsearch.client.ml.dataframe.evaluation.classification; import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; -import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult; import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; @@ -41,10 +40,10 @@ protected NamedXContentRegistry xContentRegistry() { public static Result randomResult() { int numClasses = randomIntBetween(2, 100); List classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); - List classes = new ArrayList<>(numClasses); + List classes = new ArrayList<>(numClasses); for (int i = 0; i < numClasses; i++) { double recall = randomDoubleBetween(0.0, 1.0, true); - classes.add(new PerClassResult(classNames.get(i), recall)); + classes.add(new PerClassSingleValue(classNames.get(i), recall)); } double avgRecall = randomDoubleBetween(0.0, 1.0, true); return new Result(classes, avgRecall); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/common/AucRocPointTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/common/AucRocPointTests.java new file mode 100644 index 0000000000000..7faaceb7036fd --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/common/AucRocPointTests.java @@ -0,0 +1,46 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.common; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class AucRocPointTests extends AbstractXContentTestCase { + + static AucRocPoint randomPoint() { + return new AucRocPoint(randomDouble(), randomDouble(), randomDouble()); + } + + @Override + protected AucRocPoint createTestInstance() { + return randomPoint(); + } + + @Override + protected AucRocPoint doParseInstance(XContentParser parser) throws IOException { + return AucRocPoint.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/common/AucRocResultTests.java similarity index 70% rename from client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetricResultTests.java rename to client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/common/AucRocResultTests.java index 40ada86f48445..85410b9d7e293 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetricResultTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/common/AucRocResultTests.java @@ -7,7 +7,7 @@ * not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an @@ -16,8 +16,9 @@ * specific language governing permissions and limitations * under the License. */ -package org.elasticsearch.client.ml.dataframe.evaluation.classification; +package org.elasticsearch.client.ml.dataframe.evaluation.common; +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; @@ -26,25 +27,25 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -public class AucRocMetricResultTests extends AbstractXContentTestCase { +public class AucRocResultTests extends AbstractXContentTestCase { - public static AucRocMetric.Result randomResult() { - return new AucRocMetric.Result( + public static EvaluationMetric.Result randomResult() { + return new AucRocResult( randomDouble(), Stream - .generate(AucRocMetricAucRocPointTests::randomPoint) + .generate(AucRocPointTests::randomPoint) .limit(randomIntBetween(1, 10)) .collect(Collectors.toList())); } @Override - protected AucRocMetric.Result createTestInstance() { + protected EvaluationMetric.Result createTestInstance() { return randomResult(); } @Override - protected AucRocMetric.Result doParseInstance(XContentParser parser) throws IOException { - return AucRocMetric.Result.fromXContent(parser); + protected EvaluationMetric.Result doParseInstance(XContentParser parser) throws IOException { + return AucRocResult.fromXContent(parser); } @Override diff --git a/docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc b/docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc index f9c643689a28b..4ae5db53de1b9 100644 --- a/docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc +++ b/docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc @@ -258,7 +258,7 @@ The API returns the following results: { "outlier_detection": { "auc_roc": { - "score": 0.92584757746414444 + "value": 0.92584757746414444 }, "confusion_matrix": { "0.25": { @@ -534,7 +534,7 @@ The API returns the following result: { "classification" : { "auc_roc" : { - "score" : 0.8941788639536681 + "value" : 0.8941788639536681 } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java index c7ae0a3848775..d7ad4781e1be7 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java @@ -11,10 +11,11 @@ import org.elasticsearch.plugins.spi.NamedXContentProvider; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AucRoc; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.common.AbstractAucRoc; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.ConfusionMatrix; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.OutlierDetection; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.ScoreByThresholdResult; @@ -179,8 +180,8 @@ public static List getNamedWriteables() { registeredMetricName(OutlierDetection.NAME, ConfusionMatrix.NAME), ConfusionMatrix.Result::new), new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, - registeredMetricName(Classification.NAME, AucRoc.NAME), - AucRoc.Result::new), + AbstractAucRoc.Result.NAME, + AbstractAucRoc.Result::new), new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, registeredMetricName(Classification.NAME, MulticlassConfusionMatrix.NAME), MulticlassConfusionMatrix.Result::new), diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java index 96d249326fbe6..347b13dd8fd7a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java @@ -10,11 +10,9 @@ import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser; -import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.script.Script; @@ -150,14 +148,14 @@ public Optional getResult() { * Time complexity of this method is linear wrt multiclass confusion matrix size, so O(n^2) where n is the matrix dimension. * This method is visible for testing only. */ - static List computePerClassAccuracy(MulticlassConfusionMatrix.Result matrixResult) { + static List computePerClassAccuracy(MulticlassConfusionMatrix.Result matrixResult) { assert matrixResult.getOtherActualClassCount() == 0; // Number of actual classes taken into account int n = matrixResult.getConfusionMatrix().size(); // Total number of documents taken into account long totalDocCount = matrixResult.getConfusionMatrix().stream().mapToLong(MulticlassConfusionMatrix.ActualClass::getActualClassDocCount).sum(); - List classes = new ArrayList<>(n); + List classes = new ArrayList<>(n); for (int i = 0; i < n; ++i) { String className = matrixResult.getConfusionMatrix().get(i).getActualClass(); // Start with the assumption that all the docs were predicted correctly. @@ -172,7 +170,7 @@ static List computePerClassAccuracy(MulticlassConfusionMatrix.Re } // Subtract errors (false negatives) for classes other than explicitly listed in confusion matrix correctDocCount -= matrixResult.getConfusionMatrix().get(i).getOtherPredictedClassDocCount(); - classes.add(new PerClassResult(className, ((double)correctDocCount) / totalDocCount)); + classes.add(new PerClassSingleValue(className, ((double)correctDocCount) / totalDocCount)); } return classes; } @@ -209,10 +207,10 @@ public static class Result implements EvaluationMetricResult { @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List) a[0], (double) a[1])); + new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List) a[0], (double) a[1])); static { - PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES); + PARSER.declareObjectArray(constructorArg(), PerClassSingleValue.PARSER, CLASSES); PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY); } @@ -221,17 +219,17 @@ public static Result fromXContent(XContentParser parser) { } /** List of per-class results. */ - private final List classes; + private final List classes; /** Fraction of documents for which predicted class equals the actual class. */ private final double overallAccuracy; - public Result(List classes, double overallAccuracy) { + public Result(List classes, double overallAccuracy) { this.classes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(classes, CLASSES)); this.overallAccuracy = overallAccuracy; } public Result(StreamInput in) throws IOException { - this.classes = Collections.unmodifiableList(in.readList(PerClassResult::new)); + this.classes = Collections.unmodifiableList(in.readList(PerClassSingleValue::new)); this.overallAccuracy = in.readDouble(); } @@ -245,7 +243,7 @@ public String getMetricName() { return NAME.getPreferredName(); } - public List getClasses() { + public List getClasses() { return classes; } @@ -282,71 +280,4 @@ public int hashCode() { return Objects.hash(classes, overallAccuracy); } } - - public static class PerClassResult implements ToXContentObject, Writeable { - - private static final ParseField CLASS_NAME = new ParseField("class_name"); - private static final ParseField ACCURACY = new ParseField("accuracy"); - - @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>("accuracy_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1])); - - static { - PARSER.declareString(constructorArg(), CLASS_NAME); - PARSER.declareDouble(constructorArg(), ACCURACY); - } - - /** Name of the class. */ - private final String className; - /** Fraction of documents that are either true positives or true negatives wrt {@code className}. */ - private final double accuracy; - - public PerClassResult(String className, double accuracy) { - this.className = ExceptionsHelper.requireNonNull(className, CLASS_NAME); - this.accuracy = accuracy; - } - - public PerClassResult(StreamInput in) throws IOException { - this.className = in.readString(); - this.accuracy = in.readDouble(); - } - - public String getClassName() { - return className; - } - - public double getAccuracy() { - return accuracy; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(className); - out.writeDouble(accuracy); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(CLASS_NAME.getPreferredName(), className); - builder.field(ACCURACY.getPreferredName(), accuracy); - builder.endObject(); - return builder; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - PerClassResult that = (PerClassResult) o; - return Objects.equals(this.className, that.className) - && this.accuracy == that.accuracy; - } - - @Override - public int hashCode() { - return Objects.hash(className, accuracy); - } - } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AucRoc.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AucRoc.java index 28353691d12ce..b8b53339db3e8 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AucRoc.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AucRoc.java @@ -26,6 +26,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.common.AbstractAucRoc; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PerClassSingleValue.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PerClassSingleValue.java new file mode 100644 index 0000000000000..3b33ccd5d60c6 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PerClassSingleValue.java @@ -0,0 +1,85 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +public class PerClassSingleValue implements ToXContentObject, Writeable { + + private static final ParseField CLASS_NAME = new ParseField("class_name"); + private static final ParseField VALUE = new ParseField("value"); + + public static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("per_class_result", true, a -> new PerClassSingleValue((String) a[0], (double) a[1])); + + static { + PARSER.declareString(constructorArg(), CLASS_NAME); + PARSER.declareDouble(constructorArg(), VALUE); + } + + private final String className; + private final double value; + + public PerClassSingleValue(String className, double value) { + this.className = ExceptionsHelper.requireNonNull(className, CLASS_NAME); + this.value = value; + } + + public PerClassSingleValue(StreamInput in) throws IOException { + this.className = in.readString(); + this.value = in.readDouble(); + } + + public String getClassName() { + return className; + } + + public double getValue() { + return value; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(className); + out.writeDouble(value); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CLASS_NAME.getPreferredName(), className); + builder.field(VALUE.getPreferredName(), value); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + PerClassSingleValue that = (PerClassSingleValue) o; + return Objects.equals(this.className, that.className) + && this.value == that.value; + } + + @Override + public int hashCode() { + return Objects.hash(className, value); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java index e5cabf6e90db2..f49ee69f9c3ff 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java @@ -10,11 +10,9 @@ import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser; -import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.index.query.QueryBuilders; @@ -149,13 +147,13 @@ public void process(Aggregations aggs) { aggs.get(AVG_PRECISION_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) { Filters byPredictedClassAgg = aggs.get(BY_PREDICTED_CLASS_AGG_NAME); NumericMetricsAggregation.SingleValue avgPrecisionAgg = aggs.get(AVG_PRECISION_AGG_NAME); - List classes = new ArrayList<>(byPredictedClassAgg.getBuckets().size()); + List classes = new ArrayList<>(byPredictedClassAgg.getBuckets().size()); for (Filters.Bucket bucket : byPredictedClassAgg.getBuckets()) { String className = bucket.getKeyAsString(); NumericMetricsAggregation.SingleValue precisionAgg = bucket.getAggregations().get(PER_PREDICTED_CLASS_PRECISION_AGG_NAME); double precision = precisionAgg.value(); if (Double.isFinite(precision)) { - classes.add(new PerClassResult(className, precision)); + classes.add(new PerClassSingleValue(className, precision)); } } result.set(new Result(classes, avgPrecisionAgg.value())); @@ -197,10 +195,10 @@ public static class Result implements EvaluationMetricResult { @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>("precision_result", true, a -> new Result((List) a[0], (double) a[1])); + new ConstructingObjectParser<>("precision_result", true, a -> new Result((List) a[0], (double) a[1])); static { - PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES); + PARSER.declareObjectArray(constructorArg(), PerClassSingleValue.PARSER, CLASSES); PARSER.declareDouble(constructorArg(), AVG_PRECISION); } @@ -209,17 +207,17 @@ public static Result fromXContent(XContentParser parser) { } /** List of per-class results. */ - private final List classes; + private final List classes; /** Average of per-class precisions. */ private final double avgPrecision; - public Result(List classes, double avgPrecision) { + public Result(List classes, double avgPrecision) { this.classes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(classes, CLASSES)); this.avgPrecision = avgPrecision; } public Result(StreamInput in) throws IOException { - this.classes = Collections.unmodifiableList(in.readList(PerClassResult::new)); + this.classes = Collections.unmodifiableList(in.readList(PerClassSingleValue::new)); this.avgPrecision = in.readDouble(); } @@ -233,7 +231,7 @@ public String getMetricName() { return NAME.getPreferredName(); } - public List getClasses() { + public List getClasses() { return classes; } @@ -270,71 +268,4 @@ public int hashCode() { return Objects.hash(classes, avgPrecision); } } - - public static class PerClassResult implements ToXContentObject, Writeable { - - private static final ParseField CLASS_NAME = new ParseField("class_name"); - private static final ParseField PRECISION = new ParseField("precision"); - - @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>("precision_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1])); - - static { - PARSER.declareString(constructorArg(), CLASS_NAME); - PARSER.declareDouble(constructorArg(), PRECISION); - } - - /** Name of the class. */ - private final String className; - /** Fraction of documents predicted as belonging to the {@code predictedClass} class predicted correctly. */ - private final double precision; - - public PerClassResult(String className, double precision) { - this.className = ExceptionsHelper.requireNonNull(className, CLASS_NAME); - this.precision = precision; - } - - public PerClassResult(StreamInput in) throws IOException { - this.className = in.readString(); - this.precision = in.readDouble(); - } - - public String getClassName() { - return className; - } - - public double getPrecision() { - return precision; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(className); - out.writeDouble(precision); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(CLASS_NAME.getPreferredName(), className); - builder.field(PRECISION.getPreferredName(), precision); - builder.endObject(); - return builder; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - PerClassResult that = (PerClassResult) o; - return Objects.equals(this.className, that.className) - && this.precision == that.precision; - } - - @Override - public int hashCode() { - return Objects.hash(className, precision); - } - } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java index 8592a33ac0436..cfc7f810e4edb 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java @@ -10,11 +10,9 @@ import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser; -import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.script.Script; @@ -128,11 +126,11 @@ public void process(Aggregations aggs) { "Cannot calculate average recall. Cardinality of field [{}] is too high", actualField.get()); } NumericMetricsAggregation.SingleValue avgRecallAgg = aggs.get(AVG_RECALL_AGG_NAME); - List classes = new ArrayList<>(byActualClassAgg.getBuckets().size()); + List classes = new ArrayList<>(byActualClassAgg.getBuckets().size()); for (Terms.Bucket bucket : byActualClassAgg.getBuckets()) { String className = bucket.getKeyAsString(); NumericMetricsAggregation.SingleValue recallAgg = bucket.getAggregations().get(PER_ACTUAL_CLASS_RECALL_AGG_NAME); - classes.add(new PerClassResult(className, recallAgg.value())); + classes.add(new PerClassSingleValue(className, recallAgg.value())); } result.set(new Result(classes, avgRecallAgg.value())); } @@ -173,10 +171,10 @@ public static class Result implements EvaluationMetricResult { @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>("recall_result", true, a -> new Result((List) a[0], (double) a[1])); + new ConstructingObjectParser<>("recall_result", true, a -> new Result((List) a[0], (double) a[1])); static { - PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES); + PARSER.declareObjectArray(constructorArg(), PerClassSingleValue.PARSER, CLASSES); PARSER.declareDouble(constructorArg(), AVG_RECALL); } @@ -185,17 +183,17 @@ public static Result fromXContent(XContentParser parser) { } /** List of per-class results. */ - private final List classes; + private final List classes; /** Average of per-class recalls. */ private final double avgRecall; - public Result(List classes, double avgRecall) { + public Result(List classes, double avgRecall) { this.classes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(classes, CLASSES)); this.avgRecall = avgRecall; } public Result(StreamInput in) throws IOException { - this.classes = Collections.unmodifiableList(in.readList(PerClassResult::new)); + this.classes = Collections.unmodifiableList(in.readList(PerClassSingleValue::new)); this.avgRecall = in.readDouble(); } @@ -209,7 +207,7 @@ public String getMetricName() { return NAME.getPreferredName(); } - public List getClasses() { + public List getClasses() { return classes; } @@ -246,71 +244,4 @@ public int hashCode() { return Objects.hash(classes, avgRecall); } } - - public static class PerClassResult implements ToXContentObject, Writeable { - - private static final ParseField CLASS_NAME = new ParseField("class_name"); - private static final ParseField RECALL = new ParseField("recall"); - - @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>("recall_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1])); - - static { - PARSER.declareString(constructorArg(), CLASS_NAME); - PARSER.declareDouble(constructorArg(), RECALL); - } - - /** Name of the class. */ - private final String className; - /** Fraction of documents actually belonging to the {@code actualClass} class predicted correctly. */ - private final double recall; - - public PerClassResult(String className, double recall) { - this.className = ExceptionsHelper.requireNonNull(className, CLASS_NAME); - this.recall = recall; - } - - public PerClassResult(StreamInput in) throws IOException { - this.className = in.readString(); - this.recall = in.readDouble(); - } - - public String getClassName() { - return className; - } - - public double getRecall() { - return recall; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(className); - out.writeDouble(recall); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(CLASS_NAME.getPreferredName(), className); - builder.field(RECALL.getPreferredName(), recall); - builder.endObject(); - return builder; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - PerClassResult that = (PerClassResult) o; - return Objects.equals(this.className, that.className) - && this.recall == that.recall; - } - - @Override - public int hashCode() { - return Objects.hash(className, recall); - } - } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AbstractAucRoc.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/common/AbstractAucRoc.java similarity index 92% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AbstractAucRoc.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/common/AbstractAucRoc.java index b3bb6dcefae19..5d91e6e362669 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AbstractAucRoc.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/common/AbstractAucRoc.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.common; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; @@ -25,8 +25,6 @@ import java.util.List; import java.util.Objects; -import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; - /** * Area under the curve (AUC) of the receiver operating characteristic (ROC). * The ROC curve is a plot of the TPR (true positive rate) against @@ -165,7 +163,7 @@ public static final class AucRocPoint implements Comparable, ToXCon private final double fpr; private final double threshold; - AucRocPoint(double tpr, double fpr, double threshold) { + public AucRocPoint(double tpr, double fpr, double threshold) { this.tpr = tpr; this.fpr = fpr; this.threshold = threshold; @@ -229,24 +227,26 @@ private static double interpolate(double x, double x1, double y1, double x2, dou public static class Result implements EvaluationMetricResult { - private static final String SCORE = "score"; + public static final String NAME = "auc_roc_result"; + + private static final String VALUE = "value"; private static final String CURVE = "curve"; - private final double score; + private final double value; private final List curve; - public Result(double score, List curve) { - this.score = score; + public Result(double value, List curve) { + this.value = value; this.curve = Objects.requireNonNull(curve); } public Result(StreamInput in) throws IOException { - this.score = in.readDouble(); + this.value = in.readDouble(); this.curve = in.readList(AucRocPoint::new); } - public double getScore() { - return score; + public double getValue() { + return value; } public List getCurve() { @@ -255,24 +255,24 @@ public List getCurve() { @Override public String getWriteableName() { - return registeredMetricName(Classification.NAME, NAME); + return NAME; } @Override public String getMetricName() { - return NAME.getPreferredName(); + return AbstractAucRoc.NAME.getPreferredName(); } @Override public void writeTo(StreamOutput out) throws IOException { - out.writeDouble(score); + out.writeDouble(value); out.writeList(curve); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(SCORE, score); + builder.field(VALUE, value); if (curve.isEmpty() == false) { builder.field(CURVE, curve); } @@ -285,13 +285,13 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; Result that = (Result) o; - return score == that.score + return value == that.value && Objects.equals(curve, that.curve); } @Override public int hashCode() { - return Objects.hash(score, curve); + return Objects.hash(value, curve); } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/AucRoc.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/AucRoc.java index fb3ce8da24bef..7d4a50ddb9e11 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/AucRoc.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/AucRoc.java @@ -23,7 +23,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AbstractAucRoc; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.common.AbstractAucRoc; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java index 176aa6e9a309b..2dfcdcc48978a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java @@ -9,7 +9,6 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.PerClassResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result; import java.util.ArrayList; @@ -22,10 +21,10 @@ public class AccuracyResultTests extends AbstractWireSerializingTestCase public static Result createRandom() { int numClasses = randomIntBetween(2, 100); List classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); - List classes = new ArrayList<>(numClasses); + List classes = new ArrayList<>(numClasses); for (int i = 0; i < numClasses; i++) { double accuracy = randomDoubleBetween(0.0, 1.0, true); - classes.add(new PerClassResult(classNames.get(i), accuracy)); + classes.add(new PerClassSingleValue(classNames.get(i), accuracy)); } double overallAccuracy = randomDoubleBetween(0.0, 1.0, true); return new Result(classes, overallAccuracy); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java index 44f9bd653c8f0..e9a95df6a3452 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java @@ -12,7 +12,6 @@ import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.PerClassResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result; import java.io.IOException; @@ -95,8 +94,8 @@ public void testProcess() { result.getClasses(), equalTo( List.of( - new PerClassResult("dog", 0.5), - new PerClassResult("cat", 0.5)))); + new PerClassSingleValue("dog", 0.5), + new PerClassSingleValue("cat", 0.5)))); assertThat(result.getOverallAccuracy(), equalTo(0.5)); } @@ -155,9 +154,9 @@ public void testComputePerClassAccuracy() { 0)), equalTo( List.of( - new Accuracy.PerClassResult("A", 25.0 / 51), // 13 false positives, 13 false negatives - new Accuracy.PerClassResult("B", 26.0 / 51), // 8 false positives, 17 false negatives - new Accuracy.PerClassResult("C", 28.0 / 51))) // 13 false positives, 10 false negatives + new PerClassSingleValue("A", 25.0 / 51), // 13 false positives, 13 false negatives + new PerClassSingleValue("B", 26.0 / 51), // 8 false positives, 17 false negatives + new PerClassSingleValue("C", 28.0 / 51))) // 13 false positives, 10 false negatives ); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AucRocResultTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AucRocResultTests.java index 580c1c85fbf95..ddc39ec6d9c1c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AucRocResultTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AucRocResultTests.java @@ -9,8 +9,8 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AbstractAucRoc.AucRocPoint; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AbstractAucRoc.Result; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.common.AbstractAucRoc.AucRocPoint; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.common.AbstractAucRoc.Result; import java.util.List; import java.util.stream.Collectors; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PerClassSingleValueTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PerClassSingleValueTests.java new file mode 100644 index 0000000000000..4cbed2a90c429 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PerClassSingleValueTests.java @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractSerializingTestCase; + +import java.io.IOException; + +public class PerClassSingleValueTests extends AbstractSerializingTestCase { + + @Override + protected PerClassSingleValue doParseInstance(XContentParser parser) throws IOException { + return PerClassSingleValue.PARSER.apply(parser, null); + } + + @Override + protected Writeable.Reader instanceReader() { + return PerClassSingleValue::new; + } + + @Override + protected PerClassSingleValue createTestInstance() { + return new PerClassSingleValue(randomAlphaOfLength(10), randomDouble()); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PrecisionResultTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PrecisionResultTests.java index b86448a4daacb..c9cdfb8826f27 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PrecisionResultTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PrecisionResultTests.java @@ -9,7 +9,6 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.PerClassResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.Result; import java.util.ArrayList; @@ -22,10 +21,10 @@ public class PrecisionResultTests extends AbstractWireSerializingTestCase classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); - List classes = new ArrayList<>(numClasses); + List classes = new ArrayList<>(numClasses); for (int i = 0; i < numClasses; i++) { double precision = randomDoubleBetween(0.0, 1.0, true); - classes.add(new PerClassResult(classNames.get(i), precision)); + classes.add(new PerClassSingleValue(classNames.get(i), precision)); } double avgPrecision = randomDoubleBetween(0.0, 1.0, true); return new Result(classes, avgPrecision); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/RecallResultTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/RecallResultTests.java index a2a44ded76189..7d79a5363b83a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/RecallResultTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/RecallResultTests.java @@ -9,7 +9,6 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.PerClassResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.Result; import java.util.ArrayList; @@ -22,10 +21,10 @@ public class RecallResultTests extends AbstractWireSerializingTestCase { public static Result createRandom() { int numClasses = randomIntBetween(2, 100); List classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); - List classes = new ArrayList<>(numClasses); + List classes = new ArrayList<>(numClasses); for (int i = 0; i < numClasses; i++) { double recall = randomDoubleBetween(0.0, 1.0, true); - classes.add(new PerClassResult(classNames.get(i), recall)); + classes.add(new PerClassSingleValue(classNames.get(i), recall)); } double avgRecall = randomDoubleBetween(0.0, 1.0, true); return new Result(classes, avgRecall); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AbstractAucRocTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/common/AbstractAucRocTests.java similarity index 94% rename from x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AbstractAucRocTests.java rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/common/AbstractAucRocTests.java index fef8418edb9d7..71666b6972d94 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AbstractAucRocTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/common/AbstractAucRocTests.java @@ -3,9 +3,10 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.common; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AucRoc; import java.util.Arrays; import java.util.List; @@ -20,7 +21,7 @@ public void testCalculateAucScore_GivenZeroPercentiles() { double[] tpPercentiles = zeroPercentiles(); double[] fpPercentiles = zeroPercentiles(); - List curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles); + List curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles); double aucRocScore = AucRoc.calculateAucScore(curve); assertThat(aucRocScore, closeTo(0.5, 0.01)); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index b8395940880e8..3deebf757c3a5 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -19,6 +19,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AucRoc; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.PerClassSingleValue; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall; import org.junit.After; @@ -129,7 +130,7 @@ public void testEvaluate_AllMetrics_KeywordField_CaseSensitivity() { List.of(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall()))); Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); - assertThat(accuracyResult.getClasses(), contains(new Accuracy.PerClassResult("crocodile", 0.0))); + assertThat(accuracyResult.getClasses(), contains(new PerClassSingleValue("crocodile", 0.0))); assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.0)); MulticlassConfusionMatrix.Result confusionMatrixResult = @@ -145,7 +146,7 @@ public void testEvaluate_AllMetrics_KeywordField_CaseSensitivity() { assertThat(precisionResult.getAvgPrecision(), is(notANumber())); Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(3); - assertThat(recallResult.getClasses(), contains(new Recall.PerClassResult("crocodile", 0.0))); + assertThat(recallResult.getClasses(), contains(new PerClassSingleValue("crocodile", 0.0))); assertThat(recallResult.getAvgRecall(), equalTo(0.0)); } @@ -165,13 +166,13 @@ private AucRoc.Result evaluateAucRoc(boolean includeCurve) { public void testEvaluate_AucRoc_DoNotIncludeCurve() { AucRoc.Result aucrocResult = evaluateAucRoc(false); - assertThat(aucrocResult.getScore(), is(closeTo(0.5, 0.0001))); + assertThat(aucrocResult.getValue(), is(closeTo(0.5, 0.0001))); assertThat(aucrocResult.getCurve(), hasSize(0)); } public void testEvaluate_AucRoc_IncludeCurve() { AucRoc.Result aucrocResult = evaluateAucRoc(true); - assertThat(aucrocResult.getScore(), is(closeTo(0.5, 0.0001))); + assertThat(aucrocResult.getValue(), is(closeTo(0.5, 0.0001))); assertThat(aucrocResult.getCurve(), hasSize(greaterThan(0))); } @@ -188,13 +189,13 @@ private Accuracy.Result evaluateAccuracy(String actualField, String predictedFie } public void testEvaluate_Accuracy_KeywordField() { - List expectedPerClassResults = + List expectedPerClassResults = List.of( - new Accuracy.PerClassResult("ant", 47.0 / 75), - new Accuracy.PerClassResult("cat", 47.0 / 75), - new Accuracy.PerClassResult("dog", 47.0 / 75), - new Accuracy.PerClassResult("fox", 47.0 / 75), - new Accuracy.PerClassResult("mouse", 47.0 / 75)); + new PerClassSingleValue("ant", 47.0 / 75), + new PerClassSingleValue("cat", 47.0 / 75), + new PerClassSingleValue("dog", 47.0 / 75), + new PerClassSingleValue("fox", 47.0 / 75), + new PerClassSingleValue("mouse", 47.0 / 75)); double expectedOverallAccuracy = 5.0 / 75; Accuracy.Result accuracyResult = evaluateAccuracy(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD); @@ -208,13 +209,13 @@ public void testEvaluate_Accuracy_KeywordField() { } public void testEvaluate_Accuracy_IntegerField() { - List expectedPerClassResults = + List expectedPerClassResults = List.of( - new Accuracy.PerClassResult("1", 57.0 / 75), - new Accuracy.PerClassResult("2", 54.0 / 75), - new Accuracy.PerClassResult("3", 51.0 / 75), - new Accuracy.PerClassResult("4", 48.0 / 75), - new Accuracy.PerClassResult("5", 45.0 / 75)); + new PerClassSingleValue("1", 57.0 / 75), + new PerClassSingleValue("2", 54.0 / 75), + new PerClassSingleValue("3", 51.0 / 75), + new PerClassSingleValue("4", 48.0 / 75), + new PerClassSingleValue("5", 45.0 / 75)); double expectedOverallAccuracy = 15.0 / 75; Accuracy.Result accuracyResult = evaluateAccuracy(NO_LEGS_INTEGER_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD); @@ -238,10 +239,10 @@ public void testEvaluate_Accuracy_IntegerField() { } public void testEvaluate_Accuracy_BooleanField() { - List expectedPerClassResults = + List expectedPerClassResults = List.of( - new Accuracy.PerClassResult("false", 18.0 / 30), - new Accuracy.PerClassResult("true", 27.0 / 45)); + new PerClassSingleValue("false", 18.0 / 30), + new PerClassSingleValue("true", 27.0 / 45)); double expectedOverallAccuracy = 45.0 / 75; Accuracy.Result accuracyResult = evaluateAccuracy(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD); @@ -267,13 +268,13 @@ public void testEvaluate_Accuracy_BooleanField() { public void testEvaluate_Accuracy_FieldTypeMismatch() { { // When actual and predicted fields have different types, the sets of classes are disjoint - List expectedPerClassResults = + List expectedPerClassResults = List.of( - new Accuracy.PerClassResult("1", 0.8), - new Accuracy.PerClassResult("2", 0.8), - new Accuracy.PerClassResult("3", 0.8), - new Accuracy.PerClassResult("4", 0.8), - new Accuracy.PerClassResult("5", 0.8)); + new PerClassSingleValue("1", 0.8), + new PerClassSingleValue("2", 0.8), + new PerClassSingleValue("3", 0.8), + new PerClassSingleValue("4", 0.8), + new PerClassSingleValue("5", 0.8)); double expectedOverallAccuracy = 0.0; Accuracy.Result accuracyResult = evaluateAccuracy(NO_LEGS_INTEGER_FIELD, IS_PREDATOR_BOOLEAN_FIELD); @@ -282,10 +283,10 @@ public void testEvaluate_Accuracy_FieldTypeMismatch() { } { // When actual and predicted fields have different types, the sets of classes are disjoint - List expectedPerClassResults = + List expectedPerClassResults = List.of( - new Accuracy.PerClassResult("false", 0.6), - new Accuracy.PerClassResult("true", 0.4)); + new PerClassSingleValue("false", 0.6), + new PerClassSingleValue("true", 0.4)); double expectedOverallAccuracy = 0.0; Accuracy.Result accuracyResult = evaluateAccuracy(IS_PREDATOR_BOOLEAN_FIELD, NO_LEGS_INTEGER_FIELD); @@ -307,13 +308,13 @@ private Precision.Result evaluatePrecision(String actualField, String predictedF } public void testEvaluate_Precision_KeywordField() { - List expectedPerClassResults = + List expectedPerClassResults = List.of( - new Precision.PerClassResult("ant", 1.0 / 15), - new Precision.PerClassResult("cat", 1.0 / 15), - new Precision.PerClassResult("dog", 1.0 / 15), - new Precision.PerClassResult("fox", 1.0 / 15), - new Precision.PerClassResult("mouse", 1.0 / 15)); + new PerClassSingleValue("ant", 1.0 / 15), + new PerClassSingleValue("cat", 1.0 / 15), + new PerClassSingleValue("dog", 1.0 / 15), + new PerClassSingleValue("fox", 1.0 / 15), + new PerClassSingleValue("mouse", 1.0 / 15)); double expectedAvgPrecision = 5.0 / 75; Precision.Result precisionResult = evaluatePrecision(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD); @@ -324,13 +325,13 @@ public void testEvaluate_Precision_KeywordField() { } public void testEvaluate_Precision_IntegerField() { - List expectedPerClassResults = + List expectedPerClassResults = List.of( - new Precision.PerClassResult("1", 0.2), - new Precision.PerClassResult("2", 0.2), - new Precision.PerClassResult("3", 0.2), - new Precision.PerClassResult("4", 0.2), - new Precision.PerClassResult("5", 0.2)); + new PerClassSingleValue("1", 0.2), + new PerClassSingleValue("2", 0.2), + new PerClassSingleValue("3", 0.2), + new PerClassSingleValue("4", 0.2), + new PerClassSingleValue("5", 0.2)); double expectedAvgPrecision = 0.2; Precision.Result precisionResult = evaluatePrecision(NO_LEGS_INTEGER_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD); @@ -348,10 +349,10 @@ public void testEvaluate_Precision_IntegerField() { } public void testEvaluate_Precision_BooleanField() { - List expectedPerClassResults = + List expectedPerClassResults = List.of( - new Precision.PerClassResult("false", 0.5), - new Precision.PerClassResult("true", 9.0 / 13)); + new PerClassSingleValue("false", 0.5), + new PerClassSingleValue("true", 9.0 / 13)); double expectedAvgPrecision = 31.0 / 52; Precision.Result precisionResult = evaluatePrecision(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD); @@ -407,13 +408,13 @@ private Recall.Result evaluateRecall(String actualField, String predictedField) } public void testEvaluate_Recall_KeywordField() { - List expectedPerClassResults = + List expectedPerClassResults = List.of( - new Recall.PerClassResult("ant", 1.0 / 15), - new Recall.PerClassResult("cat", 1.0 / 15), - new Recall.PerClassResult("dog", 1.0 / 15), - new Recall.PerClassResult("fox", 1.0 / 15), - new Recall.PerClassResult("mouse", 1.0 / 15)); + new PerClassSingleValue("ant", 1.0 / 15), + new PerClassSingleValue("cat", 1.0 / 15), + new PerClassSingleValue ("dog", 1.0 / 15), + new PerClassSingleValue("fox", 1.0 / 15), + new PerClassSingleValue("mouse", 1.0 / 15)); double expectedAvgRecall = 5.0 / 75; Recall.Result recallResult = evaluateRecall(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD); @@ -424,13 +425,13 @@ public void testEvaluate_Recall_KeywordField() { } public void testEvaluate_Recall_IntegerField() { - List expectedPerClassResults = + List expectedPerClassResults = List.of( - new Recall.PerClassResult("1", 1.0 / 15), - new Recall.PerClassResult("2", 2.0 / 15), - new Recall.PerClassResult("3", 3.0 / 15), - new Recall.PerClassResult("4", 4.0 / 15), - new Recall.PerClassResult("5", 5.0 / 15)); + new PerClassSingleValue("1", 1.0 / 15), + new PerClassSingleValue("2", 2.0 / 15), + new PerClassSingleValue("3", 3.0 / 15), + new PerClassSingleValue("4", 4.0 / 15), + new PerClassSingleValue("5", 5.0 / 15)); double expectedAvgRecall = 3.0 / 15; Recall.Result recallResult = evaluateRecall(NO_LEGS_INTEGER_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD); @@ -448,10 +449,10 @@ public void testEvaluate_Recall_IntegerField() { } public void testEvaluate_Recall_BooleanField() { - List expectedPerClassResults = + List expectedPerClassResults = List.of( - new Recall.PerClassResult("true", 0.6), - new Recall.PerClassResult("false", 0.6)); + new PerClassSingleValue("true", 0.6), + new PerClassSingleValue("false", 0.6)); double expectedAvgRecall = 0.6; Recall.Result recallResult = evaluateRecall(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD); @@ -471,13 +472,13 @@ public void testEvaluate_Recall_BooleanField() { public void testEvaluate_Recall_FieldTypeMismatch() { { // When actual and predicted fields have different types, the sets of classes are disjoint, hence 0.0 results here - List expectedPerClassResults = + List expectedPerClassResults = List.of( - new Recall.PerClassResult("1", 0.0), - new Recall.PerClassResult("2", 0.0), - new Recall.PerClassResult("3", 0.0), - new Recall.PerClassResult("4", 0.0), - new Recall.PerClassResult("5", 0.0)); + new PerClassSingleValue("1", 0.0), + new PerClassSingleValue("2", 0.0), + new PerClassSingleValue("3", 0.0), + new PerClassSingleValue("4", 0.0), + new PerClassSingleValue("5", 0.0)); double expectedAvgRecall = 0.0; Recall.Result recallResult = evaluateRecall(NO_LEGS_INTEGER_FIELD, IS_PREDATOR_BOOLEAN_FIELD); @@ -486,10 +487,10 @@ public void testEvaluate_Recall_FieldTypeMismatch() { } { // When actual and predicted fields have different types, the sets of classes are disjoint, hence 0.0 results here - List expectedPerClassResults = + List expectedPerClassResults = List.of( - new Recall.PerClassResult("true", 0.0), - new Recall.PerClassResult("false", 0.0)); + new PerClassSingleValue("true", 0.0), + new PerClassSingleValue("false", 0.0)); double expectedAvgRecall = 0.0; Recall.Result recallResult = evaluateRecall(IS_PREDATOR_BOOLEAN_FIELD, NO_LEGS_INTEGER_FIELD); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index 81303e647f789..c544ee995a2a1 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -40,6 +40,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AucRoc; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.PerClassSingleValue; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; @@ -969,16 +970,16 @@ private void assertEvaluation(String dependentVariable, List dependentVar { // Accuracy Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); - for (Accuracy.PerClassResult klass : accuracyResult.getClasses()) { + for (PerClassSingleValue klass : accuracyResult.getClasses()) { assertThat(klass.getClassName(), is(in(dependentVariableValuesAsStrings))); - assertThat(klass.getAccuracy(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))); + assertThat(klass.getValue(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))); } } { // AucRoc AucRoc.Result aucRocResult = (AucRoc.Result) evaluateDataFrameResponse.getMetrics().get(1); assertThat(aucRocResult.getMetricName(), equalTo(AucRoc.NAME.getPreferredName())); - assertThat(aucRocResult.getScore(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))); + assertThat(aucRocResult.getValue(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))); assertThat(aucRocResult.getCurve(), hasSize(greaterThan(0))); } @@ -1004,18 +1005,18 @@ private void assertEvaluation(String dependentVariable, List dependentVar { // Precision Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(3); assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName())); - for (Precision.PerClassResult klass : precisionResult.getClasses()) { + for (PerClassSingleValue klass : precisionResult.getClasses()) { assertThat(klass.getClassName(), is(in(dependentVariableValuesAsStrings))); - assertThat(klass.getPrecision(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))); + assertThat(klass.getValue(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))); } } { // Recall Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(4); assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName())); - for (Recall.PerClassResult klass : recallResult.getClasses()) { + for (PerClassSingleValue klass : recallResult.getClasses()) { assertThat(klass.getClassName(), is(in(dependentVariableValuesAsStrings))); - assertThat(klass.getRecall(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))); + assertThat(klass.getValue(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))); } } } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionEvaluationIT.java index 22a18654dc872..3bd7be05236da 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionEvaluationIT.java @@ -97,13 +97,13 @@ private AucRoc.Result evaluateAucRoc(String actualField, String predictedField, public void testEvaluate_AucRoc_DoNotIncludeCurve() { AucRoc.Result aucrocResult = evaluateAucRoc(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_PROBABILITY_FIELD, false); - assertThat(aucrocResult.getScore(), is(closeTo(1.0, 0.0001))); + assertThat(aucrocResult.getValue(), is(closeTo(1.0, 0.0001))); assertThat(aucrocResult.getCurve(), hasSize(0)); } public void testEvaluate_AucRoc_IncludeCurve() { AucRoc.Result aucrocResult = evaluateAucRoc(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_PROBABILITY_FIELD, true); - assertThat(aucrocResult.getScore(), is(closeTo(1.0, 0.0001))); + assertThat(aucrocResult.getValue(), is(closeTo(1.0, 0.0001))); assertThat(aucrocResult.getCurve(), hasSize(greaterThan(0))); } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml index 944d6a3a80a11..09cf11d266612 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -206,7 +206,7 @@ setup: } } } - - match: { outlier_detection.auc_roc.score: 0.9899 } + - match: { outlier_detection.auc_roc.value: 0.9899 } - is_false: outlier_detection.auc_roc.curve --- @@ -226,7 +226,7 @@ setup: } } } - - match: { outlier_detection.auc_roc.score: 0.9899 } + - match: { outlier_detection.auc_roc.value: 0.9899 } - is_false: outlier_detection.auc_roc.curve --- @@ -246,7 +246,7 @@ setup: } } } - - match: { outlier_detection.auc_roc.score: 0.9899 } + - match: { outlier_detection.auc_roc.value: 0.9899 } - is_true: outlier_detection.auc_roc.curve --- @@ -411,7 +411,7 @@ setup: } } } - - is_true: outlier_detection.auc_roc.score + - is_true: outlier_detection.auc_roc.value - is_true: outlier_detection.precision.0\.25 - is_true: outlier_detection.precision.0\.5 - is_true: outlier_detection.precision.0\.75 @@ -721,7 +721,7 @@ setup: } } } - - match: { classification.auc_roc.score: 0.8050111095212122 } + - match: { classification.auc_roc.value: 0.8050111095212122 } - is_false: classification.auc_roc.curve --- "Test classification auc_roc with default top_classes_field": @@ -741,7 +741,7 @@ setup: } } } - - match: { classification.auc_roc.score: 0.8050111095212122 } + - match: { classification.auc_roc.value: 0.8050111095212122 } - is_false: classification.auc_roc.curve --- "Test classification accuracy with missing predicted_field": @@ -778,11 +778,11 @@ setup: classification.accuracy: classes: - class_name: "cat" - accuracy: 0.625 # 5 out of 8 + value: 0.625 # 5 out of 8 - class_name: "dog" - accuracy: 0.75 # 6 out of 8 + value: 0.75 # 6 out of 8 - class_name: "mouse" - accuracy: 0.875 # 7 out of 8 + value: 0.875 # 7 out of 8 overall_accuracy: 0.625 # 5 out of 8 --- "Test classification precision": @@ -804,11 +804,11 @@ setup: classification.precision: classes: - class_name: "cat" - precision: 0.5 # 2 out of 4 + value: 0.5 # 2 out of 4 - class_name: "dog" - precision: 0.6666666666666666 # 2 out of 3 + value: 0.6666666666666666 # 2 out of 3 - class_name: "mouse" - precision: 1.0 # 1 out of 1 + value: 1.0 # 1 out of 1 avg_precision: 0.7222222222222222 --- "Test classification recall": @@ -830,11 +830,11 @@ setup: classification.recall: classes: - class_name: "cat" - recall: 0.6666666666666666 # 2 out of 3 + value: 0.6666666666666666 # 2 out of 3 - class_name: "dog" - recall: 0.6666666666666666 # 2 out of 3 + value: 0.6666666666666666 # 2 out of 3 - class_name: "mouse" - recall: 0.5 # 1 out of 2 + value: 0.5 # 1 out of 2 avg_recall: 0.611111111111111 --- "Test classification multiclass_confusion_matrix":