Skip to content

Commit 5b2b8b9

Browse files
committed
Introduce EvaluationParameters class and its unit tests
1 parent 778e540 commit 5b2b8b9

File tree

20 files changed

+111
-61
lines changed

20 files changed

+111
-61
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import org.elasticsearch.common.Nullable;
1010
import org.elasticsearch.common.collect.Tuple;
1111
import org.elasticsearch.common.io.stream.NamedWriteable;
12-
import org.elasticsearch.common.settings.Settings;
1312
import org.elasticsearch.common.xcontent.ToXContentObject;
1413
import org.elasticsearch.index.query.BoolQueryBuilder;
1514
import org.elasticsearch.index.query.QueryBuilder;
@@ -67,7 +66,7 @@ default <T extends EvaluationMetric> List<T> initMetrics(@Nullable List<T> parse
6766
* Builds the search required to collect data to compute the evaluation result
6867
* @param userProvidedQueryBuilder User-provided query that must be respected when collecting data
6968
*/
70-
default SearchSourceBuilder buildSearch(Settings settings, QueryBuilder userProvidedQueryBuilder) {
69+
default SearchSourceBuilder buildSearch(EvaluationParameters parameters, QueryBuilder userProvidedQueryBuilder) {
7170
Objects.requireNonNull(userProvidedQueryBuilder);
7271
BoolQueryBuilder boolQuery =
7372
QueryBuilders.boolQuery()
@@ -80,7 +79,7 @@ default SearchSourceBuilder buildSearch(Settings settings, QueryBuilder userProv
8079
for (EvaluationMetric metric : getMetrics()) {
8180
// Fetch aggregations requested by individual metrics
8281
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs =
83-
metric.aggs(settings, getActualField(), getPredictedField());
82+
metric.aggs(parameters, getActualField(), getPredictedField());
8483
aggs.v1().forEach(searchSourceBuilder::aggregation);
8584
aggs.v2().forEach(searchSourceBuilder::aggregation);
8685
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import org.elasticsearch.action.search.SearchResponse;
99
import org.elasticsearch.common.collect.Tuple;
1010
import org.elasticsearch.common.io.stream.NamedWriteable;
11-
import org.elasticsearch.common.settings.Settings;
1211
import org.elasticsearch.common.xcontent.ToXContentObject;
1312
import org.elasticsearch.search.aggregations.AggregationBuilder;
1413
import org.elasticsearch.search.aggregations.Aggregations;
@@ -29,12 +28,14 @@ public interface EvaluationMetric extends ToXContentObject, NamedWriteable {
2928

3029
/**
3130
* Builds the aggregations that collect the data required to compute the metric
32-
* @param settings settings that may be needed by aggregations
31+
* @param parameters evaluation parameters that may be needed by aggregations
3332
* @param actualField the field that stores the actual value
3433
* @param predictedField the field that stores the predicted value (class name or probability)
3534
* @return the aggregations required to compute the metric
3635
*/
37-
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(Settings settings, String actualField, String predictedField);
36+
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
37+
String actualField,
38+
String predictedField);
3839

3940
/**
4041
* Processes given aggregations as a step towards computing result
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
7+
8+
/**
9+
* Encapsulates parameters needed by evaluation.
10+
*/
11+
public class EvaluationParameters {
12+
13+
private final int maxBuckets;
14+
15+
public EvaluationParameters(int maxBuckets) {
16+
this.maxBuckets = maxBuckets;
17+
}
18+
19+
public int getMaxBuckets() {
20+
return maxBuckets;
21+
}
22+
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import org.elasticsearch.common.io.stream.StreamInput;
1212
import org.elasticsearch.common.io.stream.StreamOutput;
1313
import org.elasticsearch.common.io.stream.Writeable;
14-
import org.elasticsearch.common.settings.Settings;
1514
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
1615
import org.elasticsearch.common.xcontent.ObjectParser;
1716
import org.elasticsearch.common.xcontent.ToXContentObject;
@@ -25,6 +24,7 @@
2524
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
2625
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
2726
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
27+
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
2828
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
2929

3030
import java.io.IOException;
@@ -104,7 +104,7 @@ public String getName() {
104104
}
105105

106106
@Override
107-
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(Settings settings,
107+
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
108108
String actualField,
109109
String predictedField) {
110110
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
@@ -116,7 +116,7 @@ public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> a
116116
}
117117
if (result.get() == null) {
118118
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> matrixAggs =
119-
matrix.aggs(settings, actualField, predictedField);
119+
matrix.aggs(parameters, actualField, predictedField);
120120
aggs.addAll(matrixAggs.v1());
121121
pipelineAggs.addAll(matrixAggs.v2());
122122
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import org.elasticsearch.common.io.stream.StreamInput;
1414
import org.elasticsearch.common.io.stream.StreamOutput;
1515
import org.elasticsearch.common.io.stream.Writeable;
16-
import org.elasticsearch.common.settings.Settings;
1716
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
1817
import org.elasticsearch.common.xcontent.ToXContentObject;
1918
import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -23,14 +22,14 @@
2322
import org.elasticsearch.search.aggregations.AggregationBuilders;
2423
import org.elasticsearch.search.aggregations.Aggregations;
2524
import org.elasticsearch.search.aggregations.BucketOrder;
26-
import org.elasticsearch.search.aggregations.MultiBucketConsumerService;
2725
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
2826
import org.elasticsearch.search.aggregations.bucket.filter.Filters;
2927
import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregator.KeyedFilter;
3028
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
3129
import org.elasticsearch.search.aggregations.metrics.Cardinality;
3230
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
3331
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
32+
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
3433
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
3534

3635
import java.io.IOException;
@@ -128,10 +127,9 @@ public int getSize() {
128127
}
129128

130129
@Override
131-
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(Settings settings,
130+
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
132131
String actualField,
133132
String predictedField) {
134-
int maxBuckets = MultiBucketConsumerService.MAX_BUCKET_SETTING.get(settings);
135133
if (topActualClassNames.get() == null && actualClassesCardinality.get() == null) { // This is step 1
136134
return Tuple.tuple(
137135
List.of(
@@ -152,7 +150,7 @@ public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> a
152150
// too_many_buckets_exception is not thrown.
153151
// The only exception is when "search.max_buckets" is set far too low to even have 1 actual class in the batch.
154152
// In such case, the exception will be thrown telling the user they should increase the value of "search.max_buckets".
155-
int actualClassesPerBatch = Math.max(maxBuckets / (topActualClassNames.get().size() + 2), 1);
153+
int actualClassesPerBatch = Math.max(parameters.getMaxBuckets() / (topActualClassNames.get().size() + 2), 1);
156154
KeyedFilter[] keyedFiltersActual =
157155
topActualClassNames.get().stream()
158156
.skip(actualClasses.size())

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import org.elasticsearch.common.io.stream.StreamInput;
1212
import org.elasticsearch.common.io.stream.StreamOutput;
1313
import org.elasticsearch.common.io.stream.Writeable;
14-
import org.elasticsearch.common.settings.Settings;
1514
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
1615
import org.elasticsearch.common.xcontent.ObjectParser;
1716
import org.elasticsearch.common.xcontent.ToXContentObject;
@@ -31,6 +30,7 @@
3130
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
3231
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
3332
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
33+
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
3434
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
3535

3636
import java.io.IOException;
@@ -97,7 +97,7 @@ public String getName() {
9797
}
9898

9999
@Override
100-
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(Settings settings,
100+
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
101101
String actualField,
102102
String predictedField) {
103103
// Store given {@code actualField} for the purpose of generating error message in {@code process}.

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import org.elasticsearch.common.io.stream.StreamInput;
1212
import org.elasticsearch.common.io.stream.StreamOutput;
1313
import org.elasticsearch.common.io.stream.Writeable;
14-
import org.elasticsearch.common.settings.Settings;
1514
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
1615
import org.elasticsearch.common.xcontent.ObjectParser;
1716
import org.elasticsearch.common.xcontent.ToXContentObject;
@@ -27,6 +26,7 @@
2726
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
2827
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
2928
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
29+
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
3030
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
3131

3232
import java.io.IOException;
@@ -90,7 +90,7 @@ public String getName() {
9090
}
9191

9292
@Override
93-
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(Settings settings,
93+
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
9494
String actualField,
9595
String predictedField) {
9696
// Store given {@code actualField} for the purpose of generating error message in {@code process}.

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import org.elasticsearch.common.collect.Tuple;
1010
import org.elasticsearch.common.io.stream.StreamInput;
1111
import org.elasticsearch.common.io.stream.StreamOutput;
12-
import org.elasticsearch.common.settings.Settings;
1312
import org.elasticsearch.common.xcontent.ObjectParser;
1413
import org.elasticsearch.common.xcontent.XContentBuilder;
1514
import org.elasticsearch.common.xcontent.XContentParser;
@@ -21,6 +20,7 @@
2120
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
2221
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
2322
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
23+
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
2424

2525
import java.io.IOException;
2626
import java.text.MessageFormat;
@@ -66,7 +66,7 @@ public String getName() {
6666
}
6767

6868
@Override
69-
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(Settings settings,
69+
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
7070
String actualField,
7171
String predictedField) {
7272
if (result != null) {

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import org.elasticsearch.common.collect.Tuple;
1010
import org.elasticsearch.common.io.stream.StreamInput;
1111
import org.elasticsearch.common.io.stream.StreamOutput;
12-
import org.elasticsearch.common.settings.Settings;
1312
import org.elasticsearch.common.xcontent.ObjectParser;
1413
import org.elasticsearch.common.xcontent.XContentBuilder;
1514
import org.elasticsearch.common.xcontent.XContentParser;
@@ -23,6 +22,7 @@
2322
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
2423
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
2524
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
25+
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
2626

2727
import java.io.IOException;
2828
import java.text.MessageFormat;
@@ -71,7 +71,7 @@ public String getName() {
7171
}
7272

7373
@Override
74-
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(Settings settings,
74+
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
7575
String actualField,
7676
String predictedField) {
7777
if (result != null) {

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import org.elasticsearch.common.collect.Tuple;
1010
import org.elasticsearch.common.io.stream.StreamInput;
1111
import org.elasticsearch.common.io.stream.StreamOutput;
12-
import org.elasticsearch.common.settings.Settings;
1312
import org.elasticsearch.common.xcontent.XContentBuilder;
1413
import org.elasticsearch.index.query.BoolQueryBuilder;
1514
import org.elasticsearch.index.query.QueryBuilder;
@@ -20,6 +19,7 @@
2019
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
2120
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
2221
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
22+
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
2323
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
2424

2525
import java.io.IOException;
@@ -66,7 +66,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
6666
}
6767

6868
@Override
69-
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(Settings settings,
69+
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
7070
String actualField,
7171
String predictedProbabilityField) {
7272
if (result != null) {

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import org.elasticsearch.common.io.stream.StreamInput;
1212
import org.elasticsearch.common.io.stream.StreamOutput;
1313
import org.elasticsearch.common.io.stream.Writeable;
14-
import org.elasticsearch.common.settings.Settings;
1514
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
1615
import org.elasticsearch.common.xcontent.ToXContentObject;
1716
import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -25,6 +24,7 @@
2524
import org.elasticsearch.search.aggregations.metrics.Percentiles;
2625
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
2726
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
27+
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
2828
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
2929

3030
import java.io.IOException;
@@ -128,7 +128,7 @@ public int hashCode() {
128128
}
129129

130130
@Override
131-
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(Settings settings,
131+
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
132132
String actualField,
133133
String predictedProbabilityField) {
134134
if (result != null) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
7+
8+
import org.elasticsearch.test.ESTestCase;
9+
10+
import static org.hamcrest.Matchers.equalTo;
11+
12+
public class EvaluationParametersTests extends ESTestCase {
13+
14+
public void testConstructorAndGetters() {
15+
EvaluationParameters params = new EvaluationParameters(17);
16+
assertThat(params.getMaxBuckets(), equalTo(17));
17+
}
18+
}

0 commit comments

Comments
 (0)