Skip to content

Commit 9ecbbaf

Browse files
przemekwitek authored and SivagurunathanV committed
Handle nested and aliased fields correctly when copying mapping. (elastic#50918)
1 parent 8e10736 commit 9ecbbaf

File tree

3 files changed: +175 -34 lines changed

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

Lines changed: 80 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -63,6 +63,9 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
6363
private static final String NUMERICAL_FIELD = "numerical-field";
6464
private static final String DISCRETE_NUMERICAL_FIELD = "discrete-numerical-field";
6565
private static final String KEYWORD_FIELD = "keyword-field";
66+
private static final String NESTED_FIELD = "outer-field.inner-field";
67+
private static final String ALIAS_TO_KEYWORD_FIELD = "alias-to-keyword-field";
68+
private static final String ALIAS_TO_NESTED_FIELD = "alias-to-nested-field";
6669
private static final List<Boolean> BOOLEAN_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(false, true));
6770
private static final List<Double> NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0));
6871
private static final List<Integer> DISCRETE_NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(10, 20));
@@ -301,7 +304,6 @@ public void testStopAndRestart() throws Exception {
301304
assertInferenceModelPersisted(jobId);
302305
assertMlResultsFieldMappings(predictedClassField, "keyword");
303306
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
304-
305307
}
306308

307309
public void testDependentVariableCardinalityTooHighError() throws Exception {
@@ -343,6 +345,63 @@ public void testDependentVariableCardinalityTooHighButWithQueryMakesItWithinRang
343345
assertProgress(jobId, 100, 100, 100, 100);
344346
}
345347

348+
/**
 * Runs a classification job whose dependent variable is NESTED_FIELD, i.e. the dotted
 * path "outer-field.inner-field", and checks that the job completes and that the
 * prediction field written under "ml." is mapped as "keyword".
 */
public void testDependentVariableIsNested() throws Exception {
349+
initialize("dependent_variable_is_nested");
350+
// The prediction field name keeps the dotted path of the dependent variable.
String predictedClassField = NESTED_FIELD + "_prediction";
351+
// 100 and 0 are the row counts handed to indexData (presumably training / non-training rows;
// the helper's signature is truncated in this view - TODO confirm).
indexData(sourceIndex, 100, 0, NESTED_FIELD);
352+
353+
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(NESTED_FIELD));
354+
registerAnalytics(config);
355+
putAnalytics(config);
356+
startAnalytics(jobId);
357+
waitUntilAnalyticsIsStopped(jobId);
358+
359+
// All four progress values are expected to be 100 once the job has stopped.
assertProgress(jobId, 100, 100, 100, 100);
360+
// Exactly one stored-progress hit is expected for the job.
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
361+
assertModelStatePersisted(stateDocId());
362+
assertInferenceModelPersisted(jobId);
363+
assertMlResultsFieldMappings(predictedClassField, "keyword");
364+
assertEvaluation(NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
365+
}
366+
367+
/**
 * Runs a classification job whose dependent variable is ALIAS_TO_KEYWORD_FIELD, which
 * createIndex maps as "type=alias" with its path set to KEYWORD_FIELD. Checks that the job
 * completes and that the prediction field written under "ml." is mapped as "keyword".
 */
public void testDependentVariableIsAliasToKeyword() throws Exception {
368+
initialize("dependent_variable_is_alias");
369+
String predictedClassField = ALIAS_TO_KEYWORD_FIELD + "_prediction";
370+
// Data is indexed against the concrete KEYWORD_FIELD; the job below refers to it through the alias.
indexData(sourceIndex, 100, 0, KEYWORD_FIELD);
371+
372+
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(ALIAS_TO_KEYWORD_FIELD));
373+
registerAnalytics(config);
374+
putAnalytics(config);
375+
startAnalytics(jobId);
376+
waitUntilAnalyticsIsStopped(jobId);
377+
378+
// All four progress values are expected to be 100 once the job has stopped.
assertProgress(jobId, 100, 100, 100, 100);
379+
// Exactly one stored-progress hit is expected for the job.
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
380+
assertModelStatePersisted(stateDocId());
381+
assertInferenceModelPersisted(jobId);
382+
assertMlResultsFieldMappings(predictedClassField, "keyword");
383+
// Evaluation is run with the alias name as the actual-value field.
assertEvaluation(ALIAS_TO_KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
384+
}
385+
386+
/**
 * Runs a classification job whose dependent variable is ALIAS_TO_NESTED_FIELD, which
 * createIndex maps as "type=alias" with its path set to NESTED_FIELD (a dotted path).
 * Checks that the job completes and that the prediction field written under "ml." is
 * mapped as "keyword".
 */
