Skip to content

Commit 4396a1f

Browse files
authored
[ML][Inference] fix support for nested fields (#50258) (#50335)
This fixes support for nested fields. We now support fully nested, fully collapsed, or a mix of both on inference docs. ES mappings allow the `_source` to be any combination of nested objects + dot delimited fields. So, we should do our best to find the best path down the Map for the desired field.
1 parent 06a24f0 commit 4396a1f

File tree

16 files changed

+582
-33
lines changed

16 files changed

+582
-33
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.common.xcontent.XContentBuilder;
1515
import org.elasticsearch.common.xcontent.XContentParser;
1616
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
17+
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
1718

1819
import java.io.IOException;
1920
import java.util.Collections;
@@ -103,7 +104,7 @@ public String getName() {
103104

104105
@Override
105106
public void process(Map<String, Object> fields) {
106-
Object value = fields.get(field);
107+
Object value = MapHelper.dig(field, fields);
107108
if (value == null) {
108109
return;
109110
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.common.xcontent.XContentBuilder;
1515
import org.elasticsearch.common.xcontent.XContentParser;
1616
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
17+
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
1718

1819
import java.io.IOException;
1920
import java.util.Collections;
@@ -86,7 +87,7 @@ public String getName() {
8687

8788
@Override
8889
public void process(Map<String, Object> fields) {
89-
Object value = fields.get(field);
90+
Object value = MapHelper.dig(field, fields);
9091
if (value == null) {
9192
return;
9293
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.common.xcontent.XContentBuilder;
1515
import org.elasticsearch.common.xcontent.XContentParser;
1616
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
17+
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
1718

1819
import java.io.IOException;
1920
import java.util.Collections;
@@ -114,7 +115,7 @@ public String getName() {
114115

115116
@Override
116117
public void process(Map<String, Object> fields) {
117-
Object value = fields.get(field);
118+
Object value = MapHelper.dig(field, fields);
118119
if (value == null) {
119120
return;
120121
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
2929
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
3030
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
31+
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
3132

3233
import java.io.IOException;
3334
import java.util.ArrayDeque;
@@ -129,7 +130,9 @@ public InferenceResults infer(Map<String, Object> fields, InferenceConfig config
129130
"Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
130131
}
131132

132-
List<Double> features = featureNames.stream().map(f -> InferenceHelpers.toDouble(fields.get(f))).collect(Collectors.toList());
133+
List<Double> features = featureNames.stream()
134+
.map(f -> InferenceHelpers.toDouble(MapHelper.dig(f, fields)))
135+
.collect(Collectors.toList());
133136
return infer(features, config);
134137
}
135138

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
package org.elasticsearch.xpack.core.ml.utils;
7+
8+
import org.elasticsearch.common.Nullable;
9+
10+
import java.util.Arrays;
11+
import java.util.Map;
12+
import java.util.Stack;
13+
14+
public final class MapHelper {
15+
16+
private MapHelper() {}
17+
18+
/**
19+
* This eagerly digs (depth first search, longer keys first) through the map by tokenizing the provided path on '.'.
20+
*
21+
* It is possible for ES _source docs to have "mixed" path formats. So, we should search all potential paths
22+
* given the current knowledge of the map.
23+
*
24+
* Examples:
25+
*
26+
* The following maps would return `2` given the path "a.b.c.d"
27+
*
28+
* {
29+
* "a.b.c.d" : 2
30+
* }
31+
* {
32+
* "a" :{"b": {"c": {"d" : 2}}}
33+
* }
34+
* {
35+
* "a" :{"b.c": {"d" : 2}}}
36+
* }
37+
* {
38+
* "a" :{"b": {"c": {"d" : 2}}},
39+
* "a.b" :{"c": {"d" : 5}} // we choose the first one found, we go down longer keys first
40+
* }
41+
* {
42+
* "a" :{"b": {"c": {"NOT_d" : 2, "d": 2}}}
43+
* }
44+
*
45+
* Conceptual "Worse case" 5 potential paths explored for "a.b.c.d" until 2 is finally returned
46+
* {
47+
* "a.b.c": {"not_d": 2},
48+
* "a.b": {"c": {"not_d": 2}},
49+
* "a": {"b.c": {"not_d": 2}},
50+
* "a": {"b" :{ "c.not_d": 2}},
51+
* "a" :{"b": {"c": {"not_d" : 2}}},
52+
* "a" :{"b": {"c": {"d" : 2}}},
53+
* }
54+
*
55+
* We don't exhaustively create all potential paths.
56+
* If we did, this would result in 2^n-1 total possible paths, where {@code n = path.split("\\.").length}.
57+
*
58+
* Instead we lazily create potential paths once we know that they are possibilities.
59+
*
60+
* @param path Dot delimited path containing the field desired
61+
* @param map The {@link Map} map to dig
62+
* @return The found object. Returns {@code null} if not found
63+
*/
64+
@Nullable
65+
public static Object dig(String path, Map<String, Object> map) {
66+
// short cut before search
67+
if (map.keySet().contains(path)) {
68+
return map.get(path);
69+
}
70+
String[] fields = path.split("\\.");
71+
if (Arrays.stream(fields).anyMatch(String::isEmpty)) {
72+
throw new IllegalArgumentException("Empty path detected. Invalid field name");
73+
}
74+
Stack<PotentialPath> pathStack = new Stack<>();
75+
pathStack.push(new PotentialPath(map, 0));
76+
return explore(fields, pathStack);
77+
}
78+
79+
@SuppressWarnings("unchecked")
80+
private static Object explore(String[] path, Stack<PotentialPath> pathStack) {
81+
while (pathStack.empty() == false) {
82+
PotentialPath potentialPath = pathStack.pop();
83+
int endPos = potentialPath.pathPosition + 1;
84+
int startPos = potentialPath.pathPosition;
85+
Map<String, Object> map = potentialPath.map;
86+
String candidateKey = null;
87+
while(endPos <= path.length) {
88+
candidateKey = mergePath(path, startPos, endPos);
89+
Object next = map.get(candidateKey);
90+
if (endPos == path.length && next != null) { // exit early, we reached the full path and found something
91+
return next;
92+
}
93+
if (next instanceof Map<?, ?>) { // we found another map, continue exploring down this path
94+
pathStack.push(new PotentialPath((Map<String, Object>)next, endPos));
95+
}
96+
endPos++;
97+
}
98+
if (candidateKey != null && map.containsKey(candidateKey)) { //exit early
99+
return map.get(candidateKey);
100+
}
101+
}
102+
103+
return null;
104+
}
105+
106+
private static String mergePath(String[] path, int start, int end) {
107+
if (start + 1 == end) { // early exit, no need to create sb
108+
return path[start];
109+
}
110+
111+
StringBuilder sb = new StringBuilder();
112+
for (int i = start; i < end - 1; i++) {
113+
sb.append(path[i]);
114+
sb.append(".");
115+
}
116+
sb.append(path[end - 1]);
117+
return sb.toString();
118+
}
119+
120+
private static class PotentialPath {
121+
122+
// Pointer to where to start exploring
123+
private final Map<String, Object> map;
124+
// Where in the requested path are we
125+
private final int pathPosition;
126+
127+
private PotentialPath(Map<String, Object> map, int pathPosition) {
128+
this.map = map;
129+
this.pathPosition = pathPosition;
130+
}
131+
132+
}
133+
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncodingTests.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,4 +65,22 @@ public void testProcessWithFieldPresent() {
6565
testProcess(encoding, fieldValues, matchers);
6666
}
6767

68+
public void testProcessWithNestedField() {
69+
String field = "categorical.child";
70+
List<Object> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.5);
71+
Map<String, Double> valueMap = values.stream().collect(Collectors.toMap(Object::toString,
72+
v -> randomDoubleBetween(0.0, 1.0, false)));
73+
String encodedFeatureName = "encoded";
74+
FrequencyEncoding encoding = new FrequencyEncoding(field, encodedFeatureName, valueMap);
75+
76+
Map<String, Object> fieldValues = new HashMap<String, Object>() {{
77+
put("categorical", new HashMap<String, Object>(){{
78+
put("child", "farequote");
79+
}});
80+
}};
81+
82+
encoding.process(fieldValues);
83+
assertThat(fieldValues.get("encoded"), equalTo(valueMap.get("farequote")));
84+
}
85+
6886
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncodingTests.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,4 +67,19 @@ public void testProcessWithFieldPresent() {
6767
testProcess(encoding, fieldValues, matchers);
6868
}
6969

70+
public void testProcessWithNestedField() {
71+
String field = "categorical.child";
72+
List<Object> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.5);
73+
Map<String, String> valueMap = values.stream().collect(Collectors.toMap(Object::toString, v -> "Column_" + v.toString()));
74+
OneHotEncoding encoding = new OneHotEncoding(field, valueMap);
75+
Map<String, Object> fieldValues = new HashMap<String, Object>() {{
76+
put("categorical", new HashMap<String, Object>(){{
77+
put("child", "farequote");
78+
}});
79+
}};
80+
81+
encoding.process(fieldValues);
82+
assertThat(fieldValues.get("Column_farequote"), equalTo(1));
83+
}
84+
7085
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncodingTests.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,24 @@ public void testProcessWithFieldPresent() {
6868
testProcess(encoding, fieldValues, matchers);
6969
}
7070

71+
public void testProcessWithNestedField() {
72+
String field = "categorical.child";
73+
List<Object> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.5);
74+
Map<String, Double> valueMap = values.stream().collect(Collectors.toMap(Object::toString,
75+
v -> randomDoubleBetween(0.0, 1.0, false)));
76+
String encodedFeatureName = "encoded";
77+
Double defaultvalue = randomDouble();
78+
TargetMeanEncoding encoding = new TargetMeanEncoding(field, encodedFeatureName, valueMap, defaultvalue);
79+
80+
Map<String, Object> fieldValues = new HashMap<String, Object>() {{
81+
put("categorical", new HashMap<String, Object>(){{
82+
put("child", "farequote");
83+
}});
84+
}};
85+
86+
encoding.process(fieldValues);
87+
88+
assertThat(fieldValues.get("encoded"), equalTo(valueMap.get("farequote")));
89+
}
90+
7191
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,63 @@ public void testRegressionInference() {
445445
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
446446
}
447447

448+
public void testInferNestedFields() {
449+
List<String> featureNames = Arrays.asList("foo.baz", "bar.biz");
450+
Tree tree1 = Tree.builder()
451+
.setFeatureNames(featureNames)
452+
.setRoot(TreeNode.builder(0)
453+
.setLeftChild(1)
454+
.setRightChild(2)
455+
.setSplitFeature(0)
456+
.setThreshold(0.5))
457+
.addNode(TreeNode.builder(1).setLeafValue(0.3))
458+
.addNode(TreeNode.builder(2)
459+
.setThreshold(0.8)
460+
.setSplitFeature(1)
461+
.setLeftChild(3)
462+
.setRightChild(4))
463+
.addNode(TreeNode.builder(3).setLeafValue(0.1))
464+
.addNode(TreeNode.builder(4).setLeafValue(0.2)).build();
465+
Tree tree2 = Tree.builder()
466+
.setFeatureNames(featureNames)
467+
.setRoot(TreeNode.builder(0)
468+
.setLeftChild(1)
469+
.setRightChild(2)
470+
.setSplitFeature(0)
471+
.setThreshold(0.5))
472+
.addNode(TreeNode.builder(1).setLeafValue(1.5))
473+
.addNode(TreeNode.builder(2).setLeafValue(0.9))
474+
.build();
475+
Ensemble ensemble = Ensemble.builder()
476+
.setTargetType(TargetType.REGRESSION)
477+
.setFeatureNames(featureNames)
478+
.setTrainedModels(Arrays.asList(tree1, tree2))
479+
.setOutputAggregator(new WeightedSum(new double[]{0.5, 0.5}))
480+
.build();
481+
482+
Map<String, Object> featureMap = new HashMap<String, Object>() {{
483+
put("foo", new HashMap<String, Object>(){{
484+
put("baz", 0.4);
485+
}});
486+
put("bar", new HashMap<String, Object>(){{
487+
put("biz", 0.0);
488+
}});
489+
}};
490+
assertThat(0.9,
491+
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
492+
493+
featureMap = new HashMap<String, Object>() {{
494+
put("foo", new HashMap<String, Object>(){{
495+
put("baz", 2.0);
496+
}});
497+
put("bar", new HashMap<String, Object>(){{
498+
put("biz", 0.7);
499+
}});
500+
}};
501+
assertThat(0.5,
502+
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
503+
}
504+
448505
public void testOperationsEstimations() {
449506
Tree tree1 = TreeTests.buildRandomTree(Arrays.asList("foo", "bar"), 2);
450507
Tree tree2 = TreeTests.buildRandomTree(Arrays.asList("foo", "bar", "baz"), 5);

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,58 @@ public void testInfer() {
169169
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
170170
}
171171

172+
public void testInferNestedFields() {
173+
// Build a tree with 2 nodes and 3 leaves using 2 features
174+
// The leaves have unique values 0.1, 0.2, 0.3
175+
Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
176+
TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5);
177+
builder.addLeaf(rootNode.getRightChild(), 0.3);
178+
TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8);
179+
builder.addLeaf(leftChildNode.getLeftChild(), 0.1);
180+
builder.addLeaf(leftChildNode.getRightChild(), 0.2);
181+
182+
List<String> featureNames = Arrays.asList("foo.baz", "bar.biz");
183+
Tree tree = builder.setFeatureNames(featureNames).build();
184+
185+
// This feature vector should hit the right child of the root node
186+
Map<String, Object> featureMap = new HashMap<String, Object>() {{
187+
put("foo", new HashMap<String, Object>(){{
188+
put("baz", 0.6);
189+
}});
190+
put("bar", new HashMap<String, Object>(){{
191+
put("biz", 0.0);
192+
}});
193+
}};
194+
assertThat(0.3,
195+
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
196+
197+
// This should hit the left child of the left child of the root node
198+
// i.e. it takes the path left, left
199+
featureMap = new HashMap<String, Object>() {{
200+
put("foo", new HashMap<String, Object>(){{
201+
put("baz", 0.3);
202+
}});
203+
put("bar", new HashMap<String, Object>(){{
204+
put("biz", 0.7);
205+
}});
206+
}};
207+
assertThat(0.1,
208+
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
209+
210+
// This should hit the right child of the left child of the root node
211+
// i.e. it takes the path left, right
212+
featureMap = new HashMap<String, Object>() {{
213+
put("foo", new HashMap<String, Object>(){{
214+
put("baz", 0.3);
215+
}});
216+
put("bar", new HashMap<String, Object>(){{
217+
put("biz", 0.9);
218+
}});
219+
}};
220+
assertThat(0.2,
221+
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
222+
}
223+
172224
public void testTreeClassificationProbability() {
173225
// Build a tree with 2 nodes and 3 leaves using 2 features
174226
// The leaves have unique values 0.1, 0.2, 0.3

0 commit comments

Comments
 (0)