Skip to content

Commit 890b3db

Browse files
authored
[ML][Inference] adjusting definition object schema and validation (#47447)
* [ML][Inference] adjusting definition object schema and validation
* finalizing schema and fixing inference NPE
* addressing PR comments
1 parent 446dcbe commit 890b3db

File tree

10 files changed

+459
-63
lines changed

10 files changed

+459
-63
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelDefinition.java

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
2323
import org.elasticsearch.common.ParseField;
2424
import org.elasticsearch.common.Strings;
25+
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
2526
import org.elasticsearch.common.xcontent.ObjectParser;
2627
import org.elasticsearch.common.xcontent.ToXContentObject;
2728
import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -38,6 +39,7 @@ public class TrainedModelDefinition implements ToXContentObject {
3839

3940
public static final ParseField TRAINED_MODEL = new ParseField("trained_model");
4041
public static final ParseField PREPROCESSORS = new ParseField("preprocessors");
42+
public static final ParseField INPUT = new ParseField("input");
4143

4244
public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
4345
true,
@@ -51,6 +53,7 @@ public class TrainedModelDefinition implements ToXContentObject {
5153
(p, c, n) -> p.namedObject(PreProcessor.class, n, null),
5254
(trainedModelDefBuilder) -> {/* Does not matter client side*/ },
5355
PREPROCESSORS);
56+
PARSER.declareObject(TrainedModelDefinition.Builder::setInput, (p, c) -> Input.fromXContent(p), INPUT);
5457
}
5558

5659
public static TrainedModelDefinition.Builder fromXContent(XContentParser parser) throws IOException {
@@ -59,10 +62,12 @@ public static TrainedModelDefinition.Builder fromXContent(XContentParser parser)
5962

6063
private final TrainedModel trainedModel;
6164
private final List<PreProcessor> preProcessors;
65+
private final Input input;
6266

63-
TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) {
67+
TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors, Input input) {
6468
this.trainedModel = trainedModel;
6569
this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors);
70+
this.input = input;
6671
}
6772

6873
@Override
@@ -78,6 +83,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
7883
true,
7984
PREPROCESSORS.getPreferredName(),
8085
preProcessors);
86+
if (input != null) {
87+
builder.field(INPUT.getPreferredName(), input);
88+
}
8189
builder.endObject();
8290
return builder;
8391
}
@@ -90,6 +98,10 @@ public List<PreProcessor> getPreProcessors() {
9098
return preProcessors;
9199
}
92100

101+
public Input getInput() {
102+
return input;
103+
}
104+
93105
@Override
94106
public String toString() {
95107
return Strings.toString(this);
@@ -101,18 +113,20 @@ public boolean equals(Object o) {
101113
if (o == null || getClass() != o.getClass()) return false;
102114
TrainedModelDefinition that = (TrainedModelDefinition) o;
103115
return Objects.equals(trainedModel, that.trainedModel) &&
104-
Objects.equals(preProcessors, that.preProcessors) ;
116+
Objects.equals(preProcessors, that.preProcessors) &&
117+
Objects.equals(input, that.input);
105118
}
106119

107120
@Override
108121
public int hashCode() {
109-
return Objects.hash(trainedModel, preProcessors);
122+
return Objects.hash(trainedModel, preProcessors, input);
110123
}
111124

112125
public static class Builder {
113126

114127
private List<PreProcessor> preProcessors;
115128
private TrainedModel trainedModel;
129+
private Input input;
116130

117131
public Builder setPreProcessors(List<PreProcessor> preProcessors) {
118132
this.preProcessors = preProcessors;
@@ -124,14 +138,71 @@ public Builder setTrainedModel(TrainedModel trainedModel) {
124138
return this;
125139
}
126140

141+
public Builder setInput(Input input) {
142+
this.input = input;
143+
return this;
144+
}
145+
127146
private Builder setTrainedModel(List<TrainedModel> trainedModel) {
128147
assert trainedModel.size() == 1;
129148
return setTrainedModel(trainedModel.get(0));
130149
}
131150

132151
public TrainedModelDefinition build() {
133-
return new TrainedModelDefinition(this.trainedModel, this.preProcessors);
152+
return new TrainedModelDefinition(this.trainedModel, this.preProcessors, this.input);
153+
}
154+
}
155+
156+
public static class Input implements ToXContentObject {
157+
158+
public static final String NAME = "trained_mode_definition_input";
159+
public static final ParseField FIELD_NAMES = new ParseField("field_names");
160+
161+
@SuppressWarnings("unchecked")
162+
public static final ConstructingObjectParser<Input, Void> PARSER = new ConstructingObjectParser<>(NAME,
163+
true,
164+
a -> new Input((List<String>)a[0]));
165+
static {
166+
PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES);
167+
}
168+
169+
public static Input fromXContent(XContentParser parser) throws IOException {
170+
return PARSER.parse(parser, null);
134171
}
172+
173+
private final List<String> fieldNames;
174+
175+
public Input(List<String> fieldNames) {
176+
this.fieldNames = fieldNames;
177+
}
178+
179+
public List<String> getFieldNames() {
180+
return fieldNames;
181+
}
182+
183+
@Override
184+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
185+
builder.startObject();
186+
if (fieldNames != null) {
187+
builder.field(FIELD_NAMES.getPreferredName(), fieldNames);
188+
}
189+
builder.endObject();
190+
return builder;
191+
}
192+
193+
@Override
194+
public boolean equals(Object o) {
195+
if (this == o) return true;
196+
if (o == null || getClass() != o.getClass()) return false;
197+
TrainedModelDefinition.Input that = (TrainedModelDefinition.Input) o;
198+
return Objects.equals(fieldNames, that.fieldNames);
199+
}
200+
201+
@Override
202+
public int hashCode() {
203+
return Objects.hash(fieldNames);
204+
}
205+
135206
}
136207

