Skip to content

Commit 03ed7de

Browse files
[ML] Rename evaluation metric result fields to value (#63809)
Renames data frame analytics _evaluate API results as follows: - per class accuracy renamed from `accuracy` to `value` - per class precision renamed from `precision` to `value` - per class recall renamed from `recall` to `value` - auc_roc `score` renamed to `value` for both outlier detection and classification
1 parent 7f2930e commit 03ed7de

File tree

38 files changed

+686
-763
lines changed

38 files changed

+686
-763
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
2525
import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric;
2626
import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric;
27+
import org.elasticsearch.client.ml.dataframe.evaluation.common.AucRocResult;
2728
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric;
2829
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
2930
import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
@@ -122,10 +123,9 @@ Evaluation.class, new ParseField(OutlierDetection.NAME), OutlierDetection::fromX
122123
// Evaluation metrics results
123124
new NamedXContentRegistry.Entry(
124125
EvaluationMetric.Result.class,
125-
new ParseField(
126-
registeredMetricName(
127-
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME)),
128-
org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric.Result::fromXContent),
126+
new ParseField(registeredMetricName(
127+
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME)),
128+
AucRocResult::fromXContent),
129129
new NamedXContentRegistry.Entry(
130130
EvaluationMetric.Result.class,
131131
new ParseField(
@@ -145,7 +145,7 @@ Evaluation.class, new ParseField(OutlierDetection.NAME), OutlierDetection::fromX
145145
new NamedXContentRegistry.Entry(
146146
EvaluationMetric.Result.class,
147147
new ParseField(registeredMetricName(Classification.NAME, AucRocMetric.NAME)),
148-
AucRocMetric.Result::fromXContent),
148+
AucRocResult::fromXContent),
149149
new NamedXContentRegistry.Entry(
150150
EvaluationMetric.Result.class,
151151
new ParseField(registeredMetricName(Classification.NAME, AccuracyMetric.NAME)),

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetric.java

Lines changed: 5 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,9 @@
2020

2121
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
2222
import org.elasticsearch.common.ParseField;
23-
import org.elasticsearch.common.Strings;
2423
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
2524
import org.elasticsearch.common.xcontent.ObjectParser;
2625
import org.elasticsearch.common.xcontent.ToXContent;
27-
import org.elasticsearch.common.xcontent.ToXContentObject;
2826
import org.elasticsearch.common.xcontent.XContentBuilder;
2927
import org.elasticsearch.common.xcontent.XContentParser;
3028

@@ -99,10 +97,10 @@ public static class Result implements EvaluationMetric.Result {
9997

10098
@SuppressWarnings("unchecked")
10199
private static final ConstructingObjectParser<Result, Void> PARSER =
102-
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<PerClassResult>) a[0], (double) a[1]));
100+
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<PerClassSingleValue>) a[0], (double) a[1]));
103101

104102
static {
105-
PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES);
103+
PARSER.declareObjectArray(constructorArg(), PerClassSingleValue.PARSER, CLASSES);
106104
PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY);
107105
}
108106

@@ -111,11 +109,11 @@ public static Result fromXContent(XContentParser parser) {
111109
}
112110

113111
/** List of per-class results. */
114-
private final List<PerClassResult> classes;
112+
private final List<PerClassSingleValue> classes;
115113
/** Fraction of documents for which predicted class equals the actual class. */
116114
private final double overallAccuracy;
117115

