Skip to content

[Ml] Validate tree feature index is within range #52460

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ public static Tree createRandom() {
}

public static Tree buildRandomTree(List<String> featureNames, int depth, TargetType targetType) {
int numFeatures = featureNames.size();
int maxFeatureIndex = featureNames.size() -1;
Tree.Builder builder = Tree.builder();
builder.setFeatureNames(featureNames);

TreeNode.Builder node = builder.addJunction(0, randomInt(numFeatures), true, randomDouble());
TreeNode.Builder node = builder.addJunction(0, randomInt(maxFeatureIndex), true, randomDouble());
List<Integer> childNodes = List.of(node.getLeftChild(), node.getRightChild());

for (int i = 0; i < depth -1; i++) {
Expand All @@ -76,7 +76,7 @@ public static Tree buildRandomTree(List<String> featureNames, int depth, TargetT
builder.addLeaf(nodeId, randomDouble());
} else {
TreeNode.Builder childNode =
builder.addJunction(nodeId, randomInt(numFeatures), true, randomDouble());
builder.addJunction(nodeId, randomInt(maxFeatureIndex), true, randomDouble());
nextNodes.add(childNode.getLeftChild());
nextNodes.add(childNode.getRightChild());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,10 @@ public static Builder builder() {

@Override
public void validate() {
if (featureNames.isEmpty()) {
throw ExceptionsHelper.badRequestException("[{}] must not be empty for tree model", FEATURE_NAMES.getPreferredName());
int maxFeatureIndex = maxFeatureIndex();
if (maxFeatureIndex >= featureNames.size()) {
throw ExceptionsHelper.badRequestException("feature index [{}] is out of bounds for the [{}] array",
maxFeatureIndex, FEATURE_NAMES.getPreferredName());
}
checkTargetType();
detectMissingNodes();
Expand All @@ -267,6 +269,23 @@ public long estimatedNumOperations() {
return (long)Math.ceil(Math.log(nodes.size())) + featureNames.size();
}

/**
* The highest index of a feature used any of the nodes.
* If no nodes use a feature return -1. This can only happen
* if the tree contains a single leaf node.
*
* @return The max or -1
*/
int maxFeatureIndex() {
int maxFeatureIndex = -1;

for (TreeNode node : nodes) {
maxFeatureIndex = Math.max(maxFeatureIndex, node.getSplitFeature());
}

return maxFeatureIndex;
}

private void checkTargetType() {
if (this.classificationLabels != null && this.targetType != TargetType.CLASSIFICATION) {
throw ExceptionsHelper.badRequestException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import java.util.stream.IntStream;

import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;


Expand Down Expand Up @@ -72,10 +73,10 @@ public static Tree createRandom() {

public static Tree buildRandomTree(List<String> featureNames, int depth) {
Tree.Builder builder = Tree.builder();
int numFeatures = featureNames.size() - 1;
int maxFeatureIndex = featureNames.size() - 1;
builder.setFeatureNames(featureNames);

TreeNode.Builder node = builder.addJunction(0, randomInt(numFeatures), true, randomDouble());
TreeNode.Builder node = builder.addJunction(0, randomInt(maxFeatureIndex), true, randomDouble());
List<Integer> childNodes = List.of(node.getLeftChild(), node.getRightChild());

for (int i = 0; i < depth -1; i++) {
Expand All @@ -86,7 +87,7 @@ public static Tree buildRandomTree(List<String> featureNames, int depth) {
builder.addLeaf(nodeId, randomDouble());
} else {
TreeNode.Builder childNode =
builder.addJunction(nodeId, randomInt(numFeatures), true, randomDouble());
builder.addJunction(nodeId, randomInt(maxFeatureIndex), true, randomDouble());
nextNodes.add(childNode.getLeftChild());
nextNodes.add(childNode.getRightChild());
}
Expand Down Expand Up @@ -339,26 +340,83 @@ public void testTreeWithTargetTypeAndLabelsMismatch() {
assertThat(ex.getMessage(), equalTo(msg));
}

public void testTreeWithEmptyFeatureNames() {
String msg = "[feature_names] must not be empty for tree model";
ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
Tree.builder()
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setSplitFeature(1)
.setThreshold(randomDouble()))
.setFeatureNames(Collections.emptyList())
.build()
.validate();
});
assertThat(ex.getMessage(), equalTo(msg));
}

public void testOperationsEstimations() {
Tree tree = buildRandomTree(Arrays.asList("foo", "bar", "baz"), 5);
assertThat(tree.estimatedNumOperations(), equalTo(7L));
}

public void testMaxFeatureIndex() {

int numFeatures = randomIntBetween(1, 15);
// We need a tree where every feature is used, choose a depth big enough to
// accommodate those non-leave nodes (leaf nodes don't have a feature index)
int depth = (int) Math.ceil(Math.log(numFeatures +1) / Math.log(2)) + 1;
List<String> featureNames = new ArrayList<>(numFeatures);
for (int i=0; i<numFeatures; i++) {
featureNames.add("feature" + i);
}

Tree.Builder builder = Tree.builder().setFeatureNames(featureNames);

// build a tree using feature indices 0..numFeatures -1
int featureIndex = 0;
TreeNode.Builder node = builder.addJunction(0, featureIndex++, true, randomDouble());
List<Integer> childNodes = List.of(node.getLeftChild(), node.getRightChild());

for (int i = 0; i < depth -1; i++) {
List<Integer> nextNodes = new ArrayList<>();
for (int nodeId : childNodes) {
if (i == depth -2) {
builder.addLeaf(nodeId, randomDouble());
} else {
TreeNode.Builder childNode =
builder.addJunction(nodeId, featureIndex++ % numFeatures, true, randomDouble());
nextNodes.add(childNode.getLeftChild());
nextNodes.add(childNode.getRightChild());
}
}
childNodes = nextNodes;
}

Tree tree = builder.build();

assertEquals(numFeatures, tree.maxFeatureIndex() +1);
}

public void testMaxFeatureIndexSingleNodeTree() {
Tree tree = Tree.builder()
.setRoot(TreeNode.builder(0).setLeafValue(10.0))
.setFeatureNames(Collections.emptyList())
.build();

assertEquals(-1, tree.maxFeatureIndex());
}

public void testValidateGivenMissingFeatures() {
List<String> featureNames = Arrays.asList("foo", "bar", "baz");

// build a tree referencing a feature at index 3 which is not in the featureNames list
Tree.Builder builder = Tree.builder().setFeatureNames(featureNames);
builder.addJunction(0, 0, true, randomDouble());
builder.addJunction(1, 1, true, randomDouble());
builder.addJunction(2, 3, true, randomDouble());
builder.addLeaf(3, randomDouble());
builder.addLeaf(4, randomDouble());
builder.addLeaf(5, randomDouble());
builder.addLeaf(6, randomDouble());

ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> builder.build().validate());
assertThat(e.getDetailedMessage(), containsString("feature index [3] is out of bounds for the [feature_names] array"));
}

public void testValidateGivenTreeWithNoFeatures() {
Tree.builder()
.setRoot(TreeNode.builder(0).setLeafValue(10.0))
.setFeatureNames(Collections.emptyList())
.build()
.validate();
}

private static Map<String, Object> zipObjMap(List<String> keys, List<? extends Object> values) {
return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
}
Expand Down
2 changes: 1 addition & 1 deletion x-pack/plugin/ml/qa/ml-with-security/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ integTest.runner {
'ml/inference_crud/Test get given missing trained model',
'ml/inference_crud/Test get given expression without matches and allow_no_match is false',
'ml/inference_crud/Test put ensemble with empty models',
'ml/inference_crud/Test put ensemble with tree where tree has empty feature-names',
'ml/inference_crud/Test put ensemble with tree where tree has out of bounds feature_names index',
'ml/inference_crud/Test put model with empty input.field_names',
'ml/inference_stats_crud/Test get stats given missing trained model',
'ml/inference_stats_crud/Test get stats given expression without matches and allow_no_match is false',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,40 @@ setup:
- match: { count: 1 }
- match: { trained_model_configs.0.model_id: "lang_ident_model_1" }
---
"Test put ensemble with single node and empty feature_names":

- do:
ml.put_trained_model:
model_id: "ensemble_tree_empty_feature_names"
body: >
{
"input": {
"field_names": "fieldy_mc_fieldname"
},
"definition": {
"trained_model": {
"ensemble": {
"feature_names": [],
"trained_models": [
{
"tree": {
"feature_names": [],
"tree_structure": [
{
"node_index": 0,
"decision_type": "lte",
"leaf_value": 12.0,
"default_left": true
}]
}
}
]
}
}
}
}

---
"Test put ensemble with empty models":
- do:
catch: /\[trained_models\] must not be empty/
Expand All @@ -353,11 +387,11 @@ setup:
}
}
---
"Test put ensemble with tree where tree has empty feature-names":
"Test put ensemble with tree where tree has out of bounds feature_names index":
- do:
catch: /\[feature_names\] must not be empty/
catch: /feature index \[1\] is out of bounds for the \[feature_names\] array/
ml.put_trained_model:
model_id: "ensemble_tree_missing_feature_names"
model_id: "ensemble_tree_out_of_bounds_feature_names_index"
body: >
{
"input": {
Expand All @@ -374,7 +408,7 @@ setup:
"tree_structure": [
{
"node_index": 0,
"split_feature": 0,
"split_feature": 1,
"split_gain": 12.0,
"threshold": 10.0,
"decision_type": "lte",
Expand Down