Skip to content

[7.x] Add MlClientDocumentationIT tests for classification. (#47569) #47896

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1776,7 +1776,7 @@ public void testEvaluateDataFrame_Regression() throws IOException {
new EvaluateDataFrameRequest(
regressionIndex,
null,
new Regression(actualRegression, probabilityRegression, new MeanSquaredErrorMetric(), new RSquaredMetric()));
new Regression(actualRegression, predictedRegression, new MeanSquaredErrorMetric(), new RSquaredMetric()));

EvaluateDataFrameResponse evaluateDataFrameResponse =
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
Expand Down Expand Up @@ -1933,25 +1933,25 @@ private static IndexRequest docForClassification(String indexName, String actual
}

private static final String actualRegression = "regression_actual";
private static final String probabilityRegression = "regression_prob";
private static final String predictedRegression = "regression_predicted";

private static XContentBuilder mappingForRegression() throws IOException {
return XContentFactory.jsonBuilder().startObject()
.startObject("properties")
.startObject(actualRegression)
.field("type", "double")
.endObject()
.startObject(probabilityRegression)
.startObject(predictedRegression)
.field("type", "double")
.endObject()
.endObject()
.endObject();
}

private static IndexRequest docForRegression(String indexName, double act, double p) {
private static IndexRequest docForRegression(String indexName, double actualValue, double predictedValue) {
return new IndexRequest()
.index(indexName)
.source(XContentType.JSON, actualRegression, act, probabilityRegression, p);
.source(XContentType.JSON, actualRegression, actualValue, predictedRegression, predictedValue);
}

private void createIndex(String indexName, XContentBuilder mapping) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,11 @@
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
import org.elasticsearch.client.ml.dataframe.OutlierDetection;
import org.elasticsearch.client.ml.dataframe.QueryConfig;
import org.elasticsearch.client.ml.dataframe.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
Expand Down Expand Up @@ -2821,7 +2824,7 @@ public void testGetDataFrameAnalytics() throws Exception {
List<DataFrameAnalyticsConfig> configs = response.getAnalytics();
// end::get-data-frame-analytics-response

assertThat(configs.size(), equalTo(1));
assertThat(configs, hasSize(1));
}
{
GetDataFrameAnalyticsRequest request = new GetDataFrameAnalyticsRequest("my-analytics-config");
Expand Down Expand Up @@ -2871,7 +2874,7 @@ public void testGetDataFrameAnalyticsStats() throws Exception {
List<DataFrameAnalyticsStats> stats = response.getAnalyticsStats();
// end::get-data-frame-analytics-stats-response

assertThat(stats.size(), equalTo(1));
assertThat(stats, hasSize(1));
}
{
GetDataFrameAnalyticsStatsRequest request = new GetDataFrameAnalyticsStatsRequest("my-analytics-config");
Expand Down Expand Up @@ -2939,8 +2942,20 @@ public void testPutDataFrameAnalytics() throws Exception {
.build();
// end::put-data-frame-analytics-outlier-detection-customized

// tag::put-data-frame-analytics-classification
DataFrameAnalysis classification = org.elasticsearch.client.ml.dataframe.Classification.builder("my_dependent_variable") // <1>
.setLambda(1.0) // <2>
.setGamma(5.5) // <3>
.setEta(5.5) // <4>
.setMaximumNumberTrees(50) // <5>
.setFeatureBagFraction(0.4) // <6>
.setPredictionFieldName("my_prediction_field_name") // <7>
.setTrainingPercent(50.0) // <8>
.build();
// end::put-data-frame-analytics-classification

// tag::put-data-frame-analytics-regression
DataFrameAnalysis regression = Regression.builder("my_dependent_variable") // <1>
DataFrameAnalysis regression = org.elasticsearch.client.ml.dataframe.Regression.builder("my_dependent_variable") // <1>
.setLambda(1.0) // <2>
.setGamma(5.5) // <3>
.setEta(5.5) // <4>
Expand Down Expand Up @@ -3209,18 +3224,24 @@ public void testEvaluateDataFrame() throws Exception {
client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
client.bulk(bulkRequest, RequestOptions.DEFAULT);
{
// tag::evaluate-data-frame-evaluation-softclassification
Evaluation evaluation =
new BinarySoftClassification( // <1>
"label", // <2>
"p", // <3>
// Evaluation metrics // <4>
PrecisionMetric.at(0.4, 0.5, 0.6), // <5>
RecallMetric.at(0.5, 0.7), // <6>
ConfusionMatrixMetric.at(0.5), // <7>
AucRocMetric.withCurve()); // <8>
// end::evaluate-data-frame-evaluation-softclassification

// tag::evaluate-data-frame-request
EvaluateDataFrameRequest request = new EvaluateDataFrameRequest( // <1>
indexName, // <2>
new QueryConfig(QueryBuilders.termQuery("dataset", "blue")), // <3>
new BinarySoftClassification( // <4>
"label", // <5>
"p", // <6>
// Evaluation metrics // <7>
PrecisionMetric.at(0.4, 0.5, 0.6), // <8>
RecallMetric.at(0.5, 0.7), // <9>
ConfusionMatrixMetric.at(0.5), // <10>
AucRocMetric.withCurve())); // <11>
EvaluateDataFrameRequest request =
new EvaluateDataFrameRequest( // <1>
indexName, // <2>
new QueryConfig(QueryBuilders.termQuery("dataset", "blue")), // <3>
evaluation); // <4>
// end::evaluate-data-frame-request

// tag::evaluate-data-frame-execute
Expand All @@ -3229,16 +3250,18 @@ public void testEvaluateDataFrame() throws Exception {

// tag::evaluate-data-frame-response
List<EvaluationMetric.Result> metrics = response.getMetrics(); // <1>
// end::evaluate-data-frame-response

PrecisionMetric.Result precisionResult = response.getMetricByName(PrecisionMetric.NAME); // <2>
double precision = precisionResult.getScoreByThreshold("0.4"); // <3>
// tag::evaluate-data-frame-results-softclassification
PrecisionMetric.Result precisionResult = response.getMetricByName(PrecisionMetric.NAME); // <1>
double precision = precisionResult.getScoreByThreshold("0.4"); // <2>

ConfusionMatrixMetric.Result confusionMatrixResult = response.getMetricByName(ConfusionMatrixMetric.NAME); // <4>
ConfusionMatrix confusionMatrix = confusionMatrixResult.getScoreByThreshold("0.5"); // <5>
// end::evaluate-data-frame-response
ConfusionMatrixMetric.Result confusionMatrixResult = response.getMetricByName(ConfusionMatrixMetric.NAME); // <3>
ConfusionMatrix confusionMatrix = confusionMatrixResult.getScoreByThreshold("0.5"); // <4>
// end::evaluate-data-frame-results-softclassification

assertThat(
metrics.stream().map(m -> m.getMetricName()).collect(Collectors.toList()),
metrics.stream().map(EvaluationMetric.Result::getMetricName).collect(Collectors.toList()),
containsInAnyOrder(PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, AucRocMetric.NAME));
assertThat(precision, closeTo(0.6, 1e-9));
assertThat(confusionMatrix.getTruePositives(), equalTo(2L)); // docs #8 and #9
Expand Down Expand Up @@ -3284,6 +3307,140 @@ public void onFailure(Exception e) {
}
}

public void testEvaluateDataFrame_Classification() throws Exception {
String indexName = "evaluate-classification-test-index";
CreateIndexRequest createIndexRequest =
new CreateIndexRequest(indexName)
.mapping(XContentFactory.jsonBuilder().startObject()
.startObject("properties")
.startObject("actual_class")
.field("type", "keyword")
.endObject()
.startObject("predicted_class")
.field("type", "keyword")
.endObject()
.endObject()
.endObject());
BulkRequest bulkRequest =
new BulkRequest(indexName)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #0
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #1
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #2
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "dog")) // #3
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "fox")) // #4
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "cat")) // #5
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #6
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #7
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #8
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "ant", "predicted_class", "cat")); // #9
RestHighLevelClient client = highLevelClient();
client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
client.bulk(bulkRequest, RequestOptions.DEFAULT);
{
// tag::evaluate-data-frame-evaluation-classification
Evaluation evaluation =
new org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification( // <1>
"actual_class", // <2>
"predicted_class", // <3>
// Evaluation metrics // <4>
new MulticlassConfusionMatrixMetric(3)); // <5>
// end::evaluate-data-frame-evaluation-classification

EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation);
EvaluateDataFrameResponse response = client.machineLearning().evaluateDataFrame(request, RequestOptions.DEFAULT);

