@@ -27,8 +27,12 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
27
27
28
28
private static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index" ;
29
29
30
- private static final String ACTUAL_CLASS_FIELD = "actual_class_field" ;
31
- private static final String PREDICTED_CLASS_FIELD = "predicted_class_field" ;
30
+ private static final String ANIMAL_NAME_FIELD = "animal_name" ;
31
+ private static final String ANIMAL_NAME_PREDICTION_FIELD = "animal_name_prediction" ;
32
+ private static final String NO_LEGS_FIELD = "no_legs" ;
33
+ private static final String NO_LEGS_PREDICTION_FIELD = "no_legs_prediction" ;
34
+ private static final String IS_PREDATOR_FIELD = "predator" ;
35
+ private static final String IS_PREDATOR_PREDICTION_FIELD = "predator_prediction" ;
32
36
33
37
@ Before
34
38
public void setup () {
@@ -40,9 +44,9 @@ public void cleanup() {
40
44
cleanUp ();
41
45
}
42
46
43
- public void testEvaluate_MulticlassClassification_DefaultMetrics () {
47
+ public void testEvaluate_DefaultMetrics () {
44
48
EvaluateDataFrameAction .Response evaluateDataFrameResponse =
45
- evaluateDataFrame (ANIMALS_DATA_INDEX , new Classification (ACTUAL_CLASS_FIELD , PREDICTED_CLASS_FIELD , null ));
49
+ evaluateDataFrame (ANIMALS_DATA_INDEX , new Classification (ANIMAL_NAME_FIELD , ANIMAL_NAME_PREDICTION_FIELD , null ));
46
50
47
51
assertThat (evaluateDataFrameResponse .getEvaluationName (), equalTo (Classification .NAME .getPreferredName ()));
48
52
assertThat (evaluateDataFrameResponse .getMetrics (), hasSize (1 ));
@@ -51,9 +55,10 @@ public void testEvaluate_MulticlassClassification_DefaultMetrics() {
51
55
equalTo (MulticlassConfusionMatrix .NAME .getPreferredName ()));
52
56
}
53
57
54
- public void testEvaluate_MulticlassClassification_Accuracy () {
58
+ public void testEvaluate_Accuracy_KeywordField () {
55
59
EvaluateDataFrameAction .Response evaluateDataFrameResponse =
56
- evaluateDataFrame (ANIMALS_DATA_INDEX , new Classification (ACTUAL_CLASS_FIELD , PREDICTED_CLASS_FIELD , List .of (new Accuracy ())));
60
+ evaluateDataFrame (
61
+ ANIMALS_DATA_INDEX , new Classification (ANIMAL_NAME_FIELD , ANIMAL_NAME_PREDICTION_FIELD , List .of (new Accuracy ())));
57
62
58
63
assertThat (evaluateDataFrameResponse .getEvaluationName (), equalTo (Classification .NAME .getPreferredName ()));
59
64
assertThat (evaluateDataFrameResponse .getMetrics (), hasSize (1 ));
@@ -72,11 +77,50 @@ public void testEvaluate_MulticlassClassification_Accuracy() {
72
77
assertThat (accuracyResult .getOverallAccuracy (), equalTo (5.0 / 75 ));
73
78
}
74
79
75
- public void testEvaluate_MulticlassClassification_AccuracyAndConfusionMatrixMetricWithDefaultSize () {
80
+ public void testEvaluate_Accuracy_IntegerField () {
81
+ EvaluateDataFrameAction .Response evaluateDataFrameResponse =
82
+ evaluateDataFrame (
83
+ ANIMALS_DATA_INDEX , new Classification (NO_LEGS_FIELD , NO_LEGS_PREDICTION_FIELD , List .of (new Accuracy ())));
84
+
85
+ assertThat (evaluateDataFrameResponse .getEvaluationName (), equalTo (Classification .NAME .getPreferredName ()));
86
+ assertThat (evaluateDataFrameResponse .getMetrics (), hasSize (1 ));
87
+
88
+ Accuracy .Result accuracyResult = (Accuracy .Result ) evaluateDataFrameResponse .getMetrics ().get (0 );
89
+ assertThat (accuracyResult .getMetricName (), equalTo (Accuracy .NAME .getPreferredName ()));
90
+ assertThat (
91
+ accuracyResult .getActualClasses (),
92
+ equalTo (List .of (
93
+ new Accuracy .ActualClass ("1" , 15 , 1.0 / 15 ),
94
+ new Accuracy .ActualClass ("2" , 15 , 2.0 / 15 ),
95
+ new Accuracy .ActualClass ("3" , 15 , 3.0 / 15 ),
96
+ new Accuracy .ActualClass ("4" , 15 , 4.0 / 15 ),
97
+ new Accuracy .ActualClass ("5" , 15 , 5.0 / 15 ))));
98
+ assertThat (accuracyResult .getOverallAccuracy (), equalTo (15.0 / 75 ));
99
+ }
100
+
101
+ public void testEvaluate_Accuracy_BooleanField () {
102
+ EvaluateDataFrameAction .Response evaluateDataFrameResponse =
103
+ evaluateDataFrame (
104
+ ANIMALS_DATA_INDEX , new Classification (IS_PREDATOR_FIELD , IS_PREDATOR_PREDICTION_FIELD , List .of (new Accuracy ())));
105
+
106
+ assertThat (evaluateDataFrameResponse .getEvaluationName (), equalTo (Classification .NAME .getPreferredName ()));
107
+ assertThat (evaluateDataFrameResponse .getMetrics (), hasSize (1 ));
108
+
109
+ Accuracy .Result accuracyResult = (Accuracy .Result ) evaluateDataFrameResponse .getMetrics ().get (0 );
110
+ assertThat (accuracyResult .getMetricName (), equalTo (Accuracy .NAME .getPreferredName ()));
111
+ assertThat (
112
+ accuracyResult .getActualClasses (),
113
+ equalTo (List .of (
114
+ new Accuracy .ActualClass ("true" , 45 , 27.0 / 45 ),
115
+ new Accuracy .ActualClass ("false" , 30 , 18.0 / 30 ))));
116
+ assertThat (accuracyResult .getOverallAccuracy (), equalTo (45.0 / 75 ));
117
+ }
118
+
119
+ public void testEvaluate_ConfusionMatrixMetricWithDefaultSize () {
76
120
EvaluateDataFrameAction .Response evaluateDataFrameResponse =
77
121
evaluateDataFrame (
78
122
ANIMALS_DATA_INDEX ,
79
- new Classification (ACTUAL_CLASS_FIELD , PREDICTED_CLASS_FIELD , List .of (new MulticlassConfusionMatrix ())));
123
+ new Classification (ANIMAL_NAME_FIELD , ANIMAL_NAME_PREDICTION_FIELD , List .of (new MulticlassConfusionMatrix ())));
80
124
81
125
assertThat (evaluateDataFrameResponse .getEvaluationName (), equalTo (Classification .NAME .getPreferredName ()));
82
126
assertThat (evaluateDataFrameResponse .getMetrics (), hasSize (1 ));
@@ -135,11 +179,11 @@ public void testEvaluate_MulticlassClassification_AccuracyAndConfusionMatrixMetr
135
179
assertThat (confusionMatrixResult .getOtherActualClassCount (), equalTo (0L ));
136
180
}
137
181
138
- public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserProvidedSize () {
182
+ public void testEvaluate_ConfusionMatrixMetricWithUserProvidedSize () {
139
183
EvaluateDataFrameAction .Response evaluateDataFrameResponse =
140
184
evaluateDataFrame (
141
185
ANIMALS_DATA_INDEX ,
142
- new Classification (ACTUAL_CLASS_FIELD , PREDICTED_CLASS_FIELD , List .of (new MulticlassConfusionMatrix (3 ))));
186
+ new Classification (ANIMAL_NAME_FIELD , ANIMAL_NAME_PREDICTION_FIELD , List .of (new MulticlassConfusionMatrix (3 ))));
143
187
144
188
assertThat (evaluateDataFrameResponse .getEvaluationName (), equalTo (Classification .NAME .getPreferredName ()));
145
189
assertThat (evaluateDataFrameResponse .getMetrics (), hasSize (1 ));
@@ -166,20 +210,30 @@ public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserP
166
210
167
211
private static void indexAnimalsData (String indexName ) {
168
212
client ().admin ().indices ().prepareCreate (indexName )
169
- .addMapping ("_doc" , ACTUAL_CLASS_FIELD , "type=keyword" , PREDICTED_CLASS_FIELD , "type=keyword" )
213
+ .addMapping ("_doc" ,
214
+ ANIMAL_NAME_FIELD , "type=keyword" ,
215
+ ANIMAL_NAME_PREDICTION_FIELD , "type=keyword" ,
216
+ NO_LEGS_FIELD , "type=integer" ,
217
+ NO_LEGS_PREDICTION_FIELD , "type=integer" ,
218
+ IS_PREDATOR_FIELD , "type=boolean" ,
219
+ IS_PREDATOR_PREDICTION_FIELD , "type=boolean" )
170
220
.get ();
171
221
172
- List <String > classNames = List .of ("dog" , "cat" , "mouse" , "ant" , "fox" );
222
+ List <String > animalNames = List .of ("dog" , "cat" , "mouse" , "ant" , "fox" );
173
223
BulkRequestBuilder bulkRequestBuilder = client ().prepareBulk ()
174
224
.setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE );
175
- for (int i = 0 ; i < classNames .size (); i ++) {
176
- for (int j = 0 ; j < classNames .size (); j ++) {
225
+ for (int i = 0 ; i < animalNames .size (); i ++) {
226
+ for (int j = 0 ; j < animalNames .size (); j ++) {
177
227
for (int k = 0 ; k < j + 1 ; k ++) {
178
228
bulkRequestBuilder .add (
179
229
new IndexRequest (indexName )
180
230
.source (
181
- ACTUAL_CLASS_FIELD , classNames .get (i ),
182
- PREDICTED_CLASS_FIELD , classNames .get ((i + j ) % classNames .size ())));
231
+ ANIMAL_NAME_FIELD , animalNames .get (i ),
232
+ ANIMAL_NAME_PREDICTION_FIELD , animalNames .get ((i + j ) % animalNames .size ()),
233
+ NO_LEGS_FIELD , i + 1 ,
234
+ NO_LEGS_PREDICTION_FIELD , j + 1 ,
235
+ IS_PREDATOR_FIELD , i % 2 == 0 ,
236
+ IS_PREDATOR_PREDICTION_FIELD , (i + j ) % 2 == 0 ));
183
237
}
184
238
}
185
239
}
0 commit comments