137208
}

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/preprocessing/TargetMeanEncoding.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public class TargetMeanEncoding implements PreProcessor {
3939
public static final String NAME = "target_mean_encoding";
4040
public static final ParseField FIELD = new ParseField("field");
4141
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
42-
public static final ParseField TARGET_MEANS = new ParseField("target_means");
42+
public static final ParseField TARGET_MAP = new ParseField("target_map");
4343
public static final ParseField DEFAULT_VALUE = new ParseField("default_value");
4444

4545
@SuppressWarnings("unchecked")
@@ -52,7 +52,7 @@ public class TargetMeanEncoding implements PreProcessor {
5252
PARSER.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
5353
PARSER.declareObject(ConstructingObjectParser.constructorArg(),
5454
(p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
55-
TARGET_MEANS);
55+
TARGET_MAP);
5656
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), DEFAULT_VALUE);
5757
}
5858

@@ -110,7 +110,7 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
110110
builder.startObject();
111111
builder.field(FIELD.getPreferredName(), field);
112112
builder.field(FEATURE_NAME.getPreferredName(), featureName);
113-
builder.field(TARGET_MEANS.getPreferredName(), meanMap);
113+
builder.field(TARGET_MAP.getPreferredName(), meanMap);
114114
builder.field(DEFAULT_VALUE.getPreferredName(), defaultValue);
115115
builder.endObject();
116116
return builder;

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelDefinitionTests.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,10 @@ public static TrainedModelDefinition.Builder createRandomBuilder() {
6464
TargetMeanEncodingTests.createRandom()))
6565
.limit(numberOfProcessors)
6666
.collect(Collectors.toList()))
67-
.setTrainedModel(randomFrom(TreeTests.createRandom()));
67+
.setTrainedModel(randomFrom(TreeTests.createRandom()))
68+
.setInput(new TrainedModelDefinition.Input(Stream.generate(() -> randomAlphaOfLength(10))
69+
.limit(randomLongBetween(1, 10))
70+
.collect(Collectors.toList())));
6871
}
6972

7073
@Override

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java

Lines changed: 88 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.elasticsearch.common.io.stream.StreamInput;
1111
import org.elasticsearch.common.io.stream.StreamOutput;
1212
import org.elasticsearch.common.io.stream.Writeable;
13+
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
1314
import org.elasticsearch.common.xcontent.ObjectParser;
1415
import org.elasticsearch.common.xcontent.ToXContentObject;
1516
import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -30,10 +31,11 @@
3031

3132
public class TrainedModelDefinition implements ToXContentObject, Writeable {
3233

33-
public static final String NAME = "trained_model_doc";
34+
public static final String NAME = "trained_mode_definition";
3435

3536
public static final ParseField TRAINED_MODEL = new ParseField("trained_model");
3637
public static final ParseField PREPROCESSORS = new ParseField("preprocessors");
38+
public static final ParseField INPUT = new ParseField("input");
3739

3840
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
3941
public static final ObjectParser<TrainedModelDefinition.Builder, Void> LENIENT_PARSER = createParser(true);
@@ -55,6 +57,7 @@ private static ObjectParser<TrainedModelDefinition.Builder, Void> createParser(b
5557
p.namedObject(StrictlyParsedPreProcessor.class, n, null),
5658
(trainedModelDefBuilder) -> trainedModelDefBuilder.setProcessorsInOrder(true),
5759
PREPROCESSORS);
60+
parser.declareObject(TrainedModelDefinition.Builder::setInput, (p, c) -> Input.fromXContent(p, ignoreUnknownFields), INPUT);
5861
return parser;
5962
}
6063

@@ -64,21 +67,25 @@ public static TrainedModelDefinition.Builder fromXContent(XContentParser parser,
6467

6568
private final TrainedModel trainedModel;
6669
private final List<PreProcessor> preProcessors;
70+
private final Input input;
6771

68-
TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) {
69-
this.trainedModel = trainedModel;
72+
TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors, Input input) {
73+
this.trainedModel = ExceptionsHelper.requireNonNull(trainedModel, TRAINED_MODEL);
7074
this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors);
75+
this.input = ExceptionsHelper.requireNonNull(input, INPUT);
7176
}
7277

