139
139
import org .elasticsearch .client .ml .dataframe .DataFrameAnalyticsStats ;
140
140
import org .elasticsearch .client .ml .dataframe .OutlierDetection ;
141
141
import org .elasticsearch .client .ml .dataframe .QueryConfig ;
142
- import org .elasticsearch .client .ml .dataframe .Regression ;
142
+ import org .elasticsearch .client .ml .dataframe .evaluation . Evaluation ;
143
143
import org .elasticsearch .client .ml .dataframe .evaluation .EvaluationMetric ;
144
+ import org .elasticsearch .client .ml .dataframe .evaluation .classification .MulticlassConfusionMatrixMetric ;
145
+ import org .elasticsearch .client .ml .dataframe .evaluation .regression .MeanSquaredErrorMetric ;
146
+ import org .elasticsearch .client .ml .dataframe .evaluation .regression .RSquaredMetric ;
144
147
import org .elasticsearch .client .ml .dataframe .evaluation .softclassification .AucRocMetric ;
145
148
import org .elasticsearch .client .ml .dataframe .evaluation .softclassification .BinarySoftClassification ;
146
149
import org .elasticsearch .client .ml .dataframe .evaluation .softclassification .ConfusionMatrixMetric ;
@@ -2821,7 +2824,7 @@ public void testGetDataFrameAnalytics() throws Exception {
2821
2824
List <DataFrameAnalyticsConfig > configs = response .getAnalytics ();
2822
2825
// end::get-data-frame-analytics-response
2823
2826
2824
- assertThat (configs . size (), equalTo (1 ));
2827
+ assertThat (configs , hasSize (1 ));
2825
2828
}
2826
2829
{
2827
2830
GetDataFrameAnalyticsRequest request = new GetDataFrameAnalyticsRequest ("my-analytics-config" );
@@ -2871,7 +2874,7 @@ public void testGetDataFrameAnalyticsStats() throws Exception {
2871
2874
List <DataFrameAnalyticsStats > stats = response .getAnalyticsStats ();
2872
2875
// end::get-data-frame-analytics-stats-response
2873
2876
2874
- assertThat (stats . size (), equalTo (1 ));
2877
+ assertThat (stats , hasSize (1 ));
2875
2878
}
2876
2879
{
2877
2880
GetDataFrameAnalyticsStatsRequest request = new GetDataFrameAnalyticsStatsRequest ("my-analytics-config" );
@@ -2939,8 +2942,20 @@ public void testPutDataFrameAnalytics() throws Exception {
2939
2942
.build ();
2940
2943
// end::put-data-frame-analytics-outlier-detection-customized
2941
2944
2945
+ // tag::put-data-frame-analytics-classification
2946
+ DataFrameAnalysis classification = org .elasticsearch .client .ml .dataframe .Classification .builder ("my_dependent_variable" ) // <1>
2947
+ .setLambda (1.0 ) // <2>
2948
+ .setGamma (5.5 ) // <3>
2949
+ .setEta (5.5 ) // <4>
2950
+ .setMaximumNumberTrees (50 ) // <5>
2951
+ .setFeatureBagFraction (0.4 ) // <6>
2952
+ .setPredictionFieldName ("my_prediction_field_name" ) // <7>
2953
+ .setTrainingPercent (50.0 ) // <8>
2954
+ .build ();
2955
+ // end::put-data-frame-analytics-classification
2956
+
2942
2957
// tag::put-data-frame-analytics-regression
2943
- DataFrameAnalysis regression = Regression .builder ("my_dependent_variable" ) // <1>
2958
+ DataFrameAnalysis regression = org . elasticsearch . client . ml . dataframe . Regression .builder ("my_dependent_variable" ) // <1>
2944
2959
.setLambda (1.0 ) // <2>
2945
2960
.setGamma (5.5 ) // <3>
2946
2961
.setEta (5.5 ) // <4>
@@ -3209,18 +3224,24 @@ public void testEvaluateDataFrame() throws Exception {
3209
3224
client .indices ().create (createIndexRequest , RequestOptions .DEFAULT );
3210
3225
client .bulk (bulkRequest , RequestOptions .DEFAULT );
3211
3226
{
3227
+ // tag::evaluate-data-frame-evaluation-softclassification
3228
+ Evaluation evaluation =
3229
+ new BinarySoftClassification ( // <1>
3230
+ "label" , // <2>
3231
+ "p" , // <3>
3232
+ // Evaluation metrics // <4>
3233
+ PrecisionMetric .at (0.4 , 0.5 , 0.6 ), // <5>
3234
+ RecallMetric .at (0.5 , 0.7 ), // <6>
3235
+ ConfusionMatrixMetric .at (0.5 ), // <7>
3236
+ AucRocMetric .withCurve ()); // <8>
3237
+ // end::evaluate-data-frame-evaluation-softclassification
3238
+
3212
3239
// tag::evaluate-data-frame-request
3213
- EvaluateDataFrameRequest request = new EvaluateDataFrameRequest ( // <1>
3214
- indexName , // <2>
3215
- new QueryConfig (QueryBuilders .termQuery ("dataset" , "blue" )), // <3>
3216
- new BinarySoftClassification ( // <4>
3217
- "label" , // <5>
3218
- "p" , // <6>
3219
- // Evaluation metrics // <7>
3220
- PrecisionMetric .at (0.4 , 0.5 , 0.6 ), // <8>
3221
- RecallMetric .at (0.5 , 0.7 ), // <9>
3222
- ConfusionMatrixMetric .at (0.5 ), // <10>
3223
- AucRocMetric .withCurve ())); // <11>
3240
+ EvaluateDataFrameRequest request =
3241
+ new EvaluateDataFrameRequest ( // <1>
3242
+ indexName , // <2>
3243
+ new QueryConfig (QueryBuilders .termQuery ("dataset" , "blue" )), // <3>
3244
+ evaluation ); // <4>
3224
3245
// end::evaluate-data-frame-request
3225
3246
3226
3247
// tag::evaluate-data-frame-execute
@@ -3229,16 +3250,18 @@ public void testEvaluateDataFrame() throws Exception {
3229
3250
3230
3251
// tag::evaluate-data-frame-response
3231
3252
List <EvaluationMetric .Result > metrics = response .getMetrics (); // <1>
3253
+ // end::evaluate-data-frame-response
3232
3254
3233
- PrecisionMetric .Result precisionResult = response .getMetricByName (PrecisionMetric .NAME ); // <2>
3234
- double precision = precisionResult .getScoreByThreshold ("0.4" ); // <3>
3255
+ // tag::evaluate-data-frame-results-softclassification
3256
+ PrecisionMetric .Result precisionResult = response .getMetricByName (PrecisionMetric .NAME ); // <1>
3257
+ double precision = precisionResult .getScoreByThreshold ("0.4" ); // <2>
3235
3258
3236
- ConfusionMatrixMetric .Result confusionMatrixResult = response .getMetricByName (ConfusionMatrixMetric .NAME ); // <4 >
3237
- ConfusionMatrix confusionMatrix = confusionMatrixResult .getScoreByThreshold ("0.5" ); // <5 >
3238
- // end::evaluate-data-frame-response
3259
+ ConfusionMatrixMetric .Result confusionMatrixResult = response .getMetricByName (ConfusionMatrixMetric .NAME ); // <3 >
3260
+ ConfusionMatrix confusionMatrix = confusionMatrixResult .getScoreByThreshold ("0.5" ); // <4 >
3261
+ // end::evaluate-data-frame-results-softclassification
3239
3262
3240
3263
assertThat (
3241
- metrics .stream ().map (m -> m . getMetricName () ).collect (Collectors .toList ()),
3264
+ metrics .stream ().map (EvaluationMetric . Result :: getMetricName ).collect (Collectors .toList ()),
3242
3265
containsInAnyOrder (PrecisionMetric .NAME , RecallMetric .NAME , ConfusionMatrixMetric .NAME , AucRocMetric .NAME ));
3243
3266
assertThat (precision , closeTo (0.6 , 1e-9 ));
3244
3267
assertThat (confusionMatrix .getTruePositives (), equalTo (2L )); // docs #8 and #9
@@ -3284,6 +3307,140 @@ public void onFailure(Exception e) {
3284
3307
}
3285
3308
}
3286
3309
3310
+ public void testEvaluateDataFrame_Classification () throws Exception {
3311
+ String indexName = "evaluate-classification-test-index" ;
3312
+ CreateIndexRequest createIndexRequest =
3313
+ new CreateIndexRequest (indexName )
3314
+ .mapping (XContentFactory .jsonBuilder ().startObject ()
3315
+ .startObject ("properties" )
3316
+ .startObject ("actual_class" )
3317
+ .field ("type" , "keyword" )
3318
+ .endObject ()
3319
+ .startObject ("predicted_class" )
3320
+ .field ("type" , "keyword" )
3321
+ .endObject ()
3322
+ .endObject ()
3323
+ .endObject ());
3324
+ BulkRequest bulkRequest =
3325
+ new BulkRequest (indexName )
3326
+ .setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE )
3327
+ .add (new IndexRequest ().source (XContentType .JSON , "actual_class" , "cat" , "predicted_class" , "cat" )) // #0
3328
+ .add (new IndexRequest ().source (XContentType .JSON , "actual_class" , "cat" , "predicted_class" , "cat" )) // #1
3329
+ .add (new IndexRequest ().source (XContentType .JSON , "actual_class" , "cat" , "predicted_class" , "cat" )) // #2
3330
+ .add (new IndexRequest ().source (XContentType .JSON , "actual_class" , "cat" , "predicted_class" , "dog" )) // #3
3331
+ .add (new IndexRequest ().source (XContentType .JSON , "actual_class" , "cat" , "predicted_class" , "fox" )) // #4
3332
+ .add (new IndexRequest ().source (XContentType .JSON , "actual_class" , "dog" , "predicted_class" , "cat" )) // #5
3333
+ .add (new IndexRequest ().source (XContentType .JSON , "actual_class" , "dog" , "predicted_class" , "dog" )) // #6
3334
+ .add (new IndexRequest ().source (XContentType .JSON , "actual_class" , "dog" , "predicted_class" , "dog" )) // #7
3335
+ .add (new IndexRequest ().source (XContentType .JSON , "actual_class" , "dog" , "predicted_class" , "dog" )) // #8
3336
+ .add (new IndexRequest ().source (XContentType .JSON , "actual_class" , "ant" , "predicted_class" , "cat" )); // #9
3337
+ RestHighLevelClient client = highLevelClient ();
3338
+ client .indices ().create (createIndexRequest , RequestOptions .DEFAULT );
3339
+ client .bulk (bulkRequest , RequestOptions .DEFAULT );
3340
+ {
3341
+ // tag::evaluate-data-frame-evaluation-classification
3342
+ Evaluation evaluation =
3343
+ new org .elasticsearch .client .ml .dataframe .evaluation .classification .Classification ( // <1>
3344
+ "actual_class" , // <2>
3345
+ "predicted_class" , // <3>
3346
+ // Evaluation metrics // <4>
3347
+ new MulticlassConfusionMatrixMetric (3 )); // <5>
3348
+ // end::evaluate-data-frame-evaluation-classification
3349
+
3350
+ EvaluateDataFrameRequest request = new EvaluateDataFrameRequest (indexName , null , evaluation );
3351
+ EvaluateDataFrameResponse response = client .machineLearning ().evaluateDataFrame (request , RequestOptions .DEFAULT );
3352
+
3353
+ // tag::evaluate-data-frame-results-classification
3354
+ MulticlassConfusionMatrixMetric .Result multiclassConfusionMatrix =
3355
+ response .getMetricByName (MulticlassConfusionMatrixMetric .NAME ); // <1>
3356
+
3357
+ Map <String , Map <String , Long >> confusionMatrix = multiclassConfusionMatrix .getConfusionMatrix (); // <2>
3358
+ long otherClassesCount = multiclassConfusionMatrix .getOtherClassesCount (); // <3>
3359
+ // end::evaluate-data-frame-results-classification
3360
+
3361
+ assertThat (multiclassConfusionMatrix .getMetricName (), equalTo (MulticlassConfusionMatrixMetric .NAME ));
3362
+ assertThat (
3363
+ confusionMatrix ,
3364
+ equalTo (
3365
+ new HashMap <String , Map <String , Long >>() {{
3366
+ put ("cat" , new HashMap <String , Long >() {{
3367
+ put ("cat" , 3L );
3368
+ put ("dog" , 1L );
3369
+ put ("ant" , 0L );
3370
+ put ("_other_" , 1L );
3371
+ }});
3372
+ put ("dog" , new HashMap <String , Long >() {{
3373
+ put ("cat" , 1L );
3374
+ put ("dog" , 3L );
3375
+ put ("ant" , 0L );
3376
+ }});
3377
+ put ("ant" , new HashMap <String , Long >() {{
3378
+ put ("cat" , 1L );
3379
+ put ("dog" , 0L );
3380
+ put ("ant" , 0L );
3381
+ }});
3382
+ }}));
3383
+ assertThat (otherClassesCount , equalTo (0L ));
3384
+ }
3385
+ }
3386
+
3387
+ public void testEvaluateDataFrame_Regression () throws Exception {
3388
+ String indexName = "evaluate-classification-test-index" ;
3389
+ CreateIndexRequest createIndexRequest =
3390
+ new CreateIndexRequest (indexName )
3391
+ .mapping (XContentFactory .jsonBuilder ().startObject ()
3392
+ .startObject ("properties" )
3393
+ .startObject ("actual_value" )
3394
+ .field ("type" , "double" )
3395
+ .endObject ()
3396
+ .startObject ("predicted_value" )
3397
+ .field ("type" , "double" )
3398
+ .endObject ()
3399
+ .endObject ()
3400
+ .endObject ());
3401
+ BulkRequest bulkRequest =
3402
+ new BulkRequest (indexName )
3403
+ .setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE )
3404
+ .add (new IndexRequest ().source (XContentType .JSON , "actual_value" , 1.0 , "predicted_value" , 1.0 )) // #0
3405
+ .add (new IndexRequest ().source (XContentType .JSON , "actual_value" , 1.0 , "predicted_value" , 0.9 )) // #1
3406
+ .add (new IndexRequest ().source (XContentType .JSON , "actual_value" , 2.0 , "predicted_value" , 2.0 )) // #2
3407
+ .add (new IndexRequest ().source (XContentType .JSON , "actual_value" , 1.5 , "predicted_value" , 1.4 )) // #3
3408
+ .add (new IndexRequest ().source (XContentType .JSON , "actual_value" , 1.2 , "predicted_value" , 1.3 )) // #4
3409
+ .add (new IndexRequest ().source (XContentType .JSON , "actual_value" , 1.7 , "predicted_value" , 2.0 )) // #5
3410
+ .add (new IndexRequest ().source (XContentType .JSON , "actual_value" , 2.1 , "predicted_value" , 2.1 )) // #6
3411
+ .add (new IndexRequest ().source (XContentType .JSON , "actual_value" , 2.5 , "predicted_value" , 2.7 )) // #7
3412
+ .add (new IndexRequest ().source (XContentType .JSON , "actual_value" , 0.8 , "predicted_value" , 1.0 )) // #8
3413
+ .add (new IndexRequest ().source (XContentType .JSON , "actual_value" , 2.5 , "predicted_value" , 2.4 )); // #9
3414
+ RestHighLevelClient client = highLevelClient ();
3415
+ client .indices ().create (createIndexRequest , RequestOptions .DEFAULT );
3416
+ client .bulk (bulkRequest , RequestOptions .DEFAULT );
3417
+ {
3418
+ // tag::evaluate-data-frame-evaluation-regression
3419
+ Evaluation evaluation =
3420
+ new org .elasticsearch .client .ml .dataframe .evaluation .regression .Regression ( // <1>
3421
+ "actual_value" , // <2>
3422
+ "predicted_value" , // <3>
3423
+ // Evaluation metrics // <4>
3424
+ new MeanSquaredErrorMetric (), // <5>
3425
+ new RSquaredMetric ()); // <6>
3426
+ // end::evaluate-data-frame-evaluation-regression
3427
+
3428
+ EvaluateDataFrameRequest request = new EvaluateDataFrameRequest (indexName , null , evaluation );
3429
+ EvaluateDataFrameResponse response = client .machineLearning ().evaluateDataFrame (request , RequestOptions .DEFAULT );
3430
+
3431
+ // tag::evaluate-data-frame-results-regression
3432
+ MeanSquaredErrorMetric .Result meanSquaredErrorResult = response .getMetricByName (MeanSquaredErrorMetric .NAME ); // <1>
3433
+ double meanSquaredError = meanSquaredErrorResult .getError (); // <2>
3434
+
3435
+ RSquaredMetric .Result rSquaredResult = response .getMetricByName (RSquaredMetric .NAME ); // <3>
3436
+ double rSquared = rSquaredResult .getValue (); // <4>
3437
+ // end::evaluate-data-frame-results-regression
3438
+
3439
+ assertThat (meanSquaredError , closeTo (0.021 , 1e-3 ));
3440
+ assertThat (rSquared , closeTo (0.941 , 1e-3 ));
3441
+ }
3442
+ }
3443
+
3287
3444
public void testEstimateMemoryUsage () throws Exception {
3288
3445
createIndex ("estimate-test-source-index" );
3289
3446
BulkRequest bulkRequest =
0 commit comments