Skip to content

Commit 175f2a7

Browse files
committed
Apply review comments
1 parent 1503a78 commit 175f2a7

File tree

4 files changed

+48
-48
lines changed

4 files changed

+48
-48
lines changed

client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@
200200
import java.util.Map;
201201
import java.util.concurrent.TimeUnit;
202202
import java.util.stream.Collectors;
203+
import java.util.stream.IntStream;
203204

204205
import static org.hamcrest.Matchers.allOf;
205206
import static org.hamcrest.Matchers.anyOf;
@@ -1901,18 +1902,17 @@ public void testEvaluateDataFrame_Classification() throws IOException {
19011902
createIndex(indexName, mappingForClassification());
19021903
BulkRequest regressionBulk = new BulkRequest()
19031904
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
1904-
.add(docForClassification(indexName, "cat", "cat", 0.9, "dog"))
1905-
.add(docForClassification(indexName, "cat", "cat", 0.85, "dog"))
1906-
.add(docForClassification(indexName, "cat", "cat", 0.95, "horse"))
1907-
.add(docForClassification(indexName, "cat", "dog", 0.4, "cat"))
1908-
.add(docForClassification(indexName, "cat", "fish", 0.35, "cat"))
1909-
.add(docForClassification(indexName, "dog", "cat", 0.5, "dog"))
1910-
.add(docForClassification(indexName, "dog", "dog", 0.4, "cat"))
1911-
.add(docForClassification(indexName, "dog", "dog", 0.35, "cat"))
1912-
.add(docForClassification(indexName, "dog", "dog", 0.6, "cat"))
1913-
.add(docForClassification(indexName, "ant", "cat", 0.1, "ant"));
1905+
.add(docForClassification(indexName, "cat", "cat", "dog", "ant"))
1906+
.add(docForClassification(indexName, "cat", "cat", "dog", "ant"))
1907+
.add(docForClassification(indexName, "cat", "cat", "horse", "dog"))
1908+
.add(docForClassification(indexName, "cat", "dog", "cat", "mule"))
1909+
.add(docForClassification(indexName, "cat", "fish", "cat", "dog"))
1910+
.add(docForClassification(indexName, "dog", "cat", "dog", "mule"))
1911+
.add(docForClassification(indexName, "dog", "dog", "cat", "ant"))
1912+
.add(docForClassification(indexName, "dog", "dog", "cat", "ant"))
1913+
.add(docForClassification(indexName, "dog", "dog", "cat", "ant"))
1914+
.add(docForClassification(indexName, "ant", "cat", "ant", "wasp"));
19141915
highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
1915-
19161916
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
19171917

19181918
{ // AucRoc
@@ -1927,7 +1927,7 @@ public void testEvaluateDataFrame_Classification() throws IOException {
19271927

19281928
AucRocMetric.Result aucRocResult = evaluateDataFrameResponse.getMetricByName(AucRocMetric.NAME);
19291929
assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
1930-
assertThat(aucRocResult.getScore(), closeTo(0.9299, 1e-9));
1930+
assertThat(aucRocResult.getScore(), closeTo(0.6425, 1e-9));
19311931
assertNotNull(aucRocResult.getCurve());
19321932
}
19331933
{ // Accuracy
@@ -2144,17 +2144,17 @@ private static XContentBuilder mappingForClassification() throws IOException {
21442144

21452145
private static IndexRequest docForClassification(String indexName,
21462146
String actualClass,
2147-
String predictedClass,
2148-
double p,
2149-
String otherClass) {
2147+
String... topPredictedClasses) {
2148+
assert topPredictedClasses.length > 0;
21502149
return new IndexRequest()
21512150
.index(indexName)
21522151
.source(XContentType.JSON,
21532152
actualClassField, actualClass,
2154-
predictedClassField, predictedClass,
2155-
topClassesField, List.of(
2156-
Map.of("class_name", predictedClass, "class_probability", p),
2157-
Map.of("class_name", otherClass, "class_probability", 1 - p)));
2153+
predictedClassField, topPredictedClasses[0],
2154+
topClassesField, IntStream.range(0, topPredictedClasses.length)
2155+
// Consecutive assigned probabilities are: 0.5, 0.25, 0.125, etc.
2156+
.mapToObj(i -> Map.of("class_name", topPredictedClasses[i], "class_probability", 1.0 / (2 << i)))
2157+
.collect(Collectors.toList()));
21582158
}
21592159

21602160
private static final String actualRegression = "regression_actual";

client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,6 @@
201201
import org.elasticsearch.client.ml.job.results.Influencer;
202202
import org.elasticsearch.client.ml.job.results.OverallBucket;
203203
import org.elasticsearch.client.ml.job.stats.JobStats;
204-
import org.elasticsearch.common.TriFunction;
205204
import org.elasticsearch.common.bytes.BytesReference;
206205
import org.elasticsearch.common.unit.ByteSizeUnit;
207206
import org.elasticsearch.common.unit.ByteSizeValue;
@@ -229,8 +228,11 @@
229228
import java.util.Map;
230229
import java.util.concurrent.CountDownLatch;
231230
import java.util.concurrent.TimeUnit;
231+
import java.util.function.BiFunction;
232232
import java.util.stream.Collectors;
233+
import java.util.stream.IntStream;
233234

235+
import static java.util.stream.Collectors.toList;
234236
import static org.hamcrest.Matchers.allOf;
235237
import static org.hamcrest.Matchers.closeTo;
236238
import static org.hamcrest.Matchers.contains;
@@ -3466,28 +3468,30 @@ public void testEvaluateDataFrame_Classification() throws Exception {
34663468
.endObject()
34673469
.endObject()
34683470
.endObject());
3469-
TriFunction<String, String, String, IndexRequest> indexRequest = (actualClass, predictedClass, otherClass) -> {
3471+
BiFunction<String, String[], IndexRequest> indexRequest = (actualClass, topPredictedClasses) -> {
3472+
assert topPredictedClasses.length > 0;
34703473
return new IndexRequest()
34713474
.source(XContentType.JSON,
34723475
"actual_class", actualClass,
3473-
"predicted_class", predictedClass,
3474-
"ml.top_classes", List.of(
3475-
Map.of("class_name", predictedClass, "class_probability", 0.9),
3476-
Map.of("class_name", otherClass, "class_probability", 0.1)));
3476+
"predicted_class", topPredictedClasses[0],
3477+
"ml.top_classes", IntStream.range(0, topPredictedClasses.length)
3478+
// Consecutive assigned probabilities are: 0.5, 0.25, 0.125, etc.
3479+
.mapToObj(i -> Map.of("class_name", topPredictedClasses[i], "class_probability", 1.0 / (2 << i)))
3480+
.collect(toList()));
34773481
};
34783482
BulkRequest bulkRequest =
34793483
new BulkRequest(indexName)
34803484
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
3481-
.add(indexRequest.apply("cat", "cat", "dog")) // #0
3482-
.add(indexRequest.apply("cat", "cat", "dog")) // #1
3483-
.add(indexRequest.apply("cat", "cat", "dog")) // #2
3484-
.add(indexRequest.apply("cat", "dog", "cat")) // #3
3485-
.add(indexRequest.apply("cat", "fox", "cat")) // #4
3486-
.add(indexRequest.apply("dog", "cat", "dog")) // #5
3487-
.add(indexRequest.apply("dog", "dog", "cat")) // #6
3488-
.add(indexRequest.apply("dog", "dog", "cat")) // #7
3489-
.add(indexRequest.apply("dog", "dog", "cat")) // #8
3490-
.add(indexRequest.apply("ant", "cat", "ant")); // #9
3485+
.add(indexRequest.apply("cat", new String[]{"cat", "dog", "ant"})) // #0
3486+
.add(indexRequest.apply("cat", new String[]{"cat", "dog", "ant"})) // #1
3487+
.add(indexRequest.apply("cat", new String[]{"cat", "horse", "dog"})) // #2
3488+
.add(indexRequest.apply("cat", new String[]{"dog", "cat", "mule"})) // #3
3489+
.add(indexRequest.apply("cat", new String[]{"fox", "cat", "dog"})) // #4
3490+
.add(indexRequest.apply("dog", new String[]{"cat", "dog", "mule"})) // #5
3491+
.add(indexRequest.apply("dog", new String[]{"dog", "cat", "ant"})) // #6
3492+
.add(indexRequest.apply("dog", new String[]{"dog", "cat", "ant"})) // #7
3493+
.add(indexRequest.apply("dog", new String[]{"dog", "cat", "ant"})) // #8
3494+
.add(indexRequest.apply("ant", new String[]{"cat", "ant", "wasp"})); // #9
34913495
RestHighLevelClient client = highLevelClient();
34923496
client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
34933497
client.bulk(bulkRequest, RequestOptions.DEFAULT);
@@ -3561,7 +3565,7 @@ public void testEvaluateDataFrame_Classification() throws Exception {
35613565
assertThat(otherClassesCount, equalTo(0L));
35623566

35633567
assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
3564-
assertThat(aucRocScore, equalTo(0.7162000000000013));
3568+
assertThat(aucRocScore, closeTo(0.6425, 1e-9));
35653569
}
35663570
}
35673571

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AucRoc.java

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -187,28 +187,24 @@ public void process(Aggregations aggs) {
187187
Nested restNested = restAgg.getAggregations().get(NESTED_AGG_NAME);
188188
Filter restNestedFilter = restNested.getAggregations().get(NESTED_FILTER_AGG_NAME);
189189

190-
long filteredDocCount = classNestedFilter.getDocCount() + restNestedFilter.getDocCount();
191-
long totalDocCount = classAgg.getDocCount() + restAgg.getDocCount();
192-
193190
if (classAgg.getDocCount() == 0) {
194191
throw ExceptionsHelper.badRequestException(
195192
"[{}] requires at least one [{}] to have the value [{}]",
196193
getName(), fields.get().getActualField(), className);
197194
}
198-
if (classNestedFilter.getDocCount() < classAgg.getDocCount()) {
199-
throw ExceptionsHelper.badRequestException(
200-
"[{}] requires that [{}] appears as one of the [{}] for every document (appeared for {} out of {})",
201-
getName(), className, fields.get().getPredictedClassField(), filteredDocCount, totalDocCount);
202-
}
203195
if (restAgg.getDocCount() == 0) {
204196
throw ExceptionsHelper.badRequestException(
205197
"[{}] requires at least one [{}] to have a different value than [{}]",
206198
getName(), fields.get().getActualField(), className);
207199
}
208-
if (restNestedFilter.getDocCount() < restAgg.getDocCount()) {
200+
long filteredDocCount = classNestedFilter.getDocCount() + restNestedFilter.getDocCount();
201+
long totalDocCount = classAgg.getDocCount() + restAgg.getDocCount();
202+
if (filteredDocCount < totalDocCount) {
209203
throw ExceptionsHelper.badRequestException(
210-
"[{}] requires that [{}] appears as one of the [{}] for every document (appeared for {} out of {})",
211-
getName(), className, fields.get().getPredictedClassField(), filteredDocCount, totalDocCount);
204+
"[{}] requires that [{}] appears as one of the [{}] for every document (appeared in {} out of {}). "
205+
+ "This is probably caused by the {} value being less than the total number of actual classes in the dataset.",
206+
getName(), className, fields.get().getPredictedClassField(), filteredDocCount, totalDocCount,
207+
org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification.NUM_TOP_CLASSES.getPreferredName());
212208
}
213209

214210
Percentiles classPercentiles = classNestedFilter.getAggregations().get(PERCENTILES_AGG_NAME);

x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -685,7 +685,7 @@ setup:
685685
---
686686
"Test classification auc_roc given predicted_class_field is never equal to mouse":
687687
- do:
688-
catch: /\[auc_roc\] requires that \[mouse\] appears as one of the \[ml.top_classes.class_name\] for every document \(appeared for 0 out of 8\)/
688+
catch: /\[auc_roc\] requires that \[mouse\] appears as one of the \[ml.top_classes.class_name\] for every document \(appeared in 0 out of 8\)./
689689
ml.evaluate_data_frame:
690690
body: >
691691
{

0 commit comments

Comments
 (0)