Skip to content

[ML] Implement AucRoc metric for classification #60502

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Sep 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
aeba04e
Implement AucRoc metric for classification
przemekwitek Jul 31, 2020
243576e
Revert HLRC changes
przemekwitek Sep 10, 2020
8a1b503
Revert HLRC changes
przemekwitek Sep 10, 2020
6fd073a
Revert HLRC changes
przemekwitek Sep 10, 2020
f0d6317
Use ignoreUnmapped flag on nested queries so that the search doesn't …
przemekwitek Sep 10, 2020
4673c74
Apply review comment
przemekwitek Sep 14, 2020
ea136d7
Apply review comments
przemekwitek Sep 15, 2020
126cb7d
Fix compile error
przemekwitek Sep 15, 2020
147398f
Apply docs-related review comments
przemekwitek Sep 15, 2020
21efd55
Fix bug in Classification::getExplicitlyMappedFields method
przemekwitek Sep 16, 2020
0aff0a8
Add doc_count field; Make exception messages more detailed;
przemekwitek Sep 16, 2020
18b7428
Adapt outlier_detection.AucRoc to have the same error messages as cla…
przemekwitek Sep 16, 2020
782e7d2
Fix error message
przemekwitek Sep 16, 2020
40ac4e5
Rename predicted_class_name_field to predicted_class_field
przemekwitek Sep 17, 2020
0f5b82e
Remove redundant evaluation.classification.results_nested_field
przemekwitek Sep 17, 2020
f831d9f
Fix ml-with-security yaml tests
przemekwitek Sep 17, 2020
c0e6ca5
Apply default values for predicted_class_field and predicted_probabil…
przemekwitek Sep 17, 2020
aa08218
Apply review comment
przemekwitek Sep 17, 2020
6d70718
Fix org.elasticsearch.client.RestHighLevelClientTests.testProvidedNam…
przemekwitek Sep 17, 2020
634eb2e
Apply review comments.
przemekwitek Sep 21, 2020
f99bad1
Add top_classes_field to the Classification evaluation.
przemekwitek Sep 28, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ the probability that each document is an outlier.
`auc_roc`:::
(Optional, object) The AUC ROC (area under the curve of the receiver
operating characteristic) score and optionally the curve. Default value is
{"includes_curve": false}.
{"include_curve": false}.

`confusion_matrix`:::
(Optional, object) Set the different thresholds of the {olscore} at where
Expand Down Expand Up @@ -154,16 +154,39 @@ belongs.
The data type of this field must be categorical.

`predicted_field`::
(Required, string) The field in the `index` that contains the predicted value,
(Optional, string) The field in the `index` which contains the predicted value,
in other words the results of the {classanalysis}.

`top_classes_field`::
(Optional, string) The field of the `index` which is an array of documents
of the form `{ "class_name": XXX, "class_probability": YYY }`.
This field must be defined as `nested` in the mappings.

`metrics`::
(Optional, object) Specifies the metrics that are used for the evaluation.
Available metrics:

`accuracy`:::
(Optional, object) Accuracy of predictions (per-class and overall).

`auc_roc`:::
(Optional, object) The AUC ROC (area under the curve of the receiver
operating characteristic) score and optionally the curve.
It is calculated for a specific class (provided as "class_name")
treated as positive.

`class_name`::::
(Required, string) Name of the only class that will be treated as
positive during AUC ROC calculation. Other classes will be treated as
negative ("one-vs-all" strategy). Documents which do not have `class_name`
in the list of their top classes will not be taken into account for evaluation.
The number of documents taken into account is returned in the evaluation result
(`auc_roc.doc_count` field).

`include_curve`::::
(Optional, boolean) Whether or not the curve should be returned in
addition to the score. Default value is false.

`multiclass_confusion_matrix`:::
(Optional, object) Multiclass confusion matrix.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,16 @@ public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mapping
return additionalProperties;
}
additionalProperties.put(resultsFieldName + "." + predictionFieldName, dependentVariableMapping);
additionalProperties.put(resultsFieldName + ".top_classes.class_name", dependentVariableMapping);

Map<String, Object> topClassesProperties = new HashMap<>();
topClassesProperties.put("class_name", dependentVariableMapping);
topClassesProperties.put("class_probability", Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName()));

Map<String, Object> topClassesMapping = new HashMap<>();
topClassesMapping.put("type", ObjectMapper.NESTED_CONTENT_TYPE);
topClassesMapping.put("properties", topClassesProperties);

additionalProperties.put(resultsFieldName + ".top_classes", topClassesMapping);
return additionalProperties;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;

import org.apache.lucene.search.join.ScoreMode;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.collect.Tuple;
Expand All @@ -21,11 +22,16 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toSet;

