Skip to content

Commit cd1a27f

Browse files
authored
[ML] Implement AucRoc metric for classification (#60502)
1 parent 9b9f33e commit cd1a27f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+2009
-594
lines changed

docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ the probability that each document is an outlier.
8989
`auc_roc`:::
9090
(Optional, object) The AUC ROC (area under the curve of the receiver
9191
operating characteristic) score and optionally the curve. Default value is
92-
{"includes_curve": false}.
92+
{"include_curve": false}.
9393

9494
`confusion_matrix`:::
9595
(Optional, object) Set the different thresholds of the {olscore} at where
@@ -154,16 +154,39 @@ belongs.
154154
The data type of this field must be categorical.
155155

156156
`predicted_field`::
157-
(Required, string) The field in the `index` that contains the predicted value,
157+
(Optional, string) The field in the `index` which contains the predicted value,
158158
in other words the results of the {classanalysis}.
159159

160+
`top_classes_field`::
161+
(Optional, string) The field of the `index` which is an array of documents
162+
of the form `{ "class_name": XXX, "class_probability": YYY }`.
163+
This field must be defined as `nested` in the mappings.
164+
160165
`metrics`::
161166
(Optional, object) Specifies the metrics that are used for the evaluation.
162167
Available metrics:
163168

164169
`accuracy`:::
165170
(Optional, object) Accuracy of predictions (per-class and overall).
166171

172+
`auc_roc`:::
173+
(Optional, object) The AUC ROC (area under the curve of the receiver
174+
operating characteristic) score and optionally the curve.
175+
It is calculated for a specific class (provided as "class_name")
176+
treated as positive.
177+
178+
`class_name`::::
179+
(Required, string) Name of the only class that will be treated as
180+
positive during AUC ROC calculation. Other classes will be treated as
181+
negative ("one-vs-all" strategy). Documents which do not have `class_name`
182+
in the list of their top classes will not be taken into account for evaluation.
183+
The number of documents taken into account is returned in the evaluation result
184+
(`auc_roc.doc_count` field).
185+
186+
`include_curve`::::
187+
(Optional, boolean) Whether or not the curve should be returned in
188+
addition to the score. Default value is false.
189+
167190
`multiclass_confusion_matrix`:::
168191
(Optional, object) Multiclass confusion matrix.
169192

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,16 @@ public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mapping
392392
return additionalProperties;
393393
}
394394
additionalProperties.put(resultsFieldName + "." + predictionFieldName, dependentVariableMapping);
395-
additionalProperties.put(resultsFieldName + ".top_classes.class_name", dependentVariableMapping);
395+
396+
Map<String, Object> topClassesProperties = new HashMap<>();
397+
topClassesProperties.put("class_name", dependentVariableMapping);
398+
topClassesProperties.put("class_probability", Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName()));
399+
400+
Map<String, Object> topClassesMapping = new HashMap<>();
401+
topClassesMapping.put("type", ObjectMapper.NESTED_CONTENT_TYPE);
402+
topClassesMapping.put("properties", topClassesProperties);
403+
404+
additionalProperties.put(resultsFieldName + ".top_classes", topClassesMapping);
396405
return additionalProperties;
397406
}
398407

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java

Lines changed: 85 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
*/
66
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
77

8+
import org.apache.lucene.search.join.ScoreMode;
89
import org.elasticsearch.action.search.SearchResponse;
910
import org.elasticsearch.common.Nullable;
1011
import org.elasticsearch.common.collect.Tuple;
@@ -21,11 +22,16 @@
2122
import java.util.ArrayList;
2223
import java.util.Collections;
2324
import java.util.Comparator;
25+
import java.util.HashSet;
2426
import java.util.List;
2527
import java.util.Objects;
2628
import java.util.Optional;
29+
import java.util.Set;
2730
import java.util.function.Supplier;
28-
import java.util.stream.Collectors;
31+
32+
import static java.util.stream.Collectors.joining;
33+
import static java.util.stream.Collectors.toList;
34+
import static java.util.stream.Collectors.toSet;
2935

