diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 2820f2d5530f9..ac323e6a2ea01 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -1776,7 +1776,7 @@ public void testEvaluateDataFrame_Regression() throws IOException { new EvaluateDataFrameRequest( regressionIndex, null, - new Regression(actualRegression, probabilityRegression, new MeanSquaredErrorMetric(), new RSquaredMetric())); + new Regression(actualRegression, predictedRegression, new MeanSquaredErrorMetric(), new RSquaredMetric())); EvaluateDataFrameResponse evaluateDataFrameResponse = execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync); @@ -1933,7 +1933,7 @@ private static IndexRequest docForClassification(String indexName, String actual } private static final String actualRegression = "regression_actual"; - private static final String probabilityRegression = "regression_prob"; + private static final String predictedRegression = "regression_predicted"; private static XContentBuilder mappingForRegression() throws IOException { return XContentFactory.jsonBuilder().startObject() @@ -1941,17 +1941,17 @@ private static XContentBuilder mappingForRegression() throws IOException { .startObject(actualRegression) .field("type", "double") .endObject() - .startObject(probabilityRegression) + .startObject(predictedRegression) .field("type", "double") .endObject() .endObject() .endObject(); } - private static IndexRequest docForRegression(String indexName, double act, double p) { + private static IndexRequest docForRegression(String indexName, double actualValue, double predictedValue) { return new IndexRequest() .index(indexName) - .source(XContentType.JSON, actualRegression, 
act, probabilityRegression, p); + .source(XContentType.JSON, actualRegression, actualValue, predictedRegression, predictedValue); } private void createIndex(String indexName, XContentBuilder mapping) throws IOException { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index 2367af160f008..717e8c04c9d10 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -139,8 +139,11 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats; import org.elasticsearch.client.ml.dataframe.OutlierDetection; import org.elasticsearch.client.ml.dataframe.QueryConfig; -import org.elasticsearch.client.ml.dataframe.Regression; +import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation; import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric; @@ -2821,7 +2824,7 @@ public void testGetDataFrameAnalytics() throws Exception { List configs = response.getAnalytics(); // end::get-data-frame-analytics-response - assertThat(configs.size(), equalTo(1)); + assertThat(configs, hasSize(1)); } { GetDataFrameAnalyticsRequest request = new 
GetDataFrameAnalyticsRequest("my-analytics-config"); @@ -2871,7 +2874,7 @@ public void testGetDataFrameAnalyticsStats() throws Exception { List stats = response.getAnalyticsStats(); // end::get-data-frame-analytics-stats-response - assertThat(stats.size(), equalTo(1)); + assertThat(stats, hasSize(1)); } { GetDataFrameAnalyticsStatsRequest request = new GetDataFrameAnalyticsStatsRequest("my-analytics-config"); @@ -2939,8 +2942,20 @@ public void testPutDataFrameAnalytics() throws Exception { .build(); // end::put-data-frame-analytics-outlier-detection-customized + // tag::put-data-frame-analytics-classification + DataFrameAnalysis classification = org.elasticsearch.client.ml.dataframe.Classification.builder("my_dependent_variable") // <1> + .setLambda(1.0) // <2> + .setGamma(5.5) // <3> + .setEta(5.5) // <4> + .setMaximumNumberTrees(50) // <5> + .setFeatureBagFraction(0.4) // <6> + .setPredictionFieldName("my_prediction_field_name") // <7> + .setTrainingPercent(50.0) // <8> + .build(); + // end::put-data-frame-analytics-classification + // tag::put-data-frame-analytics-regression - DataFrameAnalysis regression = Regression.builder("my_dependent_variable") // <1> + DataFrameAnalysis regression = org.elasticsearch.client.ml.dataframe.Regression.builder("my_dependent_variable") // <1> .setLambda(1.0) // <2> .setGamma(5.5) // <3> .setEta(5.5) // <4> @@ -3209,18 +3224,24 @@ public void testEvaluateDataFrame() throws Exception { client.indices().create(createIndexRequest, RequestOptions.DEFAULT); client.bulk(bulkRequest, RequestOptions.DEFAULT); { + // tag::evaluate-data-frame-evaluation-softclassification + Evaluation evaluation = + new BinarySoftClassification( // <1> + "label", // <2> + "p", // <3> + // Evaluation metrics // <4> + PrecisionMetric.at(0.4, 0.5, 0.6), // <5> + RecallMetric.at(0.5, 0.7), // <6> + ConfusionMatrixMetric.at(0.5), // <7> + AucRocMetric.withCurve()); // <8> + // end::evaluate-data-frame-evaluation-softclassification + // 
tag::evaluate-data-frame-request - EvaluateDataFrameRequest request = new EvaluateDataFrameRequest( // <1> - indexName, // <2> - new QueryConfig(QueryBuilders.termQuery("dataset", "blue")), // <3> - new BinarySoftClassification( // <4> - "label", // <5> - "p", // <6> - // Evaluation metrics // <7> - PrecisionMetric.at(0.4, 0.5, 0.6), // <8> - RecallMetric.at(0.5, 0.7), // <9> - ConfusionMatrixMetric.at(0.5), // <10> - AucRocMetric.withCurve())); // <11> + EvaluateDataFrameRequest request = + new EvaluateDataFrameRequest( // <1> + indexName, // <2> + new QueryConfig(QueryBuilders.termQuery("dataset", "blue")), // <3> + evaluation); // <4> // end::evaluate-data-frame-request // tag::evaluate-data-frame-execute @@ -3229,16 +3250,18 @@ public void testEvaluateDataFrame() throws Exception { // tag::evaluate-data-frame-response List metrics = response.getMetrics(); // <1> + // end::evaluate-data-frame-response - PrecisionMetric.Result precisionResult = response.getMetricByName(PrecisionMetric.NAME); // <2> - double precision = precisionResult.getScoreByThreshold("0.4"); // <3> + // tag::evaluate-data-frame-results-softclassification + PrecisionMetric.Result precisionResult = response.getMetricByName(PrecisionMetric.NAME); // <1> + double precision = precisionResult.getScoreByThreshold("0.4"); // <2> - ConfusionMatrixMetric.Result confusionMatrixResult = response.getMetricByName(ConfusionMatrixMetric.NAME); // <4> - ConfusionMatrix confusionMatrix = confusionMatrixResult.getScoreByThreshold("0.5"); // <5> - // end::evaluate-data-frame-response + ConfusionMatrixMetric.Result confusionMatrixResult = response.getMetricByName(ConfusionMatrixMetric.NAME); // <3> + ConfusionMatrix confusionMatrix = confusionMatrixResult.getScoreByThreshold("0.5"); // <4> + // end::evaluate-data-frame-results-softclassification assertThat( - metrics.stream().map(m -> m.getMetricName()).collect(Collectors.toList()), + 
metrics.stream().map(EvaluationMetric.Result::getMetricName).collect(Collectors.toList()), containsInAnyOrder(PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, AucRocMetric.NAME)); assertThat(precision, closeTo(0.6, 1e-9)); assertThat(confusionMatrix.getTruePositives(), equalTo(2L)); // docs #8 and #9 @@ -3284,6 +3307,140 @@ public void onFailure(Exception e) { } } + public void testEvaluateDataFrame_Classification() throws Exception { + String indexName = "evaluate-classification-test-index"; + CreateIndexRequest createIndexRequest = + new CreateIndexRequest(indexName) + .mapping(XContentFactory.jsonBuilder().startObject() + .startObject("properties") + .startObject("actual_class") + .field("type", "keyword") + .endObject() + .startObject("predicted_class") + .field("type", "keyword") + .endObject() + .endObject() + .endObject()); + BulkRequest bulkRequest = + new BulkRequest(indexName) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #0 + .add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #1 + .add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #2 + .add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "dog")) // #3 + .add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "fox")) // #4 + .add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "cat")) // #5 + .add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #6 + .add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #7 + .add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #8 + .add(new 
IndexRequest().source(XContentType.JSON, "actual_class", "ant", "predicted_class", "cat")); // #9 + RestHighLevelClient client = highLevelClient(); + client.indices().create(createIndexRequest, RequestOptions.DEFAULT); + client.bulk(bulkRequest, RequestOptions.DEFAULT); + { + // tag::evaluate-data-frame-evaluation-classification + Evaluation evaluation = + new org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification( // <1> + "actual_class", // <2> + "predicted_class", // <3> + // Evaluation metrics // <4> + new MulticlassConfusionMatrixMetric(3)); // <5> + // end::evaluate-data-frame-evaluation-classification + + EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation); + EvaluateDataFrameResponse response = client.machineLearning().evaluateDataFrame(request, RequestOptions.DEFAULT); + + // tag::evaluate-data-frame-results-classification + MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix = + response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <1> + + Map> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <2> + long otherClassesCount = multiclassConfusionMatrix.getOtherClassesCount(); // <3> + // end::evaluate-data-frame-results-classification + + assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME)); + assertThat( + confusionMatrix, + equalTo( + new HashMap>() {{ + put("cat", new HashMap() {{ + put("cat", 3L); + put("dog", 1L); + put("ant", 0L); + put("_other_", 1L); + }}); + put("dog", new HashMap() {{ + put("cat", 1L); + put("dog", 3L); + put("ant", 0L); + }}); + put("ant", new HashMap() {{ + put("cat", 1L); + put("dog", 0L); + put("ant", 0L); + }}); + }})); + assertThat(otherClassesCount, equalTo(0L)); + } + } + + public void testEvaluateDataFrame_Regression() throws Exception { + String indexName = "evaluate-regression-test-index"; + CreateIndexRequest createIndexRequest = + new 
CreateIndexRequest(indexName) + .mapping(XContentFactory.jsonBuilder().startObject() + .startObject("properties") + .startObject("actual_value") + .field("type", "double") + .endObject() + .startObject("predicted_value") + .field("type", "double") + .endObject() + .endObject() + .endObject()); + BulkRequest bulkRequest = + new BulkRequest(indexName) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .add(new IndexRequest().source(XContentType.JSON, "actual_value", 1.0, "predicted_value", 1.0)) // #0 + .add(new IndexRequest().source(XContentType.JSON, "actual_value", 1.0, "predicted_value", 0.9)) // #1 + .add(new IndexRequest().source(XContentType.JSON, "actual_value", 2.0, "predicted_value", 2.0)) // #2 + .add(new IndexRequest().source(XContentType.JSON, "actual_value", 1.5, "predicted_value", 1.4)) // #3 + .add(new IndexRequest().source(XContentType.JSON, "actual_value", 1.2, "predicted_value", 1.3)) // #4 + .add(new IndexRequest().source(XContentType.JSON, "actual_value", 1.7, "predicted_value", 2.0)) // #5 + .add(new IndexRequest().source(XContentType.JSON, "actual_value", 2.1, "predicted_value", 2.1)) // #6 + .add(new IndexRequest().source(XContentType.JSON, "actual_value", 2.5, "predicted_value", 2.7)) // #7 + .add(new IndexRequest().source(XContentType.JSON, "actual_value", 0.8, "predicted_value", 1.0)) // #8 + .add(new IndexRequest().source(XContentType.JSON, "actual_value", 2.5, "predicted_value", 2.4)); // #9 + RestHighLevelClient client = highLevelClient(); + client.indices().create(createIndexRequest, RequestOptions.DEFAULT); + client.bulk(bulkRequest, RequestOptions.DEFAULT); + { + // tag::evaluate-data-frame-evaluation-regression + Evaluation evaluation = + new org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression( // <1> + "actual_value", // <2> + "predicted_value", // <3> + // Evaluation metrics // <4> + new MeanSquaredErrorMetric(), // <5> + new RSquaredMetric()); // <6> + // end::evaluate-data-frame-evaluation-regression + 
+ EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation); + EvaluateDataFrameResponse response = client.machineLearning().evaluateDataFrame(request, RequestOptions.DEFAULT); + + // tag::evaluate-data-frame-results-regression + MeanSquaredErrorMetric.Result meanSquaredErrorResult = response.getMetricByName(MeanSquaredErrorMetric.NAME); // <1> + double meanSquaredError = meanSquaredErrorResult.getError(); // <2> + + RSquaredMetric.Result rSquaredResult = response.getMetricByName(RSquaredMetric.NAME); // <3> + double rSquared = rSquaredResult.getValue(); // <4> + // end::evaluate-data-frame-results-regression + + assertThat(meanSquaredError, closeTo(0.021, 1e-3)); + assertThat(rSquared, closeTo(0.941, 1e-3)); + } + } + public void testEstimateMemoryUsage() throws Exception { createIndex("estimate-test-source-index"); BulkRequest bulkRequest = diff --git a/docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc b/docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc index 7c231d7103b4a..61f18dbd09230 100644 --- a/docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc +++ b/docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc @@ -20,14 +20,52 @@ include-tagged::{doc-tests-file}[{api}-request] <1> Constructing a new evaluation request <2> Reference to an existing index <3> The query with which to select data from indices -<4> Kind of evaluation to perform -<5> Name of the field in the index. Its value denotes the actual (i.e. ground truth) label for an example. Must be either true or false -<6> Name of the field in the index. Its value denotes the probability (as per some ML algorithm) of the example being classified as positive -<7> The remaining parameters are the metrics to be calculated based on the two fields described above. 
-<8> https://en.wikipedia.org/wiki/Precision_and_recall[Precision] calculated at thresholds: 0.4, 0.5 and 0.6 -<9> https://en.wikipedia.org/wiki/Precision_and_recall[Recall] calculated at thresholds: 0.5 and 0.7 -<10> https://en.wikipedia.org/wiki/Confusion_matrix[Confusion matrix] calculated at threshold 0.5 -<11> https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve[AuC ROC] calculated and the curve points returned +<4> Evaluation to be performed + +==== Evaluation + +Evaluation to be performed. +Currently, supported evaluations include: +BinarySoftClassification+, +Classification+, +Regression+. + +===== Binary soft classification + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-evaluation-softclassification] +-------------------------------------------------- +<1> Constructing a new evaluation +<2> Name of the field in the index. Its value denotes the actual (i.e. ground truth) label for an example. Must be either true or false. +<3> Name of the field in the index. Its value denotes the probability (as per some ML algorithm) of the example being classified as positive. 
+<4> The remaining parameters are the metrics to be calculated based on the two fields described above +<5> https://en.wikipedia.org/wiki/Precision_and_recall#Precision[Precision] calculated at thresholds: 0.4, 0.5 and 0.6 +<6> https://en.wikipedia.org/wiki/Precision_and_recall#Recall[Recall] calculated at thresholds: 0.5 and 0.7 +<7> https://en.wikipedia.org/wiki/Confusion_matrix[Confusion matrix] calculated at threshold 0.5 +<8> https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve[AuC ROC] calculated and the curve points returned + +===== Classification + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-evaluation-classification] +-------------------------------------------------- +<1> Constructing a new evaluation +<2> Name of the field in the index. Its value denotes the actual (i.e. ground truth) class the example belongs to. +<3> Name of the field in the index. Its value denotes the predicted (as per some ML algorithm) class of the example. +<4> The remaining parameters are the metrics to be calculated based on the two fields described above +<5> Multiclass confusion matrix of size 3 + +===== Regression + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-evaluation-regression] +-------------------------------------------------- +<1> Constructing a new evaluation +<2> Name of the field in the index. Its value denotes the actual (i.e. ground truth) value for an example. +<3> Name of the field in the index. Its value denotes the predicted (as per some ML algorithm) value for the example. 
+<4> The remaining parameters are the metrics to be calculated based on the two fields described above +<5> https://en.wikipedia.org/wiki/Mean_squared_error[Mean squared error] +<6> https://en.wikipedia.org/wiki/Coefficient_of_determination[R squared] include::../execution.asciidoc[] @@ -41,7 +79,40 @@ The returned +{response}+ contains the requested evaluation metrics. include-tagged::{doc-tests-file}[{api}-response] -------------------------------------------------- <1> Fetching all the calculated metrics results -<2> Fetching precision metric by name -<3> Fetching precision at a given (0.4) threshold -<4> Fetching confusion matrix metric by name -<5> Fetching confusion matrix at a given (0.5) threshold \ No newline at end of file + +==== Results + +===== Binary soft classification + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-results-softclassification] +-------------------------------------------------- + +<1> Fetching precision metric by name +<2> Fetching precision at a given (0.4) threshold +<3> Fetching confusion matrix metric by name +<4> Fetching confusion matrix at a given (0.5) threshold + +===== Classification + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-results-classification] +-------------------------------------------------- + +<1> Fetching multiclass confusion matrix metric by name +<2> Fetching the contents of the confusion matrix +<3> Fetching the number of classes that were not included in the matrix + +===== Regression + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-results-regression] +-------------------------------------------------- + +<1> Fetching mean squared error metric by name +<2> Fetching the actual mean squared error value +<3> 
Fetching R squared metric by name +<4> Fetching the actual R squared value \ No newline at end of file diff --git a/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc b/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc index 214b2a99e0644..4ca4c31ecf574 100644 --- a/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc +++ b/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc @@ -76,7 +76,7 @@ include-tagged::{doc-tests-file}[{api}-dest-config] ==== Analysis The analysis to be performed. -Currently, the supported analyses include : +OutlierDetection+, +Regression+. +Currently, the supported analyses include: +OutlierDetection+, +Classification+, +Regression+. ===== Outlier detection @@ -101,6 +101,24 @@ include-tagged::{doc-tests-file}[{api}-outlier-detection-customized] <6> The proportion of the data set that is assumed to be outlying prior to outlier detection <7> Whether to apply standardization to feature values +===== Classification + ++Classification+ analysis requires to set which is the +dependent_variable+ and +has a number of other optional parameters: + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-classification] +-------------------------------------------------- +<1> Constructing a new Classification builder object with the required dependent variable +<2> The lambda regularization parameter. A non-negative double. +<3> The gamma regularization parameter. A non-negative double. +<4> The applied shrinkage. A double in [0.001, 1]. +<5> The maximum number of trees the forest is allowed to contain. An integer in [1, 2000]. +<6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1]. +<7> The name of the prediction field in the results object. +<8> The percentage of training-eligible rows to be used in training. Defaults to 100%. 
+ ===== Regression +Regression+ analysis requires to set which is the +dependent_variable+ and