Skip to content

Commit bd761cc

Browse files
authored
[ML] Validate that AucRoc has the data necessary to be calculated (#63302) (#63454)
1 parent f453058 commit bd761cc

File tree

13 files changed: 74 additions and 126 deletions.

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetric.java

Lines changed: 3 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -114,27 +114,23 @@ public static Result fromXContent(XContentParser parser) {
114114
}
115115

116116
private static final ParseField SCORE = new ParseField("score");
117-
private static final ParseField DOC_COUNT = new ParseField("doc_count");
118117
private static final ParseField CURVE = new ParseField("curve");
119118

120119
@SuppressWarnings("unchecked")
121120
private static final ConstructingObjectParser<Result, Void> PARSER =
122121
new ConstructingObjectParser<>(
123-
"auc_roc_result", true, args -> new Result((double) args[0], (long) args[1], (List<AucRocPoint>) args[2]));
122+
"auc_roc_result", true, args -> new Result((double) args[0], (List<AucRocPoint>) args[1]));
124123

125124
static {
126125
PARSER.declareDouble(constructorArg(), SCORE);
127-
PARSER.declareLong(constructorArg(), DOC_COUNT);
128126
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> AucRocPoint.fromXContent(p), CURVE);
129127
}
130128

131129
private final double score;
132-
private final long docCount;
133130
private final List<AucRocPoint> curve;
134131

