Skip to content

[ML] Rename evaluation metric result fields to value #63809

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.common.AucRocResult;
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
Expand Down Expand Up @@ -122,10 +123,9 @@ Evaluation.class, new ParseField(OutlierDetection.NAME), OutlierDetection::fromX
// Evaluation metrics results
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class,
new ParseField(
registeredMetricName(
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME)),
org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric.Result::fromXContent),
new ParseField(registeredMetricName(
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME)),
AucRocResult::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class,
new ParseField(
Expand All @@ -145,7 +145,7 @@ Evaluation.class, new ParseField(OutlierDetection.NAME), OutlierDetection::fromX
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class,
new ParseField(registeredMetricName(Classification.NAME, AucRocMetric.NAME)),
AucRocMetric.Result::fromXContent),
AucRocResult::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class,
new ParseField(registeredMetricName(Classification.NAME, AccuracyMetric.NAME)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,9 @@

import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

Expand Down Expand Up @@ -99,10 +97,10 @@ public static class Result implements EvaluationMetric.Result {

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<PerClassResult>) a[0], (double) a[1]));
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<PerClassSingleValue>) a[0], (double) a[1]));

static {
PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES);
PARSER.declareObjectArray(constructorArg(), PerClassSingleValue.PARSER, CLASSES);
PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY);
}

Expand All @@ -111,11 +109,11 @@ public static Result fromXContent(XContentParser parser) {
}

/** List of per-class results. */
private final List<PerClassResult> classes;
private final List<PerClassSingleValue> classes;
/** Fraction of documents for which predicted class equals the actual class. */
private final double overallAccuracy;

public Result(List<PerClassResult> classes, double overallAccuracy) {
public Result(List<PerClassSingleValue> classes, double overallAccuracy) {
this.classes = Collections.unmodifiableList(Objects.requireNonNull(classes));
this.overallAccuracy = overallAccuracy;
}
Expand All @@ -125,7 +123,7 @@ public String getMetricName() {
return NAME;
}

public List<PerClassResult> getClasses() {
public List<PerClassSingleValue> getClasses() {
return classes;
}

Expand Down Expand Up @@ -156,65 +154,4 @@ public int hashCode() {
return Objects.hash(classes, overallAccuracy);
}
}

public static class PerClassResult implements ToXContentObject {

private static final ParseField CLASS_NAME = new ParseField("class_name");
private static final ParseField ACCURACY = new ParseField("accuracy");

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<PerClassResult, Void> PARSER =
new ConstructingObjectParser<>("accuracy_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1]));

static {
PARSER.declareString(constructorArg(), CLASS_NAME);
PARSER.declareDouble(constructorArg(), ACCURACY);
}

/** Name of the class. */
private final String className;
/** Fraction of documents that are either true positives or true negatives wrt {@code className}. */
private final double accuracy;

public PerClassResult(String className, double accuracy) {
this.className = Objects.requireNonNull(className);
this.accuracy = accuracy;
}

public String getClassName() {
return className;
}

public double getAccuracy() {
return accuracy;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(CLASS_NAME.getPreferredName(), className);
builder.field(ACCURACY.getPreferredName(), accuracy);
builder.endObject();
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
PerClassResult that = (PerClassResult) o;
return Objects.equals(this.className, that.className)
&& this.accuracy == that.accuracy;
}

@Override
public int hashCode() {
return Objects.hash(className, accuracy);
}

@Override
public String toString() {
return Strings.toString(this);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,14 @@
package org.elasticsearch.client.ml.dataframe.evaluation.classification;

import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.client.ml.dataframe.evaluation.common.AucRocResult;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
Expand All @@ -43,12 +39,11 @@
*/
public class AucRocMetric implements EvaluationMetric {

public static final String NAME = "auc_roc";
public static final String NAME = AucRocResult.NAME;

public static final ParseField CLASS_NAME = new ParseField("class_name");
public static final ParseField INCLUDE_CURVE = new ParseField("include_curve");

@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<AucRocMetric, Void> PARSER =
new ConstructingObjectParser<>(NAME, true, args -> new AucRocMetric((String) args[0], (Boolean) args[1]));

Expand Down Expand Up @@ -106,149 +101,4 @@ public boolean equals(Object o) {
public int hashCode() {
return Objects.hash(className, includeCurve);
}

public static class Result implements EvaluationMetric.Result {

public static Result fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

private static final ParseField SCORE = new ParseField("score");
private static final ParseField CURVE = new ParseField("curve");

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>(
"auc_roc_result", true, args -> new Result((double) args[0], (List<AucRocPoint>) args[1]));

static {
PARSER.declareDouble(constructorArg(), SCORE);
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> AucRocPoint.fromXContent(p), CURVE);
}

private final double score;
private final List<AucRocPoint> curve;

public Result(double score, @Nullable List<AucRocPoint> curve) {
this.score = score;
this.curve = curve;
}

@Override
public String getMetricName() {
return NAME;
}

public double getScore() {
return score;
}

public List<AucRocPoint> getCurve() {
return curve == null ? null : Collections.unmodifiableList(curve);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.field(SCORE.getPreferredName(), score);
if (curve != null && curve.isEmpty() == false) {
builder.field(CURVE.getPreferredName(), curve);
}
builder.endObject();
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return score == that.score
&& Objects.equals(curve, that.curve);
}

@Override
public int hashCode() {
return Objects.hash(score, curve);
}

@Override
public String toString() {
return Strings.toString(this);
}
}

public static final class AucRocPoint implements ToXContentObject {

public static AucRocPoint fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

private static final ParseField TPR = new ParseField("tpr");
private static final ParseField FPR = new ParseField("fpr");
private static final ParseField THRESHOLD = new ParseField("threshold");

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<AucRocPoint, Void> PARSER =
new ConstructingObjectParser<>(
"auc_roc_point",
true,
args -> new AucRocPoint((double) args[0], (double) args[1], (double) args[2]));

static {
PARSER.declareDouble(constructorArg(), TPR);
PARSER.declareDouble(constructorArg(), FPR);
PARSER.declareDouble(constructorArg(), THRESHOLD);
}

private final double tpr;
private final double fpr;
private final double threshold;

public AucRocPoint(double tpr, double fpr, double threshold) {
this.tpr = tpr;
this.fpr = fpr;
this.threshold = threshold;
}

public double getTruePositiveRate() {
return tpr;
}

public double getFalsePositiveRate() {
return fpr;
}

public double getThreshold() {
return threshold;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return builder
.startObject()
.field(TPR.getPreferredName(), tpr)
.field(FPR.getPreferredName(), fpr)
.field(THRESHOLD.getPreferredName(), threshold)
.endObject();
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
AucRocPoint that = (AucRocPoint) o;
return tpr == that.tpr && fpr == that.fpr && threshold == that.threshold;
}

@Override
public int hashCode() {
return Objects.hash(tpr, fpr, threshold);
}

@Override
public String toString() {
return Strings.toString(this);
}
}
}
Loading