diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetric.java index 79cb13718a52c..5e9b28303c977 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetric.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetric.java @@ -114,27 +114,23 @@ public static Result fromXContent(XContentParser parser) { } private static final ParseField SCORE = new ParseField("score"); - private static final ParseField DOC_COUNT = new ParseField("doc_count"); private static final ParseField CURVE = new ParseField("curve"); @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - "auc_roc_result", true, args -> new Result((double) args[0], (long) args[1], (List) args[2])); + "auc_roc_result", true, args -> new Result((double) args[0], (List) args[1])); static { PARSER.declareDouble(constructorArg(), SCORE); - PARSER.declareLong(constructorArg(), DOC_COUNT); PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> AucRocPoint.fromXContent(p), CURVE); } private final double score; - private final long docCount; private final List curve; - public Result(double score, long docCount, @Nullable List curve) { + public Result(double score, @Nullable List curve) { this.score = score; - this.docCount = docCount; this.curve = curve; } @@ -147,10 +143,6 @@ public double getScore() { return score; } - public long getDocCount() { - return docCount; - } - public List getCurve() { return curve == null ? null : Collections.unmodifiableList(curve); } @@ -159,7 +151,6 @@ public List getCurve() { public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { builder.startObject(); builder.field(SCORE.getPreferredName(), score); - builder.field(DOC_COUNT.getPreferredName(), docCount); if (curve != null && curve.isEmpty() == false) { builder.field(CURVE.getPreferredName(), curve); } @@ -173,13 +164,12 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; Result that = (Result) o; return score == that.score - && docCount == that.docCount && Objects.equals(curve, that.curve); } @Override public int hashCode() { - return Objects.hash(score, docCount, curve); + return Objects.hash(score, curve); } @Override diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index a817afe305ff0..acf77aec9b394 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -200,6 +200,7 @@ import java.util.Map; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; +import java.util.stream.IntStream; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.anyOf; @@ -1931,18 +1932,17 @@ public void testEvaluateDataFrame_Classification() throws IOException { createIndex(indexName, mappingForClassification()); BulkRequest regressionBulk = new BulkRequest() .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .add(docForClassification(indexName, "cat", "cat", 0.9)) - .add(docForClassification(indexName, "cat", "cat", 0.85)) - .add(docForClassification(indexName, "cat", "cat", 0.95)) - .add(docForClassification(indexName, "cat", "dog", 0.4)) - .add(docForClassification(indexName, "cat", "fish", 0.35)) - .add(docForClassification(indexName, "dog", "cat", 0.5)) - .add(docForClassification(indexName, "dog", "dog", 0.4)) - .add(docForClassification(indexName, "dog", "dog", 0.35)) - .add(docForClassification(indexName, "dog", "dog", 0.6)) - .add(docForClassification(indexName, "ant", "cat", 0.1)); + .add(docForClassification(indexName, "cat", "cat", "dog", "ant")) + .add(docForClassification(indexName, "cat", "cat", "dog", "ant")) + .add(docForClassification(indexName, "cat", "cat", "horse", "dog")) + .add(docForClassification(indexName, "cat", "dog", "cat", "mule")) + .add(docForClassification(indexName, "cat", "fish", "cat", "dog")) + .add(docForClassification(indexName, "dog", "cat", "dog", "mule")) + .add(docForClassification(indexName, "dog", "dog", "cat", "ant")) + .add(docForClassification(indexName, "dog", "dog", "cat", "ant")) + .add(docForClassification(indexName, "dog", "dog", "cat", "ant")) + .add(docForClassification(indexName, "ant", "cat", "ant", "wasp")); highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT); - MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); { // AucRoc @@ -1957,8 +1957,7 @@ public void testEvaluateDataFrame_Classification() throws IOException { AucRocMetric.Result aucRocResult = evaluateDataFrameResponse.getMetricByName(AucRocMetric.NAME); assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME)); - assertThat(aucRocResult.getScore(), closeTo(0.99995, 1e-9)); - assertThat(aucRocResult.getDocCount(), equalTo(5L)); + assertThat(aucRocResult.getScore(), closeTo(0.6425, 1e-9)); assertNotNull(aucRocResult.getCurve()); } { // Accuracy @@ -2173,21 +2172,22 @@ private static XContentBuilder mappingForClassification() throws IOException { .endObject(); } - private static IndexRequest docForClassification(String indexName, String actualClass, String predictedClass, double p) { + private static IndexRequest docForClassification(String indexName, + String actualClass, + String... topPredictedClasses) { + assert topPredictedClasses.length > 0; return new IndexRequest() .index(indexName) .source(XContentType.JSON, actualClassField, actualClass, - predictedClassField, predictedClass, - topClassesField, Arrays.asList( - new HashMap() {{ - put("class_name", predictedClass); - put("class_probability", p); - }}, - new HashMap() {{ - put("class_name", "other"); - put("class_probability", 1 - p); - }})); + predictedClassField, topPredictedClasses[0], + topClassesField, IntStream.range(0, topPredictedClasses.length) + // Consecutive assigned probabilities are: 0.5, 0.25, 0.125, etc. + .mapToObj(i -> new HashMap() {{ + put("class_name", topPredictedClasses[i]); + put("class_probability", 1.0 / (2 << i)); + }}) + .collect(Collectors.toList())); } private static final String actualRegression = "regression_actual"; diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index 660cac1176410..1c55452ea22cb 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -201,7 +201,6 @@ import org.elasticsearch.client.ml.job.results.Influencer; import org.elasticsearch.client.ml.job.results.OverallBucket; import org.elasticsearch.client.ml.job.stats.JobStats; -import org.elasticsearch.common.TriFunction; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeValue; @@ -229,8 +228,11 @@ import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.function.BiFunction; import java.util.stream.Collectors; +import java.util.stream.IntStream; +import static java.util.stream.Collectors.toList; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.contains; @@ -3463,34 +3465,33 @@ public void testEvaluateDataFrame_Classification() throws Exception { .endObject() .endObject() .endObject()); - TriFunction indexRequest = (actualClass, predictedClass, p) -> { + BiFunction indexRequest = (actualClass, topPredictedClasses) -> { + assert topPredictedClasses.length > 0; return new IndexRequest() .source(XContentType.JSON, "actual_class", actualClass, - "predicted_class", predictedClass, - "ml.top_classes", Arrays.asList( - new HashMap() {{ - put("class_name", predictedClass); - put("class_probability", p); - }}, - new HashMap() {{ - put("class_name", "other"); - put("class_probability", 1 - p); - }})); + "predicted_class", topPredictedClasses[0], + "ml.top_classes", IntStream.range(0, topPredictedClasses.length) + // Consecutive assigned probabilities are: 0.5, 0.25, 0.125, etc. + .mapToObj(i -> new HashMap() {{ + put("class_name", topPredictedClasses[i]); + put("class_probability", 1.0 / (2 << i)); + }}) + .collect(toList())); }; BulkRequest bulkRequest = new BulkRequest(indexName) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .add(indexRequest.apply("cat", "cat", 0.9)) // #0 - .add(indexRequest.apply("cat", "cat", 0.9)) // #1 - .add(indexRequest.apply("cat", "cat", 0.9)) // #2 - .add(indexRequest.apply("cat", "dog", 0.9)) // #3 - .add(indexRequest.apply("cat", "fox", 0.9)) // #4 - .add(indexRequest.apply("dog", "cat", 0.9)) // #5 - .add(indexRequest.apply("dog", "dog", 0.9)) // #6 - .add(indexRequest.apply("dog", "dog", 0.9)) // #7 - .add(indexRequest.apply("dog", "dog", 0.9)) // #8 - .add(indexRequest.apply("ant", "cat", 0.9)); // #9 + .add(indexRequest.apply("cat", new String[]{"cat", "dog", "ant"})) // #0 + .add(indexRequest.apply("cat", new String[]{"cat", "dog", "ant"})) // #1 + .add(indexRequest.apply("cat", new String[]{"cat", "horse", "dog"})) // #2 + .add(indexRequest.apply("cat", new String[]{"dog", "cat", "mule"})) // #3 + .add(indexRequest.apply("cat", new String[]{"fox", "cat", "dog"})) // #4 + .add(indexRequest.apply("dog", new String[]{"cat", "dog", "mule"})) // #5 + .add(indexRequest.apply("dog", new String[]{"dog", "cat", "ant"})) // #6 + .add(indexRequest.apply("dog", new String[]{"dog", "cat", "ant"})) // #7 + .add(indexRequest.apply("dog", new String[]{"dog", "cat", "ant"})) // #8 + .add(indexRequest.apply("ant", new String[]{"cat", "ant", "wasp"})); // #9 RestHighLevelClient client = highLevelClient(); client.indices().create(createIndexRequest, RequestOptions.DEFAULT); client.bulk(bulkRequest, RequestOptions.DEFAULT); @@ -3530,7 +3531,6 @@ public void testEvaluateDataFrame_Classification() throws Exception { AucRocMetric.Result aucRocResult = response.getMetricByName(AucRocMetric.NAME); // <10> double aucRocScore = aucRocResult.getScore(); // <11> - Long aucRocDocCount = aucRocResult.getDocCount(); // <12> // end::evaluate-data-frame-results-classification assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME)); @@ -3565,8 +3565,7 @@ public void testEvaluateDataFrame_Classification() throws Exception { assertThat(otherClassesCount, equalTo(0L)); assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME)); - assertThat(aucRocScore, equalTo(0.2625)); - assertThat(aucRocDocCount, equalTo(5L)); + assertThat(aucRocScore, closeTo(0.6425, 1e-9)); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetricResultTests.java index 9885564835799..40ada86f48445 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetricResultTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetricResultTests.java @@ -31,7 +31,6 @@ public class AucRocMetricResultTests extends AbstractXContentTestCase Fetching the number of classes that were not included in the matrix <10> Fetching AucRoc metric by name <11> Fetching the actual AucRoc score -<12> Fetching the number of documents that were used in order to calculate AucRoc score ===== Regression diff --git a/docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc b/docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc index 48481f1911418..a7c51235e3a32 100644 --- a/docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc +++ b/docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc @@ -193,10 +193,8 @@ belongs. `class_name`:::: (Required, string) Name of the only class that will be treated as positive during AUC ROC calculation. Other classes will be treated as - negative ("one-vs-all" strategy). Documents which do not have `class_name` - in the list of their top classes will not be taken into account for - evaluation. The number of documents taken into account is returned in the - evaluation result (`auc_roc.doc_count` field). + negative ("one-vs-all" strategy). All the evaluated documents must have `class_name` + in the list of their top classes. `include_curve`:::: (Optional, boolean) Whether or not the curve should be returned in diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AbstractAucRoc.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AbstractAucRoc.java index 30a7a55c3edd3..b3bb6dcefae19 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AbstractAucRoc.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AbstractAucRoc.java @@ -5,7 +5,6 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; -import org.elasticsearch.Version; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; @@ -231,26 +230,18 @@ private static double interpolate(double x, double x1, double y1, double x2, dou public static class Result implements EvaluationMetricResult { private static final String SCORE = "score"; - private static final String DOC_COUNT = "doc_count"; private static final String CURVE = "curve"; private final double score; - private final Long docCount; private final List curve; - public Result(double score, Long docCount, List curve) { + public Result(double score, List curve) { this.score = score; - this.docCount = docCount; this.curve = Objects.requireNonNull(curve); } public Result(StreamInput in) throws IOException { this.score = in.readDouble(); - if (in.getVersion().onOrAfter(Version.V_7_10_0)) { - this.docCount = in.readOptionalLong(); - } else { - this.docCount = null; - } this.curve = in.readList(AucRocPoint::new); } @@ -258,10 +249,6 @@ public double getScore() { return score; } - public Long getDocCount() { - return docCount; - } - public List getCurve() { return Collections.unmodifiableList(curve); } @@ -279,9 +266,6 @@ public String getMetricName() { @Override public void writeTo(StreamOutput out) throws IOException { out.writeDouble(score); - if (out.getVersion().onOrAfter(Version.V_7_10_0)) { - out.writeOptionalLong(docCount); - } out.writeList(curve); } @@ -289,9 +273,6 @@ public void writeTo(StreamOutput out) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field(SCORE, score); - if (docCount != null) { - builder.field(DOC_COUNT, docCount); - } if (curve.isEmpty() == false) { builder.field(CURVE, curve); } @@ -305,13 +286,12 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; Result that = (Result) o; return score == that.score - && Objects.equals(docCount, that.docCount) && Objects.equals(curve, that.curve); } @Override public int hashCode() { - return Objects.hash(score, docCount, curve); + return Objects.hash(score, curve); } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AucRoc.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AucRoc.java index e67361e457c04..d8d23ce241fc2 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AucRoc.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AucRoc.java @@ -183,42 +183,39 @@ public void process(Aggregations aggs) { Filter classAgg = aggs.get(TRUE_AGG_NAME); Nested classNested = classAgg.getAggregations().get(NESTED_AGG_NAME); Filter classNestedFilter = classNested.getAggregations().get(NESTED_FILTER_AGG_NAME); + + Filter restAgg = aggs.get(NON_TRUE_AGG_NAME); + Nested restNested = restAgg.getAggregations().get(NESTED_AGG_NAME); + Filter restNestedFilter = restNested.getAggregations().get(NESTED_FILTER_AGG_NAME); + if (classAgg.getDocCount() == 0) { throw ExceptionsHelper.badRequestException( "[{}] requires at least one [{}] to have the value [{}]", getName(), fields.get().getActualField(), className); } - if (classNestedFilter.getDocCount() == 0) { - throw ExceptionsHelper.badRequestException( - "[{}] requires at least one [{}] to have the value [{}]", - getName(), fields.get().getPredictedClassField(), className); - } - Percentiles classPercentiles = classNestedFilter.getAggregations().get(PERCENTILES_AGG_NAME); - double[] tpPercentiles = percentilesArray(classPercentiles); - - Filter restAgg = aggs.get(NON_TRUE_AGG_NAME); - Nested restNested = restAgg.getAggregations().get(NESTED_AGG_NAME); - Filter restNestedFilter = restNested.getAggregations().get(NESTED_FILTER_AGG_NAME); if (restAgg.getDocCount() == 0) { throw ExceptionsHelper.badRequestException( "[{}] requires at least one [{}] to have a different value than [{}]", getName(), fields.get().getActualField(), className); } - if (restNestedFilter.getDocCount() == 0) { + long filteredDocCount = classNestedFilter.getDocCount() + restNestedFilter.getDocCount(); + long totalDocCount = classAgg.getDocCount() + restAgg.getDocCount(); + if (filteredDocCount < totalDocCount) { throw ExceptionsHelper.badRequestException( - "[{}] requires at least one [{}] to have the value [{}]", - getName(), fields.get().getPredictedClassField(), className); + "[{}] requires that [{}] appears as one of the [{}] for every document (appeared in {} out of {}). " + + "This is probably caused by the {} value being less than the total number of actual classes in the dataset.", + getName(), className, fields.get().getPredictedClassField(), filteredDocCount, totalDocCount, + org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification.NUM_TOP_CLASSES.getPreferredName()); } + + Percentiles classPercentiles = classNestedFilter.getAggregations().get(PERCENTILES_AGG_NAME); + double[] tpPercentiles = percentilesArray(classPercentiles); Percentiles restPercentiles = restNestedFilter.getAggregations().get(PERCENTILES_AGG_NAME); double[] fpPercentiles = percentilesArray(restPercentiles); List aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles); double aucRocScore = calculateAucScore(aucRocCurve); - result.set( - new Result( - aucRocScore, - classNestedFilter.getDocCount() + restNestedFilter.getDocCount(), - includeCurve ? aucRocCurve : Collections.emptyList())); + result.set(new Result(aucRocScore, includeCurve ? aucRocCurve : Collections.emptyList())); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/AucRoc.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/AucRoc.java index 98698dad5329e..777ec2e600187 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/AucRoc.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/outlierdetection/AucRoc.java @@ -174,11 +174,7 @@ public void process(Aggregations aggs) { List aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles); double aucRocScore = calculateAucScore(aucRocCurve); - result.set( - new Result( - aucRocScore, - classAgg.getDocCount() + restAgg.getDocCount(), - includeCurve ? aucRocCurve : Collections.emptyList())); + result.set(new Result(aucRocScore, includeCurve ? aucRocCurve : Collections.emptyList())); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AucRocResultTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AucRocResultTests.java index 1a03e2e0c2c78..580c1c85fbf95 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AucRocResultTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AucRocResultTests.java @@ -20,13 +20,12 @@ public class AucRocResultTests extends AbstractWireSerializingTestCase { public static Result createRandom() { double score = randomDoubleBetween(0.0, 1.0, true); - Long docCount = randomBoolean() ? randomLong() : null; List curve = Stream .generate(() -> new AucRocPoint(randomDouble(), randomDouble(), randomDouble())) .limit(randomIntBetween(0, 20)) .collect(Collectors.toList()); - return new Result(score, docCount, curve); + return new Result(score, curve); } @Override diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index b16c8ce67b241..0eda91c0186eb 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -167,14 +167,12 @@ private AucRoc.Result evaluateAucRoc(boolean includeCurve) { public void testEvaluate_AucRoc_DoNotIncludeCurve() { AucRoc.Result aucrocResult = evaluateAucRoc(false); assertThat(aucrocResult.getScore(), is(closeTo(0.5, 0.0001))); - assertThat(aucrocResult.getDocCount(), is(equalTo(75L))); assertThat(aucrocResult.getCurve(), hasSize(0)); } public void testEvaluate_AucRoc_IncludeCurve() { AucRoc.Result aucrocResult = evaluateAucRoc(true); assertThat(aucrocResult.getScore(), is(closeTo(0.5, 0.0001))); - assertThat(aucrocResult.getDocCount(), is(equalTo(75L))); assertThat(aucrocResult.getCurve(), hasSize(greaterThan(0))); } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index 5cc876e752b26..1664192d85f8e 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -981,7 +981,6 @@ private void assertEvaluation(String dependentVariable, List dependentVar AucRoc.Result aucRocResult = (AucRoc.Result) evaluateDataFrameResponse.getMetrics().get(1); assertThat(aucRocResult.getMetricName(), equalTo(AucRoc.NAME.getPreferredName())); assertThat(aucRocResult.getScore(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))); - assertThat(aucRocResult.getDocCount(), allOf(greaterThanOrEqualTo(1L), lessThanOrEqualTo(350L))); assertThat(aucRocResult.getCurve(), hasSize(greaterThan(0))); } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml index 94282acec9ac1..944d6a3a80a11 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -207,7 +207,6 @@ setup: } } - match: { outlier_detection.auc_roc.score: 0.9899 } - - match: { outlier_detection.auc_roc.doc_count: 8 } - is_false: outlier_detection.auc_roc.curve --- @@ -228,7 +227,6 @@ setup: } } - match: { outlier_detection.auc_roc.score: 0.9899 } - - match: { outlier_detection.auc_roc.doc_count: 8 } - is_false: outlier_detection.auc_roc.curve --- @@ -249,7 +247,6 @@ setup: } } - match: { outlier_detection.auc_roc.score: 0.9899 } - - match: { outlier_detection.auc_roc.doc_count: 8 } - is_true: outlier_detection.auc_roc.curve --- @@ -415,7 +412,6 @@ setup: } } - is_true: outlier_detection.auc_roc.score - - is_true: outlier_detection.auc_roc.doc_count - is_true: outlier_detection.precision.0\.25 - is_true: outlier_detection.precision.0\.5 - is_true: outlier_detection.precision.0\.75 @@ -689,7 +685,7 @@ setup: --- "Test classification auc_roc given predicted_class_field is never equal to mouse": - do: - catch: /\[auc_roc\] requires at least one \[ml.top_classes.class_name\] to have the value \[mouse\]/ + catch: /\[auc_roc\] requires that \[mouse\] appears as one of the \[ml.top_classes.class_name\] for every document \(appeared in 0 out of 8\)./ ml.evaluate_data_frame: body: > { @@ -726,7 +722,6 @@ setup: } } - match: { classification.auc_roc.score: 0.8050111095212122 } - - match: { classification.auc_roc.doc_count: 8 } - is_false: classification.auc_roc.curve --- "Test classification auc_roc with default top_classes_field": @@ -747,7 +742,6 @@ setup: } } - match: { classification.auc_roc.score: 0.8050111095212122 } - - match: { classification.auc_roc.doc_count: 8 } - is_false: classification.auc_roc.curve --- "Test classification accuracy with missing predicted_field":