Skip to content

Commit af4e6ed

Browse files
authored
[ML][Inference] adding ensemble model objects (#47241)
* [ML][Inference] adding ensemble model objects * addressing PR comments * Update TreeTests.java * addressing PR comments * fixing test
1 parent f47da1d commit af4e6ed

31 files changed

+2421
-118
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919
package org.elasticsearch.client.ml.inference;
2020

2121
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
22+
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
23+
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator;
24+
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode;
25+
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum;
2226
import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree;
2327
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding;
2428
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;
@@ -47,6 +51,15 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
4751

4852
// Model
4953
namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class, new ParseField(Tree.NAME), Tree::fromXContent));
54+
namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class, new ParseField(Ensemble.NAME), Ensemble::fromXContent));
55+
56+
// Aggregating output
57+
namedXContent.add(new NamedXContentRegistry.Entry(OutputAggregator.class,
58+
new ParseField(WeightedMode.NAME),
59+
WeightedMode::fromXContent));
60+
namedXContent.add(new NamedXContentRegistry.Entry(OutputAggregator.class,
61+
new ParseField(WeightedSum.NAME),
62+
WeightedSum::fromXContent));
5063

5164
return namedXContent;
5265
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.elasticsearch.client.ml.inference.trainedmodel;
20+
21+
import java.util.Locale;
22+
23+
public enum TargetType {
24+
25+
REGRESSION, CLASSIFICATION;
26+
27+
public static TargetType fromString(String name) {
28+
return valueOf(name.trim().toUpperCase(Locale.ROOT));
29+
}
30+
31+
@Override
32+
public String toString() {
33+
return name().toLowerCase(Locale.ROOT);
34+
}
35+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;
20+
21+
import org.elasticsearch.client.ml.inference.NamedXContentObjectHelper;
22+
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
23+
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
24+
import org.elasticsearch.common.Nullable;
25+
import org.elasticsearch.common.ParseField;
26+
import org.elasticsearch.common.xcontent.ObjectParser;
27+
import org.elasticsearch.common.xcontent.ToXContent;
28+
import org.elasticsearch.common.xcontent.XContentBuilder;
29+
import org.elasticsearch.common.xcontent.XContentParser;
30+
31+
import java.io.IOException;
32+
import java.util.Collections;
33+
import java.util.List;
34+
import java.util.Objects;
35+
36+
public class Ensemble implements TrainedModel {
37+
38+
public static final String NAME = "ensemble";
39+
public static final ParseField FEATURE_NAMES = new ParseField("feature_names");
40+
public static final ParseField TRAINED_MODELS = new ParseField("trained_models");
41+
public static final ParseField AGGREGATE_OUTPUT = new ParseField("aggregate_output");
42+
public static final ParseField TARGET_TYPE = new ParseField("target_type");
43+
public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");
44+
45+
private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(
46+
NAME,
47+
true,
48+
Ensemble.Builder::new);
49+
50+
static {
51+
PARSER.declareStringArray(Ensemble.Builder::setFeatureNames, FEATURE_NAMES);
52+
PARSER.declareNamedObjects(Ensemble.Builder::setTrainedModels,
53+
(p, c, n) ->
54+
p.namedObject(TrainedModel.class, n, null),
55+
(ensembleBuilder) -> { /* Noop does not matter client side */ },
56+
TRAINED_MODELS);
57+
PARSER.declareNamedObjects(Ensemble.Builder::setOutputAggregatorFromParser,
58+
(p, c, n) -> p.namedObject(OutputAggregator.class, n, null),
59+
(ensembleBuilder) -> { /* Noop does not matter client side */ },
60+
AGGREGATE_OUTPUT);
61+
PARSER.declareString(Ensemble.Builder::setTargetType, TARGET_TYPE);
62+
PARSER.declareStringArray(Ensemble.Builder::setClassificationLabels, CLASSIFICATION_LABELS);
63+
}
64+
65+
public static Ensemble fromXContent(XContentParser parser) {
66+
return PARSER.apply(parser, null).build();
67+
}
68+
69+
private final List<String> featureNames;
70+
private final List<TrainedModel> models;
71+
private final OutputAggregator outputAggregator;
72+
private final TargetType targetType;
73+
private final List<String> classificationLabels;
74+
75+
Ensemble(List<String> featureNames,
76+
List<TrainedModel> models,
77+
@Nullable OutputAggregator outputAggregator,
78+
TargetType targetType,
79+
@Nullable List<String> classificationLabels) {
80+
this.featureNames = featureNames;
81+
this.models = models;
82+
this.outputAggregator = outputAggregator;
83+
this.targetType = targetType;
84+
this.classificationLabels = classificationLabels;
85+
}
86+
87+
@Override
88+
public List<String> getFeatureNames() {
89+
return featureNames;
90+
}
91+
92+
@Override
93+
public String getName() {
94+
return NAME;
95+
}
96+
97+
@Override
98+
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
99+
builder.startObject();
100+
if (featureNames != null) {
101+
builder.field(FEATURE_NAMES.getPreferredName(), featureNames);
102+
}
103+
if (models != null) {
104+
NamedXContentObjectHelper.writeNamedObjects(builder, params, true, TRAINED_MODELS.getPreferredName(), models);
105+
}
106+
if (outputAggregator != null) {
107+
NamedXContentObjectHelper.writeNamedObjects(builder,
108+
params,
109+
false,
110+
AGGREGATE_OUTPUT.getPreferredName(),
111+
Collections.singletonList(outputAggregator));
112+
}
113+
if (targetType != null) {
114+
builder.field(TARGET_TYPE.getPreferredName(), targetType);
115+
}
116+
if (classificationLabels != null) {
117+
builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
118+
}
119+
builder.endObject();
120+
return builder;
121+
}
122+
123+
@Override
124+
public boolean equals(Object o) {
125+
if (this == o) return true;
126+
if (o == null || getClass() != o.getClass()) return false;
127+
Ensemble that = (Ensemble) o;
128+
return Objects.equals(featureNames, that.featureNames)
129+
&& Objects.equals(models, that.models)
130+
&& Objects.equals(targetType, that.targetType)
131+
&& Objects.equals(classificationLabels, that.classificationLabels)
132+
&& Objects.equals(outputAggregator, that.outputAggregator);
133+
}
134+
135+
@Override
136+
public int hashCode() {
137+
return Objects.hash(featureNames, models, outputAggregator, classificationLabels, targetType);
138+
}
139+
140+
public static Builder builder() {
141+
return new Builder();
142+
}
143+
144+
public static class Builder {
145+
private List<String> featureNames;
146+
private List<TrainedModel> trainedModels;
147+
private OutputAggregator outputAggregator;
148+
private TargetType targetType;
149+
private List<String> classificationLabels;
150+
151+
public Builder setFeatureNames(List<String> featureNames) {
152+
this.featureNames = featureNames;
153+
return this;
154+
}
155+
156+
public Builder setTrainedModels(List<TrainedModel> trainedModels) {
157+
this.trainedModels = trainedModels;
158+
return this;
159+
}
160+
161+
public Builder setOutputAggregator(OutputAggregator outputAggregator) {
162+
this.outputAggregator = outputAggregator;
163+
return this;
164+
}
165+
166+
public Builder setTargetType(TargetType targetType) {
167+
this.targetType = targetType;
168+
return this;
169+
}
170+
171+
public Builder setClassificationLabels(List<String> classificationLabels) {
172+
this.classificationLabels = classificationLabels;
173+
return this;
174+
}
175+
176+
private void setOutputAggregatorFromParser(List<OutputAggregator> outputAggregators) {
177+
this.setOutputAggregator(outputAggregators.get(0));
178+
}
179+
180+
private void setTargetType(String targetType) {
181+
this.targetType = TargetType.fromString(targetType);
182+
}
183+
184+
public Ensemble build() {
185+
return new Ensemble(featureNames, trainedModels, outputAggregator, targetType, classificationLabels);
186+
}
187+
}
188+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;
20+
21+
import org.elasticsearch.client.ml.inference.NamedXContentObject;
22+
23+
public interface OutputAggregator extends NamedXContentObject {
24+
/**
25+
* @return The name of the output aggregator
26+
*/
27+
String getName();
28+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;
20+
21+
22+
import org.elasticsearch.common.ParseField;
23+
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
24+
import org.elasticsearch.common.xcontent.ToXContent;
25+
import org.elasticsearch.common.xcontent.XContentBuilder;
26+
import org.elasticsearch.common.xcontent.XContentParser;
27+
28+
import java.io.IOException;
29+
import java.util.List;
30+
import java.util.Objects;
31+
32+
33+
public class WeightedMode implements OutputAggregator {
34+
35+
public static final String NAME = "weighted_mode";
36+
public static final ParseField WEIGHTS = new ParseField("weights");
37+
38+
@SuppressWarnings("unchecked")
39+
private static final ConstructingObjectParser<WeightedMode, Void> PARSER = new ConstructingObjectParser<>(
40+
NAME,
41+
true,
42+
a -> new WeightedMode((List<Double>)a[0]));
43+
static {
44+
PARSER.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS);
45+
}
46+
47+
public static WeightedMode fromXContent(XContentParser parser) {
48+
return PARSER.apply(parser, null);
49+
}
50+
51+
private final List<Double> weights;
52+
53+
public WeightedMode(List<Double> weights) {
54+
this.weights = weights;
55+
}
56+
57+
@Override
58+
public String getName() {
59+
return NAME;
60+
}
61+
62+
@Override
63+
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
64+
builder.startObject();
65+
if (weights != null) {
66+
builder.field(WEIGHTS.getPreferredName(), weights);
67+
}
68+
builder.endObject();
69+
return builder;
70+
}
71+
72+
@Override
73+
public boolean equals(Object o) {
74+
if (this == o) return true;
75+
if (o == null || getClass() != o.getClass()) return false;
76+
WeightedMode that = (WeightedMode) o;
77+
return Objects.equals(weights, that.weights);
78+
}
79+
80+
@Override
81+
public int hashCode() {
82+
return Objects.hash(weights);
83+
}
84+
}

0 commit comments

Comments
 (0)