|
5 | 5 | */
|
6 | 6 | package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
7 | 7 |
|
| 8 | +import org.elasticsearch.ElasticsearchStatusException; |
8 | 9 | import org.elasticsearch.common.io.stream.Writeable;
|
9 | 10 | import org.elasticsearch.common.xcontent.XContentParser;
|
| 11 | +import org.elasticsearch.search.aggregations.Aggregations; |
10 | 12 | import org.elasticsearch.test.AbstractSerializingTestCase;
|
| 13 | +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.PerClassResult; |
| 14 | +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result; |
11 | 15 |
|
12 | 16 | import java.io.IOException;
|
13 | 17 | import java.util.List;
|
14 | 18 |
|
| 19 | +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockCardinality; |
| 20 | +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters; |
| 21 | +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket; |
| 22 | +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue; |
| 23 | +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms; |
| 24 | +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket; |
| 25 | +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.TupleMatchers.isTuple; |
| 26 | +import static org.hamcrest.Matchers.containsString; |
| 27 | +import static org.hamcrest.Matchers.empty; |
15 | 28 | import static org.hamcrest.Matchers.equalTo;
|
16 | 29 |
|
17 | 30 | public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
|
@@ -40,6 +53,80 @@ public static Accuracy createRandom() {
|
40 | 53 | return new Accuracy();
|
41 | 54 | }
|
42 | 55 |
|
| 56 | + public void testProcess() { |
| 57 | + Aggregations aggs = new Aggregations(List.of( |
| 58 | + mockTerms( |
| 59 | + "accuracy_" + MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS, |
| 60 | + List.of( |
| 61 | + mockTermsBucket("dog", new Aggregations(List.of())), |
| 62 | + mockTermsBucket("cat", new Aggregations(List.of()))), |
| 63 | + 100L), |
| 64 | + mockFilters( |
| 65 | + "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS, |
| 66 | + List.of( |
| 67 | + mockFiltersBucket( |
| 68 | + "dog", |
| 69 | + 30, |
| 70 | + new Aggregations(List.of(mockFilters( |
| 71 | + "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, |
| 72 | + List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))), |
| 73 | + mockFiltersBucket( |
| 74 | + "cat", |
| 75 | + 70, |
| 76 | + new Aggregations(List.of(mockFilters( |
| 77 | + "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, |
| 78 | + List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))), |
| 79 | + mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1000L), |
| 80 | + mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5))); |
| 81 | + |
| 82 | + Accuracy accuracy = new Accuracy(); |
| 83 | + accuracy.process(aggs); |
| 84 | + |
| 85 | + assertThat(accuracy.aggs("act", "pred"), isTuple(empty(), empty())); |
| 86 | + |
| 87 | + Result result = accuracy.getResult().get(); |
| 88 | + assertThat(result.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); |
| 89 | + assertThat( |
| 90 | + result.getClasses(), |
| 91 | + equalTo( |
| 92 | + List.of( |
| 93 | + new PerClassResult("dog", 0.5), |
| 94 | + new PerClassResult("cat", 0.5)))); |
| 95 | + assertThat(result.getOverallAccuracy(), equalTo(0.5)); |
| 96 | + } |
| 97 | + |
| 98 | + public void testProcess_GivenCardinalityTooHigh() { |
| 99 | + Aggregations aggs = new Aggregations(List.of( |
| 100 | + mockTerms( |
| 101 | + "accuracy_" + MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS, |
| 102 | + List.of( |
| 103 | + mockTermsBucket("dog", new Aggregations(List.of())), |
| 104 | + mockTermsBucket("cat", new Aggregations(List.of()))), |
| 105 | + 100L), |
| 106 | + mockFilters( |
| 107 | + "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS, |
| 108 | + List.of( |
| 109 | + mockFiltersBucket( |
| 110 | + "dog", |
| 111 | + 30, |
| 112 | + new Aggregations(List.of(mockFilters( |
| 113 | + "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, |
| 114 | + List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))), |
| 115 | + mockFiltersBucket( |
| 116 | + "cat", |
| 117 | + 70, |
| 118 | + new Aggregations(List.of(mockFilters( |
| 119 | + "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, |
| 120 | + List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))), |
| 121 | + mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1001L), |
| 122 | + mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5))); |
| 123 | + |
| 124 | + Accuracy accuracy = new Accuracy(); |
| 125 | + accuracy.aggs("foo", "bar"); |
| 126 | + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> accuracy.process(aggs)); |
| 127 | + assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high")); |
| 128 | + } |
| 129 | + |
43 | 130 | public void testComputePerClassAccuracy() {
|
44 | 131 | assertThat(
|
45 | 132 | Accuracy.computePerClassAccuracy(
|
|
0 commit comments