Skip to content

Commit 778e540

Browse files
committed
Perform evaluation in multiple steps when necessary
1 parent 6de1db8 commit 778e540

File tree

19 files changed

+273
-63
lines changed

19 files changed

+273
-63
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import org.elasticsearch.common.Nullable;
1010
import org.elasticsearch.common.collect.Tuple;
1111
import org.elasticsearch.common.io.stream.NamedWriteable;
12+
import org.elasticsearch.common.settings.Settings;
1213
import org.elasticsearch.common.xcontent.ToXContentObject;
1314
import org.elasticsearch.index.query.BoolQueryBuilder;
1415
import org.elasticsearch.index.query.QueryBuilder;
@@ -66,7 +67,7 @@ default <T extends EvaluationMetric> List<T> initMetrics(@Nullable List<T> parse
6667
* Builds the search required to collect data to compute the evaluation result
6768
* @param userProvidedQueryBuilder User-provided query that must be respected when collecting data
6869
*/
69-
default SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
70+
default SearchSourceBuilder buildSearch(Settings settings, QueryBuilder userProvidedQueryBuilder) {
7071
Objects.requireNonNull(userProvidedQueryBuilder);
7172
BoolQueryBuilder boolQuery =
7273
QueryBuilders.boolQuery()
@@ -78,7 +79,8 @@ default SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
7879
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
7980
for (EvaluationMetric metric : getMetrics()) {
8081
// Fetch aggregations requested by individual metrics
81-
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs = metric.aggs(getActualField(), getPredictedField());
82+
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs =
83+
metric.aggs(settings, getActualField(), getPredictedField());
8284
aggs.v1().forEach(searchSourceBuilder::aggregation);
8385
aggs.v2().forEach(searchSourceBuilder::aggregation);
8486
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import org.elasticsearch.action.search.SearchResponse;
99
import org.elasticsearch.common.collect.Tuple;
1010
import org.elasticsearch.common.io.stream.NamedWriteable;
11+
import org.elasticsearch.common.settings.Settings;
1112
import org.elasticsearch.common.xcontent.ToXContentObject;
1213
import org.elasticsearch.search.aggregations.AggregationBuilder;
1314
import org.elasticsearch.search.aggregations.Aggregations;
@@ -28,11 +29,12 @@ public interface EvaluationMetric extends ToXContentObject, NamedWriteable {
2829

2930
/**
3031
* Builds the aggregations that collect the data required to compute the metric
32+
* @param settings settings that may be needed by aggregations
3133
* @param actualField the field that stores the actual value
3234
* @param predictedField the field that stores the predicted value (class name or probability)
3335
* @return the aggregations required to compute the metric
3436
*/
35-
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField);
37+
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(Settings settings, String actualField, String predictedField);
3638

3739
/**
3840
* Processes given aggregations as a step towards computing result

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.common.io.stream.StreamInput;
1212
import org.elasticsearch.common.io.stream.StreamOutput;
1313
import org.elasticsearch.common.io.stream.Writeable;
14+
import org.elasticsearch.common.settings.Settings;
1415
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
1516
import org.elasticsearch.common.xcontent.ObjectParser;
1617
import org.elasticsearch.common.xcontent.ToXContentObject;
@@ -103,7 +104,9 @@ public String getName() {
103104
}
104105

105106
@Override
106-
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
107+
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(Settings settings,
108+
String actualField,
109+
String predictedField) {
107110
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
108111
this.actualField.trySet(actualField);
109112
List<AggregationBuilder> aggs = new ArrayList<>();
@@ -112,7 +115,8 @@ public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> a
112115
aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(buildScript(actualField, predictedField)));
113116
}
114117
if (result.get() == null) {
115-
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> matrixAggs = matrix.aggs(actualField, predictedField);
118+
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> matrixAggs =
119+
matrix.aggs(settings, actualField, predictedField);
116120
aggs.addAll(matrixAggs.v1());
117121
pipelineAggs.addAll(matrixAggs.v2());
118122
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java

Lines changed: 45 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.common.io.stream.StreamInput;
1414
import org.elasticsearch.common.io.stream.StreamOutput;
1515
import org.elasticsearch.common.io.stream.Writeable;
16+
import org.elasticsearch.common.settings.Settings;
1617
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
1718
import org.elasticsearch.common.xcontent.ToXContentObject;
1819
import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -22,6 +23,7 @@
2223
import org.elasticsearch.search.aggregations.AggregationBuilders;
2324
import org.elasticsearch.search.aggregations.Aggregations;
2425
import org.elasticsearch.search.aggregations.BucketOrder;
26+
import org.elasticsearch.search.aggregations.MultiBucketConsumerService;
2527
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
2628
import org.elasticsearch.search.aggregations.bucket.filter.Filters;
2729
import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregator.KeyedFilter;
@@ -61,7 +63,9 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
6163
private static ConstructingObjectParser<MulticlassConfusionMatrix, Void> createParser() {
6264
ConstructingObjectParser<MulticlassConfusionMatrix, Void> parser =
6365
new ConstructingObjectParser<>(
64-
NAME.getPreferredName(), true, args -> new MulticlassConfusionMatrix((Integer) args[0], (String) args[1]));
66+
NAME.getPreferredName(),
67+
true,
68+
args -> new MulticlassConfusionMatrix((Integer) args[0], (String) args[1]));
6569
parser.declareInt(optionalConstructorArg(), SIZE);
6670
parser.declareString(optionalConstructorArg(), AGG_NAME_PREFIX);
6771
return parser;
@@ -72,9 +76,9 @@ public static MulticlassConfusionMatrix fromXContent(XContentParser parser) {
7276
}
7377

7478
static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class";
79+
static final String STEP_1_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_cardinality_of_actual_class";
7580
static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class";
7681
static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class";
77-
static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class";
7882
private static final String OTHER_BUCKET_KEY = "_other_";
7983
private static final String DEFAULT_AGG_NAME_PREFIX = "";
8084
private static final int DEFAULT_SIZE = 10;
@@ -83,6 +87,9 @@ public static MulticlassConfusionMatrix fromXContent(XContentParser parser) {
8387
private final int size;
8488
private final String aggNamePrefix;
8589
private final SetOnce<List<String>> topActualClassNames = new SetOnce<>();
90+
private final SetOnce<Long> actualClassesCardinality = new SetOnce<>();
91+
/** Accumulates actual classes processed so far. It may take more than one call to the {@code process} method to fill this field completely. */
92+
private final List<ActualClass> actualClasses = new ArrayList<>();
8693
private final SetOnce<Result> result = new SetOnce<>();
8794

8895
public MulticlassConfusionMatrix() {
@@ -121,34 +128,46 @@ public int getSize() {
121128
}
122129

123130
@Override
124-
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
125-
if (topActualClassNames.get() == null) { // This is step 1
131+
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(Settings settings,
132+
String actualField,
133+
String predictedField) {
134+
int maxBuckets = MultiBucketConsumerService.MAX_BUCKET_SETTING.get(settings);
135+
if (topActualClassNames.get() == null && actualClassesCardinality.get() == null) { // This is step 1
126136
return Tuple.tuple(
127137
List.of(
128138
AggregationBuilders.terms(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS))
129139
.field(actualField)
130140
.order(List.of(BucketOrder.count(false), BucketOrder.key(true)))
131-
.size(size)),
141+
.size(size),
142+
AggregationBuilders.cardinality(aggName(STEP_1_CARDINALITY_OF_ACTUAL_CLASS))
143+
.field(actualField)),
132144
List.of());
133145
}
134-
if (result.get() == null) { // This is step 2
135-
KeyedFilter[] keyedFiltersActual =
136-
topActualClassNames.get().stream()
137-
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className)))
138-
.toArray(KeyedFilter[]::new);
146+
if (result.get() == null) { // These are steps 2, 3, 4 etc.
139147
KeyedFilter[] keyedFiltersPredicted =
140148
topActualClassNames.get().stream()
141149
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
142150
.toArray(KeyedFilter[]::new);
143-
return Tuple.tuple(
144-
List.of(
145-
AggregationBuilders.cardinality(aggName(STEP_2_CARDINALITY_OF_ACTUAL_CLASS))
146-
.field(actualField),
147-
AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS), keyedFiltersActual)
148-
.subAggregation(AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_PREDICTED_CLASS), keyedFiltersPredicted)
149-
.otherBucket(true)
150-
.otherBucketKey(OTHER_BUCKET_KEY))),
151-
List.of());
151+
// Knowing exactly how many buckets each aggregation uses, we can choose the size of the batch so that
152+
// too_many_buckets_exception is not thrown.
153+
// The only exception is when "search.max_buckets" is set far too low to even have 1 actual class in the batch.
154+
// In such a case, the exception will be thrown, telling the user they should increase the value of "search.max_buckets".
155+
int actualClassesPerBatch = Math.max(maxBuckets / (topActualClassNames.get().size() + 2), 1);
156+
KeyedFilter[] keyedFiltersActual =
157+
topActualClassNames.get().stream()
158+
.skip(actualClasses.size())
159+
.limit(actualClassesPerBatch)
160+
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className)))
161+
.toArray(KeyedFilter[]::new);
162+
if (keyedFiltersActual.length > 0) {
163+
return Tuple.tuple(
164+
List.of(
165+
AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS), keyedFiltersActual)
166+
.subAggregation(AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_PREDICTED_CLASS), keyedFiltersPredicted)
167+
.otherBucket(true)
168+
.otherBucketKey(OTHER_BUCKET_KEY))),
169+
List.of());
170+
}
152171
}
153172
return Tuple.tuple(List.of(), List.of());
154173
}
@@ -159,10 +178,12 @@ public void process(Aggregations aggs) {
159178
Terms termsAgg = aggs.get(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS));
160179
topActualClassNames.set(termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList()));
161180
}
181+
if (actualClassesCardinality.get() == null && aggs.get(aggName(STEP_1_CARDINALITY_OF_ACTUAL_CLASS)) != null) {
182+
Cardinality cardinalityAgg = aggs.get(aggName(STEP_1_CARDINALITY_OF_ACTUAL_CLASS));
183+
actualClassesCardinality.set(cardinalityAgg.getValue());
184+
}
162185
if (result.get() == null && aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS)) != null) {
163-
Cardinality cardinalityAgg = aggs.get(aggName(STEP_2_CARDINALITY_OF_ACTUAL_CLASS));
164186
Filters filtersAgg = aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS));
165-
List<ActualClass> actualClasses = new ArrayList<>(filtersAgg.getBuckets().size());
166187
for (Filters.Bucket bucket : filtersAgg.getBuckets()) {
167188
String actualClass = bucket.getKeyAsString();
168189
long actualClassDocCount = bucket.getDocCount();
@@ -181,7 +202,9 @@ public void process(Aggregations aggs) {
181202
predictedClasses.sort(comparing(PredictedClass::getPredictedClass));
182203
actualClasses.add(new ActualClass(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount));
183204
}
184-
result.set(new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0)));
205+
if (actualClasses.size() == topActualClassNames.get().size()) {
206+
result.set(new Result(actualClasses, Math.max(actualClassesCardinality.get() - size, 0)));
207+
}
185208
}
186209
}
187210

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.common.io.stream.StreamInput;
1212
import org.elasticsearch.common.io.stream.StreamOutput;
1313
import org.elasticsearch.common.io.stream.Writeable;
14+
import org.elasticsearch.common.settings.Settings;
1415
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
1516
import org.elasticsearch.common.xcontent.ObjectParser;
1617
import org.elasticsearch.common.xcontent.ToXContentObject;
@@ -96,7 +97,9 @@ public String getName() {
9697
}
9798

