Skip to content

Commit fcddaa9

Browse files
authored
[7.x] [ML][Inference] adding tree model (#47044) (#47141)
* [ML][Inference] adding tree model (#47044) * [ML][Inference] adding tree model * renaming features for updated schema * fixing 7.x compilation
1 parent 7ac647c commit fcddaa9

File tree

17 files changed

+1711
-4
lines changed

17 files changed

+1711
-4
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
*/
1919
package org.elasticsearch.client.ml.inference;
2020

21+
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
22+
import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree;
2123
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding;
2224
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;
2325
import org.elasticsearch.client.ml.inference.preprocessing.PreProcessor;
@@ -42,6 +44,10 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
4244
TargetMeanEncoding::fromXContent));
4345
namedXContent.add(new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(FrequencyEncoding.NAME),
4446
FrequencyEncoding::fromXContent));
47+
48+
// Model
49+
namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class, new ParseField(Tree.NAME), Tree::fromXContent));
50+
4551
return namedXContent;
4652
}
4753

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.elasticsearch.client.ml.inference.trainedmodel;
20+
21+
import org.elasticsearch.common.xcontent.ToXContentObject;
22+
23+
import java.util.List;
24+
25+
public interface TrainedModel extends ToXContentObject {
26+
27+
/**
28+
* @return List of featureNames expected by the model. In the order that they are expected
29+
*/
30+
List<String> getFeatureNames();
31+
32+
/**
33+
* @return The name of the model
34+
*/
35+
String getName();
36+
}
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.elasticsearch.client.ml.inference.trainedmodel.tree;
20+
21+
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
22+
import org.elasticsearch.common.ParseField;
23+
import org.elasticsearch.common.Strings;
24+
import org.elasticsearch.common.xcontent.ObjectParser;
25+
import org.elasticsearch.common.xcontent.XContentBuilder;
26+
import org.elasticsearch.common.xcontent.XContentParser;
27+
28+
import java.io.IOException;
29+
import java.util.ArrayList;
30+
import java.util.Arrays;
31+
import java.util.Collections;
32+
import java.util.List;
33+
import java.util.Objects;
34+
import java.util.stream.Collectors;
35+
36+
public class Tree implements TrainedModel {
37+
38+
public static final String NAME = "tree";
39+
40+
public static final ParseField FEATURE_NAMES = new ParseField("feature_names");
41+
public static final ParseField TREE_STRUCTURE = new ParseField("tree_structure");
42+
43+
private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME, true, Builder::new);
44+
45+
static {
46+
PARSER.declareStringArray(Builder::setFeatureNames, FEATURE_NAMES);
47+
PARSER.declareObjectArray(Builder::setNodes, (p, c) -> TreeNode.fromXContent(p), TREE_STRUCTURE);
48+
}
49+
50+
public static Tree fromXContent(XContentParser parser) {
51+
return PARSER.apply(parser, null).build();
52+
}
53+
54+
private final List<String> featureNames;
55+
private final List<TreeNode> nodes;
56+
57+
Tree(List<String> featureNames, List<TreeNode> nodes) {
58+
this.featureNames = Collections.unmodifiableList(Objects.requireNonNull(featureNames));
59+
this.nodes = Collections.unmodifiableList(Objects.requireNonNull(nodes));
60+
}
61+
62+
@Override
63+
public String getName() {
64+
return NAME;
65+
}
66+
67+
@Override
68+
public List<String> getFeatureNames() {
69+
return featureNames;
70+
}
71+
72+
public List<TreeNode> getNodes() {
73+
return nodes;
74+
}
75+
76+
@Override
77+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
78+
builder.startObject();
79+
builder.field(FEATURE_NAMES.getPreferredName(), featureNames);
80+
builder.field(TREE_STRUCTURE.getPreferredName(), nodes);
81+
builder.endObject();
82+
return builder;
83+
}
84+
85+
@Override
86+
public String toString() {
87+
return Strings.toString(this);
88+
}
89+
90+
@Override
91+
public boolean equals(Object o) {
92+
if (this == o) return true;
93+
if (o == null || getClass() != o.getClass()) return false;
94+
Tree that = (Tree) o;
95+
return Objects.equals(featureNames, that.featureNames)
96+
&& Objects.equals(nodes, that.nodes);
97+
}
98+
99+
@Override
100+
public int hashCode() {
101+
return Objects.hash(featureNames, nodes);
102+
}
103+
104+
public static Builder builder() {
105+
return new Builder();
106+
}
107+
108+
public static class Builder {
109+
private List<String> featureNames;
110+
private ArrayList<TreeNode.Builder> nodes;
111+
private int numNodes;
112+
113+
public Builder() {
114+
nodes = new ArrayList<>();
115+
// allocate space in the root node and set to a leaf
116+
nodes.add(null);
117+
addLeaf(0, 0.0);
118+
numNodes = 1;
119+
}
120+
121+
public Builder setFeatureNames(List<String> featureNames) {
122+
this.featureNames = featureNames;
123+
return this;
124+
}
125+
126+
public Builder addNode(TreeNode.Builder node) {
127+
nodes.add(node);
128+
return this;
129+
}
130+
131+
public Builder setNodes(List<TreeNode.Builder> nodes) {
132+
this.nodes = new ArrayList<>(nodes);
133+
return this;
134+
}
135+
136+
public Builder setNodes(TreeNode.Builder... nodes) {
137+
return setNodes(Arrays.asList(nodes));
138+
}
139+
140+
/**
141+
* Add a decision node. Space for the child nodes is allocated
142+
* @param nodeIndex Where to place the node. This is either 0 (root) or an existing child node index
143+
* @param featureIndex The feature index the decision is made on
144+
* @param isDefaultLeft Default left branch if the feature is missing
145+
* @param decisionThreshold The decision threshold
146+
* @return The created node
147+
*/
148+
public TreeNode.Builder addJunction(int nodeIndex, int featureIndex, boolean isDefaultLeft, double decisionThreshold) {
149+
int leftChild = numNodes++;
150+
int rightChild = numNodes++;
151+
nodes.ensureCapacity(nodeIndex + 1);
152+
for (int i = nodes.size(); i < nodeIndex + 1; i++) {
153+
nodes.add(null);
154+
}
155+
156+
TreeNode.Builder node = TreeNode.builder(nodeIndex)
157+
.setDefaultLeft(isDefaultLeft)
158+
.setLeftChild(leftChild)
159+
.setRightChild(rightChild)
160+
.setSplitFeature(featureIndex)
161+
.setThreshold(decisionThreshold);
162+
nodes.set(nodeIndex, node);
163+
164+
// allocate space for the child nodes
165+
while (nodes.size() <= rightChild) {
166+
nodes.add(null);
167+
}
168+
169+
return node;
170+
}
171+
172+
/**
173+
* Sets the node at {@code nodeIndex} to a leaf node.
174+
* @param nodeIndex The index as allocated by a call to {@link #addJunction(int, int, boolean, double)}
175+
* @param value The prediction value
176+
* @return this
177+
*/
178+
public Builder addLeaf(int nodeIndex, double value) {
179+
for (int i = nodes.size(); i < nodeIndex + 1; i++) {
180+
nodes.add(null);
181+
}
182+
nodes.set(nodeIndex, TreeNode.builder(nodeIndex).setLeafValue(value));
183+
return this;
184+
}
185+
186+
public Tree build() {
187+
return new Tree(featureNames,
188+
nodes.stream().map(TreeNode.Builder::build).collect(Collectors.toList()));
189+
}
190+
}
191+
192+
}

0 commit comments

Comments
 (0)