118-
public Result(List<PerClassResult> classes, double overallAccuracy) {
116+
public Result(List<PerClassSingleValue> classes, double overallAccuracy) {
119117
this.classes = Collections.unmodifiableList(Objects.requireNonNull(classes));
120118
this.overallAccuracy = overallAccuracy;
121119
}
@@ -125,7 +123,7 @@ public String getMetricName() {
125123
return NAME;
126124
}
127125

128-
public List<PerClassResult> getClasses() {
126+
public List<PerClassSingleValue> getClasses() {
129127
return classes;
130128
}
131129

@@ -156,65 +154,4 @@ public int hashCode() {
156154
return Objects.hash(classes, overallAccuracy);
157155
}
158156
}
159-
160-
public static class PerClassResult implements ToXContentObject {
161-
162-
private static final ParseField CLASS_NAME = new ParseField("class_name");
163-
private static final ParseField ACCURACY = new ParseField("accuracy");
164-
165-
@SuppressWarnings("unchecked")
166-
private static final ConstructingObjectParser<PerClassResult, Void> PARSER =
167-
new ConstructingObjectParser<>("accuracy_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1]));
168-
169-
static {
170-
PARSER.declareString(constructorArg(), CLASS_NAME);
171-
PARSER.declareDouble(constructorArg(), ACCURACY);
172-
}
173-
174-
/** Name of the class. */
175-
private final String className;
176-
/** Fraction of documents that are either true positives or true negatives wrt {@code className}. */
177-
private final double accuracy;
178-
179-
public PerClassResult(String className, double accuracy) {
180-
this.className = Objects.requireNonNull(className);
181-
this.accuracy = accuracy;
182-
}
183-
184-
public String getClassName() {
185-
return className;
186-
}
187-
188-
public double getAccuracy() {
189-
return accuracy;
190-
}
191-
192-
@Override
193-
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
194-
builder.startObject();
195-
builder.field(CLASS_NAME.getPreferredName(), className);
196-
builder.field(ACCURACY.getPreferredName(), accuracy);
197-
builder.endObject();
198-
return builder;
199-
}
200-
201-
@Override
202-
public boolean equals(Object o) {
203-
if (this == o) return true;
204-
if (o == null || getClass() != o.getClass()) return false;
205-
PerClassResult that = (PerClassResult) o;
206-
return Objects.equals(this.className, that.className)
207-
&& this.accuracy == that.accuracy;
208-
}
209-
210-
@Override
211-
public int hashCode() {
212-
return Objects.hash(className, accuracy);
213-
}
214-
215-
@Override
216-
public String toString() {
217-
return Strings.toString(this);
218-
}
219-
}
220157
}

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetric.java

Lines changed: 2 additions & 152 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,14 @@
1919
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
2020

2121
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
22-
import org.elasticsearch.common.Nullable;
22+
import org.elasticsearch.client.ml.dataframe.evaluation.common.AucRocResult;
2323
import org.elasticsearch.common.ParseField;
24-
import org.elasticsearch.common.Strings;
2524
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
2625
import org.elasticsearch.common.xcontent.ToXContent;
27-
import org.elasticsearch.common.xcontent.ToXContentObject;
2826
import org.elasticsearch.common.xcontent.XContentBuilder;
2927
import org.elasticsearch.common.xcontent.XContentParser;
3028

3129
import java.io.IOException;
32-
import java.util.Collections;
33-
import java.util.List;
3430
import java.util.Objects;
3531

