16
16
import org .elasticsearch .action .search .SearchResponse ;
17
17
import org .elasticsearch .action .support .WriteRequest ;
18
18
import org .elasticsearch .search .SearchHit ;
19
+ import org .elasticsearch .xpack .core .ml .action .EvaluateDataFrameAction ;
19
20
import org .elasticsearch .xpack .core .ml .dataframe .DataFrameAnalyticsConfig ;
20
21
import org .elasticsearch .xpack .core .ml .dataframe .analyses .BoostedTreeParamsTests ;
21
22
import org .elasticsearch .xpack .core .ml .dataframe .analyses .Classification ;
23
+ import org .elasticsearch .xpack .core .ml .dataframe .evaluation .classification .MulticlassConfusionMatrix ;
24
+ import org .elasticsearch .xpack .core .ml .dataframe .evaluation .classification .MulticlassConfusionMatrix .ActualClass ;
25
+ import org .elasticsearch .xpack .core .ml .dataframe .evaluation .classification .MulticlassConfusionMatrix .PredictedClass ;
22
26
import org .junit .After ;
23
27
24
28
import java .util .ArrayList ;
28
32
import java .util .Map ;
29
33
import java .util .function .Function ;
30
34
35
+ import static java .util .stream .Collectors .toList ;
31
36
import static org .hamcrest .Matchers .allOf ;
32
37
import static org .hamcrest .Matchers .equalTo ;
33
38
import static org .hamcrest .Matchers .greaterThan ;
@@ -48,7 +53,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
48
53
private static final List <Boolean > BOOLEAN_FIELD_VALUES = Collections .unmodifiableList (Arrays .asList (false , true ));
49
54
private static final List <Double > NUMERICAL_FIELD_VALUES = Collections .unmodifiableList (Arrays .asList (1.0 , 2.0 ));
50
55
private static final List <Integer > DISCRETE_NUMERICAL_FIELD_VALUES = Collections .unmodifiableList (Arrays .asList (10 , 20 ));
51
- private static final List <String > KEYWORD_FIELD_VALUES = Collections .unmodifiableList (Arrays .asList ("dog " , "cat " ));
56
+ private static final List <String > KEYWORD_FIELD_VALUES = Collections .unmodifiableList (Arrays .asList ("cat " , "dog " ));
52
57
53
58
private String jobId ;
54
59
private String sourceIndex ;
@@ -97,6 +102,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws
97
102
"Creating destination index [" + destIndex + "]" ,
98
103
"Finished reindexing to destination index [" + destIndex + "]" ,
99
104
"Finished analysis" );
105
+ assertEvaluation (KEYWORD_FIELD , KEYWORD_FIELD_VALUES , "ml.keyword-field_prediction" );
100
106
}
101
107
102
108
public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred () throws Exception {
@@ -136,11 +142,13 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti
136
142
"Creating destination index [" + destIndex + "]" ,
137
143
"Finished reindexing to destination index [" + destIndex + "]" ,
138
144
"Finished analysis" );
145
+ assertEvaluation (KEYWORD_FIELD , KEYWORD_FIELD_VALUES , "ml.keyword-field_prediction" );
139
146
}
140
147
141
148
public <T > void testWithOnlyTrainingRowsAndTrainingPercentIsFifty (
142
149
String jobId , String dependentVariable , List <T > dependentVariableValues , Function <String , T > parser ) throws Exception {
143
150
initialize (jobId );
151
+ String predictedClassField = dependentVariable + "_prediction" ;
144
152
indexData (sourceIndex , 300 , 0 , dependentVariable );
145
153
146
154
int numTopClasses = 2 ;
@@ -166,7 +174,6 @@ public <T> void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
166
174
SearchResponse sourceData = client ().prepareSearch (sourceIndex ).setTrackTotalHits (true ).setSize (1000 ).get ();
167
175
for (SearchHit hit : sourceData .getHits ()) {
168
176
Map <String , Object > resultsObject = getMlResultsObjectFromDestDoc (getDestDoc (config , hit ));
169
- String predictedClassField = dependentVariable + "_prediction" ;
170
177
assertThat (resultsObject .containsKey (predictedClassField ), is (true ));
171
178
T predictedClassValue = parser .apply ((String ) resultsObject .get (predictedClassField ));
172
179
assertThat (predictedClassValue , is (in (dependentVariableValues )));
@@ -194,6 +201,10 @@ public <T> void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
194
201
"Creating destination index [" + destIndex + "]" ,
195
202
"Finished reindexing to destination index [" + destIndex + "]" ,
196
203
"Finished analysis" );
204
+ assertEvaluation (
205
+ dependentVariable ,
206
+ dependentVariableValues .stream ().map (String ::valueOf ).collect (toList ()),
207
+ "ml." + predictedClassField );
197
208
}
198
209
199
210
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsKeyword () throws Exception {
@@ -219,51 +230,6 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableI
219
230
"classification_training_percent_is_50_boolean" , BOOLEAN_FIELD , BOOLEAN_FIELD_VALUES , Boolean ::valueOf );
220
231
}
221
232
222
- public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClassesRequested () throws Exception {
223
- initialize ("classification_top_classes_requested" );
224
- indexData (sourceIndex , 300 , 50 , KEYWORD_FIELD );
225
-
226
- int numTopClasses = 2 ;
227
- DataFrameAnalyticsConfig config =
228
- buildAnalytics (
229
- jobId ,
230
- sourceIndex ,
231
- destIndex ,
232
- null ,
233
- new Classification (KEYWORD_FIELD , BoostedTreeParamsTests .createRandom (), null , numTopClasses , null ));
234
- registerAnalytics (config );
235
- putAnalytics (config );
236
-
237
- assertIsStopped (jobId );
238
- assertProgress (jobId , 0 , 0 , 0 , 0 );
239
-
240
- startAnalytics (jobId );
241
- waitUntilAnalyticsIsStopped (jobId );
242
-
243
- client ().admin ().indices ().refresh (new RefreshRequest (destIndex ));
244
- SearchResponse sourceData = client ().prepareSearch (sourceIndex ).setTrackTotalHits (true ).setSize (1000 ).get ();
245
- for (SearchHit hit : sourceData .getHits ()) {
246
- Map <String , Object > destDoc = getDestDoc (config , hit );
247
- Map <String , Object > resultsObject = getMlResultsObjectFromDestDoc (destDoc );
248
-
249
- assertThat (resultsObject .containsKey ("keyword-field_prediction" ), is (true ));
250
- assertThat ((String ) resultsObject .get ("keyword-field_prediction" ), is (in (KEYWORD_FIELD_VALUES )));
251
- assertTopClasses (resultsObject , numTopClasses , KEYWORD_FIELD , KEYWORD_FIELD_VALUES , String ::valueOf );
252
- }
253
-
254
- assertProgress (jobId , 100 , 100 , 100 , 100 );
255
- assertThat (searchStoredProgress (jobId ).getHits ().getTotalHits ().value , equalTo (1L ));
256
- assertInferenceModelPersisted (jobId );
257
- assertThatAuditMessagesMatch (jobId ,
258
- "Created analytics with analysis type [classification]" ,
259
- "Estimated memory usage for this analytics to be" ,
260
- "Starting analytics on node" ,
261
- "Started analytics" ,
262
- "Creating destination index [" + destIndex + "]" ,
263
- "Finished reindexing to destination index [" + destIndex + "]" ,
264
- "Finished analysis" );
265
- }
266
-
267
233
public void testDependentVariableCardinalityTooHighError () {
268
234
initialize ("cardinality_too_high" );
269
235
indexData (sourceIndex , 6 , 5 , KEYWORD_FIELD );
@@ -377,4 +343,26 @@ private static <T> void assertTopClasses(
377
343
// Assert that the top classes are listed in the order of decreasing probabilities.
378
344
assertThat (Ordering .natural ().reverse ().isOrdered (classProbabilities ), is (true ));
379
345
}
346
+
347
+ private void assertEvaluation (String dependentVariable , List <String > dependentVariableValues , String predictedClassField ) {
348
+ EvaluateDataFrameAction .Response evaluateDataFrameResponse =
349
+ evaluateDataFrame (
350
+ destIndex ,
351
+ new org .elasticsearch .xpack .core .ml .dataframe .evaluation .classification .Classification (
352
+ dependentVariable , predictedClassField , null ));
353
+ assertThat (evaluateDataFrameResponse .getEvaluationName (), equalTo (Classification .NAME .getPreferredName ()));
354
+ assertThat (evaluateDataFrameResponse .getMetrics ().size (), equalTo (1 ));
355
+ MulticlassConfusionMatrix .Result confusionMatrixResult =
356
+ (MulticlassConfusionMatrix .Result ) evaluateDataFrameResponse .getMetrics ().get (0 );
357
+ assertThat (confusionMatrixResult .getMetricName (), equalTo (MulticlassConfusionMatrix .NAME .getPreferredName ()));
358
+ List <ActualClass > actualClasses = confusionMatrixResult .getConfusionMatrix ();
359
+ assertThat (actualClasses .stream ().map (ActualClass ::getActualClass ).collect (toList ()), equalTo (dependentVariableValues ));
360
+ for (ActualClass actualClass : actualClasses ) {
361
+ assertThat (actualClass .getOtherPredictedClassDocCount (), equalTo (0L ));
362
+ assertThat (
363
+ actualClass .getPredictedClasses ().stream ().map (PredictedClass ::getPredictedClass ).collect (toList ()),
364
+ equalTo (dependentVariableValues ));
365
+ }
366
+ assertThat (confusionMatrixResult .getOtherActualClassCount (), equalTo (0L ));
367
+ }
380
368
}
0 commit comments