3036
/**
3137
* Defines an evaluation
@@ -38,14 +44,9 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
3844
String getName();
3945

4046
/**
41-
* Returns the field containing the actual value
42-
*/
43-
String getActualField();
44-
45-
/**
46-
* Returns the field containing the predicted value
47+
* Returns the collection of fields required by evaluation
4748
*/
48-
String getPredictedField();
49+
EvaluationFields getFields();
4950

5051
/**
5152
* Returns the list of metrics to evaluate
@@ -59,27 +60,74 @@ default <T extends EvaluationMetric> List<T> initMetrics(@Nullable List<T> parse
5960
throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", getName());
6061
}
6162
Collections.sort(metrics, Comparator.comparing(EvaluationMetric::getName));
63+
checkRequiredFieldsAreSet(metrics);
6264
return metrics;
6365
}
6466

67+
private <T extends EvaluationMetric> void checkRequiredFieldsAreSet(List<T> metrics) {
68+
assert (metrics == null || metrics.isEmpty()) == false;
69+
for (Tuple<String, String> requiredField : getFields().listPotentiallyRequiredFields()) {
70+
String fieldDescriptor = requiredField.v1();
71+
String field = requiredField.v2();
72+
if (field == null) {
73+
String metricNamesString =
74+
metrics.stream()
75+
.filter(m -> m.getRequiredFields().contains(fieldDescriptor))
76+
.map(EvaluationMetric::getName)
77+
.collect(joining(", "));
78+
if (metricNamesString.isEmpty() == false) {
79+
throw ExceptionsHelper.badRequestException(
80+
"[{}] must define [{}] as required by the following metrics [{}]",
81+
getName(), fieldDescriptor, metricNamesString);
82+
}
83+
}
84+
}
85+
}
86+
6587
/**
6688
* Builds the search required to collect data to compute the evaluation result
6789
* @param userProvidedQueryBuilder User-provided query that must be respected when collecting data
6890
*/
6991
default SearchSourceBuilder buildSearch(EvaluationParameters parameters, QueryBuilder userProvidedQueryBuilder) {
7092
Objects.requireNonNull(userProvidedQueryBuilder);
71-
BoolQueryBuilder boolQuery =
72-
QueryBuilders.boolQuery()
73-
// Verify existence of required fields
74-
.filter(QueryBuilders.existsQuery(getActualField()))
75-
.filter(QueryBuilders.existsQuery(getPredictedField()))
76-
// Apply user-provided query
77-
.filter(userProvidedQueryBuilder);
93+
Set<String> requiredFields = new HashSet<>(getRequiredFields());
94+
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
95+
if (getFields().getActualField() != null && requiredFields.contains(getFields().getActualField())) {
96+
// Verify existence of the actual field if required
97+
boolQuery.filter(QueryBuilders.existsQuery(getFields().getActualField()));
98+
}
99+
if (getFields().getPredictedField() != null && requiredFields.contains(getFields().getPredictedField())) {
100+
// Verify existence of the predicted field if required
101+
boolQuery.filter(QueryBuilders.existsQuery(getFields().getPredictedField()));
102+
}
103+
if (getFields().getPredictedClassField() != null && requiredFields.contains(getFields().getPredictedClassField())) {
104+
assert getFields().getTopClassesField() != null;
105+
// Verify existence of the predicted class name field if required
106+
QueryBuilder predictedClassFieldExistsQuery = QueryBuilders.existsQuery(getFields().getPredictedClassField());
107+
boolQuery.filter(
108+
QueryBuilders.nestedQuery(getFields().getTopClassesField(), predictedClassFieldExistsQuery, ScoreMode.None)
109+
.ignoreUnmapped(true));
110+
}
111+
if (getFields().getPredictedProbabilityField() != null && requiredFields.contains(getFields().getPredictedProbabilityField())) {
112+
// Verify existence of the predicted probability field if required
113+
QueryBuilder predictedProbabilityFieldExistsQuery = QueryBuilders.existsQuery(getFields().getPredictedProbabilityField());
114+
// predicted probability field may be either nested (just like in case of classification evaluation) or non-nested (just like
115+
// in case of outlier detection evaluation). Here we support both modes.
116+
if (getFields().isPredictedProbabilityFieldNested()) {
117+
assert getFields().getTopClassesField() != null;
118+
boolQuery.filter(
119+
QueryBuilders.nestedQuery(getFields().getTopClassesField(), predictedProbabilityFieldExistsQuery, ScoreMode.None)
120+
.ignoreUnmapped(true));
121+
} else {
122+
boolQuery.filter(predictedProbabilityFieldExistsQuery);
123+
}
124+
}
125+
// Apply user-provided query
126+
boolQuery.filter(userProvidedQueryBuilder);
78127
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
79128
for (EvaluationMetric metric : getMetrics()) {
80129
// Fetch aggregations requested by individual metrics
81-
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs =
82-
metric.aggs(parameters, getActualField(), getPredictedField());
130+
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs = metric.aggs(parameters, getFields());
83131
aggs.v1().forEach(searchSourceBuilder::aggregation);
84132
aggs.v2().forEach(searchSourceBuilder::aggregation);
85133
}
@@ -93,14 +141,31 @@ default SearchSourceBuilder buildSearch(EvaluationParameters parameters, QueryBu
93141
default void process(SearchResponse searchResponse) {
94142
Objects.requireNonNull(searchResponse);
95143
if (searchResponse.getHits().getTotalHits().value == 0) {
96-
throw ExceptionsHelper.badRequestException(
97-
"No documents found containing both [{}, {}] fields", getActualField(), getPredictedField());
144+
String requiredFieldsString = String.join(", ", getRequiredFields());
145+
throw ExceptionsHelper.badRequestException("No documents found containing all the required fields [{}]", requiredFieldsString);
98146
}
99147
for (EvaluationMetric metric : getMetrics()) {
100148
metric.process(searchResponse.getAggregations());
101149
}
102150
}
103151

152+
/**
153+
* @return list of fields which are required by at least one of the metrics
154+
*/
155+
private List<String> getRequiredFields() {
156+
Set<String> requiredFieldDescriptors =
157+
getMetrics().stream()
158+
.map(EvaluationMetric::getRequiredFields)
159+
.flatMap(Set::stream)
160+
.collect(toSet());
161+
List<String> requiredFields =
162+
getFields().listPotentiallyRequiredFields().stream()
163+
.filter(f -> requiredFieldDescriptors.contains(f.v1()))
164+
.map(Tuple::v2)
165+
.collect(toList());
166+
return requiredFields;
167+
}
168+
104169
/**
105170
* @return true iff all the metrics have their results computed
106171
*/
@@ -117,6 +182,6 @@ default List<EvaluationMetricResult> getResults() {
117182
.map(EvaluationMetric::getResult)
118183
.filter(Optional::isPresent)
119184
.map(Optional::get)
120-
.collect(Collectors.toList());
185+
.collect(toList());
121186
}
122187
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
7+
8+
import org.elasticsearch.common.Nullable;
9+
import org.elasticsearch.common.ParseField;
10+
import org.elasticsearch.common.collect.Tuple;
11+
12+
import java.util.Arrays;
13+
import java.util.List;
14+
import java.util.Objects;
15+
16+
/**
17+
* Encapsulates fields needed by evaluation.
18+
*/
19+
public final class EvaluationFields {
20+
21+
public static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
22+
public static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
23+
public static final ParseField TOP_CLASSES_FIELD = new ParseField("top_classes_field");
24+
public static final ParseField PREDICTED_CLASS_FIELD = new ParseField("predicted_class_field");
25+
public static final ParseField PREDICTED_PROBABILITY_FIELD = new ParseField("predicted_probability_field");
26+
27+
/**
28+
* The field containing the actual value
29+
*/
30+
private final String actualField;
31+
32+
/**
33+
* The field containing the predicted value
34+
*/
35+
private final String predictedField;
36+
37+
/**
38+
* The field containing the array of top classes
39+
*/
40+
private final String topClassesField;
41+
42+
/**
43+
* The field containing the predicted class name value
44+
*/
45+
private final String predictedClassField;
46+
47+
/**
48+
* The field containing the predicted probability value in [0.0, 1.0]
49+
*/
50+
private final String predictedProbabilityField;
51+
52+
/**
53+
* Whether the {@code predictedProbabilityField} should be treated as nested (e.g.: when used in exists queries).
54+
*/
55+
private final boolean predictedProbabilityFieldNested;
56+
57+
public EvaluationFields(@Nullable String actualField,
58+
@Nullable String predictedField,
59+
@Nullable String topClassesField,
60+
@Nullable String predictedClassField,
61+
@Nullable String predictedProbabilityField,
62+
boolean predictedProbabilityFieldNested) {
63+
64+
this.actualField = actualField;
65+
this.predictedField = predictedField;
66+
this.topClassesField = topClassesField;
67+
this.predictedClassField = predictedClassField;
68+
this.predictedProbabilityField = predictedProbabilityField;
69+
this.predictedProbabilityFieldNested = predictedProbabilityFieldNested;
70+
}
71+
72+
/**
73+
* Returns the field containing the actual value
74+
*/
75+
public String getActualField() {
76+
return actualField;
77+
}
78+
79+
/**
80+
* Returns the field containing the predicted value
81+
*/
82+
public String getPredictedField() {
83+
return predictedField;
84+
}
85+
86+
/**
87+
* Returns the field containing the array of top classes
88+
*/
89+
public String getTopClassesField() {
90+
return topClassesField;
91+
}
92+
93+
/**
94+
* Returns the field containing the predicted class name value
95+
*/
96+
public String getPredictedClassField() {
97+
return predictedClassField;
98+
}
99+
100+
/**
101+
* Returns the field containing the predicted probability value in [0.0, 1.0]
102+
*/
103+
public String getPredictedProbabilityField() {
104+
return predictedProbabilityField;
105+
}
106+
107+
/**
108+
* Returns whether the {@code predictedProbabilityField} should be treated as nested (e.g.: when used in exists queries).
109+
*/
110+
public boolean isPredictedProbabilityFieldNested() {
111+
return predictedProbabilityFieldNested;
112+
}
113+
114+
public List<Tuple<String, String>> listPotentiallyRequiredFields() {
115+
return Arrays.asList(
116+
Tuple.tuple(ACTUAL_FIELD.getPreferredName(), actualField),
117+
Tuple.tuple(PREDICTED_FIELD.getPreferredName(), predictedField),
118+
Tuple.tuple(TOP_CLASSES_FIELD.getPreferredName(), topClassesField),
119+
Tuple.tuple(PREDICTED_CLASS_FIELD.getPreferredName(), predictedClassField),
120+
Tuple.tuple(PREDICTED_PROBABILITY_FIELD.getPreferredName(), predictedProbabilityField));
121+
}
122+
123+
@Override
124+
public boolean equals(Object o) {
125+
if (this == o) return true;
126+
if (o == null || getClass() != o.getClass()) return false;
127+
EvaluationFields that = (EvaluationFields) o;
128+
return Objects.equals(that.actualField, this.actualField)
129+
&& Objects.equals(that.predictedField, this.predictedField)
130+
&& Objects.equals(that.topClassesField, this.topClassesField)
131+
&& Objects.equals(that.predictedClassField, this.predictedClassField)
132+
&& Objects.equals(that.predictedProbabilityField, this.predictedProbabilityField)
133+
&& Objects.equals(that.predictedProbabilityFieldNested, this.predictedProbabilityFieldNested);
134+
}
135+
136+
@Override
137+
public int hashCode() {
138+
return Objects.hash(
139+
actualField, predictedField, topClassesField, predictedClassField, predictedProbabilityField, predictedProbabilityFieldNested);
140+
}
141+
}

0 commit comments

Comments
 (0)