Skip to content

Commit f9ab47f

Browse files
authored
Do not fail Evaluate API when the actual and predicted fields' types differ (#54255)
1 parent 4c86da9 commit f9ab47f

File tree

3 files changed

+310
-185
lines changed

3 files changed

+310
-185
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> a
142142
if (result.get() == null) { // These are steps 2, 3, 4 etc.
143143
KeyedFilter[] keyedFiltersPredicted =
144144
topActualClassNames.get().stream()
145-
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
145+
.map(className -> new KeyedFilter(className, QueryBuilders.matchQuery(predictedField, className).lenient(true)))
146146
.toArray(KeyedFilter[]::new);
147147
// Knowing exactly how many buckets does each aggregation use, we can choose the size of the batch so that
148148
// too_many_buckets_exception exception is not thrown.
@@ -153,7 +153,7 @@ public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> a
153153
topActualClassNames.get().stream()
154154
.skip(actualClasses.size())
155155
.limit(actualClassesPerBatch)
156-
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className)))
156+
.map(className -> new KeyedFilter(className, QueryBuilders.matchQuery(actualField, className).lenient(true)))
157157
.toArray(KeyedFilter[]::new);
158158
if (keyedFiltersActual.length > 0) {
159159
return Tuple.tuple(

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> a
107107
if (result.get() == null) { // This is step 2
108108
KeyedFilter[] keyedFiltersPredicted =
109109
topActualClassNames.get().stream()
110-
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
110+
.map(className -> new KeyedFilter(className, QueryBuilders.matchQuery(predictedField, className).lenient(true)))
111111
.toArray(KeyedFilter[]::new);
112112
Script script = PainlessScripts.buildIsEqualScript(actualField, predictedField);
113113
return Tuple.tuple(

0 commit comments

Comments
 (0)