Skip to content

Commit f38a129

Browse files
committed
[Ml] Validate tree feature index is within range (elastic#52460)
This changes the tree validation code to ensure no node in the tree has a feature index that is beyond the bounds of the feature_names array. Specifically this handles the situation where the C++ emits a tree containing a single node and an empty feature_names list. This is valid tree used to centre the data in the ensemble but the validation code would reject this as feature_names is empty. This meant a broken workflow as you cannot GET the model and PUT it back
1 parent 4d006f0 commit f38a129

File tree

5 files changed

+139
-28
lines changed
  • client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree
  • x-pack/plugin

5 files changed

+139
-28
lines changed

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,11 @@ public static Tree createRandom() {
6161
}
6262

6363
public static Tree buildRandomTree(List<String> featureNames, int depth, TargetType targetType) {
64-
int numFeatures = featureNames.size();
64+
int maxFeatureIndex = featureNames.size() -1;
6565
Tree.Builder builder = Tree.builder();
6666
builder.setFeatureNames(featureNames);
6767

68-
TreeNode.Builder node = builder.addJunction(0, randomInt(numFeatures), true, randomDouble());
68+
TreeNode.Builder node = builder.addJunction(0, randomInt(maxFeatureIndex), true, randomDouble());
6969
List<Integer> childNodes = Arrays.asList(node.getLeftChild(), node.getRightChild());
7070

7171
for (int i = 0; i < depth -1; i++) {
@@ -76,7 +76,7 @@ public static Tree buildRandomTree(List<String> featureNames, int depth, TargetT
7676
builder.addLeaf(nodeId, randomDouble());
7777
} else {
7878
TreeNode.Builder childNode =
79-
builder.addJunction(nodeId, randomInt(numFeatures), true, randomDouble());
79+
builder.addJunction(nodeId, randomInt(maxFeatureIndex), true, randomDouble());
8080
nextNodes.add(childNode.getLeftChild());
8181
nextNodes.add(childNode.getRightChild());
8282
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java

+21-2
Original file line numberDiff line numberDiff line change
@@ -253,8 +253,10 @@ public static Builder builder() {
253253

254254
@Override
255255
public void validate() {
256-
if (featureNames.isEmpty()) {
257-
throw ExceptionsHelper.badRequestException("[{}] must not be empty for tree model", FEATURE_NAMES.getPreferredName());
256+
int maxFeatureIndex = maxFeatureIndex();
257+
if (maxFeatureIndex >= featureNames.size()) {
258+
throw ExceptionsHelper.badRequestException("feature index [{}] is out of bounds for the [{}] array",
259+
maxFeatureIndex, FEATURE_NAMES.getPreferredName());
258260
}
259261
checkTargetType();
260262
detectMissingNodes();
@@ -267,6 +269,23 @@ public long estimatedNumOperations() {
267269
return (long)Math.ceil(Math.log(nodes.size())) + featureNames.size();
268270
}
269271

272+
/**
273+
* The highest index of a feature used any of the nodes.
274+
* If no nodes use a feature return -1. This can only happen
275+
* if the tree contains a single leaf node.
276+
*
277+
* @return The max or -1
278+
*/
279+
int maxFeatureIndex() {
280+
int maxFeatureIndex = -1;
281+
282+
for (TreeNode node : nodes) {
283+
maxFeatureIndex = Math.max(maxFeatureIndex, node.getSplitFeature());
284+
}
285+
286+
return maxFeatureIndex;
287+
}
288+
270289
private void checkTargetType() {
271290
if (this.classificationLabels != null && this.targetType != TargetType.CLASSIFICATION) {
272291
throw ExceptionsHelper.badRequestException(

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java

+76-18
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import java.util.stream.IntStream;
3030

3131
import static org.hamcrest.Matchers.closeTo;
32+
import static org.hamcrest.Matchers.containsString;
3233
import static org.hamcrest.Matchers.equalTo;
3334

3435

@@ -72,10 +73,10 @@ public static Tree createRandom() {
7273

7374
public static Tree buildRandomTree(List<String> featureNames, int depth) {
7475
Tree.Builder builder = Tree.builder();
75-
int numFeatures = featureNames.size() - 1;
76+
int maxFeatureIndex = featureNames.size() - 1;
7677
builder.setFeatureNames(featureNames);
7778

78-
TreeNode.Builder node = builder.addJunction(0, randomInt(numFeatures), true, randomDouble());
79+
TreeNode.Builder node = builder.addJunction(0, randomInt(maxFeatureIndex), true, randomDouble());
7980
List<Integer> childNodes = Arrays.asList(node.getLeftChild(), node.getRightChild());
8081

8182
for (int i = 0; i < depth -1; i++) {
@@ -86,7 +87,7 @@ public static Tree buildRandomTree(List<String> featureNames, int depth) {
8687
builder.addLeaf(nodeId, randomDouble());
8788
} else {
8889
TreeNode.Builder childNode =
89-
builder.addJunction(nodeId, randomInt(numFeatures), true, randomDouble());
90+
builder.addJunction(nodeId, randomInt(maxFeatureIndex), true, randomDouble());
9091
nextNodes.add(childNode.getLeftChild());
9192
nextNodes.add(childNode.getRightChild());
9293
}
@@ -339,26 +340,83 @@ public void testTreeWithTargetTypeAndLabelsMismatch() {
339340
assertThat(ex.getMessage(), equalTo(msg));
340341
}
341342

342-
public void testTreeWithEmptyFeatureNames() {
343-
String msg = "[feature_names] must not be empty for tree model";
344-
ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
345-
Tree.builder()
346-
.setRoot(TreeNode.builder(0)
347-
.setLeftChild(1)
348-
.setSplitFeature(1)
349-
.setThreshold(randomDouble()))
350-
.setFeatureNames(Collections.emptyList())
351-
.build()
352-
.validate();
353-
});
354-
assertThat(ex.getMessage(), equalTo(msg));
355-
}
356-
357343
public void testOperationsEstimations() {
358344
Tree tree = buildRandomTree(Arrays.asList("foo", "bar", "baz"), 5);
359345
assertThat(tree.estimatedNumOperations(), equalTo(7L));
360346
}
361347

348+
public void testMaxFeatureIndex() {
349+
350+
int numFeatures = randomIntBetween(1, 15);
351+
// We need a tree where every feature is used, choose a depth big enough to
352+
// accommodate those non-leave nodes (leaf nodes don't have a feature index)
353+
int depth = (int) Math.ceil(Math.log(numFeatures +1) / Math.log(2)) + 1;
354+
List<String> featureNames = new ArrayList<>(numFeatures);
355+
for (int i=0; i<numFeatures; i++) {
356+
featureNames.add("feature" + i);
357+
}
358+
359+
Tree.Builder builder = Tree.builder().setFeatureNames(featureNames);
360+
361+
// build a tree using feature indices 0..numFeatures -1
362+
int featureIndex = 0;
363+
TreeNode.Builder node = builder.addJunction(0, featureIndex++, true, randomDouble());
364+
List<Integer> childNodes = Arrays.asList(node.getLeftChild(), node.getRightChild());
365+
366+
for (int i = 0; i < depth -1; i++) {
367+
List<Integer> nextNodes = new ArrayList<>();
368+
for (int nodeId : childNodes) {
369+
if (i == depth -2) {
370+
builder.addLeaf(nodeId, randomDouble());
371+
} else {
372+
TreeNode.Builder childNode =
373+
builder.addJunction(nodeId, featureIndex++ % numFeatures, true, randomDouble());
374+
nextNodes.add(childNode.getLeftChild());
375+
nextNodes.add(childNode.getRightChild());
376+
}
377+
}
378+
childNodes = nextNodes;
379+
}
380+
381+
Tree tree = builder.build();
382+
383+
assertEquals(numFeatures, tree.maxFeatureIndex() +1);
384+
}
385+
386+
public void testMaxFeatureIndexSingleNodeTree() {
387+
Tree tree = Tree.builder()
388+
.setRoot(TreeNode.builder(0).setLeafValue(10.0))
389+
.setFeatureNames(Collections.emptyList())
390+
.build();
391+
392+
assertEquals(-1, tree.maxFeatureIndex());
393+
}
394+
395+
public void testValidateGivenMissingFeatures() {
396+
List<String> featureNames = Arrays.asList("foo", "bar", "baz");
397+
398+
// build a tree referencing a feature at index 3 which is not in the featureNames list
399+
Tree.Builder builder = Tree.builder().setFeatureNames(featureNames);
400+
builder.addJunction(0, 0, true, randomDouble());
401+
builder.addJunction(1, 1, true, randomDouble());
402+
builder.addJunction(2, 3, true, randomDouble());
403+
builder.addLeaf(3, randomDouble());
404+
builder.addLeaf(4, randomDouble());
405+
builder.addLeaf(5, randomDouble());
406+
builder.addLeaf(6, randomDouble());
407+
408+
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> builder.build().validate());
409+
assertThat(e.getDetailedMessage(), containsString("feature index [3] is out of bounds for the [feature_names] array"));
410+
}
411+
412+
public void testValidateGivenTreeWithNoFeatures() {
413+
Tree.builder()
414+
.setRoot(TreeNode.builder(0).setLeafValue(10.0))
415+
.setFeatureNames(Collections.emptyList())
416+
.build()
417+
.validate();
418+
}
419+
362420
private static Map<String, Object> zipObjMap(List<String> keys, List<? extends Object> values) {
363421
return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
364422
}

x-pack/plugin/ml/qa/ml-with-security/build.gradle

+1-1
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ integTest.runner {
137137
'ml/inference_crud/Test get given missing trained model',
138138
'ml/inference_crud/Test get given expression without matches and allow_no_match is false',
139139
'ml/inference_crud/Test put ensemble with empty models',
140-
'ml/inference_crud/Test put ensemble with tree where tree has empty feature-names',
140+
'ml/inference_crud/Test put ensemble with tree where tree has out of bounds feature_names index',
141141
'ml/inference_crud/Test put model with empty input.field_names',
142142
'ml/inference_stats_crud/Test get stats given missing trained model',
143143
'ml/inference_stats_crud/Test get stats given expression without matches and allow_no_match is false',

x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml

+38-4
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,40 @@ setup:
333333
- match: { count: 1 }
334334
- match: { trained_model_configs.0.model_id: "lang_ident_model_1" }
335335
---
336+
"Test put ensemble with single node and empty feature_names":
337+
338+
- do:
339+
ml.put_trained_model:
340+
model_id: "ensemble_tree_empty_feature_names"
341+
body: >
342+
{
343+
"input": {
344+
"field_names": "fieldy_mc_fieldname"
345+
},
346+
"definition": {
347+
"trained_model": {
348+
"ensemble": {
349+
"feature_names": [],
350+
"trained_models": [
351+
{
352+
"tree": {
353+
"feature_names": [],
354+
"tree_structure": [
355+
{
356+
"node_index": 0,
357+
"decision_type": "lte",
358+
"leaf_value": 12.0,
359+
"default_left": true
360+
}]
361+
}
362+
}
363+
]
364+
}
365+
}
366+
}
367+
}
368+
369+
---
336370
"Test put ensemble with empty models":
337371
- do:
338372
catch: /\[trained_models\] must not be empty/
@@ -353,11 +387,11 @@ setup:
353387
}
354388
}
355389
---
356-
"Test put ensemble with tree where tree has empty feature-names":
390+
"Test put ensemble with tree where tree has out of bounds feature_names index":
357391
- do:
358-
catch: /\[feature_names\] must not be empty/
392+
catch: /feature index \[1\] is out of bounds for the \[feature_names\] array/
359393
ml.put_trained_model:
360-
model_id: "ensemble_tree_missing_feature_names"
394+
model_id: "ensemble_tree_out_of_bounds_feature_names_index"
361395
body: >
362396
{
363397
"input": {
@@ -374,7 +408,7 @@ setup:
374408
"tree_structure": [
375409
{
376410
"node_index": 0,
377-
"split_feature": 0,
411+
"split_feature": 1,
378412
"split_gain": 12.0,
379413
"threshold": 10.0,
380414
"decision_type": "lte",

0 commit comments

Comments
 (0)