Skip to content

Commit adfa977

Browse files
[ML] Inference configs for NLP models (elastic#76350)
Introduce inference configs for NLP models. When a PyTorch model is put, the config now expects a different inference config per task type. Thus, we have `ner`, `fill_mask`, and `sentiment_analysis` configs. In addition, the tokenization parameters have been grouped together and are now part of the relevant inference config objects. As a result, the vocabulary can now be stored in a document of its own. A new vocabulary config object allows the user to specify the location of the vocabulary document.
1 parent e2b5eaf commit adfa977

File tree

50 files changed

+1344
-486
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+1344
-486
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/IndexLocation.java

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,33 +19,25 @@
1919
public class IndexLocation implements TrainedModelLocation {
2020

2121
public static final String INDEX = "index";
22-
private static final ParseField MODEL_ID = new ParseField("model_id");
2322
private static final ParseField NAME = new ParseField("name");
2423

2524
private static final ConstructingObjectParser<IndexLocation, Void> PARSER =
26-
new ConstructingObjectParser<>(INDEX, true, a -> new IndexLocation((String) a[0], (String) a[1]));
25+
new ConstructingObjectParser<>(INDEX, true, a -> new IndexLocation((String) a[0]));
2726

2827
static {
29-
PARSER.declareString(ConstructingObjectParser.constructorArg(), MODEL_ID);
3028
PARSER.declareString(ConstructingObjectParser.constructorArg(), NAME);
3129
}
3230

3331
public static IndexLocation fromXContent(XContentParser parser) throws IOException {
3432
return PARSER.parse(parser, null);
3533
}
3634

37-
private final String modelId;
3835
private final String index;
3936

40-
public IndexLocation(String modelId, String index) {
41-
this.modelId = Objects.requireNonNull(modelId);
37+
public IndexLocation(String index) {
4238
this.index = Objects.requireNonNull(index);
4339
}
4440

45-
public String getModelId() {
46-
return modelId;
47-
}
48-
4941
public String getIndex() {
5042
return index;
5143
}
@@ -59,7 +51,6 @@ public String getName() {
5951
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
6052
builder.startObject();
6153
builder.field(NAME.getPreferredName(), index);
62-
builder.field(MODEL_ID.getPreferredName(), modelId);
6354
builder.endObject();
6455
return builder;
6556
}
@@ -73,12 +64,11 @@ public boolean equals(Object o) {
7364
return false;
7465
}
7566
IndexLocation that = (IndexLocation) o;
76-
return Objects.equals(modelId, that.modelId)
77-
&& Objects.equals(index, that.index);
67+
return Objects.equals(index, that.index);
7868
}
7969

8070
@Override
8171
public int hashCode() {
82-
return Objects.hash(modelId, index);
72+
return Objects.hash(index);
8373
}
8474
}

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/IndexLocationTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
public class IndexLocationTests extends AbstractXContentTestCase<IndexLocation> {
1818

1919
static IndexLocation randomInstance() {
20-
return new IndexLocation(randomAlphaOfLength(7), randomAlphaOfLength(7));
20+
return new IndexLocation(randomAlphaOfLength(7));
2121
}
2222

2323
@Override

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,19 @@
1313
import org.elasticsearch.action.support.master.MasterNodeRequest;
1414
import org.elasticsearch.cluster.node.DiscoveryNode;
1515
import org.elasticsearch.cluster.node.DiscoveryNodeRole;
16+
import org.elasticsearch.common.Strings;
17+
import org.elasticsearch.common.io.stream.StreamInput;
18+
import org.elasticsearch.common.io.stream.StreamOutput;
1619
import org.elasticsearch.common.io.stream.Writeable;
1720
import org.elasticsearch.common.unit.ByteSizeValue;
1821
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
1922
import org.elasticsearch.common.xcontent.ParseField;
20-
import org.elasticsearch.common.Strings;
21-
import org.elasticsearch.common.io.stream.StreamInput;
22-
import org.elasticsearch.common.io.stream.StreamOutput;
23-
import org.elasticsearch.common.xcontent.XContentParser;
24-
import org.elasticsearch.core.TimeValue;
2523
import org.elasticsearch.common.xcontent.ToXContentObject;
2624
import org.elasticsearch.common.xcontent.XContentBuilder;
25+
import org.elasticsearch.common.xcontent.XContentParser;
26+
import org.elasticsearch.core.TimeValue;
2727
import org.elasticsearch.tasks.Task;
2828
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
29-
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation;
3029
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
3130
import org.elasticsearch.xpack.core.ml.utils.MlTaskParams;
3231

@@ -137,11 +136,10 @@ public static boolean mayAllocateToNode(DiscoveryNode node) {
137136
private static final ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
138137
"trained_model_deployment_params",
139138
true,
140-
a -> new TaskParams((String)a[0], (String)a[1], (Long)a[2])
139+
a -> new TaskParams((String)a[0], (Long)a[1])
141140
);
142141
static {
143142
PARSER.declareString(ConstructingObjectParser.constructorArg(), TrainedModelConfig.MODEL_ID);
144-
PARSER.declareString(ConstructingObjectParser.constructorArg(), IndexLocation.INDEX);
145143
PARSER.declareLong(ConstructingObjectParser.constructorArg(), MODEL_BYTES);
146144
}
147145

@@ -157,12 +155,10 @@ public static TaskParams fromXContent(XContentParser parser) {
157155
private static final ByteSizeValue MEMORY_OVERHEAD = ByteSizeValue.ofMb(270);
158156

159157
private final String modelId;
160-
private final String index;
161158
private final long modelBytes;
162159

163-
public TaskParams(String modelId, String index, long modelBytes) {
160+
public TaskParams(String modelId, long modelBytes) {
164161
this.modelId = Objects.requireNonNull(modelId);
165-
this.index = Objects.requireNonNull(index);
166162
this.modelBytes = modelBytes;
167163
if (modelBytes < 0) {
168164
throw new IllegalArgumentException("modelBytes must be non-negative");
@@ -171,18 +167,13 @@ public TaskParams(String modelId, String index, long modelBytes) {
171167

172168
public TaskParams(StreamInput in) throws IOException {
173169
this.modelId = in.readString();
174-
this.index = in.readString();
175170
this.modelBytes = in.readVLong();
176171
}
177172

178173
public String getModelId() {
179174
return modelId;
180175
}
181176

182-
public String getIndex() {
183-
return index;
184-
}
185-
186177
public long estimateMemoryUsageBytes() {
187178
// While loading the model in the process we need twice the model size.
188179
return MEMORY_OVERHEAD.getBytes() + 2 * modelBytes;
@@ -195,23 +186,21 @@ public Version getMinimalSupportedVersion() {
195186
@Override
196187
public void writeTo(StreamOutput out) throws IOException {
197188
out.writeString(modelId);
198-
out.writeString(index);
199189
out.writeVLong(modelBytes);
200190
}
201191

202192
@Override
203193
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
204194
builder.startObject();
205195
builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId);
206-
builder.field(IndexLocation.INDEX.getPreferredName(), index);
207196
builder.field(MODEL_BYTES.getPreferredName(), modelBytes);
208197
builder.endObject();
209198
return builder;
210199
}
211200

212201
@Override
213202
public int hashCode() {
214-
return Objects.hash(modelId, index, modelBytes);
203+
return Objects.hash(modelId, modelBytes);
215204
}
216205

217206
@Override
@@ -221,7 +210,6 @@ public boolean equals(Object o) {
221210

222211
TaskParams other = (TaskParams) o;
223212
return Objects.equals(modelId, other.modelId)
224-
&& Objects.equals(index, other.index)
225213
&& modelBytes == other.modelBytes;
226214
}
227215

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1010
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
11+
import org.elasticsearch.common.xcontent.ParseField;
1112
import org.elasticsearch.plugins.spi.NamedXContentProvider;
1213
import org.elasticsearch.xpack.core.ml.inference.preprocessing.CustomWordEmbedding;
1314
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
@@ -26,22 +27,26 @@
2627
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
2728
import org.elasticsearch.xpack.core.ml.inference.results.SentimentAnalysisResults;
2829
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
30+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertPassThroughConfig;
2931
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
3032
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
3133
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
34+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
3235
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation;
3336
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
3437
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
3538
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedInferenceConfig;
3639
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
3740
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModelLocation;
41+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
3842
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
3943
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
4044
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdate;
45+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.SentimentAnalysisConfig;
4146
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
47+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
4248
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModelLocation;
4349
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
44-
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
4550
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelLocation;
4651
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
4752
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Exponent;
@@ -155,6 +160,22 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
155160
RegressionConfig::fromXContentLenient));
156161
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedInferenceConfig.class, RegressionConfig.NAME,
157162
RegressionConfig::fromXContentStrict));
163+
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedInferenceConfig.class, new ParseField(NerConfig.NAME),
164+
NerConfig::fromXContentLenient));
165+
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedInferenceConfig.class, new ParseField(NerConfig.NAME),
166+
NerConfig::fromXContentStrict));
167+
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedInferenceConfig.class, new ParseField(FillMaskConfig.NAME),
168+
FillMaskConfig::fromXContentLenient));
169+
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedInferenceConfig.class, new ParseField(FillMaskConfig.NAME),
170+
FillMaskConfig::fromXContentStrict));
171+
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedInferenceConfig.class,
172+
new ParseField(SentimentAnalysisConfig.NAME), SentimentAnalysisConfig::fromXContentLenient));
173+
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedInferenceConfig.class, new ParseField(SentimentAnalysisConfig.NAME),
174+
SentimentAnalysisConfig::fromXContentStrict));
175+
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedInferenceConfig.class,
176+
new ParseField(BertPassThroughConfig.NAME), BertPassThroughConfig::fromXContentLenient));
177+
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedInferenceConfig.class, new ParseField(BertPassThroughConfig.NAME),
178+
BertPassThroughConfig::fromXContentStrict));
158179

