13
13
import org .elasticsearch .common .io .stream .StreamInput ;
14
14
import org .elasticsearch .common .io .stream .StreamOutput ;
15
15
import org .elasticsearch .common .io .stream .Writeable ;
16
+ import org .elasticsearch .common .settings .Settings ;
16
17
import org .elasticsearch .common .xcontent .ConstructingObjectParser ;
17
18
import org .elasticsearch .common .xcontent .ToXContentObject ;
18
19
import org .elasticsearch .common .xcontent .XContentBuilder ;
22
23
import org .elasticsearch .search .aggregations .AggregationBuilders ;
23
24
import org .elasticsearch .search .aggregations .Aggregations ;
24
25
import org .elasticsearch .search .aggregations .BucketOrder ;
26
+ import org .elasticsearch .search .aggregations .MultiBucketConsumerService ;
25
27
import org .elasticsearch .search .aggregations .PipelineAggregationBuilder ;
26
28
import org .elasticsearch .search .aggregations .bucket .filter .Filters ;
27
29
import org .elasticsearch .search .aggregations .bucket .filter .FiltersAggregator .KeyedFilter ;
@@ -61,7 +63,9 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
61
63
private static ConstructingObjectParser <MulticlassConfusionMatrix , Void > createParser () {
62
64
ConstructingObjectParser <MulticlassConfusionMatrix , Void > parser =
63
65
new ConstructingObjectParser <>(
64
- NAME .getPreferredName (), true , args -> new MulticlassConfusionMatrix ((Integer ) args [0 ], (String ) args [1 ]));
66
+ NAME .getPreferredName (),
67
+ true ,
68
+ args -> new MulticlassConfusionMatrix ((Integer ) args [0 ], (String ) args [1 ]));
65
69
parser .declareInt (optionalConstructorArg (), SIZE );
66
70
parser .declareString (optionalConstructorArg (), AGG_NAME_PREFIX );
67
71
return parser ;
@@ -72,9 +76,9 @@ public static MulticlassConfusionMatrix fromXContent(XContentParser parser) {
72
76
}
73
77
74
78
static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME .getPreferredName () + "_step_1_by_actual_class" ;
79
+ static final String STEP_1_CARDINALITY_OF_ACTUAL_CLASS = NAME .getPreferredName () + "_step_1_cardinality_of_actual_class" ;
75
80
static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME .getPreferredName () + "_step_2_by_actual_class" ;
76
81
static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME .getPreferredName () + "_step_2_by_predicted_class" ;
77
- static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME .getPreferredName () + "_step_2_cardinality_of_actual_class" ;
78
82
private static final String OTHER_BUCKET_KEY = "_other_" ;
79
83
private static final String DEFAULT_AGG_NAME_PREFIX = "" ;
80
84
private static final int DEFAULT_SIZE = 10 ;
@@ -83,6 +87,9 @@ public static MulticlassConfusionMatrix fromXContent(XContentParser parser) {
83
87
private final int size ;
84
88
private final String aggNamePrefix ;
85
89
private final SetOnce <List <String >> topActualClassNames = new SetOnce <>();
90
+ private final SetOnce <Long > actualClassesCardinality = new SetOnce <>();
91
+ /** Accumulates actual classes processed so far. It may take more than 1 call to #process method to fill this field completely. */
92
+ private final List <ActualClass > actualClasses = new ArrayList <>();
86
93
private final SetOnce <Result > result = new SetOnce <>();
87
94
88
95
public MulticlassConfusionMatrix () {
@@ -121,34 +128,46 @@ public int getSize() {
121
128
}
122
129
123
130
@ Override
124
- public final Tuple <List <AggregationBuilder >, List <PipelineAggregationBuilder >> aggs (String actualField , String predictedField ) {
125
- if (topActualClassNames .get () == null ) { // This is step 1
131
+ public final Tuple <List <AggregationBuilder >, List <PipelineAggregationBuilder >> aggs (Settings settings ,
132
+ String actualField ,
133
+ String predictedField ) {
134
+ int maxBuckets = MultiBucketConsumerService .MAX_BUCKET_SETTING .get (settings );
135
+ if (topActualClassNames .get () == null && actualClassesCardinality .get () == null ) { // This is step 1
126
136
return Tuple .tuple (
127
137
List .of (
128
138
AggregationBuilders .terms (aggName (STEP_1_AGGREGATE_BY_ACTUAL_CLASS ))
129
139
.field (actualField )
130
140
.order (List .of (BucketOrder .count (false ), BucketOrder .key (true )))
131
- .size (size )),
141
+ .size (size ),
142
+ AggregationBuilders .cardinality (aggName (STEP_1_CARDINALITY_OF_ACTUAL_CLASS ))
143
+ .field (actualField )),
132
144
List .of ());
133
145
}
134
- if (result .get () == null ) { // This is step 2
135
- KeyedFilter [] keyedFiltersActual =
136
- topActualClassNames .get ().stream ()
137
- .map (className -> new KeyedFilter (className , QueryBuilders .termQuery (actualField , className )))
138
- .toArray (KeyedFilter []::new );
146
+ if (result .get () == null ) { // These are steps 2, 3, 4 etc.
139
147
KeyedFilter [] keyedFiltersPredicted =
140
148
topActualClassNames .get ().stream ()
141
149
.map (className -> new KeyedFilter (className , QueryBuilders .termQuery (predictedField , className )))
142
150
.toArray (KeyedFilter []::new );
143
- return Tuple .tuple (
144
- List .of (
145
- AggregationBuilders .cardinality (aggName (STEP_2_CARDINALITY_OF_ACTUAL_CLASS ))
146
- .field (actualField ),
147
- AggregationBuilders .filters (aggName (STEP_2_AGGREGATE_BY_ACTUAL_CLASS ), keyedFiltersActual )
148
- .subAggregation (AggregationBuilders .filters (aggName (STEP_2_AGGREGATE_BY_PREDICTED_CLASS ), keyedFiltersPredicted )
149
- .otherBucket (true )
150
- .otherBucketKey (OTHER_BUCKET_KEY ))),
151
- List .of ());
151
+ // Knowing exactly how many buckets does each aggregation use, we can choose the size of the batch so that
152
+ // too_many_buckets_exception exception is not thrown.
153
+ // The only exception is when "search.max_buckets" is set far too low to even have 1 actual class in the batch.
154
+ // In such case, the exception will be thrown telling the user they should increase the value of "search.max_buckets".
155
+ int actualClassesPerBatch = Math .max (maxBuckets / (topActualClassNames .get ().size () + 2 ), 1 );
156
+ KeyedFilter [] keyedFiltersActual =
157
+ topActualClassNames .get ().stream ()
158
+ .skip (actualClasses .size ())
159
+ .limit (actualClassesPerBatch )
160
+ .map (className -> new KeyedFilter (className , QueryBuilders .termQuery (actualField , className )))
161
+ .toArray (KeyedFilter []::new );
162
+ if (keyedFiltersActual .length > 0 ) {
163
+ return Tuple .tuple (
164
+ List .of (
165
+ AggregationBuilders .filters (aggName (STEP_2_AGGREGATE_BY_ACTUAL_CLASS ), keyedFiltersActual )
166
+ .subAggregation (AggregationBuilders .filters (aggName (STEP_2_AGGREGATE_BY_PREDICTED_CLASS ), keyedFiltersPredicted )
167
+ .otherBucket (true )
168
+ .otherBucketKey (OTHER_BUCKET_KEY ))),
169
+ List .of ());
170
+ }
152
171
}
153
172
return Tuple .tuple (List .of (), List .of ());
154
173
}
@@ -159,10 +178,12 @@ public void process(Aggregations aggs) {
159
178
Terms termsAgg = aggs .get (aggName (STEP_1_AGGREGATE_BY_ACTUAL_CLASS ));
160
179
topActualClassNames .set (termsAgg .getBuckets ().stream ().map (Terms .Bucket ::getKeyAsString ).sorted ().collect (Collectors .toList ()));
161
180
}
181
+ if (actualClassesCardinality .get () == null && aggs .get (aggName (STEP_1_CARDINALITY_OF_ACTUAL_CLASS )) != null ) {
182
+ Cardinality cardinalityAgg = aggs .get (aggName (STEP_1_CARDINALITY_OF_ACTUAL_CLASS ));
183
+ actualClassesCardinality .set (cardinalityAgg .getValue ());
184
+ }
162
185
if (result .get () == null && aggs .get (aggName (STEP_2_AGGREGATE_BY_ACTUAL_CLASS )) != null ) {
163
- Cardinality cardinalityAgg = aggs .get (aggName (STEP_2_CARDINALITY_OF_ACTUAL_CLASS ));
164
186
Filters filtersAgg = aggs .get (aggName (STEP_2_AGGREGATE_BY_ACTUAL_CLASS ));
165
- List <ActualClass > actualClasses = new ArrayList <>(filtersAgg .getBuckets ().size ());
166
187
for (Filters .Bucket bucket : filtersAgg .getBuckets ()) {
167
188
String actualClass = bucket .getKeyAsString ();
168
189
long actualClassDocCount = bucket .getDocCount ();
@@ -181,7 +202,9 @@ public void process(Aggregations aggs) {
181
202
predictedClasses .sort (comparing (PredictedClass ::getPredictedClass ));
182
203
actualClasses .add (new ActualClass (actualClass , actualClassDocCount , predictedClasses , otherPredictedClassDocCount ));
183
204
}
184
- result .set (new Result (actualClasses , Math .max (cardinalityAgg .getValue () - size , 0 )));
205
+ if (actualClasses .size () == topActualClassNames .get ().size ()) {
206
+ result .set (new Result (actualClasses , Math .max (actualClassesCardinality .get () - size , 0 )));
207
+ }
185
208
}
186
209
}
187
210
0 commit comments