5
5
*/
6
6
package org .elasticsearch .xpack .core .ml .dataframe .evaluation ;
7
7
8
+ import org .apache .lucene .search .join .ScoreMode ;
8
9
import org .elasticsearch .action .search .SearchResponse ;
9
10
import org .elasticsearch .common .Nullable ;
10
11
import org .elasticsearch .common .collect .Tuple ;
21
22
import java .util .ArrayList ;
22
23
import java .util .Collections ;
23
24
import java .util .Comparator ;
25
+ import java .util .HashSet ;
24
26
import java .util .List ;
25
27
import java .util .Objects ;
26
28
import java .util .Optional ;
29
+ import java .util .Set ;
27
30
import java .util .function .Supplier ;
28
- import java .util .stream .Collectors ;
31
+
32
+ import static java .util .stream .Collectors .joining ;
33
+ import static java .util .stream .Collectors .toList ;
34
+ import static java .util .stream .Collectors .toSet ;
29
35
30
36
/**
31
37
* Defines an evaluation
@@ -38,14 +44,9 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
38
44
String getName ();
39
45
40
46
/**
41
- * Returns the field containing the actual value
42
- */
43
- String getActualField ();
44
-
45
- /**
46
- * Returns the field containing the predicted value
47
+ * Returns the collection of fields required by evaluation
47
48
*/
48
- String getPredictedField ();
49
+ EvaluationFields getFields ();
49
50
50
51
/**
51
52
* Returns the list of metrics to evaluate
@@ -59,27 +60,74 @@ default <T extends EvaluationMetric> List<T> initMetrics(@Nullable List<T> parse
59
60
throw ExceptionsHelper .badRequestException ("[{}] must have one or more metrics" , getName ());
60
61
}
61
62
Collections .sort (metrics , Comparator .comparing (EvaluationMetric ::getName ));
63
+ checkRequiredFieldsAreSet (metrics );
62
64
return metrics ;
63
65
}
64
66
67
+ private <T extends EvaluationMetric > void checkRequiredFieldsAreSet (List <T > metrics ) {
68
+ assert (metrics == null || metrics .isEmpty ()) == false ;
69
+ for (Tuple <String , String > requiredField : getFields ().listPotentiallyRequiredFields ()) {
70
+ String fieldDescriptor = requiredField .v1 ();
71
+ String field = requiredField .v2 ();
72
+ if (field == null ) {
73
+ String metricNamesString =
74
+ metrics .stream ()
75
+ .filter (m -> m .getRequiredFields ().contains (fieldDescriptor ))
76
+ .map (EvaluationMetric ::getName )
77
+ .collect (joining (", " ));
78
+ if (metricNamesString .isEmpty () == false ) {
79
+ throw ExceptionsHelper .badRequestException (
80
+ "[{}] must define [{}] as required by the following metrics [{}]" ,
81
+ getName (), fieldDescriptor , metricNamesString );
82
+ }
83
+ }
84
+ }
85
+ }
86
+
65
87
/**
66
88
* Builds the search required to collect data to compute the evaluation result
67
89
* @param userProvidedQueryBuilder User-provided query that must be respected when collecting data
68
90
*/
69
91
default SearchSourceBuilder buildSearch (EvaluationParameters parameters , QueryBuilder userProvidedQueryBuilder ) {
70
92
Objects .requireNonNull (userProvidedQueryBuilder );
71
- BoolQueryBuilder boolQuery =
72
- QueryBuilders .boolQuery ()
73
- // Verify existence of required fields
74
- .filter (QueryBuilders .existsQuery (getActualField ()))
75
- .filter (QueryBuilders .existsQuery (getPredictedField ()))
76
- // Apply user-provided query
77
- .filter (userProvidedQueryBuilder );
93
+ Set <String > requiredFields = new HashSet <>(getRequiredFields ());
94
+ BoolQueryBuilder boolQuery = QueryBuilders .boolQuery ();
95
+ if (getFields ().getActualField () != null && requiredFields .contains (getFields ().getActualField ())) {
96
+ // Verify existence of the actual field if required
97
+ boolQuery .filter (QueryBuilders .existsQuery (getFields ().getActualField ()));
98
+ }
99
+ if (getFields ().getPredictedField () != null && requiredFields .contains (getFields ().getPredictedField ())) {
100
+ // Verify existence of the predicted field if required
101
+ boolQuery .filter (QueryBuilders .existsQuery (getFields ().getPredictedField ()));
102
+ }
103
+ if (getFields ().getPredictedClassField () != null && requiredFields .contains (getFields ().getPredictedClassField ())) {
104
+ assert getFields ().getTopClassesField () != null ;
105
+ // Verify existence of the predicted class name field if required
106
+ QueryBuilder predictedClassFieldExistsQuery = QueryBuilders .existsQuery (getFields ().getPredictedClassField ());
107
+ boolQuery .filter (
108
+ QueryBuilders .nestedQuery (getFields ().getTopClassesField (), predictedClassFieldExistsQuery , ScoreMode .None )
109
+ .ignoreUnmapped (true ));
110
+ }
111
+ if (getFields ().getPredictedProbabilityField () != null && requiredFields .contains (getFields ().getPredictedProbabilityField ())) {
112
+ // Verify existence of the predicted probability field if required
113
+ QueryBuilder predictedProbabilityFieldExistsQuery = QueryBuilders .existsQuery (getFields ().getPredictedProbabilityField ());
114
+ // predicted probability field may be either nested (just like in case of classification evaluation) or non-nested (just like
115
+ // in case of outlier detection evaluation). Here we support both modes.
116
+ if (getFields ().isPredictedProbabilityFieldNested ()) {
117
+ assert getFields ().getTopClassesField () != null ;
118
+ boolQuery .filter (
119
+ QueryBuilders .nestedQuery (getFields ().getTopClassesField (), predictedProbabilityFieldExistsQuery , ScoreMode .None )
120
+ .ignoreUnmapped (true ));
121
+ } else {
122
+ boolQuery .filter (predictedProbabilityFieldExistsQuery );
123
+ }
124
+ }
125
+ // Apply user-provided query
126
+ boolQuery .filter (userProvidedQueryBuilder );
78
127
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder ().size (0 ).query (boolQuery );
79
128
for (EvaluationMetric metric : getMetrics ()) {
80
129
// Fetch aggregations requested by individual metrics
81
- Tuple <List <AggregationBuilder >, List <PipelineAggregationBuilder >> aggs =
82
- metric .aggs (parameters , getActualField (), getPredictedField ());
130
+ Tuple <List <AggregationBuilder >, List <PipelineAggregationBuilder >> aggs = metric .aggs (parameters , getFields ());
83
131
aggs .v1 ().forEach (searchSourceBuilder ::aggregation );
84
132
aggs .v2 ().forEach (searchSourceBuilder ::aggregation );
85
133
}
@@ -93,14 +141,31 @@ default SearchSourceBuilder buildSearch(EvaluationParameters parameters, QueryBu
93
141
default void process (SearchResponse searchResponse ) {
94
142
Objects .requireNonNull (searchResponse );
95
143
if (searchResponse .getHits ().getTotalHits ().value == 0 ) {
96
- throw ExceptionsHelper . badRequestException (
97
- "No documents found containing both [{}, {}] fields " , getActualField (), getPredictedField () );
144
+ String requiredFieldsString = String . join ( ", " , getRequiredFields ());
145
+ throw ExceptionsHelper . badRequestException ( "No documents found containing all the required fields [{}] " , requiredFieldsString );
98
146
}
99
147
for (EvaluationMetric metric : getMetrics ()) {
100
148
metric .process (searchResponse .getAggregations ());
101
149
}
102
150
}
103
151
152
+ /**
153
+ * @return list of fields which are required by at least one of the metrics
154
+ */
155
+ private List <String > getRequiredFields () {
156
+ Set <String > requiredFieldDescriptors =
157
+ getMetrics ().stream ()
158
+ .map (EvaluationMetric ::getRequiredFields )
159
+ .flatMap (Set ::stream )
160
+ .collect (toSet ());
161
+ List <String > requiredFields =
162
+ getFields ().listPotentiallyRequiredFields ().stream ()
163
+ .filter (f -> requiredFieldDescriptors .contains (f .v1 ()))
164
+ .map (Tuple ::v2 )
165
+ .collect (toList ());
166
+ return requiredFields ;
167
+ }
168
+
104
169
/**
105
170
* @return true iff all the metrics have their results computed
106
171
*/
@@ -117,6 +182,6 @@ default List<EvaluationMetricResult> getResults() {
117
182
.map (EvaluationMetric ::getResult )
118
183
.filter (Optional ::isPresent )
119
184
.map (Optional ::get )
120
- .collect (Collectors . toList ());
185
+ .collect (toList ());
121
186
}
122
187
}
0 commit comments