159180
namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfigUpdate.class, ClassificationConfigUpdate.NAME,
160181
ClassificationConfigUpdate::fromXContentStrict));
@@ -237,6 +258,14 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
237258
ClassificationConfig.NAME.getPreferredName(), ClassificationConfig::new));
238259
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
239260
RegressionConfig.NAME.getPreferredName(), RegressionConfig::new));
261+
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
262+
NerConfig.NAME, NerConfig::new));
263+
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
264+
FillMaskConfig.NAME, FillMaskConfig::new));
265+
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
266+
SentimentAnalysisConfig.NAME, SentimentAnalysisConfig::new));
267+
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
268+
BertPassThroughConfig.NAME, BertPassThroughConfig::new));
240269

241270
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfigUpdate.class,
242271
ClassificationConfigUpdate.NAME.getPreferredName(), ClassificationConfigUpdate::new));

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo
181181
this.description = description;
182182
this.tags = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(tags, TAGS));
183183
this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata);
184-
this.input = ExceptionsHelper.requireNonNull(input, INPUT);
184+
this.input = ExceptionsHelper.requireNonNull(handleDefaultInput(input, modelType), INPUT);
185185
if (ExceptionsHelper.requireNonNull(estimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES) < 0) {
186186
throw new IllegalArgumentException(
187187
"[" + ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName() + "] must be greater than or equal to 0");
@@ -197,6 +197,13 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo
197197
this.location = location;
198198
}
199199

