Skip to content

Commit 44ea9a8

Browse files
committed
Get rid of maxClassesCardinality internal parameter
1 parent c8b0259 commit 44ea9a8

File tree

6 files changed

+109
-39
lines changed

6 files changed

+109
-39
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java

+2-11
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
77

88
import org.apache.lucene.util.SetOnce;
9-
import org.elasticsearch.common.Nullable;
109
import org.elasticsearch.common.ParseField;
1110
import org.elasticsearch.common.collect.Tuple;
1211
import org.elasticsearch.common.io.stream.StreamInput;
@@ -78,26 +77,18 @@ public static Accuracy fromXContent(XContentParser parser) {
7877
return PARSER.apply(parser, null);
7978
}
8079

81-
private static final int DEFAULT_MAX_CLASSES_CARDINALITY = 1000;
80+
private static final int MAX_CLASSES_CARDINALITY = 1000;
8281

83-
private final int maxClassesCardinality;
8482
private final MulticlassConfusionMatrix matrix;
8583
private final SetOnce<String> actualField = new SetOnce<>();
8684
private final SetOnce<Double> overallAccuracy = new SetOnce<>();
8785
private final SetOnce<Result> result = new SetOnce<>();
8886

8987
public Accuracy() {
90-
this((Integer) null);
91-
}
92-
93-
// Visible for testing
94-
public Accuracy(@Nullable Integer maxClassesCardinality) {
95-
this.maxClassesCardinality = maxClassesCardinality != null ? maxClassesCardinality : DEFAULT_MAX_CLASSES_CARDINALITY;
96-
this.matrix = new MulticlassConfusionMatrix(this.maxClassesCardinality, NAME.getPreferredName() + "_");
88+
this.matrix = new MulticlassConfusionMatrix(MAX_CLASSES_CARDINALITY, NAME.getPreferredName() + "_");
9789
}
9890

9991
public Accuracy(StreamInput in) throws IOException {
100-
this.maxClassesCardinality = DEFAULT_MAX_CLASSES_CARDINALITY;
10192
this.matrix = new MulticlassConfusionMatrix(in);
10293
}
10394

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,10 @@ public static MulticlassConfusionMatrix fromXContent(XContentParser parser) {
7171
return PARSER.apply(parser, null);
7272
}
7373

