10
10
import org .elasticsearch .common .io .stream .StreamInput ;
11
11
import org .elasticsearch .common .io .stream .StreamOutput ;
12
12
import org .elasticsearch .common .io .stream .Writeable ;
13
+ import org .elasticsearch .common .xcontent .ConstructingObjectParser ;
13
14
import org .elasticsearch .common .xcontent .ObjectParser ;
14
15
import org .elasticsearch .common .xcontent .ToXContentObject ;
15
16
import org .elasticsearch .common .xcontent .XContentBuilder ;
30
31
31
32
public class TrainedModelDefinition implements ToXContentObject , Writeable {
32
33
33
- public static final String NAME = "trained_model_doc " ;
34
+ public static final String NAME = "trained_mode_definition " ;
34
35
35
36
public static final ParseField TRAINED_MODEL = new ParseField ("trained_model" );
36
37
public static final ParseField PREPROCESSORS = new ParseField ("preprocessors" );
38
+ public static final ParseField INPUT = new ParseField ("input" );
37
39
38
40
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
39
41
public static final ObjectParser <TrainedModelDefinition .Builder , Void > LENIENT_PARSER = createParser (true );
@@ -55,6 +57,7 @@ private static ObjectParser<TrainedModelDefinition.Builder, Void> createParser(b
55
57
p .namedObject (StrictlyParsedPreProcessor .class , n , null ),
56
58
(trainedModelDefBuilder ) -> trainedModelDefBuilder .setProcessorsInOrder (true ),
57
59
PREPROCESSORS );
60
+ parser .declareObject (TrainedModelDefinition .Builder ::setInput , (p , c ) -> Input .fromXContent (p , ignoreUnknownFields ), INPUT );
58
61
return parser ;
59
62
}
60
63
@@ -64,21 +67,25 @@ public static TrainedModelDefinition.Builder fromXContent(XContentParser parser,
64
67
65
68
private final TrainedModel trainedModel ;
66
69
private final List <PreProcessor > preProcessors ;
70
+ private final Input input ;
67
71
68
- TrainedModelDefinition (TrainedModel trainedModel , List <PreProcessor > preProcessors ) {
69
- this .trainedModel = trainedModel ;
72
+ TrainedModelDefinition (TrainedModel trainedModel , List <PreProcessor > preProcessors , Input input ) {
73
+ this .trainedModel = ExceptionsHelper . requireNonNull ( trainedModel , TRAINED_MODEL ) ;
70
74
this .preProcessors = preProcessors == null ? Collections .emptyList () : Collections .unmodifiableList (preProcessors );
75
+ this .input = ExceptionsHelper .requireNonNull (input , INPUT );
71
76
}
72
77
73
78
public TrainedModelDefinition (StreamInput in ) throws IOException {
74
79
this .trainedModel = in .readNamedWriteable (TrainedModel .class );
75
80
this .preProcessors = in .readNamedWriteableList (PreProcessor .class );
81
+ this .input = new Input (in );
76
82
}
77
83
78
84
@ Override
79
85
public void writeTo (StreamOutput out ) throws IOException {
80
86
out .writeNamedWriteable (trainedModel );
81
87
out .writeNamedWriteableList (preProcessors );
88
+ input .writeTo (out );
82
89
}
83
90
84
91
@ Override
@@ -94,6 +101,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
94
101
true ,
95
102
PREPROCESSORS .getPreferredName (),
96
103
preProcessors );
104
+ builder .field (INPUT .getPreferredName (), input );
97
105
builder .endObject ();
98
106
return builder ;
99
107
}
@@ -106,6 +114,10 @@ public List<PreProcessor> getPreProcessors() {
106
114
return preProcessors ;
107
115
}
108
116
117
+ public Input getInput () {
118
+ return input ;
119
+ }
120
+
109
121
@ Override
110
122
public String toString () {
111
123
return Strings .toString (this );
@@ -117,19 +129,21 @@ public boolean equals(Object o) {
117
129
if (o == null || getClass () != o .getClass ()) return false ;
118
130
TrainedModelDefinition that = (TrainedModelDefinition ) o ;
119
131
return Objects .equals (trainedModel , that .trainedModel ) &&
120
- Objects .equals (preProcessors , that .preProcessors ) ;
132
+ Objects .equals (input , that .input ) &&
133
+ Objects .equals (preProcessors , that .preProcessors );
121
134
}
122
135
123
136
@ Override
124
137
public int hashCode () {
125
- return Objects .hash (trainedModel , preProcessors );
138
+ return Objects .hash (trainedModel , input , preProcessors );
126
139
}
127
140
128
141
public static class Builder {
129
142
130
143
private List <PreProcessor > preProcessors ;
131
144
private TrainedModel trainedModel ;
132
145
private boolean processorsInOrder ;
146
+ private Input input ;
133
147
134
148
private static Builder builderForParser () {
135
149
return new Builder (false );
@@ -153,6 +167,11 @@ public Builder setTrainedModel(TrainedModel trainedModel) {
153
167
return this ;
154
168
}
155
169
170
+ public Builder setInput (Input input ) {
171
+ this .input = input ;
172
+ return this ;
173
+ }
174
+
156
175
private Builder setTrainedModel (List <TrainedModel > trainedModel ) {
157
176
if (trainedModel .size () != 1 ) {
158
177
throw ExceptionsHelper .badRequestException ("[{}] must have exactly one trained model defined." ,
@@ -169,8 +188,71 @@ public TrainedModelDefinition build() {
169
188
if (preProcessors != null && preProcessors .size () > 1 && processorsInOrder == false ) {
170
189
throw new IllegalArgumentException ("preprocessors must be an array of preprocessor objects" );
171
190
}
172
- return new TrainedModelDefinition (this .trainedModel , this .preProcessors );
191
+ return new TrainedModelDefinition (this .trainedModel , this .preProcessors , this . input );
173
192
}
174
193
}
175
194
195
+ public static class Input implements ToXContentObject , Writeable {
196
+
197
+ public static final String NAME = "trained_mode_definition_input" ;
198
+ public static final ParseField FIELD_NAMES = new ParseField ("field_names" );
199
+
200
+ public static final ConstructingObjectParser <Input , Void > LENIENT_PARSER = createParser (true );
201
+ public static final ConstructingObjectParser <Input , Void > STRICT_PARSER = createParser (false );
202
+
203
+ @ SuppressWarnings ("unchecked" )
204
+ private static ConstructingObjectParser <Input , Void > createParser (boolean ignoreUnknownFields ) {
205
+ ConstructingObjectParser <Input , Void > parser = new ConstructingObjectParser <>(NAME ,
206
+ ignoreUnknownFields ,
207
+ a -> new Input ((List <String >)a [0 ]));
208
+ parser .declareStringArray (ConstructingObjectParser .constructorArg (), FIELD_NAMES );
209
+ return parser ;
210
+ }
211
+
212
+ public static Input fromXContent (XContentParser parser , boolean lenient ) throws IOException {
213
+ return lenient ? LENIENT_PARSER .parse (parser , null ) : STRICT_PARSER .parse (parser , null );
214
+ }
215
+
216
+ private final List <String > fieldNames ;
217
+
218
+ public Input (List <String > fieldNames ) {
219
+ this .fieldNames = Collections .unmodifiableList (ExceptionsHelper .requireNonNull (fieldNames , FIELD_NAMES ));
220
+ }
221
+
222
+ public Input (StreamInput in ) throws IOException {
223
+ this .fieldNames = Collections .unmodifiableList (in .readStringList ());
224
+ }
225
+
226
+ public List <String > getFieldNames () {
227
+ return fieldNames ;
228
+ }
229
+
230
+ @ Override
231
+ public void writeTo (StreamOutput out ) throws IOException {
232
+ out .writeStringCollection (fieldNames );
233
+ }
234
+
235
+ @ Override
236
+ public XContentBuilder toXContent (XContentBuilder builder , Params params ) throws IOException {
237
+ builder .startObject ();
238
+ builder .field (FIELD_NAMES .getPreferredName (), fieldNames );
239
+ builder .endObject ();
240
+ return builder ;
241
+ }
242
+
243
+ @ Override
244
+ public boolean equals (Object o ) {
245
+ if (this == o ) return true ;
246
+ if (o == null || getClass () != o .getClass ()) return false ;
247
+ TrainedModelDefinition .Input that = (TrainedModelDefinition .Input ) o ;
248
+ return Objects .equals (fieldNames , that .fieldNames );
249
+ }
250
+
251
+ @ Override
252
+ public int hashCode () {
253
+ return Objects .hash (fieldNames );
254
+ }
255
+
256
+ }
257
+
176
258
}
0 commit comments