149
149
import org .elasticsearch .common .xcontent .XContentFactory ;
150
150
import org .elasticsearch .common .xcontent .XContentType ;
151
151
import org .elasticsearch .index .query .MatchAllQueryBuilder ;
152
+ import org .elasticsearch .index .query .QueryBuilders ;
152
153
import org .elasticsearch .rest .RestStatus ;
153
154
import org .elasticsearch .search .SearchHit ;
154
155
import org .junit .After ;
@@ -1427,7 +1428,7 @@ public void testStartDataFrameAnalyticsConfig() throws Exception {
1427
1428
public void testStopDataFrameAnalyticsConfig () throws Exception {
1428
1429
String sourceIndex = "stop-test-source-index" ;
1429
1430
String destIndex = "stop-test-dest-index" ;
1430
- createIndex (sourceIndex , mappingForClassification ());
1431
+ createIndex (sourceIndex , defaultMappingForTest ());
1431
1432
highLevelClient ().index (new IndexRequest (sourceIndex ).source (XContentType .JSON , "total" , 10000 )
1432
1433
.setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE ), RequestOptions .DEFAULT );
1433
1434
@@ -1525,27 +1526,28 @@ public void testDeleteDataFrameAnalyticsConfig_ConfigNotFound() {
1525
1526
assertThat (exception .status ().getStatus (), equalTo (404 ));
1526
1527
}
1527
1528
1528
- public void testEvaluateDataFrame () throws IOException {
1529
+ public void testEvaluateDataFrame_BinarySoftClassification () throws IOException {
1529
1530
String indexName = "evaluate-test-index" ;
1530
1531
createIndex (indexName , mappingForClassification ());
1531
1532
BulkRequest bulk = new BulkRequest ()
1532
1533
.setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE )
1533
- .add (docForClassification (indexName , false , 0.1 )) // #0
1534
- .add (docForClassification (indexName , false , 0.2 )) // #1
1535
- .add (docForClassification (indexName , false , 0.3 )) // #2
1536
- .add (docForClassification (indexName , false , 0.4 )) // #3
1537
- .add (docForClassification (indexName , false , 0.7 )) // #4
1538
- .add (docForClassification (indexName , true , 0.2 )) // #5
1539
- .add (docForClassification (indexName , true , 0.3 )) // #6
1540
- .add (docForClassification (indexName , true , 0.4 )) // #7
1541
- .add (docForClassification (indexName , true , 0.8 )) // #8
1542
- .add (docForClassification (indexName , true , 0.9 )); // #9
1534
+ .add (docForClassification (indexName , "blue" , false , 0.1 )) // #0
1535
+ .add (docForClassification (indexName , "blue" , false , 0.2 )) // #1
1536
+ .add (docForClassification (indexName , "blue" , false , 0.3 )) // #2
1537
+ .add (docForClassification (indexName , "blue" , false , 0.4 )) // #3
1538
+ .add (docForClassification (indexName , "blue" , false , 0.7 )) // #4
1539
+ .add (docForClassification (indexName , "blue" , true , 0.2 )) // #5
1540
+ .add (docForClassification (indexName , "green" , true , 0.3 )) // #6
1541
+ .add (docForClassification (indexName , "green" , true , 0.4 )) // #7
1542
+ .add (docForClassification (indexName , "green" , true , 0.8 )) // #8
1543
+ .add (docForClassification (indexName , "green" , true , 0.9 )); // #9
1543
1544
highLevelClient ().bulk (bulk , RequestOptions .DEFAULT );
1544
1545
1545
1546
MachineLearningClient machineLearningClient = highLevelClient ().machineLearning ();
1546
1547
EvaluateDataFrameRequest evaluateDataFrameRequest =
1547
1548
new EvaluateDataFrameRequest (
1548
1549
indexName ,
1550
+ null ,
1549
1551
new BinarySoftClassification (
1550
1552
actualField ,
1551
1553
probabilityField ,
@@ -1596,7 +1598,48 @@ public void testEvaluateDataFrame() throws IOException {
1596
1598
assertThat (curvePointAtThreshold1 .getTruePositiveRate (), equalTo (0.0 ));
1597
1599
assertThat (curvePointAtThreshold1 .getFalsePositiveRate (), equalTo (0.0 ));
1598
1600
assertThat (curvePointAtThreshold1 .getThreshold (), equalTo (1.0 ));
1601
+ }
1602
+
1603
// Verifies that an evaluate-data-frame request carrying a QueryConfig is evaluated against the
// queried subset: ten documents are indexed (six tagged "blue", four tagged "green", all with
// actual label true), the request is restricted to "blue" with a term query on datasetField,
// and the confusion matrix at threshold 0.5 is asserted against the six "blue" documents only.
+ public void testEvaluateDataFrame_BinarySoftClassification_WithQuery () throws IOException {
1604
+ String indexName = "evaluate-with-query-test-index" ;
1605
+ createIndex (indexName , mappingForClassification ());
1606
+ BulkRequest bulk = new BulkRequest ()
1607
+ .setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE )
1608
+ .add (docForClassification (indexName , "blue" , true , 1.0 )) // #0
1609
+ .add (docForClassification (indexName , "blue" , true , 1.0 )) // #1
1610
+ .add (docForClassification (indexName , "blue" , true , 1.0 )) // #2
1611
+ .add (docForClassification (indexName , "blue" , true , 1.0 )) // #3
1612
+ .add (docForClassification (indexName , "blue" , true , 0.0 )) // #4
1613
+ .add (docForClassification (indexName , "blue" , true , 0.0 )) // #5
1614
+ .add (docForClassification (indexName , "green" , true , 0.0 )) // #6
1615
+ .add (docForClassification (indexName , "green" , true , 0.0 )) // #7
1616
+ .add (docForClassification (indexName , "green" , true , 0.0 )) // #8
1617
+ .add (docForClassification (indexName , "green" , true , 1.0 )); // #9
1618
+ highLevelClient ().bulk (bulk , RequestOptions .DEFAULT );
1599
1619
1620
+ MachineLearningClient machineLearningClient = highLevelClient ().machineLearning ();
1621
+ EvaluateDataFrameRequest evaluateDataFrameRequest =
1622
+ new EvaluateDataFrameRequest (
1623
+ indexName ,
1624
+ // Request only "blue" subset to be evaluated
1625
+ new QueryConfig (QueryBuilders .termQuery (datasetField , "blue" )),
1626
+ new BinarySoftClassification (actualField , probabilityField , ConfusionMatrixMetric .at (0.5 )));
1627
+
1628
// The same request is passed to both the blocking and the async client method via execute(...).
+ EvaluateDataFrameResponse evaluateDataFrameResponse =
1629
+ execute (evaluateDataFrameRequest , machineLearningClient ::evaluateDataFrame , machineLearningClient ::evaluateDataFrameAsync );
1630
+ assertThat (evaluateDataFrameResponse .getEvaluationName (), equalTo (BinarySoftClassification .NAME ));
1631
// Only one metric (the confusion matrix at 0.5) was requested above.
+ assertThat (evaluateDataFrameResponse .getMetrics ().size (), equalTo (1 ));
1632
+
1633
+ ConfusionMatrixMetric .Result confusionMatrixResult = evaluateDataFrameResponse .getMetricByName (ConfusionMatrixMetric .NAME );
1634
+ assertThat (confusionMatrixResult .getMetricName (), equalTo (ConfusionMatrixMetric .NAME ));
1635
// Scores are looked up by the threshold rendered as a string.
+ ConfusionMatrixMetric .ConfusionMatrix confusionMatrix = confusionMatrixResult .getScoreByThreshold ("0.5" );
1636
// The four counts below sum to 6, the number of "blue" documents; the "green" documents
// (#6 to #9) are expected to be excluded by the query.
+ assertThat (confusionMatrix .getTruePositives (), equalTo (4L )); // docs #0, #1, #2 and #3
1637
+ assertThat (confusionMatrix .getFalsePositives (), equalTo (0L ));
1638
+ assertThat (confusionMatrix .getTrueNegatives (), equalTo (0L ));
1639
+ assertThat (confusionMatrix .getFalseNegatives (), equalTo (2L )); // docs #4 and #5
1640
+ }
1641
+
1642
+ public void testEvaluateDataFrame_Regression () throws IOException {
1600
1643
String regressionIndex = "evaluate-regression-test-index" ;
1601
1644
createIndex (regressionIndex , mappingForRegression ());
1602
1645
BulkRequest regressionBulk = new BulkRequest ()
@@ -1613,10 +1656,14 @@ public void testEvaluateDataFrame() throws IOException {
1613
1656
.add (docForRegression (regressionIndex , 0.5 , 0.9 )); // #9
1614
1657
highLevelClient ().bulk (regressionBulk , RequestOptions .DEFAULT );
1615
1658
1616
- evaluateDataFrameRequest = new EvaluateDataFrameRequest (regressionIndex ,
1617
- new Regression (actualRegression , probabilityRegression , new MeanSquaredErrorMetric (), new RSquaredMetric ()));
1659
+ MachineLearningClient machineLearningClient = highLevelClient ().machineLearning ();
1660
+ EvaluateDataFrameRequest evaluateDataFrameRequest =
1661
+ new EvaluateDataFrameRequest (
1662
+ regressionIndex ,
1663
+ null ,
1664
+ new Regression (actualRegression , probabilityRegression , new MeanSquaredErrorMetric (), new RSquaredMetric ()));
1618
1665
1619
- evaluateDataFrameResponse =
1666
+ EvaluateDataFrameResponse evaluateDataFrameResponse =
1620
1667
execute (evaluateDataFrameRequest , machineLearningClient ::evaluateDataFrame , machineLearningClient ::evaluateDataFrameAsync );
1621
1668
assertThat (evaluateDataFrameResponse .getEvaluationName (), equalTo (Regression .NAME ));
1622
1669
assertThat (evaluateDataFrameResponse .getMetrics ().size (), equalTo (2 ));
@@ -1643,12 +1690,16 @@ private static XContentBuilder defaultMappingForTest() throws IOException {
1643
1690
.endObject ();
1644
1691
}
1645
1692
1693
// Field names shared by the classification mapping, the document factory and the evaluation tests.
// datasetField: keyword field naming the subset a document belongs to; used as the term-query target.
+ private static final String datasetField = "dataset" ;
1646
1694
// actualField: keyword field holding the actual label, written as the string form of a boolean.
private static final String actualField = "label" ;
1647
1695
// probabilityField: field holding the double value passed as p to docForClassification.
private static final String probabilityField = "p" ;
1648
1696
1649
1697
private static XContentBuilder mappingForClassification () throws IOException {
1650
1698
return XContentFactory .jsonBuilder ().startObject ()
1651
1699
.startObject ("properties" )
1700
+ .startObject (datasetField )
1701
+ .field ("type" , "keyword" )
1702
+ .endObject ()
1652
1703
.startObject (actualField )
1653
1704
.field ("type" , "keyword" )
1654
1705
.endObject ()
@@ -1659,10 +1710,10 @@ private static XContentBuilder mappingForClassification() throws IOException {
1659
1710
.endObject ();
1660
1711
}
1661
1712
1662
// Builds an IndexRequest targeting indexName whose JSON source holds the classification test fields:
// datasetField = dataset, actualField = Boolean.toString(isTrue), probabilityField = p.
// The '-' lines are the previous form of this helper, which took no dataset argument and wrote
// only the label and probability fields.
- private static IndexRequest docForClassification (String indexName , boolean isTrue , double p ) {
1713
+ private static IndexRequest docForClassification (String indexName , String dataset , boolean isTrue , double p ) {
1663
1714
return new IndexRequest ()
1664
1715
.index (indexName )
1665
// The source is given as alternating field-name / value arguments.
- .source (XContentType .JSON , actualField , Boolean .toString (isTrue ), probabilityField , p );
1716
+ .source (XContentType .JSON , datasetField , dataset , actualField , Boolean .toString (isTrue ), probabilityField , p );
1666
1717
}
1667
1718
1668
1719
private static final String actualRegression = "regression_actual" ;
@@ -1697,7 +1748,7 @@ public void testEstimateMemoryUsage() throws IOException {
1697
1748
BulkRequest bulk1 = new BulkRequest ()
1698
1749
.setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE );
1699
1750
for (int i = 0 ; i < 10 ; ++i ) {
1700
- bulk1 .add (docForClassification (indexName , randomBoolean (), randomDoubleBetween (0.0 , 1.0 , true )));
1751
+ bulk1 .add (docForClassification (indexName , randomAlphaOfLength ( 10 ), randomBoolean (), randomDoubleBetween (0.0 , 1.0 , true )));
1701
1752
}
1702
1753
highLevelClient ().bulk (bulk1 , RequestOptions .DEFAULT );
1703
1754
@@ -1723,7 +1774,7 @@ public void testEstimateMemoryUsage() throws IOException {
1723
1774
BulkRequest bulk2 = new BulkRequest ()
1724
1775
.setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE );
1725
1776
for (int i = 10 ; i < 100 ; ++i ) {
1726
- bulk2 .add (docForClassification (indexName , randomBoolean (), randomDoubleBetween (0.0 , 1.0 , true )));
1777
+ bulk2 .add (docForClassification (indexName , randomAlphaOfLength ( 10 ), randomBoolean (), randomDoubleBetween (0.0 , 1.0 , true )));
1727
1778
}
1728
1779
highLevelClient ().bulk (bulk2 , RequestOptions .DEFAULT );
1729
1780
0 commit comments