Skip to content

Commit 7c944d2

Browse files
authored
[7.x] Assert that the results of classification analysis can be evaluated using _evaluate API. (#48626) (#48634)
1 parent 2a863ac commit 7c944d2

File tree

4 files changed

+64
-67
lines changed

4 files changed

+64
-67
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java

+12
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,18 @@ public ActualClass(StreamInput in) throws IOException {
323323
this.otherPredictedClassDocCount = in.readVLong();
324324
}
325325

326+
public String getActualClass() {
327+
return actualClass;
328+
}
329+
330+
public List<PredictedClass> getPredictedClasses() {
331+
return predictedClasses;
332+
}
333+
334+
public long getOtherPredictedClassDocCount() {
335+
return otherPredictedClassDocCount;
336+
}
337+
326338
@Override
327339
public void writeTo(StreamOutput out) throws IOException {
328340
out.writeString(actualClass);

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java

+7-20
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,8 @@ public void cleanup() {
4040
}
4141

4242
public void testEvaluate_MulticlassClassification_DefaultMetrics() {
43-
EvaluateDataFrameAction.Request evaluateDataFrameRequest =
44-
new EvaluateDataFrameAction.Request()
45-
.setIndices(Arrays.asList(ANIMALS_DATA_INDEX))
46-
.setEvaluation(new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, null));
47-
4843
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
49-
client().execute(EvaluateDataFrameAction.INSTANCE, evaluateDataFrameRequest).actionGet();
44+
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, null));
5045

5146
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
5247
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
@@ -105,14 +100,10 @@ public void testEvaluate_MulticlassClassification_DefaultMetrics() {
105100
}
106101

107102
public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithDefaultSize() {
108-
EvaluateDataFrameAction.Request evaluateDataFrameRequest =
109-
new EvaluateDataFrameAction.Request()
110-
.setIndices(Arrays.asList(ANIMALS_DATA_INDEX))
111-
.setEvaluation(
112-
new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new MulticlassConfusionMatrix())));
113-
114103
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
115-
client().execute(EvaluateDataFrameAction.INSTANCE, evaluateDataFrameRequest).actionGet();
104+
evaluateDataFrame(
105+
ANIMALS_DATA_INDEX,
106+
new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new MulticlassConfusionMatrix())));
116107

117108
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
118109
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
@@ -171,14 +162,10 @@ public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithDefau
171162
}
172163

173164
public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserProvidedSize() {
174-
EvaluateDataFrameAction.Request evaluateDataFrameRequest =
175-
new EvaluateDataFrameAction.Request()
176-
.setIndices(Arrays.asList(ANIMALS_DATA_INDEX))
177-
.setEvaluation(
178-
new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3))));
179-
180165
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
181-
client().execute(EvaluateDataFrameAction.INSTANCE, evaluateDataFrameRequest).actionGet();
166+
evaluateDataFrame(
167+
ANIMALS_DATA_INDEX,
168+
new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3))));
182169

183170
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
184171
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

+35-47
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,13 @@
1616
import org.elasticsearch.action.search.SearchResponse;
1717
import org.elasticsearch.action.support.WriteRequest;
1818
import org.elasticsearch.search.SearchHit;
19+
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
1920
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
2021
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests;
2122
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
23+
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
24+
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
25+
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
2226
import org.junit.After;
2327

2428
import java.util.ArrayList;
@@ -28,6 +32,7 @@
2832
import java.util.Map;
2933
import java.util.function.Function;
3034

35+
import static java.util.stream.Collectors.toList;
3136
import static org.hamcrest.Matchers.allOf;
3237
import static org.hamcrest.Matchers.equalTo;
3338
import static org.hamcrest.Matchers.greaterThan;
@@ -48,7 +53,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
4853
private static final List<Boolean> BOOLEAN_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(false, true));
4954
private static final List<Double> NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0));
5055
private static final List<Integer> DISCRETE_NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(10, 20));
51-
private static final List<String> KEYWORD_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList("dog", "cat"));
56+
private static final List<String> KEYWORD_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList("cat", "dog"));
5257

5358
private String jobId;
5459
private String sourceIndex;
@@ -97,6 +102,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws
97102
"Creating destination index [" + destIndex + "]",
98103
"Finished reindexing to destination index [" + destIndex + "]",
99104
"Finished analysis");
105+
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml.keyword-field_prediction");
100106
}
101107

102108
public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
@@ -136,11 +142,13 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti
136142
"Creating destination index [" + destIndex + "]",
137143
"Finished reindexing to destination index [" + destIndex + "]",
138144
"Finished analysis");
145+
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml.keyword-field_prediction");
139146
}
140147