7378
public TrainedModelDefinition(StreamInput in) throws IOException {
7479
this.trainedModel = in.readNamedWriteable(TrainedModel.class);
7580
this.preProcessors = in.readNamedWriteableList(PreProcessor.class);
81+
this.input = new Input(in);
7682
}
7783

7884
@Override
7985
public void writeTo(StreamOutput out) throws IOException {
8086
out.writeNamedWriteable(trainedModel);
8187
out.writeNamedWriteableList(preProcessors);
88+
input.writeTo(out);
8289
}
8390

8491
@Override
@@ -94,6 +101,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
94101
true,
95102
PREPROCESSORS.getPreferredName(),
96103
preProcessors);
104+
builder.field(INPUT.getPreferredName(), input);
97105
builder.endObject();
98106
return builder;
99107
}
@@ -106,6 +114,10 @@ public List<PreProcessor> getPreProcessors() {
106114
return preProcessors;
107115
}
108116

117+
public Input getInput() {
118+
return input;
119+
}
120+
109121
@Override
110122
public String toString() {
111123
return Strings.toString(this);
@@ -117,19 +129,21 @@ public boolean equals(Object o) {
117129
if (o == null || getClass() != o.getClass()) return false;
118130
TrainedModelDefinition that = (TrainedModelDefinition) o;
119131
return Objects.equals(trainedModel, that.trainedModel) &&
120-
Objects.equals(preProcessors, that.preProcessors) ;
132+
Objects.equals(input, that.input) &&
133+
Objects.equals(preProcessors, that.preProcessors);
121134
}
122135

123136
@Override
124137
public int hashCode() {
125-
return Objects.hash(trainedModel, preProcessors);
138+
return Objects.hash(trainedModel, input, preProcessors);
126139
}
127140

128141
public static class Builder {
129142

130143
private List<PreProcessor> preProcessors;
131144
private TrainedModel trainedModel;
132145
private boolean processorsInOrder;
146+
private Input input;
133147

134148
private static Builder builderForParser() {
135149
return new Builder(false);
@@ -153,6 +167,11 @@ public Builder setTrainedModel(TrainedModel trainedModel) {
153167
return this;
154168
}
155169

170+
public Builder setInput(Input input) {
171+
this.input = input;
172+
return this;
173+
}
174+
156175
private Builder setTrainedModel(List<TrainedModel> trainedModel) {
157176
if (trainedModel.size() != 1) {
158177
throw ExceptionsHelper.badRequestException("[{}] must have exactly one trained model defined.",
@@ -169,8 +188,71 @@ public TrainedModelDefinition build() {
169188
if (preProcessors != null && preProcessors.size() > 1 && processorsInOrder == false) {
170189
throw new IllegalArgumentException("preprocessors must be an array of preprocessor objects");
171190
}
172-
return new TrainedModelDefinition(this.trainedModel, this.preProcessors);
191+
return new TrainedModelDefinition(this.trainedModel, this.preProcessors, this.input);
173192
}
174193
}
175194

195+
public static class Input implements ToXContentObject, Writeable {
196+
197+
public static final String NAME = "trained_mode_definition_input";
198+
public static final ParseField FIELD_NAMES = new ParseField("field_names");
199+
200+
public static final ConstructingObjectParser<Input, Void> LENIENT_PARSER = createParser(true);
201+
public static final ConstructingObjectParser<Input, Void> STRICT_PARSER = createParser(false);
202+
203+
@SuppressWarnings("unchecked")
204+
private static ConstructingObjectParser<Input, Void> createParser(boolean ignoreUnknownFields) {
205+
ConstructingObjectParser<Input, Void> parser = new ConstructingObjectParser<>(NAME,
206+
ignoreUnknownFields,
207+
a -> new Input((List<String>)a[0]));
208+
parser.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES);
209+
return parser;
210+
}
211+
212+
public static Input fromXContent(XContentParser parser, boolean lenient) throws IOException {
213+
return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
214+
}
215+
216+
private final List<String> fieldNames;
217+
218+
public Input(List<String> fieldNames) {
219+
this.fieldNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(fieldNames, FIELD_NAMES));
220+
}
221+
222+
public Input(StreamInput in) throws IOException {
223+
this.fieldNames = Collections.unmodifiableList(in.readStringList());
224+
}
225+
226+
public List<String> getFieldNames() {
227+
return fieldNames;
228+
}
229+
230+
@Override
231+
public void writeTo(StreamOutput out) throws IOException {
232+
out.writeStringCollection(fieldNames);
233+
}
234+
235+
@Override
236+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
237+
builder.startObject();
238+
builder.field(FIELD_NAMES.getPreferredName(), fieldNames);
239+
builder.endObject();
240+
return builder;
241+
}
242+
243+
@Override
244+
public boolean equals(Object o) {
245+
if (this == o) return true;
246+
if (o == null || getClass() != o.getClass()) return false;
247+
TrainedModelDefinition.Input that = (TrainedModelDefinition.Input) o;
248+
return Objects.equals(fieldNames, that.fieldNames);
249+
}
250+
251+
@Override
252+
public int hashCode() {
253+
return Objects.hash(fieldNames);
254+
}
255+
256+
}
257+
176258
}

0 commit comments

Comments
 (0)