Skip to content

Commit 4bfe288

Browse files
authored
Make classification evaluation metrics work when there is field mapping type mismatch (#53458)
1 parent ec3481e commit 4bfe288

File tree

6 files changed

+227
-53
lines changed

6 files changed

+227
-53
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java

+2-9
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,9 @@
2828
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
2929

3030
import java.io.IOException;
31-
import java.text.MessageFormat;
3231
import java.util.ArrayList;
3332
import java.util.Collections;
3433
import java.util.List;
35-
import java.util.Locale;
3634
import java.util.Objects;
3735
import java.util.Optional;
3836

@@ -66,12 +64,6 @@ public class Accuracy implements EvaluationMetric {
6664

6765
static final String OVERALL_ACCURACY_AGG_NAME = "classification_overall_accuracy";
6866

69-
private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";
70-
71-
private static Script buildScript(Object...args) {
72-
return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args));
73-
}
74-
7567
private static final ObjectParser<Accuracy, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Accuracy::new);
7668

7769
public static Accuracy fromXContent(XContentParser parser) {
@@ -112,7 +104,8 @@ public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> a
112104
List<AggregationBuilder> aggs = new ArrayList<>();
113105
List<PipelineAggregationBuilder> pipelineAggs = new ArrayList<>();
114106
if (overallAccuracy.get() == null) {
115-
aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(buildScript(actualField, predictedField)));
107+
Script script = PainlessScripts.buildIsEqualScript(actualField, predictedField);
108+
aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(script));
116109
}
117110
if (result.get() == null) {
118111
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> matrixAggs =
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
7+
8+
import org.elasticsearch.script.Script;
9+
10+
import java.text.MessageFormat;
11+
import java.util.Locale;
12+
13+
/**
14+
* Painless scripts used by classification metrics in this package.
15+
*/
16+
final class PainlessScripts {
17+
18+
/**
19+
* Template for the comparison script.
20+
* It uses "String.valueOf" method in case the mapping types of the two fields are different.
21+
*/
22+
private static final MessageFormat COMPARISON_SCRIPT_TEMPLATE =
23+
new MessageFormat("String.valueOf(doc[''{0}''].value).equals(String.valueOf(doc[''{1}''].value))", Locale.ROOT);
24+
25+
/**
26+
* Builds script that tests field values equality for the given actual and predicted field names.
27+
*
28+
* @param actualField name of the actual field
29+
* @param predictedField name of the predicted field
30+
* @return script that tests whether the values of actualField and predictedField are equal
31+
*/
32+
static Script buildIsEqualScript(String actualField, String predictedField) {
33+
return new Script(COMPARISON_SCRIPT_TEMPLATE.format(new Object[]{ actualField, predictedField }));
34+
}
35+
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java

+1-8
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,9 @@
3434
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
3535

3636
import java.io.IOException;
37-
import java.text.MessageFormat;
3837
import java.util.ArrayList;
3938
import java.util.Collections;
4039
import java.util.List;
41-
import java.util.Locale;
4240
import java.util.Objects;
4341
import java.util.Optional;
4442
import java.util.stream.Collectors;
@@ -59,17 +57,12 @@ public class Precision implements EvaluationMetric {
5957

6058
public static final ParseField NAME = new ParseField("precision");
6159

62-
private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";
6360
private static final String AGG_NAME_PREFIX = "classification_precision_";
6461
static final String ACTUAL_CLASSES_NAMES_AGG_NAME = AGG_NAME_PREFIX + "by_actual_class";
6562
static final String BY_PREDICTED_CLASS_AGG_NAME = AGG_NAME_PREFIX + "by_predicted_class";
6663
static final String PER_PREDICTED_CLASS_PRECISION_AGG_NAME = AGG_NAME_PREFIX + "per_predicted_class_precision";
6764
static final String AVG_PRECISION_AGG_NAME = AGG_NAME_PREFIX + "avg_precision";
6865

69-
private static Script buildScript(Object...args) {
70-
return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args));
71-
}
72-
7366
private static final ObjectParser<Precision, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Precision::new);
7467

7568
public static Precision fromXContent(XContentParser parser) {
@@ -116,7 +109,7 @@ public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> a
116109
topActualClassNames.get().stream()
117110
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
118111
.toArray(KeyedFilter[]::new);
119-
Script script = buildScript(actualField, predictedField);
112+
Script script = PainlessScripts.buildIsEqualScript(actualField, predictedField);
120113
return Tuple.tuple(
121114
List.of(
122115
AggregationBuilders.filters(BY_PREDICTED_CLASS_AGG_NAME, keyedFiltersPredicted)

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java

+3-8
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.elasticsearch.search.aggregations.AggregationBuilder;
2121
import org.elasticsearch.search.aggregations.AggregationBuilders;
2222
import org.elasticsearch.search.aggregations.Aggregations;
23+
import org.elasticsearch.search.aggregations.BucketOrder;
2324
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
2425
import org.elasticsearch.search.aggregations.PipelineAggregatorBuilders;
2526
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
@@ -30,11 +31,9 @@
3031
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
3132

3233
import java.io.IOException;
33-
import java.text.MessageFormat;
3434
import java.util.ArrayList;
3535
import java.util.Collections;
3636
import java.util.List;
37-
import java.util.Locale;
3837
import java.util.Objects;
3938
import java.util.Optional;
4039

@@ -54,16 +53,11 @@ public class Recall implements EvaluationMetric {
5453

5554
public static final ParseField NAME = new ParseField("recall");
5655

57-
private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";
5856
private static final String AGG_NAME_PREFIX = "classification_recall_";
5957
static final String BY_ACTUAL_CLASS_AGG_NAME = AGG_NAME_PREFIX + "by_actual_class";
6058
static final String PER_ACTUAL_CLASS_RECALL_AGG_NAME = AGG_NAME_PREFIX + "per_actual_class_recall";
6159
static final String AVG_RECALL_AGG_NAME = AGG_NAME_PREFIX + "avg_recall";
6260

63-
private static Script buildScript(Object...args) {
64-
return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args));
65-
}
66-
6761
private static final ObjectParser<Recall, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Recall::new);
6862

6963
public static Recall fromXContent(XContentParser parser) {
@@ -98,11 +92,12 @@ public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> a
9892
if (result.get() != null) {
9993
return Tuple.tuple(List.of(), List.of());
10094
}
101-
Script script = buildScript(actualField, predictedField);
95+
Script script = PainlessScripts.buildIsEqualScript(actualField, predictedField);
10296
return Tuple.tuple(
10397
List.of(
10498
AggregationBuilders.terms(BY_ACTUAL_CLASS_AGG_NAME)
10599
.field(actualField)
100+
.order(List.of(BucketOrder.count(false), BucketOrder.key(true)))
106101
.size(MAX_CLASSES_CARDINALITY)
107102
.subAggregation(AggregationBuilders.avg(PER_ACTUAL_CLASS_RECALL_AGG_NAME).script(script))),
108103
List.of(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
7+
8+
import org.elasticsearch.script.Script;
9+
import org.elasticsearch.test.ESTestCase;
10+
11+
import static org.hamcrest.Matchers.equalTo;
12+
13+
public class PainlessScriptsTests extends ESTestCase {
14+
15+
public void testBuildIsEqualScript() {
16+
Script script = PainlessScripts.buildIsEqualScript("act", "pred");
17+
assertThat(script.getIdOrCode(), equalTo("String.valueOf(doc['act'].value).equals(String.valueOf(doc['pred'].value))"));
18+
}
19+
}

0 commit comments

Comments
 (0)