public void testDependentVariableIsAliasToNested() throws Exception {
387+
initialize("dependent_variable_is_alias_to_nested");
388+
String predictedClassField = ALIAS_TO_NESTED_FIELD + "_prediction";
389+
// Data is indexed against the concrete NESTED_FIELD; the job below refers to it through the alias.
indexData(sourceIndex, 100, 0, NESTED_FIELD);
390+
391+
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(ALIAS_TO_NESTED_FIELD));
392+
registerAnalytics(config);
393+
putAnalytics(config);
394+
startAnalytics(jobId);
395+
waitUntilAnalyticsIsStopped(jobId);
396+
397+
// All four progress values are expected to be 100 once the job has stopped.
assertProgress(jobId, 100, 100, 100, 100);
398+
// Exactly one stored-progress hit is expected for the job.
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
399+
assertModelStatePersisted(stateDocId());
400+
assertInferenceModelPersisted(jobId);
401+
assertMlResultsFieldMappings(predictedClassField, "keyword");
402+
// Evaluation is run with the alias name as the actual-value field.
assertEvaluation(ALIAS_TO_NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
403+
}
404+
346405
public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exception {
347406
String sourceIndex = "classification_two_jobs_with_same_randomize_seed_source";
348407
String dependentVariable = KEYWORD_FIELD;
@@ -434,7 +493,10 @@ private static void createIndex(String index) {
434493
BOOLEAN_FIELD, "type=boolean",
435494
NUMERICAL_FIELD, "type=double",
436495
DISCRETE_NUMERICAL_FIELD, "type=integer",
437-
KEYWORD_FIELD, "type=keyword")
496+
KEYWORD_FIELD, "type=keyword",
497+
NESTED_FIELD, "type=keyword",
498+
ALIAS_TO_KEYWORD_FIELD, "type=alias,path=" + KEYWORD_FIELD,
499+
ALIAS_TO_NESTED_FIELD, "type=alias,path=" + NESTED_FIELD)
438500
.get();
439501
}
440502

@@ -446,7 +508,8 @@ private static void indexData(String sourceIndex, int numTrainingRows, int numNo
446508
BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES.get(i % BOOLEAN_FIELD_VALUES.size()),
447509
NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES.get(i % NUMERICAL_FIELD_VALUES.size()),
448510
DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES.get(i % DISCRETE_NUMERICAL_FIELD_VALUES.size()),
449-
KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()));
511+
KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()),
512+
NESTED_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()));
450513
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
451514
bulkRequestBuilder.add(indexRequest);
452515
}
@@ -465,6 +528,9 @@ private static void indexData(String sourceIndex, int numTrainingRows, int numNo
465528
if (KEYWORD_FIELD.equals(dependentVariable) == false) {
466529
source.addAll(List.of(KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())));
467530
}
531+
if (NESTED_FIELD.equals(dependentVariable) == false) {
532+
source.addAll(List.of(NESTED_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())));
533+
}
468534
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
469535
bulkRequestBuilder.add(indexRequest);
470536
}
@@ -487,10 +553,12 @@ private static Map<String, Object> getDestDoc(DataFrameAnalyticsConfig config, S
487553
}
488554

489555
/**
490-
* Wrapper around extractValue with implicit casting to the appropriate type.
556+
* Wrapper around extractValue that:
557+
* - allows dots (".") in the path elements provided as arguments
558+
* - supports implicit casting to the appropriate type
491559
*/
492560
private static <T> T getFieldValue(Map<String, Object> doc, String... path) {
493-
return (T)extractValue(doc, path);
561+
return (T)extractValue(String.join(".", path), doc);
494562
}
495563