3632
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
@@ -43,12 +39,11 @@
4339
*/
4440
public class AucRocMetric implements EvaluationMetric {
4541

46-
public static final String NAME = "auc_roc";
42+
public static final String NAME = AucRocResult.NAME;
4743

4844
public static final ParseField CLASS_NAME = new ParseField("class_name");
4945
public static final ParseField INCLUDE_CURVE = new ParseField("include_curve");
5046

51-
@SuppressWarnings("unchecked")
5247
public static final ConstructingObjectParser<AucRocMetric, Void> PARSER =
5348
new ConstructingObjectParser<>(NAME, true, args -> new AucRocMetric((String) args[0], (Boolean) args[1]));
5449

@@ -106,149 +101,4 @@ public boolean equals(Object o) {
106101
public int hashCode() {
107102
return Objects.hash(className, includeCurve);
108103
}
109-
110-
public static class Result implements EvaluationMetric.Result {
111-
112-
public static Result fromXContent(XContentParser parser) {
113-
return PARSER.apply(parser, null);
114-
}
115-
116-
private static final ParseField SCORE = new ParseField("score");
117-
private static final ParseField CURVE = new ParseField("curve");
118-
119-
@SuppressWarnings("unchecked")
120-
private static final ConstructingObjectParser<Result, Void> PARSER =
121-
new ConstructingObjectParser<>(
122-
"auc_roc_result", true, args -> new Result((double) args[0], (List<AucRocPoint>) args[1]));
123-
124-
static {
125-
PARSER.declareDouble(constructorArg(), SCORE);
126-
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> AucRocPoint.fromXContent(p), CURVE);
127-
}
128-
129-
private final double score;
130-
private final List<AucRocPoint> curve;
131-
132-
public Result(double score, @Nullable List<AucRocPoint> curve) {
133-
this.score = score;
134-
this.curve = curve;
135-
}
136-
137-
@Override
138-
public String getMetricName() {
139-
return NAME;
140-
}
141-
142-
public double getScore() {
143-
return score;
144-
}
145-
146-
public List<AucRocPoint> getCurve() {
147-
return curve == null ? null : Collections.unmodifiableList(curve);
148-
}
149-
150-
@Override
151-
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
152-
builder.startObject();
153-
builder.field(SCORE.getPreferredName(), score);
154-
if (curve != null && curve.isEmpty() == false) {
155-
builder.field(CURVE.getPreferredName(), curve);
156-
}
157-
builder.endObject();
158-
return builder;
159-
}
160-
161-
@Override
162-
public boolean equals(Object o) {
163-
if (this == o) return true;
164-
if (o == null || getClass() != o.getClass()) return false;
165-
Result that = (Result) o;
166-
return score == that.score
167-
&& Objects.equals(curve, that.curve);
168-
}
169-
170-
@Override
171-
public int hashCode() {
172-
return Objects.hash(score, curve);
173-
}
174-
175-
@Override
176-
public String toString() {
177-
return Strings.toString(this);
178-
}
179-
}
180-
181-
public static final class AucRocPoint implements ToXContentObject {
182-
183-
public static AucRocPoint fromXContent(XContentParser parser) {
184-
return PARSER.apply(parser, null);
185-
}
186-
187-
private static final ParseField TPR = new ParseField("tpr");
188-
private static final ParseField FPR = new ParseField("fpr");
189-
private static final ParseField THRESHOLD = new ParseField("threshold");
190-
191-
@SuppressWarnings("unchecked")
192-
private static final ConstructingObjectParser<AucRocPoint, Void> PARSER =
193-
new ConstructingObjectParser<>(
194-
"auc_roc_point",
195-
true,
196-
args -> new AucRocPoint((double) args[0], (double) args[1], (double) args[2]));
197-
198-
static {
199-
PARSER.declareDouble(constructorArg(), TPR);
200-
PARSER.declareDouble(constructorArg(), FPR);
201-
PARSER.declareDouble(constructorArg(), THRESHOLD);
202-
}
203-
204-
private final double tpr;
205-
private final double fpr;
206-
private final double threshold;
207-
208-
public AucRocPoint(double tpr, double fpr, double threshold) {
209-
this.tpr = tpr;
210-
this.fpr = fpr;
211-
this.threshold = threshold;
212-
}
213-
214-
public double getTruePositiveRate() {
215-
return tpr;
216-
}
217-
218-
public double getFalsePositiveRate() {
219-
return fpr;
220-
}
221-
222-
public double getThreshold() {
223-
return threshold;
224-
}
225-
226-
@Override
227-
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
228-
return builder
229-
.startObject()
230-
.field(TPR.getPreferredName(), tpr)
231-
.field(FPR.getPreferredName(), fpr)
232-
.field(THRESHOLD.getPreferredName(), threshold)
233-
.endObject();
234-
}
235-
236-
@Override
237-
public boolean equals(Object o) {
238-
if (this == o) return true;
239-
if (o == null || getClass() != o.getClass()) return false;
240-
AucRocPoint that = (AucRocPoint) o;
241-
return tpr == that.tpr && fpr == that.fpr && threshold == that.threshold;
242-
}
243-
244-
@Override
245-
public int hashCode() {
246-
return Objects.hash(tpr, fpr, threshold);
247-
}
248-
249-
@Override
250-
public String toString() {
251-
return Strings.toString(this);
252-
}
253-
}
254104
}

0 commit comments

Comments
 (0)