@@ -63,6 +63,9 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
63
63
private static final String NUMERICAL_FIELD = "numerical-field" ;
64
64
private static final String DISCRETE_NUMERICAL_FIELD = "discrete-numerical-field" ;
65
65
private static final String KEYWORD_FIELD = "keyword-field" ;
66
+ private static final String NESTED_FIELD = "outer-field.inner-field" ;
67
+ private static final String ALIAS_TO_KEYWORD_FIELD = "alias-to-keyword-field" ;
68
+ private static final String ALIAS_TO_NESTED_FIELD = "alias-to-nested-field" ;
66
69
private static final List <Boolean > BOOLEAN_FIELD_VALUES = Collections .unmodifiableList (Arrays .asList (false , true ));
67
70
private static final List <Double > NUMERICAL_FIELD_VALUES = Collections .unmodifiableList (Arrays .asList (1.0 , 2.0 ));
68
71
private static final List <Integer > DISCRETE_NUMERICAL_FIELD_VALUES = Collections .unmodifiableList (Arrays .asList (10 , 20 ));
@@ -301,7 +304,6 @@ public void testStopAndRestart() throws Exception {
301
304
assertInferenceModelPersisted (jobId );
302
305
assertMlResultsFieldMappings (predictedClassField , "keyword" );
303
306
assertEvaluation (KEYWORD_FIELD , KEYWORD_FIELD_VALUES , "ml." + predictedClassField );
304
-
305
307
}
306
308
307
309
public void testDependentVariableCardinalityTooHighError () throws Exception {
@@ -343,6 +345,63 @@ public void testDependentVariableCardinalityTooHighButWithQueryMakesItWithinRang
343
345
assertProgress (jobId , 100 , 100 , 100 , 100 );
344
346
}
345
347
348
+ public void testDependentVariableIsNested () throws Exception {
349
+ initialize ("dependent_variable_is_nested" );
350
+ String predictedClassField = NESTED_FIELD + "_prediction" ;
351
+ indexData (sourceIndex , 100 , 0 , NESTED_FIELD );
352
+
353
+ DataFrameAnalyticsConfig config = buildAnalytics (jobId , sourceIndex , destIndex , null , new Classification (NESTED_FIELD ));
354
+ registerAnalytics (config );
355
+ putAnalytics (config );
356
+ startAnalytics (jobId );
357
+ waitUntilAnalyticsIsStopped (jobId );
358
+
359
+ assertProgress (jobId , 100 , 100 , 100 , 100 );
360
+ assertThat (searchStoredProgress (jobId ).getHits ().getTotalHits ().value , equalTo (1L ));
361
+ assertModelStatePersisted (stateDocId ());
362
+ assertInferenceModelPersisted (jobId );
363
+ assertMlResultsFieldMappings (predictedClassField , "keyword" );
364
+ assertEvaluation (NESTED_FIELD , KEYWORD_FIELD_VALUES , "ml." + predictedClassField );
365
+ }
366
+
367
+ public void testDependentVariableIsAliasToKeyword () throws Exception {
368
+ initialize ("dependent_variable_is_alias" );
369
+ String predictedClassField = ALIAS_TO_KEYWORD_FIELD + "_prediction" ;
370
+ indexData (sourceIndex , 100 , 0 , KEYWORD_FIELD );
371
+
372
+ DataFrameAnalyticsConfig config = buildAnalytics (jobId , sourceIndex , destIndex , null , new Classification (ALIAS_TO_KEYWORD_FIELD ));
373
+ registerAnalytics (config );
374
+ putAnalytics (config );
375
+ startAnalytics (jobId );
376
+ waitUntilAnalyticsIsStopped (jobId );
377
+
378
+ assertProgress (jobId , 100 , 100 , 100 , 100 );
379
+ assertThat (searchStoredProgress (jobId ).getHits ().getTotalHits ().value , equalTo (1L ));
380
+ assertModelStatePersisted (stateDocId ());
381
+ assertInferenceModelPersisted (jobId );
382
+ assertMlResultsFieldMappings (predictedClassField , "keyword" );
383
+ assertEvaluation (ALIAS_TO_KEYWORD_FIELD , KEYWORD_FIELD_VALUES , "ml." + predictedClassField );
384
+ }
385
+
386
+ public void testDependentVariableIsAliasToNested () throws Exception {
387
+ initialize ("dependent_variable_is_alias_to_nested" );
388
+ String predictedClassField = ALIAS_TO_NESTED_FIELD + "_prediction" ;
389
+ indexData (sourceIndex , 100 , 0 , NESTED_FIELD );
390
+
391
+ DataFrameAnalyticsConfig config = buildAnalytics (jobId , sourceIndex , destIndex , null , new Classification (ALIAS_TO_NESTED_FIELD ));
392
+ registerAnalytics (config );
393
+ putAnalytics (config );
394
+ startAnalytics (jobId );
395
+ waitUntilAnalyticsIsStopped (jobId );
396
+
397
+ assertProgress (jobId , 100 , 100 , 100 , 100 );
398
+ assertThat (searchStoredProgress (jobId ).getHits ().getTotalHits ().value , equalTo (1L ));
399
+ assertModelStatePersisted (stateDocId ());
400
+ assertInferenceModelPersisted (jobId );
401
+ assertMlResultsFieldMappings (predictedClassField , "keyword" );
402
+ assertEvaluation (ALIAS_TO_NESTED_FIELD , KEYWORD_FIELD_VALUES , "ml." + predictedClassField );
403
+ }
404
+
346
405
public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet () throws Exception {
347
406
String sourceIndex = "classification_two_jobs_with_same_randomize_seed_source" ;
348
407
String dependentVariable = KEYWORD_FIELD ;
@@ -434,7 +493,10 @@ private static void createIndex(String index) {
434
493
BOOLEAN_FIELD , "type=boolean" ,
435
494
NUMERICAL_FIELD , "type=double" ,
436
495
DISCRETE_NUMERICAL_FIELD , "type=integer" ,
437
- KEYWORD_FIELD , "type=keyword" )
496
+ KEYWORD_FIELD , "type=keyword" ,
497
+ NESTED_FIELD , "type=keyword" ,
498
+ ALIAS_TO_KEYWORD_FIELD , "type=alias,path=" + KEYWORD_FIELD ,
499
+ ALIAS_TO_NESTED_FIELD , "type=alias,path=" + NESTED_FIELD )
438
500
.get ();
439
501
}
440
502
@@ -446,7 +508,8 @@ private static void indexData(String sourceIndex, int numTrainingRows, int numNo
446
508
BOOLEAN_FIELD , BOOLEAN_FIELD_VALUES .get (i % BOOLEAN_FIELD_VALUES .size ()),
447
509
NUMERICAL_FIELD , NUMERICAL_FIELD_VALUES .get (i % NUMERICAL_FIELD_VALUES .size ()),
448
510
DISCRETE_NUMERICAL_FIELD , DISCRETE_NUMERICAL_FIELD_VALUES .get (i % DISCRETE_NUMERICAL_FIELD_VALUES .size ()),
449
- KEYWORD_FIELD , KEYWORD_FIELD_VALUES .get (i % KEYWORD_FIELD_VALUES .size ()));
511
+ KEYWORD_FIELD , KEYWORD_FIELD_VALUES .get (i % KEYWORD_FIELD_VALUES .size ()),
512
+ NESTED_FIELD , KEYWORD_FIELD_VALUES .get (i % KEYWORD_FIELD_VALUES .size ()));
450
513
IndexRequest indexRequest = new IndexRequest (sourceIndex ).source (source .toArray ());
451
514
bulkRequestBuilder .add (indexRequest );
452
515
}
@@ -465,6 +528,9 @@ private static void indexData(String sourceIndex, int numTrainingRows, int numNo
465
528
if (KEYWORD_FIELD .equals (dependentVariable ) == false ) {
466
529
source .addAll (List .of (KEYWORD_FIELD , KEYWORD_FIELD_VALUES .get (i % KEYWORD_FIELD_VALUES .size ())));
467
530
}
531
+ if (NESTED_FIELD .equals (dependentVariable ) == false ) {
532
+ source .addAll (List .of (NESTED_FIELD , KEYWORD_FIELD_VALUES .get (i % KEYWORD_FIELD_VALUES .size ())));
533
+ }
468
534
IndexRequest indexRequest = new IndexRequest (sourceIndex ).source (source .toArray ());
469
535
bulkRequestBuilder .add (indexRequest );
470
536
}
@@ -487,10 +553,12 @@ private static Map<String, Object> getDestDoc(DataFrameAnalyticsConfig config, S
487
553
}
488
554
489
555
/**
490
- * Wrapper around extractValue with implicit casting to the appropriate type.
556
+ * Wrapper around extractValue that:
557
+ * - allows dots (".") in the path elements provided as arguments
558
+ * - supports implicit casting to the appropriate type
491
559
*/
492
560
private static <T > T getFieldValue (Map <String , Object > doc , String ... path ) {
493
- return (T )extractValue (doc , path );
561
+ return (T )extractValue (String . join ( "." , path ), doc );
494
562
}
495
563
496
564
private static <T > void assertTopClasses (Map <String , Object > resultsObject ,
@@ -582,8 +650,14 @@ private void assertMlResultsFieldMappings(String predictedClassField, String exp
582
650
.mappings ()
583
651
.get (destIndex )
584
652
.sourceAsMap ();
585
- assertThat (getFieldValue (mappings , "properties" , "ml" , "properties" , predictedClassField , "type" ), equalTo (expectedType ));
586
653
assertThat (
654
+ mappings .toString (),
655
+ getFieldValue (
656
+ mappings ,
657
+ "properties" , "ml" , "properties" , String .join (".properties." , predictedClassField .split ("\\ ." )), "type" ),
658
+ equalTo (expectedType ));
659
+ assertThat (
660
+ mappings .toString (),
587
661
getFieldValue (mappings , "properties" , "ml" , "properties" , "top_classes" , "properties" , "class_name" , "type" ),
588
662
equalTo (expectedType ));
589
663
}
0 commit comments