Skip to content

Commit 19a6c5d

Browse files
authored
[7.x] [ML][Inference] Add support for multi-value leaves to the tree model (#52531) (#52901)
* [ML][Inference] Add support for multi-value leaves to the tree model (#52531) This adds support for multi-value leaves. This is a prerequisite for multi-class boosted tree classification.
1 parent 710a9ea commit 19a6c5d

File tree

26 files changed

+575
-197
lines changed

26 files changed

+575
-197
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/Tree.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import java.io.IOException;
3131
import java.util.ArrayList;
3232
import java.util.Arrays;
33+
import java.util.Collections;
3334
import java.util.List;
3435
import java.util.Objects;
3536
import java.util.stream.Collectors;
@@ -225,7 +226,7 @@ public Builder addLeaf(int nodeIndex, double value) {
225226
for (int i = nodes.size(); i < nodeIndex + 1; i++) {
226227
nodes.add(null);
227228
}
228-
nodes.set(nodeIndex, TreeNode.builder(nodeIndex).setLeafValue(value));
229+
nodes.set(nodeIndex, TreeNode.builder(nodeIndex).setLeafValue(Collections.singletonList(value)));
229230
return this;
230231
}
231232

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeNode.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.elasticsearch.common.xcontent.XContentParser;
2828

2929
import java.io.IOException;
30+
import java.util.List;
3031
import java.util.Objects;
3132

3233
public class TreeNode implements ToXContentObject {
@@ -61,7 +62,7 @@ public class TreeNode implements ToXContentObject {
6162
PARSER.declareInt(Builder::setSplitFeature, SPLIT_FEATURE);
6263
PARSER.declareInt(Builder::setNodeIndex, NODE_INDEX);
6364
PARSER.declareDouble(Builder::setSplitGain, SPLIT_GAIN);
64-
PARSER.declareDouble(Builder::setLeafValue, LEAF_VALUE);
65+
PARSER.declareDoubleArray(Builder::setLeafValue, LEAF_VALUE);
6566
PARSER.declareLong(Builder::setNumberSamples, NUMBER_SAMPLES);
6667
}
6768

@@ -74,7 +75,7 @@ public static Builder fromXContent(XContentParser parser) {
7475
private final Integer splitFeature;
7576
private final int nodeIndex;
7677
private final Double splitGain;
77-
private final Double leafValue;
78+
private final List<Double> leafValue;
7879
private final Boolean defaultLeft;
7980
private final Integer leftChild;
8081
private final Integer rightChild;
@@ -86,7 +87,7 @@ public static Builder fromXContent(XContentParser parser) {
8687
Integer splitFeature,
8788
int nodeIndex,
8889
Double splitGain,
89-
Double leafValue,
90+
List<Double> leafValue,
9091
Boolean defaultLeft,
9192
Integer leftChild,
9293
Integer rightChild,
@@ -123,7 +124,7 @@ public Double getSplitGain() {
123124
return splitGain;
124125
}
125126

126-
public Double getLeafValue() {
127+
public List<Double> getLeafValue() {
127128
return leafValue;
128129
}
129130

@@ -212,7 +213,7 @@ public static class Builder {
212213
private Integer splitFeature;
213214
private int nodeIndex;
214215
private Double splitGain;
215-
private Double leafValue;
216+
private List<Double> leafValue;
216217
private Boolean defaultLeft;
217218
private Integer leftChild;
218219
private Integer rightChild;
@@ -250,7 +251,7 @@ public Builder setSplitGain(Double splitGain) {
250251
return this;
251252
}
252253

253-
public Builder setLeafValue(Double leafValue) {
254+
public Builder setLeafValue(List<Double> leafValue) {
254255
this.leafValue = leafValue;
255256
return this;
256257
}

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeNodeTests.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.elasticsearch.test.AbstractXContentTestCase;
2424

2525
import java.io.IOException;
26+
import java.util.Collections;
2627

2728
public class TreeNodeTests extends AbstractXContentTestCase<TreeNode> {
2829

@@ -48,7 +49,7 @@ protected TreeNode createTestInstance() {
4849
public static TreeNode createRandomLeafNode(double internalValue) {
4950
return TreeNode.builder(randomInt(100))
5051
.setDefaultLeft(randomBoolean() ? null : randomBoolean())
51-
.setLeafValue(internalValue)
52+
.setLeafValue(Collections.singletonList(internalValue))
5253
.setNumberSamples(randomNonNegativeLong())
5354
.build();
5455
}
@@ -60,7 +61,7 @@ public static TreeNode.Builder createRandom(int nodeIndex,
6061
Integer featureIndex,
6162
Operator operator) {
6263
return TreeNode.builder(nodeIndex)
63-
.setLeafValue(left == null ? randomDouble() : null)
64+
.setLeafValue(left == null ? Collections.singletonList(randomDouble()) : null)
6465
.setDefaultLeft(randomBoolean() ? null : randomBoolean())
6566
.setLeftChild(left)
6667
.setRightChild(right)

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,43 +5,51 @@
55
*/
66
package org.elasticsearch.xpack.core.ml.inference.results;
77

8-
import org.elasticsearch.common.io.stream.StreamInput;
98
import org.elasticsearch.common.io.stream.StreamOutput;
109
import org.elasticsearch.ingest.IngestDocument;
1110

1211
import java.io.IOException;
12+
import java.util.Arrays;
1313
import java.util.Map;
1414
import java.util.Objects;
1515

16-
public class RawInferenceResults extends SingleValueInferenceResults {
16+
public class RawInferenceResults implements InferenceResults {
1717

1818
public static final String NAME = "raw";
1919

20-
public RawInferenceResults(double value, Map<String, Double> featureImportance) {
21-
super(value, featureImportance);
20+
private final double[] value;
21+
private final Map<String, Double> featureImportance;
22+
23+
public RawInferenceResults(double[] value, Map<String, Double> featureImportance) {
24+
this.value = value;
25+
this.featureImportance = featureImportance;
26+
}
27+
28+
public double[] getValue() {
29+
return value;
2230
}
2331

24-
public RawInferenceResults(StreamInput in) throws IOException {
25-
super(in);
32+
public Map<String, Double> getFeatureImportance() {
33+
return featureImportance;
2634
}
2735

2836
@Override
2937
public void writeTo(StreamOutput out) throws IOException {
30-
super.writeTo(out);
38+
throw new UnsupportedOperationException("[raw] does not support wire serialization");
3139
}
3240

3341
@Override
3442
public boolean equals(Object object) {
3543
if (object == this) { return true; }
3644
if (object == null || getClass() != object.getClass()) { return false; }
3745
RawInferenceResults that = (RawInferenceResults) object;
38-
return Objects.equals(value(), that.value())
39-
&& Objects.equals(getFeatureImportance(), that.getFeatureImportance());
46+
return Arrays.equals(value, that.value)
47+
&& Objects.equals(featureImportance, that.featureImportance);
4048
}
4149

4250
@Override
4351
public int hashCode() {
44-
return Objects.hash(value(), getFeatureImportance());
52+
return Objects.hash(Arrays.hashCode(value), featureImportance);
4553
}
4654

4755
@Override

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,30 +26,29 @@ private InferenceHelpers() { }
2626
/**
2727
* @return Tuple of the highest scored index and the top classes
2828
*/
29-
public static Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses(List<Double> probabilities,
29+
public static Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses(double[] probabilities,
3030
List<String> classificationLabels,
3131
@Nullable double[] classificationWeights,
3232
int numToInclude) {
3333

34-
if (classificationLabels != null && probabilities.size() != classificationLabels.size()) {
34+
if (classificationLabels != null && probabilities.length != classificationLabels.size()) {
3535
throw ExceptionsHelper
3636
.serverError(
3737
"model returned classification probabilities of size [{}] which is not equal to classification labels size [{}]",
3838
null,
39-
probabilities.size(),
39+
probabilities.length,
4040
classificationLabels.size());
4141
}
4242

43-
List<Double> scores = classificationWeights == null ?
43+
double[] scores = classificationWeights == null ?
4444
probabilities :
45-
IntStream.range(0, probabilities.size())
46-
.mapToDouble(i -> probabilities.get(i) * classificationWeights[i])
47-
.boxed()
48-
.collect(Collectors.toList());
45+
IntStream.range(0, probabilities.length)
46+
.mapToDouble(i -> probabilities[i] * classificationWeights[i])
47+
.toArray();
4948

50-
int[] sortedIndices = IntStream.range(0, probabilities.size())
49+
int[] sortedIndices = IntStream.range(0, scores.length)
5150
.boxed()
52-
.sorted(Comparator.comparing(scores::get).reversed())
51+
.sorted(Comparator.comparing(i -> scores[(Integer)i]).reversed())
5352
.mapToInt(i -> i)
5453
.toArray();
5554

@@ -59,14 +58,14 @@ public static Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>>
5958

6059
List<String> labels = classificationLabels == null ?
6160
// If we don't have the labels we should return the top classification values anyways, they will just be numeric
62-
IntStream.range(0, probabilities.size()).boxed().map(String::valueOf).collect(Collectors.toList()) :
61+
IntStream.range(0, probabilities.length).boxed().map(String::valueOf).collect(Collectors.toList()) :
6362
classificationLabels;
6463

65-
int count = numToInclude < 0 ? probabilities.size() : Math.min(numToInclude, probabilities.size());
64+
int count = numToInclude < 0 ? probabilities.length : Math.min(numToInclude, probabilities.length);
6665
List<ClassificationInferenceResults.TopClassEntry> topClassEntries = new ArrayList<>(count);
6766
for(int i = 0; i < count; i++) {
6867
int idx = sortedIndices[i];
69-
topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities.get(idx), scores.get(idx)));
68+
topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities[idx], scores[idx]));
7069
}
7170

7271
return Tuple.tuple(sortedIndices[0], topClassEntries);

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
77

88
import org.apache.lucene.util.Accountable;
9+
import org.elasticsearch.Version;
910
import org.elasticsearch.common.Nullable;
1011
import org.elasticsearch.common.io.stream.NamedWriteable;
1112
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
@@ -62,4 +63,8 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou
6263
* @return A {@code Map<String, Double>} mapping each featureName to its importance
6364
*/
6465
Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder);
66+
67+
default Version getMinimalCompatibilityVersion() {
68+
return Version.V_7_6_0;
69+
}
6570
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import org.apache.lucene.util.Accountable;
99
import org.apache.lucene.util.Accountables;
1010
import org.apache.lucene.util.RamUsageEstimator;
11+
import org.elasticsearch.Version;
1112
import org.elasticsearch.common.Nullable;
1213
import org.elasticsearch.common.ParseField;
1314
import org.elasticsearch.common.collect.Tuple;
@@ -20,7 +21,6 @@
2021
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
2122
import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
2223
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
23-
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
2424
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
2525
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
2626
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
@@ -139,19 +139,20 @@ public InferenceResults infer(Map<String, Object> fields, InferenceConfig config
139139
throw ExceptionsHelper.badRequestException(
140140
"Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
141141
}
142-
List<Double> inferenceResults = new ArrayList<>(this.models.size());
142+
double[][] inferenceResults = new double[this.models.size()][];
143143
List<Map<String, Double>> featureInfluence = new ArrayList<>();
144+
int i = 0;
144145
NullInferenceConfig subModelInferenceConfig = new NullInferenceConfig(config.requestingImportance());
145-
this.models.forEach(model -> {
146+
for (TrainedModel model : models) {
146147
InferenceResults result = model.infer(fields, subModelInferenceConfig, Collections.emptyMap());
147-
assert result instanceof SingleValueInferenceResults;
148-
SingleValueInferenceResults inferenceResult = (SingleValueInferenceResults) result;
149-
inferenceResults.add(inferenceResult.value());
148+
assert result instanceof RawInferenceResults;
149+
RawInferenceResults inferenceResult = (RawInferenceResults) result;
150+
inferenceResults[i++] = inferenceResult.getValue();
150151
if (config.requestingImportance()) {
151152
featureInfluence.add(inferenceResult.getFeatureImportance());
152153
}
153-
});
154-
List<Double> processed = outputAggregator.processValues(inferenceResults);
154+
}
155+
double[] processed = outputAggregator.processValues(inferenceResults);
155156
return buildResults(processed, featureInfluence, config, featureDecoderMap);
156157
}
157158

@@ -160,13 +161,13 @@ public TargetType targetType() {
160161
return targetType;
161162
}
162163

163-
private InferenceResults buildResults(List<Double> processedInferences,
164+
private InferenceResults buildResults(double[] processedInferences,
164165
List<Map<String, Double>> featureInfluence,
165166
InferenceConfig config,
166167
Map<String, String> featureDecoderMap) {
167168
// Indicates that the config is useless and the caller just wants the raw value
168169
if (config instanceof NullInferenceConfig) {
169-
return new RawInferenceResults(outputAggregator.aggregate(processedInferences),
170+
return new RawInferenceResults(new double[] {outputAggregator.aggregate(processedInferences)},
170171
InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)));
171172
}
172173
switch(targetType) {
@@ -176,7 +177,7 @@ private InferenceResults buildResults(List<Double> processedInferences,
176177
InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)));
177178
case CLASSIFICATION:
178179
ClassificationConfig classificationConfig = (ClassificationConfig) config;
179-
assert classificationWeights == null || processedInferences.size() == classificationWeights.length;
180+
assert classificationWeights == null || processedInferences.length == classificationWeights.length;
180181
// Adjust the probabilities according to the thresholds
181182
Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses = InferenceHelpers.topClasses(
182183
processedInferences,
@@ -356,6 +357,11 @@ public Collection<Accountable> getChildResources() {
356357
return Collections.unmodifiableCollection(accountables);
357358
}
358359

360+
@Override
361+
public Version getMinimalCompatibilityVersion() {
362+
return models.stream().map(TrainedModel::getMinimalCompatibilityVersion).max(Version::compareTo).orElse(Version.V_7_6_0);
363+
}
364+
359365
public static class Builder {
360366
private List<String> featureNames;
361367
private List<TrainedModel> trainedModels;

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
import java.util.Arrays;
2020
import java.util.List;
2121
import java.util.Objects;
22-
import java.util.stream.IntStream;
2322

2423
import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.sigmoid;
24+
import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax;
2525

2626
public class LogisticRegression implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {
2727

@@ -78,31 +78,39 @@ public Integer expectedValueSize() {
7878
}
7979

8080
@Override
81-
public List<Double> processValues(List<Double> values) {
81+
public double[] processValues(double[][] values) {
8282
Objects.requireNonNull(values, "values must not be null");
83-
if (weights != null && values.size() != weights.length) {
83+
if (weights != null && values.length != weights.length) {
8484
throw new IllegalArgumentException("values must be the same length as weights.");
8585
}
86-
double summation = weights == null ?
87-
values.stream().mapToDouble(Double::valueOf).sum() :
88-
IntStream.range(0, weights.length).mapToDouble(i -> values.get(i) * weights[i]).sum();
89-
double probOfClassOne = sigmoid(summation);
86+
double[] sumOnAxis1 = new double[values[0].length];
87+
for (int j = 0; j < values.length; j++) {
88+
double[] value = values[j];
89+
double weight = weights == null ? 1.0 : weights[j];
90+
for(int i = 0; i < value.length; i++) {
91+
if (i >= sumOnAxis1.length) {
92+
throw new IllegalArgumentException("value entries must have the same dimensions");
93+
}
94+
sumOnAxis1[i] += (value[i] * weight);
95+
}
96+
}
97+
if (sumOnAxis1.length > 1) {
98+
return softMax(sumOnAxis1);
99+
}
100+
101+
double probOfClassOne = sigmoid(sumOnAxis1[0]);
90102
assert 0.0 <= probOfClassOne && probOfClassOne <= 1.0;
91-
return Arrays.asList(1.0 - probOfClassOne, probOfClassOne);
103+
return new double[] {1.0 - probOfClassOne, probOfClassOne};
92104
}
93105

94106
@Override
95-
public double aggregate(List<Double> values) {
107+
public double aggregate(double[] values) {
96108
Objects.requireNonNull(values, "values must not be null");
97-
assert values.size() == 2;
98109
int bestValue = 0;
99110
double bestProb = Double.NEGATIVE_INFINITY;
100-
for (int i = 0; i < values.size(); i++) {
101-
if (values.get(i) == null) {
102-
throw new IllegalArgumentException("values must not contain null values");
103-
}
104-
if (values.get(i) > bestProb) {
105-
bestProb = values.get(i);
111+
for (int i = 0; i < values.length; i++) {
112+
if (values[i] > bestProb) {
113+
bestProb = values[i];
106114
bestValue = i;
107115
}
108116
}

0 commit comments

Comments (0)