Skip to content

Commit 18ea144

Browse files
authored
[ML][Inference] fixing classification inference for ensemble (#48463)
1 parent 1b8e288 commit 18ea144

File tree

6 files changed

+110
-5
lines changed

6 files changed

+110
-5
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
package org.elasticsearch.xpack.core.ml.inference.results;
7+
8+
import org.elasticsearch.common.io.stream.StreamInput;
9+
import org.elasticsearch.common.io.stream.StreamOutput;
10+
import org.elasticsearch.common.xcontent.XContentBuilder;
11+
import org.elasticsearch.ingest.IngestDocument;
12+
13+
import java.io.IOException;
14+
import java.util.Objects;
15+
16+
public class RawInferenceResults extends SingleValueInferenceResults {
17+
18+
public static final String NAME = "raw";
19+
20+
public RawInferenceResults(double value) {
21+
super(value);
22+
}
23+
24+
public RawInferenceResults(StreamInput in) throws IOException {
25+
super(in.readDouble());
26+
}
27+
28+
@Override
29+
public void writeTo(StreamOutput out) throws IOException {
30+
super.writeTo(out);
31+
}
32+
33+
@Override
34+
XContentBuilder innerToXContent(XContentBuilder builder, Params params) {
35+
return builder;
36+
}
37+
38+
@Override
39+
public boolean equals(Object object) {
40+
if (object == this) { return true; }
41+
if (object == null || getClass() != object.getClass()) { return false; }
42+
RawInferenceResults that = (RawInferenceResults) object;
43+
return Objects.equals(value(), that.value());
44+
}
45+
46+
@Override
47+
public int hashCode() {
48+
return Objects.hash(value());
49+
}
50+
51+
@Override
52+
public void writeResult(IngestDocument document, String resultField) {
53+
throw new UnsupportedOperationException("[raw] does not support writing inference results");
54+
}
55+
56+
@Override
57+
public String getWriteableName() {
58+
return NAME;
59+
}
60+
61+
@Override
62+
public String getName() {
63+
return NAME;
64+
}
65+
}
Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,18 @@
33
* or more contributor license agreements. Licensed under the Elastic License;
44
* you may not use this file except in compliance with the Elastic License.
55
*/
6-
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
6+
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
77

88
import org.elasticsearch.Version;
99
import org.elasticsearch.common.io.stream.StreamOutput;
1010
import org.elasticsearch.common.xcontent.XContentBuilder;
11-
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
12-
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
1311

1412
import java.io.IOException;
1513

1614
/**
1715
* Used by ensemble to pass into sub-models.
1816
*/
19-
class NullInferenceConfig implements InferenceConfig {
17+
public class NullInferenceConfig implements InferenceConfig {
2018

2119
public static final NullInferenceConfig INSTANCE = new NullInferenceConfig();
2220

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414
import org.elasticsearch.common.xcontent.XContentParser;
1515
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
1616
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
17+
import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
1718
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
1819
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
1920
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
2021
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
2122
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
2223
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
24+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfig;
2325
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
2426
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
2527
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
@@ -135,6 +137,10 @@ public TargetType targetType() {
135137
}
136138

137139
private InferenceResults buildResults(List<Double> processedInferences, InferenceConfig config) {
140+
// Indicates that the config is useless and the caller just wants the raw value
141+
if (config instanceof NullInferenceConfig) {
142+
return new RawInferenceResults(outputAggregator.aggregate(processedInferences));
143+
}
138144
switch(targetType) {
139145
case REGRESSION:
140146
return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences));

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515
import org.elasticsearch.common.xcontent.XContentParser;
1616
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
1717
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
18+
import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
1819
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
1920
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
2021
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
2122
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
2223
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
24+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfig;
2325
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
2426
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
2527
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@@ -134,6 +136,10 @@ private InferenceResults infer(List<Double> features, InferenceConfig config) {
134136
}
135137

136138
private InferenceResults buildResult(Double value, InferenceConfig config) {
139+
// Indicates that the config is useless and the caller just wants the raw value
140+
if (config instanceof NullInferenceConfig) {
141+
return new RawInferenceResults(value);
142+
}
137143
switch (targetType) {
138144
case CLASSIFICATION:
139145
ClassificationConfig classificationConfig = (ClassificationConfig) config;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
package org.elasticsearch.xpack.core.ml.inference.results;
7+
8+
import org.elasticsearch.common.io.stream.Writeable;
9+
import org.elasticsearch.test.AbstractWireSerializingTestCase;
10+
11+
public class RawInferenceResultsTests extends AbstractWireSerializingTestCase<RawInferenceResults> {
12+
13+
public static RawInferenceResults createRandomResults() {
14+
return new RawInferenceResults(randomDouble());
15+
}
16+
17+
@Override
18+
protected RawInferenceResults createTestInstance() {
19+
return createRandomResults();
20+
}
21+
22+
@Override
23+
protected Writeable.Reader<RawInferenceResults> instanceReader() {
24+
return RawInferenceResults::new;
25+
}
26+
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,9 @@ public void testClassificationInference() {
322322
.setLeftChild(3)
323323
.setRightChild(4))
324324
.addNode(TreeNode.builder(3).setLeafValue(0.0))
325-
.addNode(TreeNode.builder(4).setLeafValue(1.0)).build();
325+
.addNode(TreeNode.builder(4).setLeafValue(1.0))
326+
.setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION))
327+
.build();
326328
Tree tree2 = Tree.builder()
327329
.setFeatureNames(featureNames)
328330
.setRoot(TreeNode.builder(0)
@@ -332,6 +334,7 @@ public void testClassificationInference() {
332334
.setThreshold(0.5))
333335
.addNode(TreeNode.builder(1).setLeafValue(0.0))
334336
.addNode(TreeNode.builder(2).setLeafValue(1.0))
337+
.setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION))
335338
.build();
336339
Tree tree3 = Tree.builder()
337340
.setFeatureNames(featureNames)
@@ -342,6 +345,7 @@ public void testClassificationInference() {
342345
.setThreshold(1.0))
343346
.addNode(TreeNode.builder(1).setLeafValue(1.0))
344347
.addNode(TreeNode.builder(2).setLeafValue(0.0))
348+
.setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION))
345349
.build();
346350
Ensemble ensemble = Ensemble.builder()
347351
.setTargetType(TargetType.CLASSIFICATION)

0 commit comments

Comments
 (0)