135-
public Result(double score, long docCount, @Nullable List<AucRocPoint> curve) {
132+
public Result(double score, @Nullable List<AucRocPoint> curve) {
136133
this.score = score;
137-
this.docCount = docCount;
138134
this.curve = curve;
139135
}
140136

@@ -147,10 +143,6 @@ public double getScore() {
147143
return score;
148144
}
149145

150-
public long getDocCount() {
151-
return docCount;
152-
}
153-
154146
public List<AucRocPoint> getCurve() {
155147
return curve == null ? null : Collections.unmodifiableList(curve);
156148
}
@@ -159,7 +151,6 @@ public List<AucRocPoint> getCurve() {
159151
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
160152
builder.startObject();
161153
builder.field(SCORE.getPreferredName(), score);
162-
builder.field(DOC_COUNT.getPreferredName(), docCount);
163154
if (curve != null && curve.isEmpty() == false) {
164155
builder.field(CURVE.getPreferredName(), curve);
165156
}
@@ -173,13 +164,12 @@ public boolean equals(Object o) {
173164
if (o == null || getClass() != o.getClass()) return false;
174165
Result that = (Result) o;
175166
return score == that.score
176-
&& docCount == that.docCount
177167
&& Objects.equals(curve, that.curve);
178168
}
179169

180170
@Override
181171
public int hashCode() {
182-
return Objects.hash(score, docCount, curve);
172+
return Objects.hash(score, curve);
183173
}
184174

185175
@Override

client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

Lines changed: 24 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -200,6 +200,7 @@
200200
import java.util.Map;
201201
import java.util.concurrent.TimeUnit;
202202
import java.util.stream.Collectors;
203+
import java.util.stream.IntStream;
203204

204205
import static org.hamcrest.Matchers.allOf;
205206
import static org.hamcrest.Matchers.anyOf;
@@ -1931,18 +1932,17 @@ public void testEvaluateDataFrame_Classification() throws IOException {
19311932
createIndex(indexName, mappingForClassification());
19321933
BulkRequest regressionBulk = new BulkRequest()
19331934
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
1934-
.add(docForClassification(indexName, "cat", "cat", 0.9))
1935-
.add(docForClassification(indexName, "cat", "cat", 0.85))
1936-
.add(docForClassification(indexName, "cat", "cat", 0.95))
1937-
.add(docForClassification(indexName, "cat", "dog", 0.4))
1938-
.add(docForClassification(indexName, "cat", "fish", 0.35))
1939-
.add(docForClassification(indexName, "dog", "cat", 0.5))
1940-
.add(docForClassification(indexName, "dog", "dog", 0.4))
1941-
.add(docForClassification(indexName, "dog", "dog", 0.35))
1942-
.add(docForClassification(indexName, "dog", "dog", 0.6))
1943-
.add(docForClassification(indexName, "ant", "cat", 0.1));
1935+
.add(docForClassification(indexName, "cat", "cat", "dog", "ant"))
1936+
.add(docForClassification(indexName, "cat", "cat", "dog", "ant"))
1937+
.add(docForClassification(indexName, "cat", "cat", "horse", "dog"))
1938+
.add(docForClassification(indexName, "cat", "dog", "cat", "mule"))
1939+
.add(docForClassification(indexName, "cat", "fish", "cat", "dog"))
1940+
.add(docForClassification(indexName, "dog", "cat", "dog", "mule"))
1941+
.add(docForClassification(indexName, "dog", "dog", "cat", "ant"))
1942+
.add(docForClassification(indexName, "dog", "dog", "cat", "ant"))
1943+
.add(docForClassification(indexName, "dog", "dog", "cat", "ant"))
1944+
.add(docForClassification(indexName, "ant", "cat", "ant", "wasp"));
19441945
highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
1945-
19461946
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
19471947

19481948
{ // AucRoc
@@ -1957,8 +1957,7 @@ public void testEvaluateDataFrame_Classification() throws IOException {
19571957

19581958
AucRocMetric.Result aucRocResult = evaluateDataFrameResponse.getMetricByName(AucRocMetric.NAME);
19591959
assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
1960-
assertThat(aucRocResult.getScore(), closeTo(0.99995, 1e-9));
1961-
assertThat(aucRocResult.getDocCount(), equalTo(5L));
1960+
assertThat(aucRocResult.getScore(), closeTo(0.6425, 1e-9));
19621961
assertNotNull(aucRocResult.getCurve());
19631962
}
19641963
{ // Accuracy
@@ -2173,21 +2172,22 @@ private static XContentBuilder mappingForClassification() throws IOException {
21732172
.endObject();
21742173
}
21752174

2176-
private static IndexRequest docForClassification(String indexName, String actualClass, String predictedClass, double p) {
2175+
private static IndexRequest docForClassification(String indexName,
2176+
String actualClass,
2177+
String... topPredictedClasses) {
2178+
assert topPredictedClasses.length > 0;
21772179
return new IndexRequest()
21782180
.index(indexName)
21792181
.source(XContentType.JSON,
21802182
actualClassField, actualClass,
2181-
predictedClassField, predictedClass,
2182-
topClassesField, Arrays.asList(
2183-
new HashMap<String, Object>() {{
2184-
put("class_name", predictedClass);
2185-
put("class_probability", p);
2186-
}},
2187-
new HashMap<String, Object>() {{
2188-
put("class_name", "other");
2189-
put("class_probability", 1 - p);
2190-
}}));
2183+
predictedClassField, topPredictedClasses[0],
2184+
topClassesField, IntStream.range(0, topPredictedClasses.length)
2185+
// Consecutive assigned probabilities are: 0.5, 0.25, 0.125, etc.
2186+
.mapToObj(i -> new HashMap<String, Object>() {{
2187+
put("class_name", topPredictedClasses[i]);
2188+
put("class_probability", 1.0 / (2 << i));
2189+
}})
2190+
.collect(Collectors.toList()));
21912191
}
21922192

21932193
private static final String actualRegression = "regression_actual";

client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

Lines changed: 24 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -201,7 +201,6 @@
201201
import org.elasticsearch.client.ml.job.results.Influencer;
202202
import org.elasticsearch.client.ml.job.results.OverallBucket;
203203
import org.elasticsearch.client.ml.job.stats.JobStats;
204-
import org.elasticsearch.common.TriFunction;
205204
import org.elasticsearch.common.bytes.BytesReference;
206205
import org.elasticsearch.common.unit.ByteSizeUnit;
207206
import org.elasticsearch.common.unit.ByteSizeValue;
@@ -229,8 +228,11 @@
229228
import java.util.Map;
230229
import java.util.concurrent.CountDownLatch;
231230
import java.util.concurrent.TimeUnit;
231+
import java.util.function.BiFunction;
232232
import java.util.stream.Collectors;
233+
import java.util.stream.IntStream;
233234

235+
import static java.util.stream.Collectors.toList;
234236
import static org.hamcrest.Matchers.allOf;
235237
import static org.hamcrest.Matchers.closeTo;
236238
import static org.hamcrest.Matchers.contains;
@@ -3463,34 +3465,33 @@ public void testEvaluateDataFrame_Classification() throws Exception {
34633465
.endObject()
34643466
.endObject()
34653467
.endObject());
3466-
TriFunction<String, String, Double, IndexRequest> indexRequest = (actualClass, predictedClass, p) -> {
3468+
BiFunction<String, String[], IndexRequest> indexRequest = (actualClass, topPredictedClasses) -> {
3469+
assert topPredictedClasses.length > 0;
34673470
return new IndexRequest()
34683471
.source(XContentType.JSON,
34693472
"actual_class", actualClass,
3470-
"predicted_class", predictedClass,
3471-
"ml.top_classes", Arrays.asList(
3472-
new HashMap<String, Object>() {{
3473-
put("class_name", predictedClass);
3474-
put("class_probability", p);
3475-
}},
3476-
new HashMap<String, Object>() {{
3477-
put("class_name", "other");
3478-
put("class_probability", 1 - p);
3479-
}}));
3473+
"predicted_class", topPredictedClasses[0],
3474+
"ml.top_classes", IntStream.range(0, topPredictedClasses.length)
3475+
// Consecutive assigned probabilities are: 0.5, 0.25, 0.125, etc.
3476+
.mapToObj(i -> new HashMap<String, Object>() {{
3477+
put("class_name", topPredictedClasses[i]);
3478+
put("class_probability", 1.0 / (2 << i));
3479+
}})
3480+
.collect(toList()));
34803481
};
34813482
BulkRequest bulkRequest =
34823483
new BulkRequest(indexName)
34833484
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
3484-
.add(indexRequest.apply("cat", "cat", 0.9)) // #0
3485-
.add(indexRequest.apply("cat", "cat", 0.9)) // #1
3486-
.add(indexRequest.apply("cat", "cat", 0.9)) // #2
3487-
.add(indexRequest.apply("cat", "dog", 0.9)) // #3
3488-
.add(indexRequest.apply("cat", "fox", 0.9)) // #4
3489-
.add(indexRequest.apply("dog", "cat", 0.9)) // #5
3490-
.add(indexRequest.apply("dog", "dog", 0.9)) // #6
3491-
.add(indexRequest.apply("dog", "dog", 0.9)) // #7
3492-
.add(indexRequest.apply("dog", "dog", 0.9)) // #8
3493-
.add(indexRequest.apply("ant", "cat", 0.9)); // #9
3485+
.add(indexRequest.apply("cat", new String[]{"cat", "dog", "ant"})) // #0
3486+
.add(indexRequest.apply("cat", new String[]{"cat", "dog", "ant"})) // #1
3487+
.add(indexRequest.apply("cat", new String[]{"cat", "horse", "dog"})) // #2
3488+
.add(indexRequest.apply("cat", new String[]{"dog", "cat", "mule"})) // #3
3489+
.add(indexRequest.apply("cat", new String[]{"fox", "cat", "dog"})) // #4
3490+
.add(indexRequest.apply("dog", new String[]{"cat", "dog", "mule"})) // #5
3491+
.add(indexRequest.apply("dog", new String[]{"dog", "cat", "ant"})) // #6
3492+
.add(indexRequest.apply("dog", new String[]{"dog", "cat", "ant"})) // #7
3493+
.add(indexRequest.apply("dog", new String[]{"dog", "cat", "ant"})) // #8
3494+
.add(indexRequest.apply("ant", new String[]{"cat", "ant", "wasp"})); // #9
34943495
RestHighLevelClient client = highLevelClient();
34953496
client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
34963497
client.bulk(bulkRequest, RequestOptions.DEFAULT);
@@ -3530,7 +3531,6 @@ public void testEvaluateDataFrame_Classification() throws Exception {
35303531

35313532
AucRocMetric.Result aucRocResult = response.getMetricByName(AucRocMetric.NAME); // <10>
35323533
double aucRocScore = aucRocResult.getScore(); // <11>
3533-
Long aucRocDocCount = aucRocResult.getDocCount(); // <12>
35343534
// end::evaluate-data-frame-results-classification
35353535

35363536
assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
@@ -3565,8 +3565,7 @@ public void testEvaluateDataFrame_Classification() throws Exception {
35653565
assertThat(otherClassesCount, equalTo(0L));
35663566

35673567
assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
3568-
assertThat(aucRocScore, equalTo(0.2625));
3569-
assertThat(aucRocDocCount, equalTo(5L));
3568+
assertThat(aucRocScore, closeTo(0.6425, 1e-9));
35703569
}
35713570
}
35723571

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetricResultTests.java

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,6 @@ public class AucRocMetricResultTests extends AbstractXContentTestCase<AucRocMetr
3131
public static AucRocMetric.Result randomResult() {
3232
return new AucRocMetric.Result(
3333
randomDouble(),
34-
randomLong(),
3534
Stream
3635
.generate(AucRocMetricAucRocPointTests::randomPoint)
3736
.limit(randomIntBetween(1, 10))

docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -121,7 +121,6 @@ include-tagged::{doc-tests-file}[{api}-results-classification]
121121
<9> Fetching the number of classes that were not included in the matrix
122122
<10> Fetching AucRoc metric by name
123123
<11> Fetching the actual AucRoc score
124-
<12> Fetching the number of documents that were used in order to calculate AucRoc score
125124

126125
===== Regression
127126

docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -193,10 +193,8 @@ belongs.
193193
`class_name`::::
194194
(Required, string) Name of the only class that will be treated as
195195
positive during AUC ROC calculation. Other classes will be treated as
196-
negative ("one-vs-all" strategy). Documents which do not have `class_name`
197-
in the list of their top classes will not be taken into account for
198-
evaluation. The number of documents taken into account is returned in the
199-
evaluation result (`auc_roc.doc_count` field).
196+
negative ("one-vs-all" strategy). All the evaluated documents must have `class_name`
197+
in the list of their top classes.
200198

201199
`include_curve`::::
202200
(Optional, boolean) Whether or not the curve should be returned in

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AbstractAucRoc.java

Lines changed: 2 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,6 @@
55
*/
66
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
77

8-
import org.elasticsearch.Version;
98
import org.elasticsearch.common.ParseField;
109
import org.elasticsearch.common.Strings;
1110
import org.elasticsearch.common.io.stream.StreamInput;
@@ -231,37 +230,25 @@ private static double interpolate(double x, double x1, double y1, double x2, dou
231230
public static class Result implements EvaluationMetricResult {
232231

233232
private static final String SCORE = "score";
234-
private static final String DOC_COUNT = "doc_count";
235233
private static final String CURVE = "curve";
236234

237235
private final double score;
238-
private final Long docCount;
239236
private final List<AucRocPoint> curve;
240237

241-
public Result(double score, Long docCount, List<AucRocPoint> curve) {
238+
public Result(double score, List<AucRocPoint> curve) {
242239
this.score = score;
243-
this.docCount = docCount;
244240
this.curve = Objects.requireNonNull(curve);
245241
}
246242

247243
public Result(StreamInput in) throws IOException {
248244
this.score = in.readDouble();
249-
if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
250-
this.docCount = in.readOptionalLong();
251-
} else {
252-
this.docCount = null;
253-
}
254245
this.curve = in.readList(AucRocPoint::new);
255246
}
256247

257248
public double getScore() {
258249
return score;
259250
}
260251

261-
public Long getDocCount() {
262-
return docCount;
263-
}
264-
265252
public List<AucRocPoint> getCurve() {
266253
return Collections.unmodifiableList(curve);
267254
}
@@ -279,19 +266,13 @@ public String getMetricName() {
279266
@Override
280267
public void writeTo(StreamOutput out) throws IOException {
281268
out.writeDouble(score);
282-
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
283-
out.writeOptionalLong(docCount);
284-
}
285269
out.writeList(curve);
286270
}
287271

288272
@Override
289273
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
290274
builder.startObject();
291275
builder.field(SCORE, score);
292-
if (docCount != null) {
293-
builder.field(DOC_COUNT, docCount);
294-
}
295276
if (curve.isEmpty() == false) {
296277
builder.field(CURVE, curve);
297278
}
@@ -305,13 +286,12 @@ public boolean equals(Object o) {
305286
if (o == null || getClass() != o.getClass()) return false;
306287
Result that = (Result) o;
307288
return score == that.score
308-
&& Objects.equals(docCount, that.docCount)
309289
&& Objects.equals(curve, that.curve);
310290
}
311291

312292
@Override
313293
public int hashCode() {
314-
return Objects.hash(score, docCount, curve);
294+
return Objects.hash(score, curve);
315295
}
316296
}
317297
}

0 commit comments

Comments
 (0)