Skip to content

Commit 6806081

Browse files
authored
[ML] ML Model Inference Ingest Processor (#49052)
* [ML][Inference] adds lazy model loader and inference (#47410)

  This adds a couple of things:

  - A model loader service that is accessible via transport calls. This service will load in models and cache them. They will stay loaded until a processor no longer references them.
  - A Model class and its first sub-class LocalModel. Used to cache model information and run inference.
  - Transport action and handler for requests to infer against a local model.

  Related Feature PRs:

* [ML][Inference] Adjust inference configuration option API (#47812)
* [ML][Inference] adds logistic_regression output aggregator (#48075)
* [ML][Inference] Adding read/del trained models (#47882)
* [ML][Inference] Adding inference ingest processor (#47859)
* [ML][Inference] fixing classification inference for ensemble (#48463)
* [ML][Inference] Adding model memory estimations (#48323)
* [ML][Inference] adding more options to inference processor (#48545)
* [ML][Inference] handle string values better in feature extraction (#48584)
* [ML][Inference] Adding _stats endpoint for inference (#48492)
* [ML][Inference] add inference processors and trained models to usage (#47869)
* [ML][Inference] add new flag for optionally including model definition (#48718)
* [ML][Inference] adding license checks (#49056)
* [ML][Inference] Adding memory and compute estimates to inference (#48955)
1 parent 0f6ffc2 commit 6806081

File tree

97 files changed

+7789
-296
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

97 files changed

+7789
-296
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.elasticsearch.client.common.TimeUtil;
2323
import org.elasticsearch.common.ParseField;
2424
import org.elasticsearch.common.Strings;
25+
import org.elasticsearch.common.unit.ByteSizeValue;
2526
import org.elasticsearch.common.xcontent.ObjectParser;
2627
import org.elasticsearch.common.xcontent.ToXContentObject;
2728
import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -47,6 +48,8 @@ public class TrainedModelConfig implements ToXContentObject {
4748
public static final ParseField TAGS = new ParseField("tags");
4849
public static final ParseField METADATA = new ParseField("metadata");
4950
public static final ParseField INPUT = new ParseField("input");
51+
public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes");
52+
public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations");
5053

5154
public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
5255
true,
@@ -66,6 +69,8 @@ public class TrainedModelConfig implements ToXContentObject {
6669
PARSER.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS);
6770
PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
6871
PARSER.declareObject(TrainedModelConfig.Builder::setInput, (p, c) -> TrainedModelInput.fromXContent(p), INPUT);
72+
PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES);
73+
PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS);
6974
}
7075

7176
public static TrainedModelConfig.Builder fromXContent(XContentParser parser) throws IOException {
@@ -81,6 +86,8 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr
8186
private final List<String> tags;
8287
private final Map<String, Object> metadata;
8388
private final TrainedModelInput input;
89+
private final Long estimatedHeapMemory;
90+
private final Long estimatedOperations;
8491

8592
TrainedModelConfig(String modelId,
8693
String createdBy,
@@ -90,7 +97,9 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr
9097
TrainedModelDefinition definition,
9198
List<String> tags,
9299
Map<String, Object> metadata,
93-
TrainedModelInput input) {
100+
TrainedModelInput input,
101+
Long estimatedHeapMemory,
102+
Long estimatedOperations) {
94103
this.modelId = modelId;
95104
this.createdBy = createdBy;
96105
this.version = version;
@@ -100,6 +109,8 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr
100109
this.tags = tags == null ? null : Collections.unmodifiableList(tags);
101110
this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata);
102111
this.input = input;
112+
this.estimatedHeapMemory = estimatedHeapMemory;
113+
this.estimatedOperations = estimatedOperations;
103114
}
104115

105116
public String getModelId() {
@@ -138,6 +149,18 @@ public TrainedModelInput getInput() {
138149
return input;
139150
}
140151

152+
public ByteSizeValue getEstimatedHeapMemory() {
153+
return estimatedHeapMemory == null ? null : new ByteSizeValue(estimatedHeapMemory);
154+
}
155+
156+
public Long getEstimatedHeapMemoryBytes() {
157+
return estimatedHeapMemory;
158+
}
159+
160+
public Long getEstimatedOperations() {
161+
return estimatedOperations;
162+
}
163+
141164
public static Builder builder() {
142165
return new Builder();
143166
}
@@ -172,6 +195,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
172195
if (input != null) {
173196
builder.field(INPUT.getPreferredName(), input);
174197
}
198+
if (estimatedHeapMemory != null) {
199+
builder.field(ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(), estimatedHeapMemory);
200+
}
201+
if (estimatedOperations != null) {
202+
builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
203+
}
175204
builder.endObject();
176205
return builder;
177206
}
@@ -194,6 +223,8 @@ public boolean equals(Object o) {
194223
Objects.equals(definition, that.definition) &&
195224
Objects.equals(tags, that.tags) &&
196225
Objects.equals(input, that.input) &&
226+
Objects.equals(estimatedHeapMemory, that.estimatedHeapMemory) &&
227+
Objects.equals(estimatedOperations, that.estimatedOperations) &&
197228
Objects.equals(metadata, that.metadata);
198229
}
199230

@@ -206,6 +237,8 @@ public int hashCode() {
206237
definition,
207238
description,
208239
tags,
240+
estimatedHeapMemory,
241+
estimatedOperations,
209242
metadata,
210243
input);
211244
}
@@ -222,6 +255,8 @@ public static class Builder {
222255
private List<String> tags;
223256
private TrainedModelDefinition definition;
224257
private TrainedModelInput input;
258+
private Long estimatedHeapMemory;
259+
private Long estimatedOperations;
225260

226261
public Builder setModelId(String modelId) {
227262
this.modelId = modelId;
@@ -277,6 +312,16 @@ public Builder setInput(TrainedModelInput input) {
277312
return this;
278313
}
279314

315+
public Builder setEstimatedHeapMemory(Long estimatedHeapMemory) {
316+
this.estimatedHeapMemory = estimatedHeapMemory;
317+
return this;
318+
}
319+
320+
public Builder setEstimatedOperations(Long estimatedOperations) {
321+
this.estimatedOperations = estimatedOperations;
322+
return this;
323+
}
324+
280325
public TrainedModelConfig build() {
281326
return new TrainedModelConfig(
282327
modelId,
@@ -287,7 +332,9 @@ public TrainedModelConfig build() {
287332
definition,
288333
tags,
289334
metadata,
290-
input);
335+
input,
336+
estimatedHeapMemory,
337+
estimatedOperations);
291338
}
292339
}
293340

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,10 @@ protected TrainedModelConfig createTestInstance() {
6464
randomBoolean() ? null :
6565
Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()),
6666
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
67-
randomBoolean() ? null : TrainedModelInputTests.createRandomInput());
67+
randomBoolean() ? null : TrainedModelInputTests.createRandomInput(),
68+
randomBoolean() ? null : randomNonNegativeLong(),
69+
randomBoolean() ? null : randomNonNegativeLong());
70+
6871
}
6972

7073
@Override

server/src/main/java/org/elasticsearch/ingest/IngestStats.java

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import java.util.HashMap;
3434
import java.util.List;
3535
import java.util.Map;
36+
import java.util.Objects;
3637
import java.util.concurrent.TimeUnit;
3738

3839
public class IngestStats implements Writeable, ToXContentFragment {
@@ -146,6 +147,21 @@ public Map<String, List<ProcessorStat>> getProcessorStats() {
146147
return processorStats;
147148
}
148149

150+
@Override
151+
public boolean equals(Object o) {
152+
if (this == o) return true;
153+
if (o == null || getClass() != o.getClass()) return false;
154+
IngestStats that = (IngestStats) o;
155+
return Objects.equals(totalStats, that.totalStats)
156+
&& Objects.equals(pipelineStats, that.pipelineStats)
157+
&& Objects.equals(processorStats, that.processorStats);
158+
}
159+
160+
@Override
161+
public int hashCode() {
162+
return Objects.hash(totalStats, pipelineStats, processorStats);
163+
}
164+
149165
public static class Stats implements Writeable, ToXContentFragment {
150166

151167
private final long ingestCount;
@@ -214,6 +230,22 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
214230
builder.field("failed", ingestFailedCount);
215231
return builder;
216232
}
233+
234+
@Override
235+
public boolean equals(Object o) {
236+
if (this == o) return true;
237+
if (o == null || getClass() != o.getClass()) return false;
238+
IngestStats.Stats that = (IngestStats.Stats) o;
239+
return Objects.equals(ingestCount, that.ingestCount)
240+
&& Objects.equals(ingestTimeInMillis, that.ingestTimeInMillis)
241+
&& Objects.equals(ingestFailedCount, that.ingestFailedCount)
242+
&& Objects.equals(ingestCurrent, that.ingestCurrent);
243+
}
244+
245+
@Override
246+
public int hashCode() {
247+
return Objects.hash(ingestCount, ingestTimeInMillis, ingestFailedCount, ingestCurrent);
248+
}
217249
}
218250

219251
/**
@@ -266,6 +298,20 @@ public String getPipelineId() {
266298
public Stats getStats() {
267299
return stats;
268300
}
301+
302+
@Override
303+
public boolean equals(Object o) {
304+
if (this == o) return true;
305+
if (o == null || getClass() != o.getClass()) return false;
306+
IngestStats.PipelineStat that = (IngestStats.PipelineStat) o;
307+
return Objects.equals(pipelineId, that.pipelineId)
308+
&& Objects.equals(stats, that.stats);
309+
}
310+
311+
@Override
312+
public int hashCode() {
313+
return Objects.hash(pipelineId, stats);
314+
}
269315
}
270316

271317
/**
@@ -293,5 +339,21 @@ public String getType() {
293339
public Stats getStats() {
294340
return stats;
295341
}
342+
343+
344+
@Override
345+
public boolean equals(Object o) {
346+
if (this == o) return true;
347+
if (o == null || getClass() != o.getClass()) return false;
348+
IngestStats.ProcessorStat that = (IngestStats.ProcessorStat) o;
349+
return Objects.equals(name, that.name)
350+
&& Objects.equals(type, that.type)
351+
&& Objects.equals(stats, that.stats);
352+
}
353+
354+
@Override
355+
public int hashCode() {
356+
return Objects.hash(name, type, stats);
357+
}
296358
}
297359
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
import org.elasticsearch.xpack.core.ml.action.DeleteForecastAction;
7878
import org.elasticsearch.xpack.core.ml.action.DeleteJobAction;
7979
import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction;
80+
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction;
8081
import org.elasticsearch.xpack.core.ml.action.EstimateMemoryUsageAction;
8182
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
8283
import org.elasticsearch.xpack.core.ml.action.FinalizeJobExecutionAction;
@@ -98,6 +99,9 @@
9899
import org.elasticsearch.xpack.core.ml.action.GetModelSnapshotsAction;
99100
import org.elasticsearch.xpack.core.ml.action.GetOverallBucketsAction;
100101
import org.elasticsearch.xpack.core.ml.action.GetRecordsAction;
102+
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
103+
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
104+
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
101105
import org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction;
102106
import org.elasticsearch.xpack.core.ml.action.KillProcessAction;
103107
import org.elasticsearch.xpack.core.ml.action.MlInfoAction;
@@ -139,6 +143,19 @@
139143
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Recall;
140144
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ScoreByThresholdResult;
141145
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric;
146+
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
147+
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
148+
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
149+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
150+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
151+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
152+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
153+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
154+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LogisticRegression;
155+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator;
156+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode;
157+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum;
158+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
142159
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
143160
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
144161
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
@@ -329,6 +346,10 @@ public List<ActionType<? extends ActionResponse>> getClientActions() {
329346
StartDataFrameAnalyticsAction.INSTANCE,
330347
EvaluateDataFrameAction.INSTANCE,
331348
EstimateMemoryUsageAction.INSTANCE,
349+
InferModelAction.INSTANCE,
350+
GetTrainedModelsAction.INSTANCE,
351+
DeleteTrainedModelAction.INSTANCE,
352+
GetTrainedModelsStatsAction.INSTANCE,
332353
// security
333354
ClearRealmCacheAction.INSTANCE,
334355
ClearRolesCacheAction.INSTANCE,
@@ -464,6 +485,16 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
464485
new NamedWriteableRegistry.Entry(OutputAggregator.class,
465486
LogisticRegression.NAME.getPreferredName(),
466487
LogisticRegression::new),
488+
// ML - Inference Results
489+
new NamedWriteableRegistry.Entry(InferenceResults.class,
490+
ClassificationInferenceResults.NAME,
491+
ClassificationInferenceResults::new),
492+
new NamedWriteableRegistry.Entry(InferenceResults.class,
493+
RegressionInferenceResults.NAME,
494+
RegressionInferenceResults::new),
495+
// ML - Inference Configuration
496+
new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME, ClassificationConfig::new),
497+
new NamedWriteableRegistry.Entry(InferenceConfig.class, RegressionConfig.NAME, RegressionConfig::new),
467498

468499
// monitoring
469500
new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.MONITORING, MonitoringFeatureSetUsage::new),

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsage.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,22 +29,26 @@ public class MachineLearningFeatureSetUsage extends XPackFeatureSet.Usage {
2929
public static final String CREATED_BY = "created_by";
3030
public static final String NODE_COUNT = "node_count";
3131
public static final String DATA_FRAME_ANALYTICS_JOBS_FIELD = "data_frame_analytics_jobs";
32+
public static final String INFERENCE_FIELD = "inference";
3233

3334
private final Map<String, Object> jobsUsage;
3435
private final Map<String, Object> datafeedsUsage;
3536
private final Map<String, Object> analyticsUsage;
37+
private final Map<String, Object> inferenceUsage;
3638
private final int nodeCount;
3739

3840
/**
 * Builds the machine learning usage report from its individual sections.
 *
 * @param available      forwarded to the superclass constructor
 * @param enabled        forwarded to the superclass constructor
 * @param jobsUsage      usage section for jobs; must not be {@code null}
 * @param datafeedsUsage usage section for datafeeds; must not be {@code null}
 * @param analyticsUsage usage section for data frame analytics jobs; must not be {@code null}
 * @param inferenceUsage usage section for inference; must not be {@code null}
 * @param nodeCount      node count stored on this usage object
 * @throws NullPointerException if any of the usage maps is {@code null}
 */
public MachineLearningFeatureSetUsage(boolean available,
                                      boolean enabled,
                                      Map<String, Object> jobsUsage,
                                      Map<String, Object> datafeedsUsage,
                                      Map<String, Object> analyticsUsage,
                                      Map<String, Object> inferenceUsage,
                                      int nodeCount) {
    super(XPackField.MACHINE_LEARNING, available, enabled);

    // Null checks run in declaration order so the first missing section is the one reported.
    this.jobsUsage = Objects.requireNonNull(jobsUsage);
    this.datafeedsUsage = Objects.requireNonNull(datafeedsUsage);
    this.analyticsUsage = Objects.requireNonNull(analyticsUsage);
    this.inferenceUsage = Objects.requireNonNull(inferenceUsage);

    this.nodeCount = nodeCount;
}
5054

@@ -57,6 +61,11 @@ public MachineLearningFeatureSetUsage(StreamInput in) throws IOException {
5761
} else {
5862
this.analyticsUsage = Collections.emptyMap();
5963
}
64+
if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
65+
this.inferenceUsage = in.readMap();
66+
} else {
67+
this.inferenceUsage = Collections.emptyMap();
68+
}
6069
this.nodeCount = in.readInt();
6170
}
6271

@@ -68,6 +77,9 @@ public void writeTo(StreamOutput out) throws IOException {
6877
if (out.getVersion().onOrAfter(Version.V_7_4_0)) {
6978
out.writeMap(analyticsUsage);
7079
}
80+
if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
81+
out.writeMap(inferenceUsage);
82+
}
7183
out.writeInt(nodeCount);
7284
}
7385

@@ -77,6 +89,7 @@ protected void innerXContent(XContentBuilder builder, Params params) throws IOEx
7789
builder.field(JOBS_FIELD, jobsUsage);
7890
builder.field(DATAFEEDS_FIELD, datafeedsUsage);
7991
builder.field(DATA_FRAME_ANALYTICS_JOBS_FIELD, analyticsUsage);
92+
builder.field(INFERENCE_FIELD, inferenceUsage);
8093
if (nodeCount >= 0) {
8194
builder.field(NODE_COUNT, nodeCount);
8295
}

0 commit comments

Comments
 (0)