/**
* Defines an evaluation
Expand All @@ -38,14 +44,9 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
String getName();

/**
* Returns the field containing the actual value
*/
String getActualField();

/**
* Returns the field containing the predicted value
* Returns the collection of fields required by evaluation
*/
String getPredictedField();
EvaluationFields getFields();

/**
* Returns the list of metrics to evaluate
Expand All @@ -59,27 +60,74 @@ default <T extends EvaluationMetric> List<T> initMetrics(@Nullable List<T> parse
throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", getName());
}
Collections.sort(metrics, Comparator.comparing(EvaluationMetric::getName));
checkRequiredFieldsAreSet(metrics);
return metrics;
}

private <T extends EvaluationMetric> void checkRequiredFieldsAreSet(List<T> metrics) {
assert (metrics == null || metrics.isEmpty()) == false;
for (Tuple<String, String> requiredField : getFields().listPotentiallyRequiredFields()) {
String fieldDescriptor = requiredField.v1();
String field = requiredField.v2();
if (field == null) {
String metricNamesString =
metrics.stream()
.filter(m -> m.getRequiredFields().contains(fieldDescriptor))
.map(EvaluationMetric::getName)
.collect(joining(", "));
if (metricNamesString.isEmpty() == false) {
throw ExceptionsHelper.badRequestException(
"[{}] must define [{}] as required by the following metrics [{}]",
getName(), fieldDescriptor, metricNamesString);
}
}
}
}

/**
* Builds the search required to collect data to compute the evaluation result
* @param userProvidedQueryBuilder User-provided query that must be respected when collecting data
*/
default SearchSourceBuilder buildSearch(EvaluationParameters parameters, QueryBuilder userProvidedQueryBuilder) {
Objects.requireNonNull(userProvidedQueryBuilder);
BoolQueryBuilder boolQuery =
QueryBuilders.boolQuery()
// Verify existence of required fields
.filter(QueryBuilders.existsQuery(getActualField()))
.filter(QueryBuilders.existsQuery(getPredictedField()))
// Apply user-provided query
.filter(userProvidedQueryBuilder);
Set<String> requiredFields = new HashSet<>(getRequiredFields());
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
if (getFields().getActualField() != null && requiredFields.contains(getFields().getActualField())) {
// Verify existence of the actual field if required
boolQuery.filter(QueryBuilders.existsQuery(getFields().getActualField()));
}
if (getFields().getPredictedField() != null && requiredFields.contains(getFields().getPredictedField())) {
// Verify existence of the predicted field if required
boolQuery.filter(QueryBuilders.existsQuery(getFields().getPredictedField()));
}
if (getFields().getPredictedClassField() != null && requiredFields.contains(getFields().getPredictedClassField())) {
assert getFields().getTopClassesField() != null;
// Verify existence of the predicted class name field if required
QueryBuilder predictedClassFieldExistsQuery = QueryBuilders.existsQuery(getFields().getPredictedClassField());
boolQuery.filter(
QueryBuilders.nestedQuery(getFields().getTopClassesField(), predictedClassFieldExistsQuery, ScoreMode.None)
.ignoreUnmapped(true));
}
if (getFields().getPredictedProbabilityField() != null && requiredFields.contains(getFields().getPredictedProbabilityField())) {
// Verify existence of the predicted probability field if required
QueryBuilder predictedProbabilityFieldExistsQuery = QueryBuilders.existsQuery(getFields().getPredictedProbabilityField());
// predicted probability field may be either nested (just like in case of classification evaluation) or non-nested (just like
// in case of outlier detection evaluation). Here we support both modes.
if (getFields().isPredictedProbabilityFieldNested()) {
assert getFields().getTopClassesField() != null;
boolQuery.filter(
QueryBuilders.nestedQuery(getFields().getTopClassesField(), predictedProbabilityFieldExistsQuery, ScoreMode.None)
.ignoreUnmapped(true));
} else {
boolQuery.filter(predictedProbabilityFieldExistsQuery);
}
}
// Apply user-provided query
boolQuery.filter(userProvidedQueryBuilder);
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
for (EvaluationMetric metric : getMetrics()) {
// Fetch aggregations requested by individual metrics
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs =
metric.aggs(parameters, getActualField(), getPredictedField());
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs = metric.aggs(parameters, getFields());
aggs.v1().forEach(searchSourceBuilder::aggregation);
aggs.v2().forEach(searchSourceBuilder::aggregation);
}
Expand All @@ -93,14 +141,31 @@ default SearchSourceBuilder buildSearch(EvaluationParameters parameters, QueryBu
default void process(SearchResponse searchResponse) {
Objects.requireNonNull(searchResponse);
if (searchResponse.getHits().getTotalHits().value == 0) {
throw ExceptionsHelper.badRequestException(
"No documents found containing both [{}, {}] fields", getActualField(), getPredictedField());
String requiredFieldsString = String.join(", ", getRequiredFields());
throw ExceptionsHelper.badRequestException("No documents found containing all the required fields [{}]", requiredFieldsString);
}
for (EvaluationMetric metric : getMetrics()) {
metric.process(searchResponse.getAggregations());
}
}

/**
* @return list of fields which are required by at least one of the metrics
*/
private List<String> getRequiredFields() {
Set<String> requiredFieldDescriptors =
getMetrics().stream()
.map(EvaluationMetric::getRequiredFields)
.flatMap(Set::stream)
.collect(toSet());
List<String> requiredFields =
getFields().listPotentiallyRequiredFields().stream()
.filter(f -> requiredFieldDescriptors.contains(f.v1()))
.map(Tuple::v2)
.collect(toList());
return requiredFields;
}

/**
* @return true iff all the metrics have their results computed
*/
Expand All @@ -117,6 +182,6 @@ default List<EvaluationMetricResult> getResults() {
.map(EvaluationMetric::getResult)
.filter(Optional::isPresent)
.map(Optional::get)
.collect(Collectors.toList());
.collect(toList());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;

import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;

/**
* Encapsulates fields needed by evaluation.
*/
public final class EvaluationFields {

public static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
public static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
public static final ParseField TOP_CLASSES_FIELD = new ParseField("top_classes_field");
public static final ParseField PREDICTED_CLASS_FIELD = new ParseField("predicted_class_field");
public static final ParseField PREDICTED_PROBABILITY_FIELD = new ParseField("predicted_probability_field");

/**
* The field containing the actual value
*/
private final String actualField;

/**
* The field containing the predicted value
*/
private final String predictedField;

/**
* The field containing the array of top classes
*/
private final String topClassesField;

/**
* The field containing the predicted class name value
*/
private final String predictedClassField;

/**
* The field containing the predicted probability value in [0.0, 1.0]
*/
private final String predictedProbabilityField;

/**
* Whether the {@code predictedProbabilityField} should be treated as nested (e.g.: when used in exists queries).
*/
private final boolean predictedProbabilityFieldNested;

public EvaluationFields(@Nullable String actualField,
@Nullable String predictedField,
@Nullable String topClassesField,
@Nullable String predictedClassField,
@Nullable String predictedProbabilityField,
boolean predictedProbabilityFieldNested) {

this.actualField = actualField;
this.predictedField = predictedField;
this.topClassesField = topClassesField;
this.predictedClassField = predictedClassField;
this.predictedProbabilityField = predictedProbabilityField;
this.predictedProbabilityFieldNested = predictedProbabilityFieldNested;
}

/**
* Returns the field containing the actual value
*/
public String getActualField() {
return actualField;
}

/**
* Returns the field containing the predicted value
*/
public String getPredictedField() {
return predictedField;
}

/**
* Returns the field containing the array of top classes
*/
public String getTopClassesField() {
return topClassesField;
}

/**
* Returns the field containing the predicted class name value
*/
public String getPredictedClassField() {
return predictedClassField;
}

/**
* Returns the field containing the predicted probability value in [0.0, 1.0]
*/
public String getPredictedProbabilityField() {
return predictedProbabilityField;
}

/**
* Returns whether the {@code predictedProbabilityField} should be treated as nested (e.g.: when used in exists queries).
*/
public boolean isPredictedProbabilityFieldNested() {
return predictedProbabilityFieldNested;
}

public List<Tuple<String, String>> listPotentiallyRequiredFields() {
return Arrays.asList(
Tuple.tuple(ACTUAL_FIELD.getPreferredName(), actualField),
Tuple.tuple(PREDICTED_FIELD.getPreferredName(), predictedField),
Tuple.tuple(TOP_CLASSES_FIELD.getPreferredName(), topClassesField),
Tuple.tuple(PREDICTED_CLASS_FIELD.getPreferredName(), predictedClassField),
Tuple.tuple(PREDICTED_PROBABILITY_FIELD.getPreferredName(), predictedProbabilityField));
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
EvaluationFields that = (EvaluationFields) o;
return Objects.equals(that.actualField, this.actualField)
&& Objects.equals(that.predictedField, this.predictedField)
&& Objects.equals(that.topClassesField, this.topClassesField)
&& Objects.equals(that.predictedClassField, this.predictedClassField)
&& Objects.equals(that.predictedProbabilityField, this.predictedProbabilityField)
&& Objects.equals(that.predictedProbabilityFieldNested, this.predictedProbabilityFieldNested);
}

@Override
public int hashCode() {
return Objects.hash(
actualField, predictedField, topClassesField, predictedClassField, predictedProbabilityField, predictedProbabilityFieldNested);
}
}
Loading