9899
@Override
99-
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
100+
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(Settings settings,
101+
String actualField,
102+
String predictedField) {
100103
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
101104
this.actualField.trySet(actualField);
102105
if (topActualClassNames.get() == null) { // This is step 1

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.common.io.stream.StreamInput;
1212
import org.elasticsearch.common.io.stream.StreamOutput;
1313
import org.elasticsearch.common.io.stream.Writeable;
14+
import org.elasticsearch.common.settings.Settings;
1415
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
1516
import org.elasticsearch.common.xcontent.ObjectParser;
1617
import org.elasticsearch.common.xcontent.ToXContentObject;
@@ -89,7 +90,9 @@ public String getName() {
8990
}
9091

9192
@Override
92-
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
93+
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(Settings settings,
94+
String actualField,
95+
String predictedField) {
9396
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
9497
this.actualField.trySet(actualField);
9598
if (result.get() != null) {

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import org.elasticsearch.common.collect.Tuple;
1010
import org.elasticsearch.common.io.stream.StreamInput;
1111
import org.elasticsearch.common.io.stream.StreamOutput;
12+
import org.elasticsearch.common.settings.Settings;
1213
import org.elasticsearch.common.xcontent.ObjectParser;
1314
import org.elasticsearch.common.xcontent.XContentBuilder;
1415
import org.elasticsearch.common.xcontent.XContentParser;
@@ -65,7 +66,9 @@ public String getName() {
6566
}
6667

6768
@Override
68-
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
69+
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(Settings settings,
70+
String actualField,
71+
String predictedField) {
6972
if (result != null) {
7073
return Tuple.tuple(List.of(), List.of());
7174
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import org.elasticsearch.common.collect.Tuple;
1010
import org.elasticsearch.common.io.stream.StreamInput;
1111
import org.elasticsearch.common.io.stream.StreamOutput;
12+
import org.elasticsearch.common.settings.Settings;
1213
import org.elasticsearch.common.xcontent.ObjectParser;
1314
import org.elasticsearch.common.xcontent.XContentBuilder;
1415
import org.elasticsearch.common.xcontent.XContentParser;
@@ -70,7 +71,9 @@ public String getName() {
7071
}
7172

7273
@Override
73-
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
74+
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(Settings settings,
75+
String actualField,
76+
String predictedField) {
7477
if (result != null) {
7578
return Tuple.tuple(List.of(), List.of());
7679
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import org.elasticsearch.common.collect.Tuple;
1010
import org.elasticsearch.common.io.stream.StreamInput;
1111
import org.elasticsearch.common.io.stream.StreamOutput;
12+
import org.elasticsearch.common.settings.Settings;
1213
import org.elasticsearch.common.xcontent.XContentBuilder;
1314
import org.elasticsearch.index.query.BoolQueryBuilder;
1415
import org.elasticsearch.index.query.QueryBuilder;
@@ -65,7 +66,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
6566
}
6667

6768
@Override
68-
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedProbabilityField) {
69+
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(Settings settings,
70+
String actualField,
71+
String predictedProbabilityField) {
6972
if (result != null) {
7073
return Tuple.tuple(List.of(), List.of());
7174
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.common.io.stream.StreamInput;
1212
import org.elasticsearch.common.io.stream.StreamOutput;
1313
import org.elasticsearch.common.io.stream.Writeable;
14+
import org.elasticsearch.common.settings.Settings;
1415
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
1516
import org.elasticsearch.common.xcontent.ToXContentObject;
1617
import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -127,7 +128,9 @@ public int hashCode() {
127128
}
128129

129130
@Override
130-
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedProbabilityField) {
131+
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(Settings settings,
132+
String actualField,
133+
String predictedProbabilityField) {
131134
if (result != null) {
132135
return Tuple.tuple(List.of(), List.of());
133136
}

0 commit comments

Comments
 (0)