diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java index 9f0150c5b8fe6..1a53b48666a5a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java @@ -322,6 +322,18 @@ public ActualClass(StreamInput in) throws IOException { this.otherPredictedClassDocCount = in.readVLong(); } + public String getActualClass() { + return actualClass; + } + + public List getPredictedClasses() { + return predictedClasses; + } + + public long getOtherPredictedClassDocCount() { + return otherPredictedClassDocCount; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(actualClass); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index 196ca87fb1213..13f90bf5becdb 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -39,13 +39,8 @@ public void cleanup() { } public void testEvaluate_MulticlassClassification_DefaultMetrics() { - EvaluateDataFrameAction.Request evaluateDataFrameRequest = - new EvaluateDataFrameAction.Request() - .setIndices(List.of(ANIMALS_DATA_INDEX)) - .setEvaluation(new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, null)); - EvaluateDataFrameAction.Response evaluateDataFrameResponse = - client().execute(EvaluateDataFrameAction.INSTANCE, evaluateDataFrameRequest).actionGet(); + evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, null)); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); @@ -104,13 +99,10 @@ public void testEvaluate_MulticlassClassification_DefaultMetrics() { } public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithDefaultSize() { - EvaluateDataFrameAction.Request evaluateDataFrameRequest = - new EvaluateDataFrameAction.Request() - .setIndices(List.of(ANIMALS_DATA_INDEX)) - .setEvaluation(new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, List.of(new MulticlassConfusionMatrix()))); - EvaluateDataFrameAction.Response evaluateDataFrameResponse = - client().execute(EvaluateDataFrameAction.INSTANCE, evaluateDataFrameRequest).actionGet(); + evaluateDataFrame( + ANIMALS_DATA_INDEX, + new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, List.of(new MulticlassConfusionMatrix()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); @@ -169,13 +161,10 @@ public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithDefau } public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserProvidedSize() { - EvaluateDataFrameAction.Request evaluateDataFrameRequest = - new EvaluateDataFrameAction.Request() - .setIndices(List.of(ANIMALS_DATA_INDEX)) - .setEvaluation(new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, List.of(new MulticlassConfusionMatrix(3)))); - EvaluateDataFrameAction.Response evaluateDataFrameResponse = - client().execute(EvaluateDataFrameAction.INSTANCE, evaluateDataFrameRequest).actionGet(); + evaluateDataFrame( + ANIMALS_DATA_INDEX, + new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, List.of(new MulticlassConfusionMatrix(3)))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index e7a0751610334..d354cf4cd4ba7 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -16,9 +16,13 @@ import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.search.SearchHit; +import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass; import org.junit.After; import java.util.ArrayList; @@ -28,6 +32,7 @@ import java.util.Map; import java.util.function.Function; +import static java.util.stream.Collectors.toList; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; @@ -48,7 +53,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { private static final List BOOLEAN_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(false, true)); private static final List NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0)); private static final List DISCRETE_NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(10, 20)); - private static final List KEYWORD_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList("dog", "cat")); + private static final List KEYWORD_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList("cat", "dog")); private String jobId; private String sourceIndex; @@ -97,6 +102,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws "Creating destination index [" + destIndex + "]", "Finished reindexing to destination index [" + destIndex + "]", "Finished analysis"); + assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml.keyword-field_prediction"); } public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception { @@ -136,11 +142,13 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti "Creating destination index [" + destIndex + "]", "Finished reindexing to destination index [" + destIndex + "]", "Finished analysis"); + assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml.keyword-field_prediction"); } public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty( String jobId, String dependentVariable, List dependentVariableValues, Function parser) throws Exception { initialize(jobId); + String predictedClassField = dependentVariable + "_prediction"; indexData(sourceIndex, 300, 0, dependentVariable); int numTopClasses = 2; @@ -166,7 +174,6 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty( SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get(); for (SearchHit hit : sourceData.getHits()) { Map resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit)); - String predictedClassField = dependentVariable + "_prediction"; assertThat(resultsObject.containsKey(predictedClassField), is(true)); T predictedClassValue = parser.apply((String) resultsObject.get(predictedClassField)); assertThat(predictedClassValue, is(in(dependentVariableValues))); @@ -194,6 +201,10 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty( "Creating destination index [" + destIndex + "]", "Finished reindexing to destination index [" + destIndex + "]", "Finished analysis"); + assertEvaluation( + dependentVariable, + dependentVariableValues.stream().map(String::valueOf).collect(toList()), + "ml." + predictedClassField); } public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsKeyword() throws Exception { @@ -219,51 +230,6 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableI "classification_training_percent_is_50_boolean", BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES, Boolean::valueOf); } - public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClassesRequested() throws Exception { - initialize("classification_top_classes_requested"); - indexData(sourceIndex, 300, 50, KEYWORD_FIELD); - - int numTopClasses = 2; - DataFrameAnalyticsConfig config = - buildAnalytics( - jobId, - sourceIndex, - destIndex, - null, - new Classification(KEYWORD_FIELD, BoostedTreeParamsTests.createRandom(), null, numTopClasses, null)); - registerAnalytics(config); - putAnalytics(config); - - assertIsStopped(jobId); - assertProgress(jobId, 0, 0, 0, 0); - - startAnalytics(jobId); - waitUntilAnalyticsIsStopped(jobId); - - client().admin().indices().refresh(new RefreshRequest(destIndex)); - SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get(); - for (SearchHit hit : sourceData.getHits()) { - Map destDoc = getDestDoc(config, hit); - Map resultsObject = getMlResultsObjectFromDestDoc(destDoc); - - assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true)); - assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES))); - assertTopClasses(resultsObject, numTopClasses, KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf); - } - - assertProgress(jobId, 100, 100, 100, 100); - assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); - assertInferenceModelPersisted(jobId); - assertThatAuditMessagesMatch(jobId, - "Created analytics with analysis type [classification]", - "Estimated memory usage for this analytics to be", - "Starting analytics on node", - "Started analytics", - "Creating destination index [" + destIndex + "]", - "Finished reindexing to destination index [" + destIndex + "]", - "Finished analysis"); - } - public void testDependentVariableCardinalityTooHighError() { initialize("cardinality_too_high"); indexData(sourceIndex, 6, 5, KEYWORD_FIELD); @@ -377,4 +343,26 @@ private static void assertTopClasses( // Assert that the top classes are listed in the order of decreasing probabilities. assertThat(Ordering.natural().reverse().isOrdered(classProbabilities), is(true)); } + + private void assertEvaluation(String dependentVariable, List dependentVariableValues, String predictedClassField) { + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + evaluateDataFrame( + destIndex, + new org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification( + dependentVariable, predictedClassField, null)); + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); + assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); + MulticlassConfusionMatrix.Result confusionMatrixResult = + (MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0); + assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); + List actualClasses = confusionMatrixResult.getConfusionMatrix(); + assertThat(actualClasses.stream().map(ActualClass::getActualClass).collect(toList()), equalTo(dependentVariableValues)); + for (ActualClass actualClass : actualClasses) { + assertThat(actualClass.getOtherPredictedClassDocCount(), equalTo(0L)); + assertThat( + actualClass.getPredictedClasses().stream().map(PredictedClass::getPredictedClass).collect(toList()), + equalTo(dependentVariableValues)); + } + assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L)); + } } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java index 31ceeaf63297a..c988684881308 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java @@ -18,6 +18,7 @@ import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction; @@ -28,6 +29,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; @@ -143,6 +145,14 @@ protected GetDataFrameAnalyticsStatsAction.Response.Stats getAnalyticsStats(Stri return stats.get(0); } + protected EvaluateDataFrameAction.Response evaluateDataFrame(String index, Evaluation evaluation) { + EvaluateDataFrameAction.Request request = + new EvaluateDataFrameAction.Request() + .setIndices(List.of(index)) + .setEvaluation(evaluation); + return client().execute(EvaluateDataFrameAction.INSTANCE, request).actionGet(); + } + protected static DataFrameAnalyticsConfig buildAnalytics(String id, String sourceIndex, String destIndex, @Nullable String resultsField, DataFrameAnalysis analysis) { DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder();