@@ -63,6 +63,9 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
63
63
private static final String NUMERICAL_FIELD = "numerical-field" ;
64
64
private static final String DISCRETE_NUMERICAL_FIELD = "discrete-numerical-field" ;
65
65
private static final String KEYWORD_FIELD = "keyword-field" ;
66
+ private static final String NESTED_FIELD = "outer-field.inner-field" ;
67
+ private static final String ALIAS_TO_KEYWORD_FIELD = "alias-to-keyword-field" ;
68
+ private static final String ALIAS_TO_NESTED_FIELD = "alias-to-nested-field" ;
66
69
private static final List <Boolean > BOOLEAN_FIELD_VALUES = Collections .unmodifiableList (Arrays .asList (false , true ));
67
70
private static final List <Double > NUMERICAL_FIELD_VALUES = Collections .unmodifiableList (Arrays .asList (1.0 , 2.0 ));
68
71
private static final List <Integer > DISCRETE_NUMERICAL_FIELD_VALUES = Collections .unmodifiableList (Arrays .asList (10 , 20 ));
@@ -301,7 +304,6 @@ public void testStopAndRestart() throws Exception {
301
304
assertInferenceModelPersisted (jobId );
302
305
assertMlResultsFieldMappings (predictedClassField , "keyword" );
303
306
assertEvaluation (KEYWORD_FIELD , KEYWORD_FIELD_VALUES , "ml." + predictedClassField );
304
-
305
307
}
306
308
307
309
public void testDependentVariableCardinalityTooHighError () throws Exception {
@@ -342,6 +344,63 @@ public void testDependentVariableCardinalityTooHighButWithQueryMakesItWithinRang
342
344
assertProgress (jobId , 100 , 100 , 100 , 100 );
343
345
}
344
346
347
+ public void testDependentVariableIsNested () throws Exception {
348
+ initialize ("dependent_variable_is_nested" );
349
+ String predictedClassField = NESTED_FIELD + "_prediction" ;
350
+ indexData (sourceIndex , 100 , 0 , NESTED_FIELD );
351
+
352
+ DataFrameAnalyticsConfig config = buildAnalytics (jobId , sourceIndex , destIndex , null , new Classification (NESTED_FIELD ));
353
+ registerAnalytics (config );
354
+ putAnalytics (config );
355
+ startAnalytics (jobId );
356
+ waitUntilAnalyticsIsStopped (jobId );
357
+
358
+ assertProgress (jobId , 100 , 100 , 100 , 100 );
359
+ assertThat (searchStoredProgress (jobId ).getHits ().getTotalHits ().value , equalTo (1L ));
360
+ assertModelStatePersisted (stateDocId ());
361
+ assertInferenceModelPersisted (jobId );
362
+ assertMlResultsFieldMappings (predictedClassField , "keyword" );
363
+ assertEvaluation (NESTED_FIELD , KEYWORD_FIELD_VALUES , "ml." + predictedClassField );
364
+ }
365
+
366
+ public void testDependentVariableIsAliasToKeyword () throws Exception {
367
+ initialize ("dependent_variable_is_alias" );
368
+ String predictedClassField = ALIAS_TO_KEYWORD_FIELD + "_prediction" ;
369
+ indexData (sourceIndex , 100 , 0 , KEYWORD_FIELD );
370
+
371
+ DataFrameAnalyticsConfig config = buildAnalytics (jobId , sourceIndex , destIndex , null , new Classification (ALIAS_TO_KEYWORD_FIELD ));
372
+ registerAnalytics (config );
373
+ putAnalytics (config );
374
+ startAnalytics (jobId );
375
+ waitUntilAnalyticsIsStopped (jobId );
376
+
377
+ assertProgress (jobId , 100 , 100 , 100 , 100 );
378
+ assertThat (searchStoredProgress (jobId ).getHits ().getTotalHits ().value , equalTo (1L ));
379
+ assertModelStatePersisted (stateDocId ());
380
+ assertInferenceModelPersisted (jobId );
381
+ assertMlResultsFieldMappings (predictedClassField , "keyword" );
382
+ assertEvaluation (ALIAS_TO_KEYWORD_FIELD , KEYWORD_FIELD_VALUES , "ml." + predictedClassField );
383
+ }
384
+
385
+ public void testDependentVariableIsAliasToNested () throws Exception {
386
+ initialize ("dependent_variable_is_alias_to_nested" );
387
+ String predictedClassField = ALIAS_TO_NESTED_FIELD + "_prediction" ;
388
+ indexData (sourceIndex , 100 , 0 , NESTED_FIELD );
389
+
390
+ DataFrameAnalyticsConfig config = buildAnalytics (jobId , sourceIndex , destIndex , null , new Classification (ALIAS_TO_NESTED_FIELD ));
391
+ registerAnalytics (config );
392
+ putAnalytics (config );
393
+ startAnalytics (jobId );
394
+ waitUntilAnalyticsIsStopped (jobId );
395
+
396
+ assertProgress (jobId , 100 , 100 , 100 , 100 );
397
+ assertThat (searchStoredProgress (jobId ).getHits ().getTotalHits ().value , equalTo (1L ));
398
+ assertModelStatePersisted (stateDocId ());
399
+ assertInferenceModelPersisted (jobId );
400
+ assertMlResultsFieldMappings (predictedClassField , "keyword" );
401
+ assertEvaluation (ALIAS_TO_NESTED_FIELD , KEYWORD_FIELD_VALUES , "ml." + predictedClassField );
402
+ }
403
+
345
404
public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet () throws Exception {
346
405
String sourceIndex = "classification_two_jobs_with_same_randomize_seed_source" ;
347
406
String dependentVariable = KEYWORD_FIELD ;
@@ -433,7 +492,10 @@ private static void createIndex(String index) {
433
492
BOOLEAN_FIELD , "type=boolean" ,
434
493
NUMERICAL_FIELD , "type=double" ,
435
494
DISCRETE_NUMERICAL_FIELD , "type=integer" ,
436
- KEYWORD_FIELD , "type=keyword" )
495
+ KEYWORD_FIELD , "type=keyword" ,
496
+ NESTED_FIELD , "type=keyword" ,
497
+ ALIAS_TO_KEYWORD_FIELD , "type=alias,path=" + KEYWORD_FIELD ,
498
+ ALIAS_TO_NESTED_FIELD , "type=alias,path=" + NESTED_FIELD )
437
499
.get ();
438
500
}
439
501
@@ -445,7 +507,8 @@ private static void indexData(String sourceIndex, int numTrainingRows, int numNo
445
507
BOOLEAN_FIELD , BOOLEAN_FIELD_VALUES .get (i % BOOLEAN_FIELD_VALUES .size ()),
446
508
NUMERICAL_FIELD , NUMERICAL_FIELD_VALUES .get (i % NUMERICAL_FIELD_VALUES .size ()),
447
509
DISCRETE_NUMERICAL_FIELD , DISCRETE_NUMERICAL_FIELD_VALUES .get (i % DISCRETE_NUMERICAL_FIELD_VALUES .size ()),
448
- KEYWORD_FIELD , KEYWORD_FIELD_VALUES .get (i % KEYWORD_FIELD_VALUES .size ()));
510
+ KEYWORD_FIELD , KEYWORD_FIELD_VALUES .get (i % KEYWORD_FIELD_VALUES .size ()),
511
+ NESTED_FIELD , KEYWORD_FIELD_VALUES .get (i % KEYWORD_FIELD_VALUES .size ()));
449
512
IndexRequest indexRequest = new IndexRequest (sourceIndex ).source (source .toArray ());
450
513
bulkRequestBuilder .add (indexRequest );
451
514
}
@@ -465,6 +528,9 @@ private static void indexData(String sourceIndex, int numTrainingRows, int numNo
465
528
if (KEYWORD_FIELD .equals (dependentVariable ) == false ) {
466
529
source .addAll (Arrays .asList (KEYWORD_FIELD , KEYWORD_FIELD_VALUES .get (i % KEYWORD_FIELD_VALUES .size ())));
467
530
}
531
+ if (NESTED_FIELD .equals (dependentVariable ) == false ) {
532
+ source .addAll (Arrays .asList (NESTED_FIELD , KEYWORD_FIELD_VALUES .get (i % KEYWORD_FIELD_VALUES .size ())));
533
+ }
468
534
IndexRequest indexRequest = new IndexRequest (sourceIndex ).source (source .toArray ());
469
535
bulkRequestBuilder .add (indexRequest );
470
536
}
@@ -487,10 +553,12 @@ private static Map<String, Object> getDestDoc(DataFrameAnalyticsConfig config, S
487
553
}
488
554
489
555
/**
490
- * Wrapper around extractValue with implicit casting to the appropriate type.
556
+ * Wrapper around extractValue that:
557
+ * - allows dots (".") in the path elements provided as arguments
558
+ * - supports implicit casting to the appropriate type
491
559
*/
492
560
private static <T > T getFieldValue (Map <String , Object > doc , String ... path ) {
493
- return (T )extractValue (doc , path );
561
+ return (T )extractValue (String . join ( "." , path ), doc );
494
562
}
495
563
496
564
private static <T > void assertTopClasses (Map <String , Object > resultsObject ,
@@ -583,8 +651,14 @@ private void assertMlResultsFieldMappings(String predictedClassField, String exp
583
651
.get (destIndex )
584
652
.get ("_doc" )
585
653
.sourceAsMap ();
586
- assertThat (getFieldValue (mappings , "properties" , "ml" , "properties" , predictedClassField , "type" ), equalTo (expectedType ));
587
654
assertThat (
655
+ mappings .toString (),
656
+ getFieldValue (
657
+ mappings ,
658
+ "properties" , "ml" , "properties" , String .join (".properties." , predictedClassField .split ("\\ ." )), "type" ),
659
+ equalTo (expectedType ));
660
+ assertThat (
661
+ mappings .toString (),
588
662
getFieldValue (mappings , "properties" , "ml" , "properties" , "top_classes" , "properties" , "class_name" , "type" ),
589
663
equalTo (expectedType ));
590
664
}
0 commit comments