Skip to content

Commit 31f6e78

Browse files
authored
Allow the user to specify 'query' in Evaluate Data Frame request (elastic#45775)
1 parent 3cf174d commit 31f6e78

File tree

19 files changed

+414
-108
lines changed

19 files changed

+414
-108
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameRequest.java

+33-11
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121

2222
import org.elasticsearch.client.Validatable;
2323
import org.elasticsearch.client.ValidationException;
24+
import org.elasticsearch.client.ml.dataframe.QueryConfig;
2425
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
26+
import org.elasticsearch.common.Nullable;
2527
import org.elasticsearch.common.ParseField;
2628
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
2729
import org.elasticsearch.common.xcontent.ToXContentObject;
@@ -37,20 +39,25 @@
3739
import java.util.Optional;
3840

3941
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
42+
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
4043
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
4144

4245
public class EvaluateDataFrameRequest implements ToXContentObject, Validatable {
4346

4447
private static final ParseField INDEX = new ParseField("index");
48+
private static final ParseField QUERY = new ParseField("query");
4549
private static final ParseField EVALUATION = new ParseField("evaluation");
4650

4751
@SuppressWarnings("unchecked")
4852
private static final ConstructingObjectParser<EvaluateDataFrameRequest, Void> PARSER =
4953
new ConstructingObjectParser<>(
50-
"evaluate_data_frame_request", true, args -> new EvaluateDataFrameRequest((List<String>) args[0], (Evaluation) args[1]));
54+
"evaluate_data_frame_request",
55+
true,
56+
args -> new EvaluateDataFrameRequest((List<String>) args[0], (QueryConfig) args[1], (Evaluation) args[2]));
5157

5258
static {
5359
PARSER.declareStringArray(constructorArg(), INDEX);
60+
PARSER.declareObject(optionalConstructorArg(), (p, c) -> QueryConfig.fromXContent(p), QUERY);
5461
PARSER.declareObject(constructorArg(), (p, c) -> parseEvaluation(p), EVALUATION);
5562
}
5663

@@ -67,14 +74,16 @@ public static EvaluateDataFrameRequest fromXContent(XContentParser parser) {
6774
}
6875

6976
private List<String> indices;
77+
private QueryConfig queryConfig;
7078
private Evaluation evaluation;
7179

72-
public EvaluateDataFrameRequest(String index, Evaluation evaluation) {
73-
this(Arrays.asList(index), evaluation);
80+
public EvaluateDataFrameRequest(String index, @Nullable QueryConfig queryConfig, Evaluation evaluation) {
81+
this(Arrays.asList(index), queryConfig, evaluation);
7482
}
7583

76-
public EvaluateDataFrameRequest(List<String> indices, Evaluation evaluation) {
84+
public EvaluateDataFrameRequest(List<String> indices, @Nullable QueryConfig queryConfig, Evaluation evaluation) {
7785
setIndices(indices);
86+
setQueryConfig(queryConfig);
7887
setEvaluation(evaluation);
7988
}
8089

@@ -87,6 +96,14 @@ public final void setIndices(List<String> indices) {
8796
this.indices = new ArrayList<>(indices);
8897
}
8998

99+
public QueryConfig getQueryConfig() {
100+
return queryConfig;
101+
}
102+
103+
public final void setQueryConfig(QueryConfig queryConfig) {
104+
this.queryConfig = queryConfig;
105+
}
106+
90107
public Evaluation getEvaluation() {
91108
return evaluation;
92109
}
@@ -111,18 +128,22 @@ public Optional<ValidationException> validate() {
111128

112129
@Override
113130
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
114-
return builder
115-
.startObject()
116-
.array(INDEX.getPreferredName(), indices.toArray())
117-
.startObject(EVALUATION.getPreferredName())
118-
.field(evaluation.getName(), evaluation)
119-
.endObject()
131+
builder.startObject();
132+
builder.array(INDEX.getPreferredName(), indices.toArray());
133+
if (queryConfig != null) {
134+
builder.field(QUERY.getPreferredName(), queryConfig.getQuery());
135+
}
136+
builder
137+
.startObject(EVALUATION.getPreferredName())
138+
.field(evaluation.getName(), evaluation)
120139
.endObject();
140+
builder.endObject();
141+
return builder;
121142
}
122143

123144
@Override
124145
public int hashCode() {
125-
return Objects.hash(indices, evaluation);
146+
return Objects.hash(indices, queryConfig, evaluation);
126147
}
127148

128149
@Override
@@ -131,6 +152,7 @@ public boolean equals(Object o) {
131152
if (o == null || getClass() != o.getClass()) return false;
132153
EvaluateDataFrameRequest that = (EvaluateDataFrameRequest) o;
133154
return Objects.equals(indices, that.indices)
155+
&& Objects.equals(queryConfig, that.queryConfig)
134156
&& Objects.equals(evaluation, that.evaluation);
135157
}
136158
}

client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java

+2-10
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import org.elasticsearch.client.ml.DeleteJobRequest;
3737
import org.elasticsearch.client.ml.DeleteModelSnapshotRequest;
3838
import org.elasticsearch.client.ml.EvaluateDataFrameRequest;
39+
import org.elasticsearch.client.ml.EvaluateDataFrameRequestTests;
3940
import org.elasticsearch.client.ml.FindFileStructureRequest;
4041
import org.elasticsearch.client.ml.FindFileStructureRequestTests;
4142
import org.elasticsearch.client.ml.FlushJobRequest;
@@ -85,9 +86,6 @@
8586
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig;
8687
import org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider;
8788
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
88-
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
89-
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric;
90-
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric;
9189
import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
9290
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
9391
import org.elasticsearch.client.ml.job.config.Detector;
@@ -779,13 +777,7 @@ public void testDeleteDataFrameAnalytics() {
779777
}
780778

781779
public void testEvaluateDataFrame() throws IOException {
782-
EvaluateDataFrameRequest evaluateRequest =
783-
new EvaluateDataFrameRequest(
784-
Arrays.asList(generateRandomStringArray(1, 10, false, false)),
785-
new BinarySoftClassification(
786-
randomAlphaOfLengthBetween(1, 10),
787-
randomAlphaOfLengthBetween(1, 10),
788-
PrecisionMetric.at(0.5), RecallMetric.at(0.6, 0.7)));
780+
EvaluateDataFrameRequest evaluateRequest = EvaluateDataFrameRequestTests.createRandom();
789781
Request request = MLRequestConverters.evaluateDataFrame(evaluateRequest);
790782
assertEquals(HttpPost.METHOD_NAME, request.getMethod());
791783
assertEquals("/_ml/data_frame/_evaluate", request.getEndpoint());

client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

+70-19
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@
149149
import org.elasticsearch.common.xcontent.XContentFactory;
150150
import org.elasticsearch.common.xcontent.XContentType;
151151
import org.elasticsearch.index.query.MatchAllQueryBuilder;
152+
import org.elasticsearch.index.query.QueryBuilders;
152153
import org.elasticsearch.rest.RestStatus;
153154
import org.elasticsearch.search.SearchHit;
154155
import org.junit.After;
@@ -1427,7 +1428,7 @@ public void testStartDataFrameAnalyticsConfig() throws Exception {
14271428
public void testStopDataFrameAnalyticsConfig() throws Exception {
14281429
String sourceIndex = "stop-test-source-index";
14291430
String destIndex = "stop-test-dest-index";
1430-
createIndex(sourceIndex, mappingForClassification());
1431+
createIndex(sourceIndex, defaultMappingForTest());
14311432
highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000)
14321433
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE), RequestOptions.DEFAULT);
14331434

@@ -1525,27 +1526,28 @@ public void testDeleteDataFrameAnalyticsConfig_ConfigNotFound() {
15251526
assertThat(exception.status().getStatus(), equalTo(404));
15261527
}
15271528

1528-
public void testEvaluateDataFrame() throws IOException {
1529+
public void testEvaluateDataFrame_BinarySoftClassification() throws IOException {
15291530
String indexName = "evaluate-test-index";
15301531
createIndex(indexName, mappingForClassification());
15311532
BulkRequest bulk = new BulkRequest()
15321533
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
1533-
.add(docForClassification(indexName, false, 0.1)) // #0
1534-
.add(docForClassification(indexName, false, 0.2)) // #1
1535-
.add(docForClassification(indexName, false, 0.3)) // #2
1536-
.add(docForClassification(indexName, false, 0.4)) // #3
1537-
.add(docForClassification(indexName, false, 0.7)) // #4
1538-
.add(docForClassification(indexName, true, 0.2)) // #5
1539-
.add(docForClassification(indexName, true, 0.3)) // #6
1540-
.add(docForClassification(indexName, true, 0.4)) // #7
1541-
.add(docForClassification(indexName, true, 0.8)) // #8
1542-
.add(docForClassification(indexName, true, 0.9)); // #9
1534+
.add(docForClassification(indexName, "blue", false, 0.1)) // #0
1535+
.add(docForClassification(indexName, "blue", false, 0.2)) // #1
1536+
.add(docForClassification(indexName, "blue", false, 0.3)) // #2
1537+
.add(docForClassification(indexName, "blue", false, 0.4)) // #3
1538+
.add(docForClassification(indexName, "blue", false, 0.7)) // #4
1539+
.add(docForClassification(indexName, "blue", true, 0.2)) // #5
1540+
.add(docForClassification(indexName, "green", true, 0.3)) // #6
1541+
.add(docForClassification(indexName, "green", true, 0.4)) // #7
1542+
.add(docForClassification(indexName, "green", true, 0.8)) // #8
1543+
.add(docForClassification(indexName, "green", true, 0.9)); // #9
15431544
highLevelClient().bulk(bulk, RequestOptions.DEFAULT);
15441545

15451546
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
15461547
EvaluateDataFrameRequest evaluateDataFrameRequest =
15471548
new EvaluateDataFrameRequest(
15481549
indexName,
1550+
null,
15491551
new BinarySoftClassification(
15501552
actualField,
15511553
probabilityField,
@@ -1596,7 +1598,48 @@ public void testEvaluateDataFrame() throws IOException {
15961598
assertThat(curvePointAtThreshold1.getTruePositiveRate(), equalTo(0.0));
15971599
assertThat(curvePointAtThreshold1.getFalsePositiveRate(), equalTo(0.0));
15981600
assertThat(curvePointAtThreshold1.getThreshold(), equalTo(1.0));
1601+
}
1602+
1603+
public void testEvaluateDataFrame_BinarySoftClassification_WithQuery() throws IOException {
1604+
String indexName = "evaluate-with-query-test-index";
1605+
createIndex(indexName, mappingForClassification());
1606+
BulkRequest bulk = new BulkRequest()
1607+
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
1608+
.add(docForClassification(indexName, "blue", true, 1.0)) // #0
1609+
.add(docForClassification(indexName, "blue", true, 1.0)) // #1
1610+
.add(docForClassification(indexName, "blue", true, 1.0)) // #2
1611+
.add(docForClassification(indexName, "blue", true, 1.0)) // #3
1612+
.add(docForClassification(indexName, "blue", true, 0.0)) // #4
1613+
.add(docForClassification(indexName, "blue", true, 0.0)) // #5
1614+
.add(docForClassification(indexName, "green", true, 0.0)) // #6
1615+
.add(docForClassification(indexName, "green", true, 0.0)) // #7
1616+
.add(docForClassification(indexName, "green", true, 0.0)) // #8
1617+
.add(docForClassification(indexName, "green", true, 1.0)); // #9
1618+
highLevelClient().bulk(bulk, RequestOptions.DEFAULT);
15991619

1620+
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
1621+
EvaluateDataFrameRequest evaluateDataFrameRequest =
1622+
new EvaluateDataFrameRequest(
1623+
indexName,
1624+
// Request only "blue" subset to be evaluated
1625+
new QueryConfig(QueryBuilders.termQuery(datasetField, "blue")),
1626+
new BinarySoftClassification(actualField, probabilityField, ConfusionMatrixMetric.at(0.5)));
1627+
1628+
EvaluateDataFrameResponse evaluateDataFrameResponse =
1629+
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
1630+
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(BinarySoftClassification.NAME));
1631+
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
1632+
1633+
ConfusionMatrixMetric.Result confusionMatrixResult = evaluateDataFrameResponse.getMetricByName(ConfusionMatrixMetric.NAME);
1634+
assertThat(confusionMatrixResult.getMetricName(), equalTo(ConfusionMatrixMetric.NAME));
1635+
ConfusionMatrixMetric.ConfusionMatrix confusionMatrix = confusionMatrixResult.getScoreByThreshold("0.5");
1636+
assertThat(confusionMatrix.getTruePositives(), equalTo(4L)); // docs #0, #1, #2 and #3
1637+
assertThat(confusionMatrix.getFalsePositives(), equalTo(0L));
1638+
assertThat(confusionMatrix.getTrueNegatives(), equalTo(0L));
1639+
assertThat(confusionMatrix.getFalseNegatives(), equalTo(2L)); // docs #4 and #5
1640+
}
1641+
1642+
public void testEvaluateDataFrame_Regression() throws IOException {
16001643
String regressionIndex = "evaluate-regression-test-index";
16011644
createIndex(regressionIndex, mappingForRegression());
16021645
BulkRequest regressionBulk = new BulkRequest()
@@ -1613,10 +1656,14 @@ public void testEvaluateDataFrame() throws IOException {
16131656
.add(docForRegression(regressionIndex, 0.5, 0.9)); // #9
16141657
highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
16151658

1616-
evaluateDataFrameRequest = new EvaluateDataFrameRequest(regressionIndex,
1617-
new Regression(actualRegression, probabilityRegression, new MeanSquaredErrorMetric(), new RSquaredMetric()));
1659+
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
1660+
EvaluateDataFrameRequest evaluateDataFrameRequest =
1661+
new EvaluateDataFrameRequest(
1662+
regressionIndex,
1663+
null,
1664+
new Regression(actualRegression, probabilityRegression, new MeanSquaredErrorMetric(), new RSquaredMetric()));
16181665

1619-
evaluateDataFrameResponse =
1666+
EvaluateDataFrameResponse evaluateDataFrameResponse =
16201667
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
16211668
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME));
16221669
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(2));
@@ -1643,12 +1690,16 @@ private static XContentBuilder defaultMappingForTest() throws IOException {
16431690
.endObject();
16441691
}
16451692

1693+
private static final String datasetField = "dataset";
16461694
private static final String actualField = "label";
16471695
private static final String probabilityField = "p";
16481696

16491697
private static XContentBuilder mappingForClassification() throws IOException {
16501698
return XContentFactory.jsonBuilder().startObject()
16511699
.startObject("properties")
1700+
.startObject(datasetField)
1701+
.field("type", "keyword")
1702+
.endObject()
16521703
.startObject(actualField)
16531704
.field("type", "keyword")
16541705
.endObject()
@@ -1659,10 +1710,10 @@ private static XContentBuilder mappingForClassification() throws IOException {
16591710
.endObject();
16601711
}
16611712

1662-
private static IndexRequest docForClassification(String indexName, boolean isTrue, double p) {
1713+
private static IndexRequest docForClassification(String indexName, String dataset, boolean isTrue, double p) {
16631714
return new IndexRequest()
16641715
.index(indexName)
1665-
.source(XContentType.JSON, actualField, Boolean.toString(isTrue), probabilityField, p);
1716+
.source(XContentType.JSON, datasetField, dataset, actualField, Boolean.toString(isTrue), probabilityField, p);
16661717
}
16671718

16681719
private static final String actualRegression = "regression_actual";
@@ -1697,7 +1748,7 @@ public void testEstimateMemoryUsage() throws IOException {
16971748
BulkRequest bulk1 = new BulkRequest()
16981749
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
16991750
for (int i = 0; i < 10; ++i) {
1700-
bulk1.add(docForClassification(indexName, randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
1751+
bulk1.add(docForClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
17011752
}
17021753
highLevelClient().bulk(bulk1, RequestOptions.DEFAULT);
17031754

@@ -1723,7 +1774,7 @@ public void testEstimateMemoryUsage() throws IOException {
17231774
BulkRequest bulk2 = new BulkRequest()
17241775
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
17251776
for (int i = 10; i < 100; ++i) {
1726-
bulk2.add(docForClassification(indexName, randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
1777+
bulk2.add(docForClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
17271778
}
17281779
highLevelClient().bulk(bulk2, RequestOptions.DEFAULT);
17291780

0 commit comments

Comments
 (0)