496564
private static <T> void assertTopClasses(Map<String, Object> resultsObject,
@@ -582,8 +650,14 @@ private void assertMlResultsFieldMappings(String predictedClassField, String exp
582650
.mappings()
583651
.get(destIndex)
584652
.sourceAsMap();
585-
assertThat(getFieldValue(mappings, "properties", "ml", "properties", predictedClassField, "type"), equalTo(expectedType));
586653
assertThat(
654+
mappings.toString(),
655+
getFieldValue(
656+
mappings,
657+
"properties", "ml", "properties", String.join(".properties.", predictedClassField.split("\\.")), "type"),
658+
equalTo(expectedType));
659+
assertThat(
660+
mappings.toString(),
587661
getFieldValue(mappings, "properties", "ml", "properties", "top_classes", "properties", "class_name", "type"),
588662
equalTo(expectedType));
589663
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsIndex.java

Lines changed: 21 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,8 @@
2424
import org.elasticsearch.common.Nullable;
2525
import org.elasticsearch.common.settings.Settings;
2626
import org.elasticsearch.index.IndexSortConfig;
27+
import org.elasticsearch.index.mapper.FieldAliasMapper;
28+
import org.elasticsearch.index.mapper.KeywordFieldMapper;
2729
import org.elasticsearch.search.sort.SortOrder;
2830
import org.elasticsearch.xpack.core.ClientHelper;
2931
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
@@ -38,6 +40,7 @@
3840
import java.util.concurrent.atomic.AtomicReference;
3941
import java.util.function.Supplier;
4042

43+
import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
4144
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
4245

4346
/**
@@ -155,21 +158,36 @@ private static Integer findMaxSettingValue(GetSettingsResponse settingsResponse,
155158
return maxValue;
156159
}
157160

161+
@SuppressWarnings("unchecked")
158162
private static Map<String, Object> createAdditionalMappings(DataFrameAnalyticsConfig config, Map<String, Object> mappingsProperties) {
159163
Map<String, Object> properties = new HashMap<>();
160-
properties.put(ID_COPY, Map.of("type", "keyword"));
164+
properties.put(ID_COPY, Map.of("type", KeywordFieldMapper.CONTENT_TYPE));
161165
for (Map.Entry<String, String> entry
162166
: config.getAnalysis().getExplicitlyMappedFields(config.getDest().getResultsField()).entrySet()) {
163167
String destFieldPath = entry.getKey();
164168
String sourceFieldPath = entry.getValue();
165-
Object sourceFieldMapping = mappingsProperties.get(sourceFieldPath);
166-
if (sourceFieldMapping != null) {
169+
Object sourceFieldMapping = extractMapping(sourceFieldPath, mappingsProperties);
170+
if (sourceFieldMapping instanceof Map) {
171+
Map<String, Object> sourceFieldMappingAsMap = (Map) sourceFieldMapping;
172+
// If the source field is an alias, fetch the concrete field that the alias points to.
173+
if (FieldAliasMapper.CONTENT_TYPE.equals(sourceFieldMappingAsMap.get("type"))) {
174+
String path = (String) sourceFieldMappingAsMap.get(FieldAliasMapper.Names.PATH);
175+
sourceFieldMapping = extractMapping(path, mappingsProperties);
176+
}
177+
}
178+
// We may have updated the value of {@code sourceFieldMapping} in the "if" block above.
179+
// Hence, we need to check the "instanceof" condition again.
180+
if (sourceFieldMapping instanceof Map) {
167181
properties.put(destFieldPath, sourceFieldMapping);
168182
}
169183
}
170184
return properties;
171185
}
172186

187+
/**
 * Looks up the mapping entry for a (possibly dotted) field path inside a mappings
 * "properties" map.
 *
 * The path is split on "." and its segments are re-joined with "." + PROPERTIES + ".",
 * so "a.b" is looked up as "a.&lt;PROPERTIES&gt;.b". The result of extractValue for that
 * rewritten path is returned unchanged.
 *
 * @param path field path as given in the analysis config; segments are separated by "."
 * @param mappingsProperties map to search (presumably the source index's "properties" object - TODO confirm)
 * @return whatever extractValue yields for the rewritten path
 */
private static Object extractMapping(String path, Map<String, Object> mappingsProperties) {
188+
// String.split takes a regex, hence the escaped dot.
return extractValue(String.join("." + PROPERTIES + ".", path.split("\\.")), mappingsProperties);
189+
}
190+
173191
private static Map<String, Object> createMetaData(String analyticsId, Clock clock) {
174192
Map<String, Object> metadata = new HashMap<>();
175193
metadata.put(CREATION_DATE_MILLIS, clock.millis());
@@ -227,4 +245,3 @@ private static void checkResultsFieldIsNotPresentInProperties(DataFrameAnalytics
227245
}
228246
}
229247
}
230-

0 commit comments

Comments
 (0)