Skip to content

[ML] Validate that AucRoc has the data necessary to be calculated #63302

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -114,27 +114,23 @@ public static Result fromXContent(XContentParser parser) {
}

private static final ParseField SCORE = new ParseField("score");
private static final ParseField DOC_COUNT = new ParseField("doc_count");
private static final ParseField CURVE = new ParseField("curve");

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>(
"auc_roc_result", true, args -> new Result((double) args[0], (long) args[1], (List<AucRocPoint>) args[2]));
"auc_roc_result", true, args -> new Result((double) args[0], (List<AucRocPoint>) args[1]));

static {
PARSER.declareDouble(constructorArg(), SCORE);
PARSER.declareLong(constructorArg(), DOC_COUNT);
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> AucRocPoint.fromXContent(p), CURVE);
}

private final double score;
private final long docCount;
private final List<AucRocPoint> curve;

public Result(double score, long docCount, @Nullable List<AucRocPoint> curve) {
public Result(double score, @Nullable List<AucRocPoint> curve) {
this.score = score;
this.docCount = docCount;
this.curve = curve;
}

Expand All @@ -147,10 +143,6 @@ public double getScore() {
return score;
}

public long getDocCount() {
return docCount;
}

public List<AucRocPoint> getCurve() {
return curve == null ? null : Collections.unmodifiableList(curve);
}
Expand All @@ -159,7 +151,6 @@ public List<AucRocPoint> getCurve() {
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.field(SCORE.getPreferredName(), score);
builder.field(DOC_COUNT.getPreferredName(), docCount);
if (curve != null && curve.isEmpty() == false) {
builder.field(CURVE.getPreferredName(), curve);
}
Expand All @@ -173,13 +164,12 @@ public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return score == that.score
&& docCount == that.docCount
&& Objects.equals(curve, that.curve);
}

@Override
public int hashCode() {
return Objects.hash(score, docCount, curve);
return Objects.hash(score, curve);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.anyOf;
Expand Down Expand Up @@ -1901,18 +1902,17 @@ public void testEvaluateDataFrame_Classification() throws IOException {
createIndex(indexName, mappingForClassification());
BulkRequest regressionBulk = new BulkRequest()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.add(docForClassification(indexName, "cat", "cat", 0.9))
.add(docForClassification(indexName, "cat", "cat", 0.85))
.add(docForClassification(indexName, "cat", "cat", 0.95))
.add(docForClassification(indexName, "cat", "dog", 0.4))
.add(docForClassification(indexName, "cat", "fish", 0.35))
.add(docForClassification(indexName, "dog", "cat", 0.5))
.add(docForClassification(indexName, "dog", "dog", 0.4))
.add(docForClassification(indexName, "dog", "dog", 0.35))
.add(docForClassification(indexName, "dog", "dog", 0.6))
.add(docForClassification(indexName, "ant", "cat", 0.1));
.add(docForClassification(indexName, "cat", "cat", "dog", "ant"))
.add(docForClassification(indexName, "cat", "cat", "dog", "ant"))
.add(docForClassification(indexName, "cat", "cat", "horse", "dog"))
.add(docForClassification(indexName, "cat", "dog", "cat", "mule"))
.add(docForClassification(indexName, "cat", "fish", "cat", "dog"))
.add(docForClassification(indexName, "dog", "cat", "dog", "mule"))
.add(docForClassification(indexName, "dog", "dog", "cat", "ant"))
.add(docForClassification(indexName, "dog", "dog", "cat", "ant"))
.add(docForClassification(indexName, "dog", "dog", "cat", "ant"))
.add(docForClassification(indexName, "ant", "cat", "ant", "wasp"));
highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);

MachineLearningClient machineLearningClient = highLevelClient().machineLearning();

{ // AucRoc
Expand All @@ -1927,8 +1927,7 @@ public void testEvaluateDataFrame_Classification() throws IOException {

AucRocMetric.Result aucRocResult = evaluateDataFrameResponse.getMetricByName(AucRocMetric.NAME);
assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
assertThat(aucRocResult.getScore(), closeTo(0.99995, 1e-9));
assertThat(aucRocResult.getDocCount(), equalTo(5L));
assertThat(aucRocResult.getScore(), closeTo(0.6425, 1e-9));
assertNotNull(aucRocResult.getCurve());
}
{ // Accuracy
Expand Down Expand Up @@ -2143,15 +2142,19 @@ private static XContentBuilder mappingForClassification() throws IOException {
.endObject();
}

private static IndexRequest docForClassification(String indexName, String actualClass, String predictedClass, double p) {
private static IndexRequest docForClassification(String indexName,
String actualClass,
String... topPredictedClasses) {
assert topPredictedClasses.length > 0;
return new IndexRequest()
.index(indexName)
.source(XContentType.JSON,
actualClassField, actualClass,
predictedClassField, predictedClass,
topClassesField, List.of(
Map.of("class_name", predictedClass, "class_probability", p),
Map.of("class_name", "other", "class_probability", 1 - p)));
predictedClassField, topPredictedClasses[0],
topClassesField, IntStream.range(0, topPredictedClasses.length)
// Consecutive assigned probabilities are: 0.5, 0.25, 0.125, etc.
.mapToObj(i -> Map.of("class_name", topPredictedClasses[i], "class_probability", 1.0 / (2 << i)))
.collect(Collectors.toList()));
}

private static final String actualRegression = "regression_actual";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,6 @@
import org.elasticsearch.client.ml.job.results.Influencer;
import org.elasticsearch.client.ml.job.results.OverallBucket;
import org.elasticsearch.client.ml.job.stats.JobStats;
import org.elasticsearch.common.TriFunction;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
Expand Down Expand Up @@ -229,8 +228,11 @@
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static java.util.stream.Collectors.toList;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.contains;
Expand Down Expand Up @@ -3466,28 +3468,30 @@ public void testEvaluateDataFrame_Classification() throws Exception {
.endObject()
.endObject()
.endObject());
TriFunction<String, String, Double, IndexRequest> indexRequest = (actualClass, predictedClass, p) -> {
BiFunction<String, String[], IndexRequest> indexRequest = (actualClass, topPredictedClasses) -> {
assert topPredictedClasses.length > 0;
return new IndexRequest()
.source(XContentType.JSON,
"actual_class", actualClass,
"predicted_class", predictedClass,
"ml.top_classes", List.of(
Map.of("class_name", predictedClass, "class_probability", p),
Map.of("class_name", "other", "class_probability", 1 - p)));
"predicted_class", topPredictedClasses[0],
"ml.top_classes", IntStream.range(0, topPredictedClasses.length)
// Consecutive assigned probabilities are: 0.5, 0.25, 0.125, etc.
.mapToObj(i -> Map.of("class_name", topPredictedClasses[i], "class_probability", 1.0 / (2 << i)))
.collect(toList()));
};
BulkRequest bulkRequest =
new BulkRequest(indexName)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.add(indexRequest.apply("cat", "cat", 0.9)) // #0
.add(indexRequest.apply("cat", "cat", 0.9)) // #1
.add(indexRequest.apply("cat", "cat", 0.9)) // #2
.add(indexRequest.apply("cat", "dog", 0.9)) // #3
.add(indexRequest.apply("cat", "fox", 0.9)) // #4
.add(indexRequest.apply("dog", "cat", 0.9)) // #5
.add(indexRequest.apply("dog", "dog", 0.9)) // #6
.add(indexRequest.apply("dog", "dog", 0.9)) // #7
.add(indexRequest.apply("dog", "dog", 0.9)) // #8
.add(indexRequest.apply("ant", "cat", 0.9)); // #9
.add(indexRequest.apply("cat", new String[]{"cat", "dog", "ant"})) // #0
.add(indexRequest.apply("cat", new String[]{"cat", "dog", "ant"})) // #1
.add(indexRequest.apply("cat", new String[]{"cat", "horse", "dog"})) // #2
.add(indexRequest.apply("cat", new String[]{"dog", "cat", "mule"})) // #3
.add(indexRequest.apply("cat", new String[]{"fox", "cat", "dog"})) // #4
.add(indexRequest.apply("dog", new String[]{"cat", "dog", "mule"})) // #5
.add(indexRequest.apply("dog", new String[]{"dog", "cat", "ant"})) // #6
.add(indexRequest.apply("dog", new String[]{"dog", "cat", "ant"})) // #7
.add(indexRequest.apply("dog", new String[]{"dog", "cat", "ant"})) // #8
.add(indexRequest.apply("ant", new String[]{"cat", "ant", "wasp"})); // #9
RestHighLevelClient client = highLevelClient();
client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
client.bulk(bulkRequest, RequestOptions.DEFAULT);
Expand Down Expand Up @@ -3527,7 +3531,6 @@ public void testEvaluateDataFrame_Classification() throws Exception {

AucRocMetric.Result aucRocResult = response.getMetricByName(AucRocMetric.NAME); // <10>
double aucRocScore = aucRocResult.getScore(); // <11>
Long aucRocDocCount = aucRocResult.getDocCount(); // <12>
// end::evaluate-data-frame-results-classification

assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
Expand Down Expand Up @@ -3562,8 +3565,7 @@ public void testEvaluateDataFrame_Classification() throws Exception {
assertThat(otherClassesCount, equalTo(0L));

assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
assertThat(aucRocScore, equalTo(0.2625));
assertThat(aucRocDocCount, equalTo(5L));
assertThat(aucRocScore, closeTo(0.6425, 1e-9));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ public class AucRocMetricResultTests extends AbstractXContentTestCase<AucRocMetr
public static AucRocMetric.Result randomResult() {
return new AucRocMetric.Result(
randomDouble(),
randomLong(),
Stream
.generate(AucRocMetricAucRocPointTests::randomPoint)
.limit(randomIntBetween(1, 10))
Expand Down
1 change: 0 additions & 1 deletion docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ include-tagged::{doc-tests-file}[{api}-results-classification]
<9> Fetching the number of classes that were not included in the matrix
<10> Fetching AucRoc metric by name
<11> Fetching the actual AucRoc score
<12> Fetching the number of documents that were used in order to calculate AucRoc score

===== Regression

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,8 @@ belongs.
`class_name`::::
(Required, string) Name of the only class that will be treated as
positive during AUC ROC calculation. Other classes will be treated as
negative ("one-vs-all" strategy). Documents which do not have `class_name`
in the list of their top classes will not be taken into account for
evaluation. The number of documents taken into account is returned in the
evaluation result (`auc_roc.doc_count` field).
negative ("one-vs-all" strategy). All the evaluated documents must have `class_name`
in the list of their top classes.

`include_curve`::::
(Optional, boolean) Whether or not the curve should be returned in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,33 +230,25 @@ private static double interpolate(double x, double x1, double y1, double x2, dou
public static class Result implements EvaluationMetricResult {

private static final String SCORE = "score";
private static final String DOC_COUNT = "doc_count";
private static final String CURVE = "curve";

private final double score;
private final Long docCount;
private final List<AucRocPoint> curve;

public Result(double score, Long docCount, List<AucRocPoint> curve) {
public Result(double score, List<AucRocPoint> curve) {
this.score = score;
this.docCount = docCount;
this.curve = Objects.requireNonNull(curve);
}

public Result(StreamInput in) throws IOException {
this.score = in.readDouble();
this.docCount = in.readOptionalLong();
this.curve = in.readList(AucRocPoint::new);
}

public double getScore() {
return score;
}

public Long getDocCount() {
return docCount;
}

public List<AucRocPoint> getCurve() {
return Collections.unmodifiableList(curve);
}
Expand All @@ -274,17 +266,13 @@ public String getMetricName() {
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(score);
out.writeOptionalLong(docCount);
out.writeList(curve);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(SCORE, score);
if (docCount != null) {
builder.field(DOC_COUNT, docCount);
}
if (curve.isEmpty() == false) {
builder.field(CURVE, curve);
}
Expand All @@ -298,13 +286,12 @@ public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return score == that.score
&& Objects.equals(docCount, that.docCount)
&& Objects.equals(curve, that.curve);
}

@Override
public int hashCode() {
return Objects.hash(score, docCount, curve);
return Objects.hash(score, curve);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -182,42 +182,39 @@ public void process(Aggregations aggs) {
Filter classAgg = aggs.get(TRUE_AGG_NAME);
Nested classNested = classAgg.getAggregations().get(NESTED_AGG_NAME);
Filter classNestedFilter = classNested.getAggregations().get(NESTED_FILTER_AGG_NAME);

Filter restAgg = aggs.get(NON_TRUE_AGG_NAME);
Nested restNested = restAgg.getAggregations().get(NESTED_AGG_NAME);
Filter restNestedFilter = restNested.getAggregations().get(NESTED_FILTER_AGG_NAME);

if (classAgg.getDocCount() == 0) {
throw ExceptionsHelper.badRequestException(
"[{}] requires at least one [{}] to have the value [{}]",
getName(), fields.get().getActualField(), className);
}
if (classNestedFilter.getDocCount() == 0) {
throw ExceptionsHelper.badRequestException(
"[{}] requires at least one [{}] to have the value [{}]",
getName(), fields.get().getPredictedClassField(), className);
}
Percentiles classPercentiles = classNestedFilter.getAggregations().get(PERCENTILES_AGG_NAME);
double[] tpPercentiles = percentilesArray(classPercentiles);

Filter restAgg = aggs.get(NON_TRUE_AGG_NAME);
Nested restNested = restAgg.getAggregations().get(NESTED_AGG_NAME);
Filter restNestedFilter = restNested.getAggregations().get(NESTED_FILTER_AGG_NAME);
if (restAgg.getDocCount() == 0) {
throw ExceptionsHelper.badRequestException(
"[{}] requires at least one [{}] to have a different value than [{}]",
getName(), fields.get().getActualField(), className);
}
if (restNestedFilter.getDocCount() == 0) {
long filteredDocCount = classNestedFilter.getDocCount() + restNestedFilter.getDocCount();
long totalDocCount = classAgg.getDocCount() + restAgg.getDocCount();
if (filteredDocCount < totalDocCount) {
throw ExceptionsHelper.badRequestException(
"[{}] requires at least one [{}] to have the value [{}]",
getName(), fields.get().getPredictedClassField(), className);
"[{}] requires that [{}] appears as one of the [{}] for every document (appeared in {} out of {}). "
+ "This is probably caused by the {} value being less than the total number of actual classes in the dataset.",
getName(), className, fields.get().getPredictedClassField(), filteredDocCount, totalDocCount,
org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification.NUM_TOP_CLASSES.getPreferredName());
}

Percentiles classPercentiles = classNestedFilter.getAggregations().get(PERCENTILES_AGG_NAME);
double[] tpPercentiles = percentilesArray(classPercentiles);
Percentiles restPercentiles = restNestedFilter.getAggregations().get(PERCENTILES_AGG_NAME);
double[] fpPercentiles = percentilesArray(restPercentiles);

List<AucRocPoint> aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles);
double aucRocScore = calculateAucScore(aucRocCurve);
result.set(
new Result(
aucRocScore,
classNestedFilter.getDocCount() + restNestedFilter.getDocCount(),
includeCurve ? aucRocCurve : Collections.emptyList()));
result.set(new Result(aucRocScore, includeCurve ? aucRocCurve : Collections.emptyList()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,7 @@ public void process(Aggregations aggs) {

List<AucRocPoint> aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles);
double aucRocScore = calculateAucScore(aucRocCurve);
result.set(
new Result(
aucRocScore,
classAgg.getDocCount() + restAgg.getDocCount(),
includeCurve ? aucRocCurve : Collections.emptyList()));
result.set(new Result(aucRocScore, includeCurve ? aucRocCurve : Collections.emptyList()));
}

@Override
Expand Down
Loading