// tag::evaluate-data-frame-results-classification
MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix =
response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <1>

Map<String, Map<String, Long>> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <2>
long otherClassesCount = multiclassConfusionMatrix.getOtherClassesCount(); // <3>
// end::evaluate-data-frame-results-classification

assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
assertThat(
confusionMatrix,
equalTo(
new HashMap<String, Map<String, Long>>() {{
put("cat", new HashMap<String, Long>() {{
put("cat", 3L);
put("dog", 1L);
put("ant", 0L);
put("_other_", 1L);
}});
put("dog", new HashMap<String, Long>() {{
put("cat", 1L);
put("dog", 3L);
put("ant", 0L);
}});
put("ant", new HashMap<String, Long>() {{
put("cat", 1L);
put("dog", 0L);
put("ant", 0L);
}});
}}));
assertThat(otherClassesCount, equalTo(0L));
}
}

public void testEvaluateDataFrame_Regression() throws Exception {
String indexName = "evaluate-classification-test-index";
CreateIndexRequest createIndexRequest =
new CreateIndexRequest(indexName)
.mapping(XContentFactory.jsonBuilder().startObject()
.startObject("properties")
.startObject("actual_value")
.field("type", "double")
.endObject()
.startObject("predicted_value")
.field("type", "double")
.endObject()
.endObject()
.endObject());
BulkRequest bulkRequest =
new BulkRequest(indexName)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 1.0, "predicted_value", 1.0)) // #0
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 1.0, "predicted_value", 0.9)) // #1
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 2.0, "predicted_value", 2.0)) // #2
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 1.5, "predicted_value", 1.4)) // #3
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 1.2, "predicted_value", 1.3)) // #4
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 1.7, "predicted_value", 2.0)) // #5
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 2.1, "predicted_value", 2.1)) // #6
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 2.5, "predicted_value", 2.7)) // #7
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 0.8, "predicted_value", 1.0)) // #8
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 2.5, "predicted_value", 2.4)); // #9
RestHighLevelClient client = highLevelClient();
client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
client.bulk(bulkRequest, RequestOptions.DEFAULT);
{
// tag::evaluate-data-frame-evaluation-regression
Evaluation evaluation =
new org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression( // <1>
"actual_value", // <2>
"predicted_value", // <3>
// Evaluation metrics // <4>
new MeanSquaredErrorMetric(), // <5>
new RSquaredMetric()); // <6>
// end::evaluate-data-frame-evaluation-regression

EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation);
EvaluateDataFrameResponse response = client.machineLearning().evaluateDataFrame(request, RequestOptions.DEFAULT);

// tag::evaluate-data-frame-results-regression
MeanSquaredErrorMetric.Result meanSquaredErrorResult = response.getMetricByName(MeanSquaredErrorMetric.NAME); // <1>
double meanSquaredError = meanSquaredErrorResult.getError(); // <2>

RSquaredMetric.Result rSquaredResult = response.getMetricByName(RSquaredMetric.NAME); // <3>
double rSquared = rSquaredResult.getValue(); // <4>
// end::evaluate-data-frame-results-regression

assertThat(meanSquaredError, closeTo(0.021, 1e-3));
assertThat(rSquared, closeTo(0.941, 1e-3));
}
}

public void testEstimateMemoryUsage() throws Exception {
createIndex("estimate-test-source-index");
BulkRequest bulkRequest =
Expand Down
Loading