Skip to content

Commit d210bfa

Browse files
authored
[7.x] Add MlClientDocumentationIT tests for classification. (#47569) (#47896)
1 parent e60221d commit d210bfa

File tree

4 files changed

+285
-39
lines changed

4 files changed

+285
-39
lines changed

client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1776,7 +1776,7 @@ public void testEvaluateDataFrame_Regression() throws IOException {
17761776
new EvaluateDataFrameRequest(
17771777
regressionIndex,
17781778
null,
1779-
new Regression(actualRegression, probabilityRegression, new MeanSquaredErrorMetric(), new RSquaredMetric()));
1779+
new Regression(actualRegression, predictedRegression, new MeanSquaredErrorMetric(), new RSquaredMetric()));
17801780

17811781
EvaluateDataFrameResponse evaluateDataFrameResponse =
17821782
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
@@ -1933,25 +1933,25 @@ private static IndexRequest docForClassification(String indexName, String actual
19331933
}
19341934

19351935
private static final String actualRegression = "regression_actual";
1936-
private static final String probabilityRegression = "regression_prob";
1936+
private static final String predictedRegression = "regression_predicted";
19371937

19381938
private static XContentBuilder mappingForRegression() throws IOException {
19391939
return XContentFactory.jsonBuilder().startObject()
19401940
.startObject("properties")
19411941
.startObject(actualRegression)
19421942
.field("type", "double")
19431943
.endObject()
1944-
.startObject(probabilityRegression)
1944+
.startObject(predictedRegression)
19451945
.field("type", "double")
19461946
.endObject()
19471947
.endObject()
19481948
.endObject();
19491949
}
19501950

1951-
private static IndexRequest docForRegression(String indexName, double act, double p) {
1951+
private static IndexRequest docForRegression(String indexName, double actualValue, double predictedValue) {
19521952
return new IndexRequest()
19531953
.index(indexName)
1954-
.source(XContentType.JSON, actualRegression, act, probabilityRegression, p);
1954+
.source(XContentType.JSON, actualRegression, actualValue, predictedRegression, predictedValue);
19551955
}
19561956

19571957
private void createIndex(String indexName, XContentBuilder mapping) throws IOException {

client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

Lines changed: 178 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,11 @@
139139
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
140140
import org.elasticsearch.client.ml.dataframe.OutlierDetection;
141141
import org.elasticsearch.client.ml.dataframe.QueryConfig;
142-
import org.elasticsearch.client.ml.dataframe.Regression;
142+
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
143143
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
144+
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
145+
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
146+
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
144147
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
145148
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
146149
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
@@ -2821,7 +2824,7 @@ public void testGetDataFrameAnalytics() throws Exception {
28212824
List<DataFrameAnalyticsConfig> configs = response.getAnalytics();
28222825
// end::get-data-frame-analytics-response
28232826

2824-
assertThat(configs.size(), equalTo(1));
2827+
assertThat(configs, hasSize(1));
28252828
}
28262829
{
28272830
GetDataFrameAnalyticsRequest request = new GetDataFrameAnalyticsRequest("my-analytics-config");
@@ -2871,7 +2874,7 @@ public void testGetDataFrameAnalyticsStats() throws Exception {
28712874
List<DataFrameAnalyticsStats> stats = response.getAnalyticsStats();
28722875
// end::get-data-frame-analytics-stats-response
28732876

2874-
assertThat(stats.size(), equalTo(1));
2877+
assertThat(stats, hasSize(1));
28752878
}
28762879
{
28772880
GetDataFrameAnalyticsStatsRequest request = new GetDataFrameAnalyticsStatsRequest("my-analytics-config");
@@ -2939,8 +2942,20 @@ public void testPutDataFrameAnalytics() throws Exception {
29392942
.build();
29402943
// end::put-data-frame-analytics-outlier-detection-customized
29412944

2945+
// tag::put-data-frame-analytics-classification
2946+
DataFrameAnalysis classification = org.elasticsearch.client.ml.dataframe.Classification.builder("my_dependent_variable") // <1>
2947+
.setLambda(1.0) // <2>
2948+
.setGamma(5.5) // <3>
2949+
.setEta(5.5) // <4>
2950+
.setMaximumNumberTrees(50) // <5>
2951+
.setFeatureBagFraction(0.4) // <6>
2952+
.setPredictionFieldName("my_prediction_field_name") // <7>
2953+
.setTrainingPercent(50.0) // <8>
2954+
.build();
2955+
// end::put-data-frame-analytics-classification
2956+
29422957
// tag::put-data-frame-analytics-regression
2943-
DataFrameAnalysis regression = Regression.builder("my_dependent_variable") // <1>
2958+
DataFrameAnalysis regression = org.elasticsearch.client.ml.dataframe.Regression.builder("my_dependent_variable") // <1>
29442959
.setLambda(1.0) // <2>
29452960
.setGamma(5.5) // <3>
29462961
.setEta(5.5) // <4>
@@ -3209,18 +3224,24 @@ public void testEvaluateDataFrame() throws Exception {
32093224
client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
32103225
client.bulk(bulkRequest, RequestOptions.DEFAULT);
32113226
{
3227+
// tag::evaluate-data-frame-evaluation-softclassification
3228+
Evaluation evaluation =
3229+
new BinarySoftClassification( // <1>
3230+
"label", // <2>
3231+
"p", // <3>
3232+
// Evaluation metrics // <4>
3233+
PrecisionMetric.at(0.4, 0.5, 0.6), // <5>
3234+
RecallMetric.at(0.5, 0.7), // <6>
3235+
ConfusionMatrixMetric.at(0.5), // <7>
3236+
AucRocMetric.withCurve()); // <8>
3237+
// end::evaluate-data-frame-evaluation-softclassification
3238+
32123239
// tag::evaluate-data-frame-request
3213-
EvaluateDataFrameRequest request = new EvaluateDataFrameRequest( // <1>
3214-
indexName, // <2>
3215-
new QueryConfig(QueryBuilders.termQuery("dataset", "blue")), // <3>
3216-
new BinarySoftClassification( // <4>
3217-
"label", // <5>
3218-
"p", // <6>
3219-
// Evaluation metrics // <7>
3220-
PrecisionMetric.at(0.4, 0.5, 0.6), // <8>
3221-
RecallMetric.at(0.5, 0.7), // <9>
3222-
ConfusionMatrixMetric.at(0.5), // <10>
3223-
AucRocMetric.withCurve())); // <11>
3240+
EvaluateDataFrameRequest request =
3241+
new EvaluateDataFrameRequest( // <1>
3242+
indexName, // <2>
3243+
new QueryConfig(QueryBuilders.termQuery("dataset", "blue")), // <3>
3244+
evaluation); // <4>
32243245
// end::evaluate-data-frame-request
32253246

32263247
// tag::evaluate-data-frame-execute
@@ -3229,16 +3250,18 @@ public void testEvaluateDataFrame() throws Exception {
32293250

32303251
// tag::evaluate-data-frame-response
32313252
List<EvaluationMetric.Result> metrics = response.getMetrics(); // <1>
3253+
// end::evaluate-data-frame-response
32323254

3233-
PrecisionMetric.Result precisionResult = response.getMetricByName(PrecisionMetric.NAME); // <2>
3234-
double precision = precisionResult.getScoreByThreshold("0.4"); // <3>
3255+
// tag::evaluate-data-frame-results-softclassification
3256+
PrecisionMetric.Result precisionResult = response.getMetricByName(PrecisionMetric.NAME); // <1>
3257+
double precision = precisionResult.getScoreByThreshold("0.4"); // <2>
32353258

3236-
ConfusionMatrixMetric.Result confusionMatrixResult = response.getMetricByName(ConfusionMatrixMetric.NAME); // <4>
3237-
ConfusionMatrix confusionMatrix = confusionMatrixResult.getScoreByThreshold("0.5"); // <5>
3238-
// end::evaluate-data-frame-response
3259+
ConfusionMatrixMetric.Result confusionMatrixResult = response.getMetricByName(ConfusionMatrixMetric.NAME); // <3>
3260+
ConfusionMatrix confusionMatrix = confusionMatrixResult.getScoreByThreshold("0.5"); // <4>
3261+
// end::evaluate-data-frame-results-softclassification
32393262

32403263
assertThat(
3241-
metrics.stream().map(m -> m.getMetricName()).collect(Collectors.toList()),
3264+
metrics.stream().map(EvaluationMetric.Result::getMetricName).collect(Collectors.toList()),
32423265
containsInAnyOrder(PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, AucRocMetric.NAME));
32433266
assertThat(precision, closeTo(0.6, 1e-9));
32443267
assertThat(confusionMatrix.getTruePositives(), equalTo(2L)); // docs #8 and #9
@@ -3284,6 +3307,140 @@ public void onFailure(Exception e) {
32843307
}
32853308
}
32863309

3310+
public void testEvaluateDataFrame_Classification() throws Exception {
3311+
String indexName = "evaluate-classification-test-index";
3312+
CreateIndexRequest createIndexRequest =
3313+
new CreateIndexRequest(indexName)
3314+
.mapping(XContentFactory.jsonBuilder().startObject()
3315+
.startObject("properties")
3316+
.startObject("actual_class")
3317+
.field("type", "keyword")
3318+
.endObject()
3319+
.startObject("predicted_class")
3320+
.field("type", "keyword")
3321+
.endObject()
3322+
.endObject()
3323+
.endObject());
3324+
BulkRequest bulkRequest =
3325+
new BulkRequest(indexName)
3326+
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
3327+
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #0
3328+
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #1
3329+
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #2
3330+
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "dog")) // #3
3331+
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "fox")) // #4
3332+
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "cat")) // #5
3333+
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #6
3334+
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #7
3335+
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #8
3336+
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "ant", "predicted_class", "cat")); // #9
3337+
RestHighLevelClient client = highLevelClient();
3338+
client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
3339+
client.bulk(bulkRequest, RequestOptions.DEFAULT);
3340+
{
3341+
// tag::evaluate-data-frame-evaluation-classification
3342+
Evaluation evaluation =
3343+
new org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification( // <1>
3344+
"actual_class", // <2>
3345+
"predicted_class", // <3>
3346+
// Evaluation metrics // <4>
3347+
new MulticlassConfusionMatrixMetric(3)); // <5>
3348+
// end::evaluate-data-frame-evaluation-classification
3349+
3350+
EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation);
3351+
EvaluateDataFrameResponse response = client.machineLearning().evaluateDataFrame(request, RequestOptions.DEFAULT);
3352+
3353+
// tag::evaluate-data-frame-results-classification
3354+
MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix =
3355+
response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <1>
3356+
3357+
Map<String, Map<String, Long>> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <2>
3358+
long otherClassesCount = multiclassConfusionMatrix.getOtherClassesCount(); // <3>
3359+
// end::evaluate-data-frame-results-classification
3360+
3361+
assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
3362+
assertThat(
3363+
confusionMatrix,
3364+
equalTo(
3365+
new HashMap<String, Map<String, Long>>() {{
3366+
put("cat", new HashMap<String, Long>() {{
3367+
put("cat", 3L);
3368+
put("dog", 1L);
3369+
put("ant", 0L);
3370+
put("_other_", 1L);
3371+
}});
3372+
put("dog", new HashMap<String, Long>() {{
3373+
put("cat", 1L);
3374+
put("dog", 3L);
3375+
put("ant", 0L);
3376+
}});
3377+
put("ant", new HashMap<String, Long>() {{
3378+
put("cat", 1L);
3379+
put("dog", 0L);
3380+
put("ant", 0L);
3381+
}});
3382+
}}));
3383+
assertThat(otherClassesCount, equalTo(0L));
3384+
}
3385+
}
3386+
3387+
public void testEvaluateDataFrame_Regression() throws Exception {
3388+
String indexName = "evaluate-classification-test-index";
3389+
CreateIndexRequest createIndexRequest =
3390+
new CreateIndexRequest(indexName)
3391+
.mapping(XContentFactory.jsonBuilder().startObject()
3392+
.startObject("properties")
3393+
.startObject("actual_value")
3394+
.field("type", "double")
3395+
.endObject()
3396+
.startObject("predicted_value")
3397+
.field("type", "double")
3398+
.endObject()
3399+
.endObject()
3400+
.endObject());
3401+
BulkRequest bulkRequest =
3402+
new BulkRequest(indexName)
3403+
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
3404+
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 1.0, "predicted_value", 1.0)) // #0
3405+
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 1.0, "predicted_value", 0.9)) // #1
3406+
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 2.0, "predicted_value", 2.0)) // #2
3407+
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 1.5, "predicted_value", 1.4)) // #3
3408+
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 1.2, "predicted_value", 1.3)) // #4
3409+
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 1.7, "predicted_value", 2.0)) // #5
3410+
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 2.1, "predicted_value", 2.1)) // #6
3411+
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 2.5, "predicted_value", 2.7)) // #7
3412+
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 0.8, "predicted_value", 1.0)) // #8
3413+
.add(new IndexRequest().source(XContentType.JSON, "actual_value", 2.5, "predicted_value", 2.4)); // #9
3414+
RestHighLevelClient client = highLevelClient();
3415+
client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
3416+
client.bulk(bulkRequest, RequestOptions.DEFAULT);
3417+
{
3418+
// tag::evaluate-data-frame-evaluation-regression
3419+
Evaluation evaluation =
3420+
new org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression( // <1>
3421+
"actual_value", // <2>
3422+
"predicted_value", // <3>
3423+
// Evaluation metrics // <4>
3424+
new MeanSquaredErrorMetric(), // <5>
3425+
new RSquaredMetric()); // <6>
3426+
// end::evaluate-data-frame-evaluation-regression
3427+
3428+
EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation);
3429+
EvaluateDataFrameResponse response = client.machineLearning().evaluateDataFrame(request, RequestOptions.DEFAULT);
3430+
3431+
// tag::evaluate-data-frame-results-regression
3432+
MeanSquaredErrorMetric.Result meanSquaredErrorResult = response.getMetricByName(MeanSquaredErrorMetric.NAME); // <1>
3433+
double meanSquaredError = meanSquaredErrorResult.getError(); // <2>
3434+
3435+
RSquaredMetric.Result rSquaredResult = response.getMetricByName(RSquaredMetric.NAME); // <3>
3436+
double rSquared = rSquaredResult.getValue(); // <4>
3437+
// end::evaluate-data-frame-results-regression
3438+
3439+
assertThat(meanSquaredError, closeTo(0.021, 1e-3));
3440+
assertThat(rSquared, closeTo(0.941, 1e-3));
3441+
}
3442+
}
3443+
32873444
public void testEstimateMemoryUsage() throws Exception {
32883445
createIndex("estimate-test-source-index");
32893446
BulkRequest bulkRequest =

0 commit comments

Comments
 (0)