10
10
import org .elasticsearch .action .bulk .BulkRequestBuilder ;
11
11
import org .elasticsearch .action .bulk .BulkResponse ;
12
12
import org .elasticsearch .action .get .GetResponse ;
13
+ import org .elasticsearch .action .index .IndexAction ;
13
14
import org .elasticsearch .action .index .IndexRequest ;
14
15
import org .elasticsearch .action .search .SearchResponse ;
15
16
import org .elasticsearch .action .support .WriteRequest ;
24
25
import java .util .Collections ;
25
26
import java .util .List ;
26
27
import java .util .Map ;
28
+ import java .util .function .Function ;
27
29
28
30
import static org .hamcrest .Matchers .allOf ;
29
31
import static org .hamcrest .Matchers .equalTo ;
34
36
import static org .hamcrest .Matchers .in ;
35
37
import static org .hamcrest .Matchers .is ;
36
38
import static org .hamcrest .Matchers .lessThanOrEqualTo ;
39
+ import static org .hamcrest .Matchers .startsWith ;
37
40
38
41
public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
39
42
43
+ private static final String BOOLEAN_FIELD = "boolean-field" ;
40
44
private static final String NUMERICAL_FIELD = "numerical-field" ;
45
+ private static final String DISCRETE_NUMERICAL_FIELD = "discrete-numerical-field" ;
41
46
private static final String KEYWORD_FIELD = "keyword-field" ;
42
- private static final List <Double > NUMERICAL_FIELD_VALUES = Collections .unmodifiableList (Arrays .asList (1.0 , 2.0 , 3.0 , 4.0 ));
47
+ private static final List <Boolean > BOOLEAN_FIELD_VALUES = Collections .unmodifiableList (Arrays .asList (false , true ));
48
+ private static final List <Double > NUMERICAL_FIELD_VALUES = Collections .unmodifiableList (Arrays .asList (1.0 , 2.0 ));
49
+ private static final List <Integer > DISCRETE_NUMERICAL_FIELD_VALUES = Collections .unmodifiableList (Arrays .asList (10 , 20 ));
43
50
private static final List <String > KEYWORD_FIELD_VALUES = Collections .unmodifiableList (Arrays .asList ("dog" , "cat" ));
44
51
45
52
private String jobId ;
@@ -53,7 +60,7 @@ public void cleanup() {
53
60
54
61
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows () throws Exception {
55
62
initialize ("classification_single_numeric_feature_and_mixed_data_set" );
56
- indexData (sourceIndex , 300 , 50 , NUMERICAL_FIELD_VALUES , KEYWORD_FIELD_VALUES );
63
+ indexData (sourceIndex , 300 , 50 , KEYWORD_FIELD );
57
64
58
65
DataFrameAnalyticsConfig config = buildAnalytics (jobId , sourceIndex , destIndex , null , new Classification (KEYWORD_FIELD ));
59
66
registerAnalytics (config );
@@ -91,7 +98,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws
91
98
92
99
public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred () throws Exception {
93
100
initialize ("classification_only_training_data_and_training_percent_is_100" );
94
- indexData (sourceIndex , 300 , 0 , NUMERICAL_FIELD_VALUES , KEYWORD_FIELD_VALUES );
101
+ indexData (sourceIndex , 300 , 0 , KEYWORD_FIELD );
95
102
96
103
DataFrameAnalyticsConfig config = buildAnalytics (jobId , sourceIndex , destIndex , null , new Classification (KEYWORD_FIELD ));
97
104
registerAnalytics (config );
@@ -126,17 +133,19 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti
126
133
"Finished analysis" );
127
134
}
128
135
129
- public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty () throws Exception {
130
- initialize ("classification_only_training_data_and_training_percent_is_50" );
131
- indexData (sourceIndex , 300 , 0 , NUMERICAL_FIELD_VALUES , KEYWORD_FIELD_VALUES );
136
+ public <T > void testWithOnlyTrainingRowsAndTrainingPercentIsFifty (
137
+ String jobId , String dependentVariable , List <T > dependentVariableValues , Function <String , T > parser ) throws Exception {
138
+ initialize (jobId );
139
+ indexData (sourceIndex , 300 , 0 , dependentVariable );
132
140
141
+ int numTopClasses = 2 ;
133
142
DataFrameAnalyticsConfig config =
134
143
buildAnalytics (
135
144
jobId ,
136
145
sourceIndex ,
137
146
destIndex ,
138
147
null ,
139
- new Classification (KEYWORD_FIELD , BoostedTreeParamsTests .createRandom (), null , null , 50.0 ));
148
+ new Classification (dependentVariable , BoostedTreeParamsTests .createRandom (), null , numTopClasses , 50.0 ));
140
149
registerAnalytics (config );
141
150
putAnalytics (config );
142
151
@@ -151,8 +160,11 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception
151
160
SearchResponse sourceData = client ().prepareSearch (sourceIndex ).setTrackTotalHits (true ).setSize (1000 ).get ();
152
161
for (SearchHit hit : sourceData .getHits ()) {
153
162
Map <String , Object > resultsObject = getMlResultsObjectFromDestDoc (getDestDoc (config , hit ));
154
- assertThat (resultsObject .containsKey ("keyword-field_prediction" ), is (true ));
155
- assertThat ((String ) resultsObject .get ("keyword-field_prediction" ), is (in (KEYWORD_FIELD_VALUES )));
163
+ String predictedClassField = dependentVariable + "_prediction" ;
164
+ assertThat (resultsObject .containsKey (predictedClassField ), is (true ));
165
+ T predictedClassValue = parser .apply ((String ) resultsObject .get (predictedClassField ));
166
+ assertThat (predictedClassValue , is (in (dependentVariableValues )));
167
+ assertTopClasses (resultsObject , numTopClasses , dependentVariable , dependentVariableValues , parser );
156
168
157
169
assertThat (resultsObject .containsKey ("is_training" ), is (true ));
158
170
// Let's just assert there's both training and non-training results
@@ -161,11 +173,9 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception
161
173
} else {
162
174
nonTrainingRowsCount ++;
163
175
}
164
- assertThat (resultsObject .containsKey ("top_classes" ), is (false ));
165
176
}
166
177
assertThat (trainingRowsCount , greaterThan (0 ));
167
178
assertThat (nonTrainingRowsCount , greaterThan (0 ));
168
-
169
179
assertProgress (jobId , 100 , 100 , 100 , 100 );
170
180
assertThat (searchStoredProgress (jobId ).getHits ().getTotalHits ().value , equalTo (1L ));
171
181
assertThatAuditMessagesMatch (jobId ,
@@ -178,9 +188,32 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception
178
188
"Finished analysis" );
179
189
}
180
190
191
+ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsKeyword () throws Exception {
192
+ testWithOnlyTrainingRowsAndTrainingPercentIsFifty (
193
+ "classification_training_percent_is_50_keyword" , KEYWORD_FIELD , KEYWORD_FIELD_VALUES , String ::valueOf );
194
+ }
195
+
196
+ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsInteger () throws Exception {
197
+ testWithOnlyTrainingRowsAndTrainingPercentIsFifty (
198
+ "classification_training_percent_is_50_integer" , DISCRETE_NUMERICAL_FIELD , DISCRETE_NUMERICAL_FIELD_VALUES , Integer ::valueOf );
199
+ }
200
+
201
+ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsDouble () throws Exception {
202
+ ElasticsearchStatusException e = expectThrows (
203
+ ElasticsearchStatusException .class ,
204
+ () -> testWithOnlyTrainingRowsAndTrainingPercentIsFifty (
205
+ "classification_training_percent_is_50_double" , NUMERICAL_FIELD , NUMERICAL_FIELD_VALUES , Double ::valueOf ));
206
+ assertThat (e .getMessage (), startsWith ("invalid types [double] for required field [numerical-field];" ));
207
+ }
208
+
209
+ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsBoolean () throws Exception {
210
+ testWithOnlyTrainingRowsAndTrainingPercentIsFifty (
211
+ "classification_training_percent_is_50_boolean" , BOOLEAN_FIELD , BOOLEAN_FIELD_VALUES , Boolean ::valueOf );
212
+ }
213
+
181
214
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClassesRequested () throws Exception {
182
215
initialize ("classification_top_classes_requested" );
183
- indexData (sourceIndex , 300 , 0 , NUMERICAL_FIELD_VALUES , KEYWORD_FIELD_VALUES );
216
+ indexData (sourceIndex , 300 , 50 , KEYWORD_FIELD );
184
217
185
218
int numTopClasses = 2 ;
186
219
DataFrameAnalyticsConfig config =
@@ -206,7 +239,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClasse
206
239
207
240
assertThat (resultsObject .containsKey ("keyword-field_prediction" ), is (true ));
208
241
assertThat ((String ) resultsObject .get ("keyword-field_prediction" ), is (in (KEYWORD_FIELD_VALUES )));
209
- assertTopClasses (resultsObject , numTopClasses );
242
+ assertTopClasses (resultsObject , numTopClasses , KEYWORD_FIELD , KEYWORD_FIELD_VALUES , String :: valueOf );
210
243
}
211
244
212
245
assertProgress (jobId , 100 , 100 , 100 , 100 );
@@ -223,7 +256,9 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClasse
223
256
224
257
public void testDependentVariableCardinalityTooHighError () {
225
258
initialize ("cardinality_too_high" );
226
- indexData (sourceIndex , 6 , 5 , NUMERICAL_FIELD_VALUES , Arrays .asList ("dog" , "cat" , "fox" ));
259
+ indexData (sourceIndex , 6 , 5 , KEYWORD_FIELD );
260
+ // Index one more document with a class different than the two already used.
261
+ client ().execute (IndexAction .INSTANCE , new IndexRequest (sourceIndex ).source (KEYWORD_FIELD , "fox" ));
227
262
228
263
DataFrameAnalyticsConfig config = buildAnalytics (jobId , sourceIndex , destIndex , null , new Classification (KEYWORD_FIELD ));
229
264
registerAnalytics (config );
@@ -240,28 +275,42 @@ private void initialize(String jobId) {
240
275
this .destIndex = sourceIndex + "_results" ;
241
276
}
242
277
243
- private static void indexData (String sourceIndex ,
244
- int numTrainingRows , int numNonTrainingRows ,
245
- List <Double > numericalFieldValues , List <String > keywordFieldValues ) {
278
+ private static void indexData (String sourceIndex , int numTrainingRows , int numNonTrainingRows , String dependentVariable ) {
246
279
client ().admin ().indices ().prepareCreate (sourceIndex )
247
- .addMapping ("_doc" , NUMERICAL_FIELD , "type=double" , KEYWORD_FIELD , "type=keyword" )
280
+ .addMapping ("_doc" ,
281
+ BOOLEAN_FIELD , "type=boolean" ,
282
+ NUMERICAL_FIELD , "type=double" ,
283
+ DISCRETE_NUMERICAL_FIELD , "type=integer" ,
284
+ KEYWORD_FIELD , "type=keyword" )
248
285
.get ();
249
286
250
287
BulkRequestBuilder bulkRequestBuilder = client ().prepareBulk ()
251
288
.setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE );
252
289
for (int i = 0 ; i < numTrainingRows ; i ++) {
253
- Double numericalValue = numericalFieldValues .get (i % numericalFieldValues .size ());
254
- String keywordValue = keywordFieldValues .get (i % keywordFieldValues .size ());
255
-
256
- IndexRequest indexRequest = new IndexRequest (sourceIndex )
257
- .source (NUMERICAL_FIELD , numericalValue , KEYWORD_FIELD , keywordValue );
290
+ List <Object > source = List .of (
291
+ BOOLEAN_FIELD , BOOLEAN_FIELD_VALUES .get (i % BOOLEAN_FIELD_VALUES .size ()),
292
+ NUMERICAL_FIELD , NUMERICAL_FIELD_VALUES .get (i % NUMERICAL_FIELD_VALUES .size ()),
293
+ DISCRETE_NUMERICAL_FIELD , DISCRETE_NUMERICAL_FIELD_VALUES .get (i % DISCRETE_NUMERICAL_FIELD_VALUES .size ()),
294
+ KEYWORD_FIELD , KEYWORD_FIELD_VALUES .get (i % KEYWORD_FIELD_VALUES .size ()));
295
+ IndexRequest indexRequest = new IndexRequest (sourceIndex ).source (source .toArray ());
258
296
bulkRequestBuilder .add (indexRequest );
259
297
}
260
298
for (int i = numTrainingRows ; i < numTrainingRows + numNonTrainingRows ; i ++) {
261
- Double numericalValue = numericalFieldValues .get (i % numericalFieldValues .size ());
262
-
263
- IndexRequest indexRequest = new IndexRequest (sourceIndex )
264
- .source (NUMERICAL_FIELD , numericalValue );
299
+ List <Object > source = new ArrayList <>();
300
+ if (BOOLEAN_FIELD .equals (dependentVariable ) == false ) {
301
+ source .addAll (List .of (BOOLEAN_FIELD , BOOLEAN_FIELD_VALUES .get (i % BOOLEAN_FIELD_VALUES .size ())));
302
+ }
303
+ if (NUMERICAL_FIELD .equals (dependentVariable ) == false ) {
304
+ source .addAll (List .of (NUMERICAL_FIELD , NUMERICAL_FIELD_VALUES .get (i % NUMERICAL_FIELD_VALUES .size ())));
305
+ }
306
+ if (DISCRETE_NUMERICAL_FIELD .equals (dependentVariable ) == false ) {
307
+ source .addAll (
308
+ List .of (DISCRETE_NUMERICAL_FIELD , DISCRETE_NUMERICAL_FIELD_VALUES .get (i % DISCRETE_NUMERICAL_FIELD_VALUES .size ())));
309
+ }
310
+ if (KEYWORD_FIELD .equals (dependentVariable ) == false ) {
311
+ source .addAll (List .of (KEYWORD_FIELD , KEYWORD_FIELD_VALUES .get (i % KEYWORD_FIELD_VALUES .size ())));
312
+ }
313
+ IndexRequest indexRequest = new IndexRequest (sourceIndex ).source (source .toArray ());
265
314
bulkRequestBuilder .add (indexRequest );
266
315
}
267
316
BulkResponse bulkResponse = bulkRequestBuilder .get ();
@@ -289,7 +338,12 @@ private static Map<String, Object> getMlResultsObjectFromDestDoc(Map<String, Obj
289
338
return resultsObject ;
290
339
}
291
340
292
- private static void assertTopClasses (Map <String , Object > resultsObject , int numTopClasses ) {
341
+ private static <T > void assertTopClasses (
342
+ Map <String , Object > resultsObject ,
343
+ int numTopClasses ,
344
+ String dependentVariable ,
345
+ List <T > dependentVariableValues ,
346
+ Function <String , T > parser ) {
293
347
assertThat (resultsObject .containsKey ("top_classes" ), is (true ));
294
348
@ SuppressWarnings ("unchecked" )
295
349
List <Map <String , Object >> topClasses = (List <Map <String , Object >>) resultsObject .get ("top_classes" );
@@ -302,9 +356,9 @@ private static void assertTopClasses(Map<String, Object> resultsObject, int numT
302
356
classProbabilities .add ((Double ) topClass .get ("class_probability" ));
303
357
}
304
358
// Assert that all the predicted class names come from the set of dependent variable values.
305
- classNames .forEach (className -> assertThat (className , is (in (KEYWORD_FIELD_VALUES ))));
359
+ classNames .forEach (className -> assertThat (parser . apply ( className ) , is (in (dependentVariableValues ))));
306
360
// Assert that the first class listed in top classes is the same as the predicted class.
307
- assertThat (classNames .get (0 ), equalTo (resultsObject .get ("keyword-field_prediction " )));
361
+ assertThat (classNames .get (0 ), equalTo (resultsObject .get (dependentVariable + "_prediction " )));
308
362
// Assert that all the class probabilities lie within [0, 1] interval.
309
363
classProbabilities .forEach (p -> assertThat (p , allOf (greaterThanOrEqualTo (0.0 ), lessThanOrEqualTo (1.0 ))));
310
364
// Assert that the top classes are listed in the order of decreasing probabilities.
0 commit comments