141148
public <T> void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
142149
String jobId, String dependentVariable, List<T> dependentVariableValues, Function<String, T> parser) throws Exception {
143150
initialize(jobId);
151+
String predictedClassField = dependentVariable + "_prediction";
144152
indexData(sourceIndex, 300, 0, dependentVariable);
145153

146154
int numTopClasses = 2;
@@ -166,7 +174,6 @@ public <T> void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
166174
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
167175
for (SearchHit hit : sourceData.getHits()) {
168176
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
169-
String predictedClassField = dependentVariable + "_prediction";
170177
assertThat(resultsObject.containsKey(predictedClassField), is(true));
171178
T predictedClassValue = parser.apply((String) resultsObject.get(predictedClassField));
172179
assertThat(predictedClassValue, is(in(dependentVariableValues)));
@@ -194,6 +201,10 @@ public <T> void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
194201
"Creating destination index [" + destIndex + "]",
195202
"Finished reindexing to destination index [" + destIndex + "]",
196203
"Finished analysis");
204+
assertEvaluation(
205+
dependentVariable,
206+
dependentVariableValues.stream().map(String::valueOf).collect(toList()),
207+
"ml." + predictedClassField);
197208
}
198209

199210
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsKeyword() throws Exception {
@@ -219,51 +230,6 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableI
219230
"classification_training_percent_is_50_boolean", BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES, Boolean::valueOf);
220231
}
221232

222-
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClassesRequested() throws Exception {
223-
initialize("classification_top_classes_requested");
224-
indexData(sourceIndex, 300, 50, KEYWORD_FIELD);
225-
226-
int numTopClasses = 2;
227-
DataFrameAnalyticsConfig config =
228-
buildAnalytics(
229-
jobId,
230-
sourceIndex,
231-
destIndex,
232-
null,
233-
new Classification(KEYWORD_FIELD, BoostedTreeParamsTests.createRandom(), null, numTopClasses, null));
234-
registerAnalytics(config);
235-
putAnalytics(config);
236-
237-
assertIsStopped(jobId);
238-
assertProgress(jobId, 0, 0, 0, 0);
239-
240-
startAnalytics(jobId);
241-
waitUntilAnalyticsIsStopped(jobId);
242-
243-
client().admin().indices().refresh(new RefreshRequest(destIndex));
244-
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
245-
for (SearchHit hit : sourceData.getHits()) {
246-
Map<String, Object> destDoc = getDestDoc(config, hit);
247-
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(destDoc);
248-
249-
assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
250-
assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
251-
assertTopClasses(resultsObject, numTopClasses, KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf);
252-
}
253-
254-
assertProgress(jobId, 100, 100, 100, 100);
255-
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
256-
assertInferenceModelPersisted(jobId);
257-
assertThatAuditMessagesMatch(jobId,
258-
"Created analytics with analysis type [classification]",
259-
"Estimated memory usage for this analytics to be",
260-
"Starting analytics on node",
261-
"Started analytics",
262-
"Creating destination index [" + destIndex + "]",
263-
"Finished reindexing to destination index [" + destIndex + "]",
264-
"Finished analysis");
265-
}
266-
267233
public void testDependentVariableCardinalityTooHighError() {
268234
initialize("cardinality_too_high");
269235
indexData(sourceIndex, 6, 5, KEYWORD_FIELD);
@@ -377,4 +343,26 @@ private static <T> void assertTopClasses(
377343
// Assert that the top classes are listed in the order of decreasing probabilities.
378344
assertThat(Ordering.natural().reverse().isOrdered(classProbabilities), is(true));
379345
}
346+
347+
private void assertEvaluation(String dependentVariable, List<String> dependentVariableValues, String predictedClassField) {
348+
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
349+
evaluateDataFrame(
350+
destIndex,
351+
new org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification(
352+
dependentVariable, predictedClassField, null));
353+
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
354+
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
355+
MulticlassConfusionMatrix.Result confusionMatrixResult =
356+
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
357+
assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
358+
List<ActualClass> actualClasses = confusionMatrixResult.getConfusionMatrix();
359+
assertThat(actualClasses.stream().map(ActualClass::getActualClass).collect(toList()), equalTo(dependentVariableValues));
360+
for (ActualClass actualClass : actualClasses) {
361+
assertThat(actualClass.getOtherPredictedClassDocCount(), equalTo(0L));
362+
assertThat(
363+
actualClass.getPredictedClasses().stream().map(PredictedClass::getPredictedClass).collect(toList()),
364+
equalTo(dependentVariableValues));
365+
}
366+
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
367+
}
380368
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java

+10
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import org.elasticsearch.index.query.QueryBuilders;
1919
import org.elasticsearch.search.sort.SortOrder;
2020
import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction;
21+
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
2122
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction;
2223
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
2324
import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction;
@@ -28,6 +29,7 @@
2829
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
2930
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
3031
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
32+
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
3133
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
3234
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
3335
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
@@ -143,6 +145,14 @@ protected GetDataFrameAnalyticsStatsAction.Response.Stats getAnalyticsStats(Stri
143145
return stats.get(0);
144146
}
145147

148+
protected EvaluateDataFrameAction.Response evaluateDataFrame(String index, Evaluation evaluation) {
149+
EvaluateDataFrameAction.Request request =
150+
new EvaluateDataFrameAction.Request()
151+
.setIndices(Arrays.asList(index))
152+
.setEvaluation(evaluation);
153+
return client().execute(EvaluateDataFrameAction.INSTANCE, request).actionGet();
154+
}
155+
146156
protected static DataFrameAnalyticsConfig buildAnalytics(String id, String sourceIndex, String destIndex,
147157
@Nullable String resultsField, DataFrameAnalysis analysis) {
148158
DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder();

0 commit comments

Comments
 (0)