Skip to content

Commit f2d4c94

Browse files
[7.x][ML] Deduplicate multi-fields for data frame analytics (#48799) (#48806)
When multi-fields exist in the source index, we pick all variants of them in our extracted fields detection for data frame analytics. This means we may have multiple instances of the same feature. The worst consequence of this is when the dependent variable (for regression or classification) is also duplicated, which means we train a model on the dependent variable itself. Now that #48770 is merged, this commit adds logic to select only one variant of each multi-field. Closes #48756. Backport of #48799.
1 parent fd4ae69 commit f2d4c94

File tree

2 files changed

+209
-3
lines changed

2 files changed

+209
-3
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java

+55-2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import java.util.Collections;
3232
import java.util.HashSet;
3333
import java.util.Iterator;
34+
import java.util.LinkedHashMap;
3435
import java.util.List;
3536
import java.util.Map;
3637
import java.util.Objects;
@@ -238,7 +239,9 @@ private ExtractedFields detectExtractedFields(Set<String> fields) {
238239
// We sort the fields to ensure the checksum for each document is deterministic
239240
Collections.sort(sortedFields);
240241
ExtractedFields extractedFields = ExtractedFields.build(sortedFields, Collections.emptySet(), fieldCapabilitiesResponse);
241-
if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) {
242+
boolean preferSource = extractedFields.getDocValueFields().size() > docValueFieldsLimit;
243+
extractedFields = deduplicateMultiFields(extractedFields, preferSource);
244+
if (preferSource) {
242245
extractedFields = fetchFromSourceIfSupported(extractedFields);
243246
if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) {
244247
throw ExceptionsHelper.badRequestException("[{}] fields must be retrieved from doc_values but the limit is [{}]; " +
@@ -250,9 +253,59 @@ private ExtractedFields detectExtractedFields(Set<String> fields) {
250253
return extractedFields;
251254
}
252255

256+
private ExtractedFields deduplicateMultiFields(ExtractedFields extractedFields, boolean preferSource) {
257+
Set<String> requiredFields = config.getAnalysis().getRequiredFields().stream().map(RequiredField::getName)
258+
.collect(Collectors.toSet());
259+
Map<String, ExtractedField> nameOrParentToField = new LinkedHashMap<>();
260+
for (ExtractedField currentField : extractedFields.getAllFields()) {
261+
String nameOrParent = currentField.isMultiField() ? currentField.getParentField() : currentField.getName();
262+
ExtractedField existingField = nameOrParentToField.putIfAbsent(nameOrParent, currentField);
263+
if (existingField != null) {
264+
ExtractedField parent = currentField.isMultiField() ? existingField : currentField;
265+
ExtractedField multiField = currentField.isMultiField() ? currentField : existingField;
266+
nameOrParentToField.put(nameOrParent, chooseMultiFieldOrParent(preferSource, requiredFields, parent, multiField));
267+
}
268+
}
269+
return new ExtractedFields(new ArrayList<>(nameOrParentToField.values()));
270+
}
271+
272+
private ExtractedField chooseMultiFieldOrParent(boolean preferSource, Set<String> requiredFields,
273+
ExtractedField parent, ExtractedField multiField) {
274+
// Check requirements first
275+
if (requiredFields.contains(parent.getName())) {
276+
return parent;
277+
}
278+
if (requiredFields.contains(multiField.getName())) {
279+
return multiField;
280+
}
281+
282+
// If both are multi-fields it means there are several. In this case parent is the previous multi-field
283+
// we selected. We'll just keep that.
284+
if (parent.isMultiField() && multiField.isMultiField()) {
285+
return parent;
286+
}
287+
288+
// If we prefer source only the parent may support it. If it does we pick it immediately.
289+
if (preferSource && parent.supportsFromSource()) {
290+
return parent;
291+
}
292+
293+
// If any of the two is a doc_value field let's prefer it as it'd support aggregations.
294+
// We check the parent first as it'd be a shorter field name.
295+
if (parent.getMethod() == ExtractedField.Method.DOC_VALUE) {
296+
return parent;
297+
}
298+
if (multiField.getMethod() == ExtractedField.Method.DOC_VALUE) {
299+
return multiField;
300+
}
301+
302+
// None is aggregatable. Let's pick the parent for its shorter name.
303+
return parent;
304+
}
305+
253306
private ExtractedFields fetchFromSourceIfSupported(ExtractedFields extractedFields) {
254307
List<ExtractedField> adjusted = new ArrayList<>(extractedFields.getAllFields().size());
255-
for (ExtractedField field : extractedFields.getDocValueFields()) {
308+
for (ExtractedField field : extractedFields.getAllFields()) {
256309
adjusted.add(field.supportsFromSource() ? field.newFromSource() : field);
257310
}
258311
return new ExtractedFields(adjusted);

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java

+154-1
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,151 @@ public void testDetect_GivenBooleanField_BooleanMappedAsString() {
534534
assertThat(booleanField.value(hit), arrayContaining("false", "true", "false"));
535535
}
536536

537+
public void testDetect_GivenMultiFields() {
538+
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
539+
.addAggregatableField("a_float", "float")
540+
.addNonAggregatableField("text_without_keyword", "text")
541+
.addNonAggregatableField("text_1", "text")
542+
.addAggregatableField("text_1.keyword", "keyword")
543+
.addNonAggregatableField("text_2", "text")
544+
.addAggregatableField("text_2.keyword", "keyword")
545+
.addAggregatableField("keyword_1", "keyword")
546+
.addNonAggregatableField("keyword_1.text", "text")
547+
.build();
548+
549+
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
550+
SOURCE_INDEX, buildRegressionConfig("a_float"), RESULTS_FIELD, true, 100, fieldCapabilities);
551+
ExtractedFields extractedFields = extractedFieldsDetector.detect();
552+
553+
assertThat(extractedFields.getAllFields().size(), equalTo(5));
554+
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
555+
.collect(Collectors.toList());
556+
assertThat(extractedFieldNames, contains("a_float", "keyword_1", "text_1.keyword", "text_2.keyword", "text_without_keyword"));
557+
}
558+
559+
public void testDetect_GivenMultiFieldAndParentIsRequired() {
560+
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
561+
.addAggregatableField("field_1", "keyword")
562+
.addAggregatableField("field_1.keyword", "keyword")
563+
.addAggregatableField("field_2", "float")
564+
.build();
565+
566+
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
567+
SOURCE_INDEX, buildClassificationConfig("field_1"), RESULTS_FIELD, true, 100, fieldCapabilities);
568+
ExtractedFields extractedFields = extractedFieldsDetector.detect();
569+
570+
assertThat(extractedFields.getAllFields().size(), equalTo(2));
571+
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
572+
.collect(Collectors.toList());
573+
assertThat(extractedFieldNames, contains("field_1", "field_2"));
574+
}
575+
576+
public void testDetect_GivenMultiFieldAndMultiFieldIsRequired() {
577+
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
578+
.addAggregatableField("field_1", "keyword")
579+
.addAggregatableField("field_1.keyword", "keyword")
580+
.addAggregatableField("field_2", "float")
581+
.build();
582+
583+
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
584+
SOURCE_INDEX, buildClassificationConfig("field_1.keyword"), RESULTS_FIELD, true, 100, fieldCapabilities);
585+
ExtractedFields extractedFields = extractedFieldsDetector.detect();
586+
587+
assertThat(extractedFields.getAllFields().size(), equalTo(2));
588+
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
589+
.collect(Collectors.toList());
590+
assertThat(extractedFieldNames, contains("field_1.keyword", "field_2"));
591+
}
592+
593+
public void testDetect_GivenSeveralMultiFields_ShouldPickFirstSorted() {
594+
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
595+
.addNonAggregatableField("field_1", "text")
596+
.addAggregatableField("field_1.keyword_3", "keyword")
597+
.addAggregatableField("field_1.keyword_2", "keyword")
598+
.addAggregatableField("field_1.keyword_1", "keyword")
599+
.addAggregatableField("field_2", "float")
600+
.build();
601+
602+
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
603+
SOURCE_INDEX, buildRegressionConfig("field_2"), RESULTS_FIELD, true, 100, fieldCapabilities);
604+
ExtractedFields extractedFields = extractedFieldsDetector.detect();
605+
606+
assertThat(extractedFields.getAllFields().size(), equalTo(2));
607+
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
608+
.collect(Collectors.toList());
609+
assertThat(extractedFieldNames, contains("field_1.keyword_1", "field_2"));
610+
}
611+
612+
public void testDetect_GivenMultiFields_OverDocValueLimit() {
613+
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
614+
.addNonAggregatableField("field_1", "text")
615+
.addAggregatableField("field_1.keyword_1", "keyword")
616+
.addAggregatableField("field_2", "float")
617+
.build();
618+
619+
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
620+
SOURCE_INDEX, buildRegressionConfig("field_2"), RESULTS_FIELD, true, 0, fieldCapabilities);
621+
ExtractedFields extractedFields = extractedFieldsDetector.detect();
622+
623+
assertThat(extractedFields.getAllFields().size(), equalTo(2));
624+
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
625+
.collect(Collectors.toList());
626+
assertThat(extractedFieldNames, contains("field_1", "field_2"));
627+
}
628+
629+
public void testDetect_GivenParentAndMultiFieldBothAggregatable() {
630+
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
631+
.addAggregatableField("field_1", "keyword")
632+
.addAggregatableField("field_1.keyword", "keyword")
633+
.addAggregatableField("field_2.keyword", "float")
634+
.addAggregatableField("field_2.double", "double")
635+
.build();
636+
637+
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
638+
SOURCE_INDEX, buildRegressionConfig("field_2.double"), RESULTS_FIELD, true, 100, fieldCapabilities);
639+
ExtractedFields extractedFields = extractedFieldsDetector.detect();
640+
641+
assertThat(extractedFields.getAllFields().size(), equalTo(2));
642+
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
643+
.collect(Collectors.toList());
644+
assertThat(extractedFieldNames, contains("field_1", "field_2.double"));
645+
}
646+
647+
public void testDetect_GivenParentAndMultiFieldNoneAggregatable() {
648+
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
649+
.addNonAggregatableField("field_1", "text")
650+
.addNonAggregatableField("field_1.text", "text")
651+
.addAggregatableField("field_2", "float")
652+
.build();
653+
654+
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
655+
SOURCE_INDEX, buildRegressionConfig("field_2"), RESULTS_FIELD, true, 100, fieldCapabilities);
656+
ExtractedFields extractedFields = extractedFieldsDetector.detect();
657+
658+
assertThat(extractedFields.getAllFields().size(), equalTo(2));
659+
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
660+
.collect(Collectors.toList());
661+
assertThat(extractedFieldNames, contains("field_1", "field_2"));
662+
}
663+
664+
public void testDetect_GivenMultiFields_AndExplicitlyIncludedFields() {
665+
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
666+
.addNonAggregatableField("field_1", "text")
667+
.addAggregatableField("field_1.keyword", "keyword")
668+
.addAggregatableField("field_2", "float")
669+
.build();
670+
FetchSourceContext analyzedFields = new FetchSourceContext(true, new String[] { "field_1", "field_2" }, new String[0]);
671+
672+
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
673+
SOURCE_INDEX, buildRegressionConfig("field_2", analyzedFields), RESULTS_FIELD, false, 100, fieldCapabilities);
674+
ExtractedFields extractedFields = extractedFieldsDetector.detect();
675+
676+
assertThat(extractedFields.getAllFields().size(), equalTo(2));
677+
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
678+
.collect(Collectors.toList());
679+
assertThat(extractedFieldNames, contains("field_1", "field_2"));
680+
}
681+
537682
private static DataFrameAnalyticsConfig buildOutlierDetectionConfig() {
538683
return buildOutlierDetectionConfig(null);
539684
}
@@ -576,9 +721,17 @@ private static class MockFieldCapsResponseBuilder {
576721
private final Map<String, Map<String, FieldCapabilities>> fieldCaps = new HashMap<>();
577722

578723
private MockFieldCapsResponseBuilder addAggregatableField(String field, String... types) {
724+
return addField(field, true, types);
725+
}
726+
727+
private MockFieldCapsResponseBuilder addNonAggregatableField(String field, String... types) {
728+
return addField(field, false, types);
729+
}
730+
731+
private MockFieldCapsResponseBuilder addField(String field, boolean isAggregatable, String... types) {
579732
Map<String, FieldCapabilities> caps = new HashMap<>();
580733
for (String type : types) {
581-
caps.put(type, new FieldCapabilities(field, type, true, true));
734+
caps.put(type, new FieldCapabilities(field, type, true, isAggregatable));
582735
}
583736
fieldCaps.put(field, caps);
584737
return this;

0 commit comments

Comments
 (0)