Skip to content

Commit 9c6ffdc

Browse files
authored
[7.x] Handle nested and aliased fields correctly when copying mapping. (#50918) (#50968)
1 parent f028ab0 commit 9c6ffdc

File tree

3 files changed

+191
-44
lines changed

3 files changed

+191
-44
lines changed

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

+80-6
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
6363
private static final String NUMERICAL_FIELD = "numerical-field";
6464
private static final String DISCRETE_NUMERICAL_FIELD = "discrete-numerical-field";
6565
private static final String KEYWORD_FIELD = "keyword-field";
66+
private static final String NESTED_FIELD = "outer-field.inner-field";
67+
private static final String ALIAS_TO_KEYWORD_FIELD = "alias-to-keyword-field";
68+
private static final String ALIAS_TO_NESTED_FIELD = "alias-to-nested-field";
6669
private static final List<Boolean> BOOLEAN_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(false, true));
6770
private static final List<Double> NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0));
6871
private static final List<Integer> DISCRETE_NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(10, 20));
@@ -301,7 +304,6 @@ public void testStopAndRestart() throws Exception {
301304
assertInferenceModelPersisted(jobId);
302305
assertMlResultsFieldMappings(predictedClassField, "keyword");
303306
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
304-
305307
}
306308

307309
public void testDependentVariableCardinalityTooHighError() throws Exception {
@@ -342,6 +344,63 @@ public void testDependentVariableCardinalityTooHighButWithQueryMakesItWithinRang
342344
assertProgress(jobId, 100, 100, 100, 100);
343345
}
344346

347+
public void testDependentVariableIsNested() throws Exception {
348+
initialize("dependent_variable_is_nested");
349+
String predictedClassField = NESTED_FIELD + "_prediction";
350+
indexData(sourceIndex, 100, 0, NESTED_FIELD);
351+
352+
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(NESTED_FIELD));
353+
registerAnalytics(config);
354+
putAnalytics(config);
355+
startAnalytics(jobId);
356+
waitUntilAnalyticsIsStopped(jobId);
357+
358+
assertProgress(jobId, 100, 100, 100, 100);
359+
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
360+
assertModelStatePersisted(stateDocId());
361+
assertInferenceModelPersisted(jobId);
362+
assertMlResultsFieldMappings(predictedClassField, "keyword");
363+
assertEvaluation(NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
364+
}
365+
366+
public void testDependentVariableIsAliasToKeyword() throws Exception {
367+
initialize("dependent_variable_is_alias");
368+
String predictedClassField = ALIAS_TO_KEYWORD_FIELD + "_prediction";
369+
indexData(sourceIndex, 100, 0, KEYWORD_FIELD);
370+
371+
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(ALIAS_TO_KEYWORD_FIELD));
372+
registerAnalytics(config);
373+
putAnalytics(config);
374+
startAnalytics(jobId);
375+
waitUntilAnalyticsIsStopped(jobId);
376+
377+
assertProgress(jobId, 100, 100, 100, 100);
378+
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
379+
assertModelStatePersisted(stateDocId());
380+
assertInferenceModelPersisted(jobId);
381+
assertMlResultsFieldMappings(predictedClassField, "keyword");
382+
assertEvaluation(ALIAS_TO_KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
383+
}
384+
385+
public void testDependentVariableIsAliasToNested() throws Exception {
386+
initialize("dependent_variable_is_alias_to_nested");
387+
String predictedClassField = ALIAS_TO_NESTED_FIELD + "_prediction";
388+
indexData(sourceIndex, 100, 0, NESTED_FIELD);
389+
390+
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(ALIAS_TO_NESTED_FIELD));
391+
registerAnalytics(config);
392+
putAnalytics(config);
393+
startAnalytics(jobId);
394+
waitUntilAnalyticsIsStopped(jobId);
395+
396+
assertProgress(jobId, 100, 100, 100, 100);
397+
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
398+
assertModelStatePersisted(stateDocId());
399+
assertInferenceModelPersisted(jobId);
400+
assertMlResultsFieldMappings(predictedClassField, "keyword");
401+
assertEvaluation(ALIAS_TO_NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
402+
}
403+
345404
public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exception {
346405
String sourceIndex = "classification_two_jobs_with_same_randomize_seed_source";
347406
String dependentVariable = KEYWORD_FIELD;
@@ -433,7 +492,10 @@ private static void createIndex(String index) {
433492
BOOLEAN_FIELD, "type=boolean",
434493
NUMERICAL_FIELD, "type=double",
435494
DISCRETE_NUMERICAL_FIELD, "type=integer",
436-
KEYWORD_FIELD, "type=keyword")
495+
KEYWORD_FIELD, "type=keyword",
496+
NESTED_FIELD, "type=keyword",
497+
ALIAS_TO_KEYWORD_FIELD, "type=alias,path=" + KEYWORD_FIELD,
498+
ALIAS_TO_NESTED_FIELD, "type=alias,path=" + NESTED_FIELD)
437499
.get();
438500
}
439501