200+
private static TrainedModelInput handleDefaultInput(TrainedModelInput input, TrainedModelType modelType) {
201+
if (modelType == null) {
202+
return input;
203+
}
204+
return input == null ? modelType.getDefaultInput() : input;
205+
}
206+
200207
public TrainedModelConfig(StreamInput in) throws IOException {
201208
modelId = in.readString();
202209
createdBy = in.readString();

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelType.java

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,20 @@
77

88
package org.elasticsearch.xpack.core.ml.inference;
99

10+
import org.elasticsearch.core.Nullable;
1011
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
1112
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
1213
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
1314
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
1415

16+
import java.util.Collections;
1517
import java.util.Locale;
1618

1719
public enum TrainedModelType {
1820

19-
TREE_ENSEMBLE,
20-
LANG_IDENT,
21-
PYTORCH {
22-
@Override
23-
public boolean hasInferenceDefinition() {
24-
return false;
25-
}
26-
};
21+
TREE_ENSEMBLE(null),
22+
LANG_IDENT(null),
23+
PYTORCH(new TrainedModelInput(Collections.singletonList("input")));
2724

2825
public static TrainedModelType fromString(String name) {
2926
return valueOf(name.trim().toUpperCase(Locale.ROOT));
@@ -45,12 +42,19 @@ public static TrainedModelType typeFromTrainedModel(TrainedModel model) {
4542
}
4643
}
4744

48-
public boolean hasInferenceDefinition() {
49-
return true;
45+
private final TrainedModelInput defaultInput;
46+
47+
TrainedModelType(@Nullable TrainedModelInput defaultInput) {
48+
this.defaultInput =defaultInput;
5049
}
5150

5251
@Override
5352
public String toString() {
5453
return name().toLowerCase(Locale.ROOT);
5554
}
55+
56+
@Nullable
57+
public TrainedModelInput getDefaultInput() {
58+
return defaultInput;
59+
}
5660
}

0 commit comments

Comments
 (0)