74-
private static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class";
75-
private static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class";
76-
private static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class";
77-
private static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class";
74+
static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class";
75+
static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class";
76+
static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class";
77+
static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class";
7878
private static final String OTHER_BUCKET_KEY = "_other_";
7979
private static final String DEFAULT_AGG_NAME_PREFIX = "";
8080
private static final int DEFAULT_SIZE = 10;

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> a
108108
.size(MAX_CLASSES_CARDINALITY)),
109109
List.of());
110110
}
111-
if (result == null) { // This is step 2
111+
if (result.get() == null) { // This is step 2
112112
KeyedFilter[] keyedFiltersPredicted =
113113
topActualClassNames.get().stream()
114114
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java

+87
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,26 @@
55
*/
66
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
77

8+
import org.elasticsearch.ElasticsearchStatusException;
89
import org.elasticsearch.common.io.stream.Writeable;
910
import org.elasticsearch.common.xcontent.XContentParser;
11+
import org.elasticsearch.search.aggregations.Aggregations;
1012
import org.elasticsearch.test.AbstractSerializingTestCase;
13+
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.PerClassResult;
14+
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result;
1115

1216
import java.io.IOException;
1317
import java.util.List;
1418

19+
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockCardinality;
20+
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters;
21+
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket;
22+
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
23+
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
24+
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket;
25+
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.TupleMatchers.isTuple;
26+
import static org.hamcrest.Matchers.containsString;
27+
import static org.hamcrest.Matchers.empty;
1528
import static org.hamcrest.Matchers.equalTo;
1629

1730
public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
@@ -40,6 +53,80 @@ public static Accuracy createRandom() {
4053
return new Accuracy();
4154
}
4255

56+
public void testProcess() {
57+
Aggregations aggs = new Aggregations(List.of(
58+
mockTerms(
59+
"accuracy_" + MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
60+
List.of(
61+
mockTermsBucket("dog", new Aggregations(List.of())),
62+
mockTermsBucket("cat", new Aggregations(List.of()))),
63+
100L),
64+
mockFilters(
65+
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
66+
List.of(
67+
mockFiltersBucket(
68+
"dog",
69+
30,
70+
new Aggregations(List.of(mockFilters(
71+
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
72+
List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
73+
mockFiltersBucket(
74+
"cat",
75+
70,
76+
new Aggregations(List.of(mockFilters(
77+
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
78+
List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
79+
mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1000L),
80+
mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5)));
81+
82+
Accuracy accuracy = new Accuracy();
83+
accuracy.process(aggs);
84+
85+
assertThat(accuracy.aggs("act", "pred"), isTuple(empty(), empty()));
86+
87+
Result result = accuracy.getResult().get();
88+
assertThat(result.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
89+
assertThat(
90+
result.getClasses(),
91+
equalTo(
92+
List.of(
93+
new PerClassResult("dog", 0.5),
94+
new PerClassResult("cat", 0.5))));
95+
assertThat(result.getOverallAccuracy(), equalTo(0.5));
96+
}
97+
98+
public void testProcess_GivenCardinalityTooHigh() {
99+
Aggregations aggs = new Aggregations(List.of(
100+
mockTerms(
101+
"accuracy_" + MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
102+
List.of(
103+
mockTermsBucket("dog", new Aggregations(List.of())),
104+
mockTermsBucket("cat", new Aggregations(List.of()))),
105+
100L),
106+
mockFilters(
107+
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
108+
List.of(
109+
mockFiltersBucket(
110+
"dog",
111+
30,
112+
new Aggregations(List.of(mockFilters(
113+
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
114+
List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
115+
mockFiltersBucket(
116+
"cat",
117+
70,
118+
new Aggregations(List.of(mockFilters(
119+
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
120+
List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
121+
mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1001L),
122+
mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5)));
123+
124+
Accuracy accuracy = new Accuracy();
125+
accuracy.aggs("foo", "bar");
126+
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> accuracy.process(aggs));
127+
assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
128+
}
129+
43130
public void testComputePerClassAccuracy() {
44131
assertThat(
45132
Accuracy.computePerClassAccuracy(

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java

+15-14
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.elasticsearch.test.AbstractSerializingTestCase;
1616
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
1717
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
18+
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.Result;
1819

1920
import java.io.IOException;
2021
import java.util.List;
@@ -85,34 +86,34 @@ public void testAggs() {
8586
public void testEvaluate() {
8687
Aggregations aggs = new Aggregations(List.of(
8788
mockTerms(
88-
"multiclass_confusion_matrix_step_1_by_actual_class",
89+
MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
8990
List.of(
9091
mockTermsBucket("dog", new Aggregations(List.of())),
9192
mockTermsBucket("cat", new Aggregations(List.of()))),
9293
0L),
9394
mockFilters(
94-
"multiclass_confusion_matrix_step_2_by_actual_class",
95+
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
9596
List.of(
9697
mockFiltersBucket(
9798
"dog",
9899
30,
99100
new Aggregations(List.of(mockFilters(
100-
"multiclass_confusion_matrix_step_2_by_predicted_class",
101+
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
101102
List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
102103
mockFiltersBucket(
103104
"cat",
104105
70,
105106
new Aggregations(List.of(mockFilters(
106-
"multiclass_confusion_matrix_step_2_by_predicted_class",
107+
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
107108
List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
108-
mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 2L)));
109+
mockCardinality(MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 2L)));
109110

110111
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null);
111112
confusionMatrix.process(aggs);
112113

113114
assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty()));
114-
MulticlassConfusionMatrix.Result result = confusionMatrix.getResult().get();
115-
assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
115+
Result result = confusionMatrix.getResult().get();
116+
assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
116117
assertThat(
117118
result.getConfusionMatrix(),
118119
equalTo(
@@ -125,34 +126,34 @@ public void testEvaluate() {
125126
public void testEvaluate_OtherClassesCountGreaterThanZero() {
126127
Aggregations aggs = new Aggregations(List.of(
127128
mockTerms(
128-
"multiclass_confusion_matrix_step_1_by_actual_class",
129+
MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
129130
List.of(
130131
mockTermsBucket("dog", new Aggregations(List.of())),
131132
mockTermsBucket("cat", new Aggregations(List.of()))),
132133
100L),
133134
mockFilters(
134-
"multiclass_confusion_matrix_step_2_by_actual_class",
135+
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
135136
List.of(
136137
mockFiltersBucket(
137138
"dog",
138139
30,
139140
new Aggregations(List.of(mockFilters(
140-
"multiclass_confusion_matrix_step_2_by_predicted_class",
141+
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
141142
List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
142143
mockFiltersBucket(
143144
"cat",
144145
85,
145146
new Aggregations(List.of(mockFilters(
146-
"multiclass_confusion_matrix_step_2_by_predicted_class",
147+
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
147148
List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L)))))))),
148-
mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 5L)));
149+
mockCardinality(MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 5L)));
149150

150151
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null);
151152
confusionMatrix.process(aggs);
152153

153154
assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty()));
154-
MulticlassConfusionMatrix.Result result = confusionMatrix.getResult().get();
155-
assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
155+
Result result = confusionMatrix.getResult().get();
156+
assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
156157
assertThat(
157158
result.getConfusionMatrix(),
158159
equalTo(

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java

-9
Original file line numberDiff line numberDiff line change
@@ -142,15 +142,6 @@ public void testEvaluate_Accuracy_BooleanField() {
142142
assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75));
143143
}
144144

145-
public void testEvaluate_Accuracy_CardinalityTooHigh() {
146-
ElasticsearchStatusException e =
147-
expectThrows(
148-
ElasticsearchStatusException.class,
149-
() -> evaluateDataFrame(
150-
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Accuracy(4)))));
151-
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high"));
152-
}
153-
154145
public void testEvaluate_Precision() {
155146
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
156147
evaluateDataFrame(

0 commit comments

Comments
 (0)