@@ -445,7 +507,8 @@ private static void indexData(String sourceIndex, int numTrainingRows, int numNo
445507
BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES.get(i % BOOLEAN_FIELD_VALUES.size()),
446508
NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES.get(i % NUMERICAL_FIELD_VALUES.size()),
447509
DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES.get(i % DISCRETE_NUMERICAL_FIELD_VALUES.size()),
448-
KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()));
510+
KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()),
511+
NESTED_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size()));
449512
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
450513
bulkRequestBuilder.add(indexRequest);
451514
}
@@ -465,6 +528,9 @@ private static void indexData(String sourceIndex, int numTrainingRows, int numNo
465528
if (KEYWORD_FIELD.equals(dependentVariable) == false) {
466529
source.addAll(Arrays.asList(KEYWORD_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())));
467530
}
531+
if (NESTED_FIELD.equals(dependentVariable) == false) {
532+
source.addAll(Arrays.asList(NESTED_FIELD, KEYWORD_FIELD_VALUES.get(i % KEYWORD_FIELD_VALUES.size())));
533+
}
468534
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
469535
bulkRequestBuilder.add(indexRequest);
470536
}
@@ -487,10 +553,12 @@ private static Map<String, Object> getDestDoc(DataFrameAnalyticsConfig config, S
487553
}
488554

489555
/**
490-
* Wrapper around extractValue with implicit casting to the appropriate type.
556+
* Wrapper around extractValue that:
557+
* - allows dots (".") in the path elements provided as arguments
558+
* - supports implicit casting to the appropriate type
491559
*/
492560
private static <T> T getFieldValue(Map<String, Object> doc, String... path) {
493-
return (T)extractValue(doc, path);
561+
return (T)extractValue(String.join(".", path), doc);
494562
}
495563

496564
private static <T> void assertTopClasses(Map<String, Object> resultsObject,
@@ -583,8 +651,14 @@ private void assertMlResultsFieldMappings(String predictedClassField, String exp
583651
.get(destIndex)
584652
.get("_doc")
585653
.sourceAsMap();
586-
assertThat(getFieldValue(mappings, "properties", "ml", "properties", predictedClassField, "type"), equalTo(expectedType));
587654
assertThat(
655+
mappings.toString(),
656+
getFieldValue(
657+
mappings,
658+
"properties", "ml", "properties", String.join(".properties.", predictedClassField.split("\\.")), "type"),
659+
equalTo(expectedType));
660+
assertThat(
661+
mappings.toString(),
588662
getFieldValue(mappings, "properties", "ml", "properties", "top_classes", "properties", "class_name", "type"),
589663
equalTo(expectedType));
590664
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsIndex.java

+21-4
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
import org.elasticsearch.common.collect.ImmutableOpenMap;
2626
import org.elasticsearch.common.settings.Settings;
2727
import org.elasticsearch.index.IndexSortConfig;
28+
import org.elasticsearch.index.mapper.FieldAliasMapper;
29+
import org.elasticsearch.index.mapper.KeywordFieldMapper;
2830
import org.elasticsearch.search.sort.SortOrder;
2931
import org.elasticsearch.xpack.core.ClientHelper;
3032
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
@@ -39,6 +41,7 @@
3941
import java.util.concurrent.atomic.AtomicReference;
4042
import java.util.function.Supplier;
4143

44+
import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
4245
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
4346

4447
/**
@@ -160,23 +163,38 @@ private static Integer findMaxSettingValue(GetSettingsResponse settingsResponse,
160163
return maxValue;
161164
}
162165

166+
@SuppressWarnings("unchecked")
163167
private static Map<String, Object> createAdditionalMappings(DataFrameAnalyticsConfig config, Map<String, Object> mappingsProperties) {
164168
Map<String, Object> properties = new HashMap<>();
165169
Map<String, String> idCopyMapping = new HashMap<>();
166-
idCopyMapping.put("type", "keyword");
170+
idCopyMapping.put("type", KeywordFieldMapper.CONTENT_TYPE);
167171
properties.put(ID_COPY, idCopyMapping);
168172
for (Map.Entry<String, String> entry
169173
: config.getAnalysis().getExplicitlyMappedFields(config.getDest().getResultsField()).entrySet()) {
170174
String destFieldPath = entry.getKey();
171175
String sourceFieldPath = entry.getValue();
172-
Object sourceFieldMapping = mappingsProperties.get(sourceFieldPath);
173-
if (sourceFieldMapping != null) {
176+
Object sourceFieldMapping = extractMapping(sourceFieldPath, mappingsProperties);
177+
if (sourceFieldMapping instanceof Map) {
178+
Map<String, Object> sourceFieldMappingAsMap = (Map) sourceFieldMapping;
179+
// If the source field is an alias, fetch the concrete field that the alias points to.
180+
if (FieldAliasMapper.CONTENT_TYPE.equals(sourceFieldMappingAsMap.get("type"))) {
181+
String path = (String) sourceFieldMappingAsMap.get(FieldAliasMapper.Names.PATH);
182+
sourceFieldMapping = extractMapping(path, mappingsProperties);
183+
}
184+
}
185+
// We may have updated the value of {@code sourceFieldMapping} in the "if" block above.
186+
// Hence, we need to check the "instanceof" condition again.
187+
if (sourceFieldMapping instanceof Map) {
174188
properties.put(destFieldPath, sourceFieldMapping);
175189
}
176190
}
177191
return properties;
178192
}
179193

194+
private static Object extractMapping(String path, Map<String, Object> mappingsProperties) {
195+
return extractValue(String.join("." + PROPERTIES + ".", path.split("\\.")), mappingsProperties);
196+
}
197+
180198
private static Map<String, Object> createMetaData(String analyticsId, Clock clock) {
181199
Map<String, Object> metadata = new HashMap<>();
182200
metadata.put(CREATION_DATE_MILLIS, clock.millis());
@@ -239,4 +257,3 @@ private static void checkResultsFieldIsNotPresentInProperties(DataFrameAnalytics
239257
}
240258
}
241259
}
242-

0 commit comments

Comments
 (0)