Skip to content

Commit f4d4106

Browse files
authored
[ML][Inference] adds lazy model loader and inference (#47410)
This adds a few things:
- A model loader service that is accessible via transport calls. The service loads models and caches them; a model stays loaded until no processor references it any longer.
- A Model class and its first subclass, LocalModel, which are used to cache model information and run inference.
- A transport action and handler for requests to infer against a local model.
1 parent 890b3db commit f4d4106

34 files changed

+2151
-122
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@
9898
import org.elasticsearch.xpack.core.ml.action.GetModelSnapshotsAction;
9999
import org.elasticsearch.xpack.core.ml.action.GetOverallBucketsAction;
100100
import org.elasticsearch.xpack.core.ml.action.GetRecordsAction;
101+
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
101102
import org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction;
102103
import org.elasticsearch.xpack.core.ml.action.KillProcessAction;
103104
import org.elasticsearch.xpack.core.ml.action.MlInfoAction;
@@ -139,7 +140,14 @@
139140
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Recall;
140141
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ScoreByThresholdResult;
141142
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric;
143+
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
144+
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
145+
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
142146
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
147+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
148+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator;
149+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode;
150+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum;
143151
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
144152
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
145153
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
@@ -323,6 +331,7 @@ public List<ActionType<? extends ActionResponse>> getClientActions() {
323331
StartDataFrameAnalyticsAction.INSTANCE,
324332
EvaluateDataFrameAction.INSTANCE,
325333
EstimateMemoryUsageAction.INSTANCE,
334+
InferModelAction.INSTANCE,
326335
// security
327336
ClearRealmCacheAction.INSTANCE,
328337
ClearRolesCacheAction.INSTANCE,
@@ -451,6 +460,17 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
451460
new NamedWriteableRegistry.Entry(PreProcessor.class, TargetMeanEncoding.NAME.getPreferredName(), TargetMeanEncoding::new),
452461
// ML - Inference models
453462
new NamedWriteableRegistry.Entry(TrainedModel.class, Tree.NAME.getPreferredName(), Tree::new),
463+
new NamedWriteableRegistry.Entry(TrainedModel.class, Ensemble.NAME.getPreferredName(), Ensemble::new),
464+
// ML - Inference aggregators
465+
new NamedWriteableRegistry.Entry(OutputAggregator.class, WeightedSum.NAME.getPreferredName(), WeightedSum::new),
466+
new NamedWriteableRegistry.Entry(OutputAggregator.class, WeightedMode.NAME.getPreferredName(), WeightedMode::new),
467+
// ML - Inference Results
468+
new NamedWriteableRegistry.Entry(InferenceResults.class,
469+
ClassificationInferenceResults.NAME,
470+
ClassificationInferenceResults::new),
471+
new NamedWriteableRegistry.Entry(InferenceResults.class,
472+
RegressionInferenceResults.NAME,
473+
RegressionInferenceResults::new),
454474

455475
// monitoring
456476
new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.MONITORING, MonitoringFeatureSetUsage::new),
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
package org.elasticsearch.xpack.core.ml.action;
7+
8+
import org.elasticsearch.action.ActionRequest;
9+
import org.elasticsearch.action.ActionRequestValidationException;
10+
import org.elasticsearch.action.ActionResponse;
11+
import org.elasticsearch.action.ActionType;
12+
import org.elasticsearch.common.io.stream.StreamInput;
13+
import org.elasticsearch.common.io.stream.StreamOutput;
14+
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
15+
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
16+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams;
17+
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
18+
19+
import java.io.IOException;
20+
import java.util.Arrays;
21+
import java.util.Collections;
22+
import java.util.List;
23+
import java.util.Map;
24+
import java.util.Objects;
25+
26+
public class InferModelAction extends ActionType<InferModelAction.Response> {
27+
28+
public static final InferModelAction INSTANCE = new InferModelAction();
29+
public static final String NAME = "cluster:admin/xpack/ml/infer";
30+
31+
private InferModelAction() {
32+
super(NAME, Response::new);
33+
}
34+
35+
public static class Request extends ActionRequest {
36+
37+
private final String modelId;
38+
private final long modelVersion;
39+
private final List<Map<String, Object>> objectsToInfer;
40+
private final InferenceParams params;
41+
42+
public Request(String modelId, long modelVersion) {
43+
this(modelId, modelVersion, Collections.emptyList(), InferenceParams.EMPTY_PARAMS);
44+
}
45+
46+
public Request(String modelId, long modelVersion, List<Map<String, Object>> objectsToInfer, InferenceParams inferenceParams) {
47+
this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID);
48+
this.modelVersion = modelVersion;
49+
this.objectsToInfer = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(objectsToInfer, "objects_to_infer"));
50+
this.params = inferenceParams == null ? InferenceParams.EMPTY_PARAMS : inferenceParams;
51+
}
52+
53+
public Request(String modelId, long modelVersion, Map<String, Object> objectToInfer, InferenceParams params) {
54+
this(modelId,
55+
modelVersion,
56+
Arrays.asList(ExceptionsHelper.requireNonNull(objectToInfer, "objects_to_infer")),
57+
params);
58+
}
59+
60+
public Request(StreamInput in) throws IOException {
61+
super(in);
62+
this.modelId = in.readString();
63+
this.modelVersion = in.readVLong();
64+
this.objectsToInfer = Collections.unmodifiableList(in.readList(StreamInput::readMap));
65+
this.params = new InferenceParams(in);
66+
}
67+
68+
public String getModelId() {
69+
return modelId;
70+
}
71+
72+
public long getModelVersion() {
73+
return modelVersion;
74+
}
75+
76+
public List<Map<String, Object>> getObjectsToInfer() {
77+
return objectsToInfer;
78+
}
79+
80+
public InferenceParams getParams() {
81+
return params;
82+
}
83+
84+
@Override
85+
public ActionRequestValidationException validate() {
86+
return null;
87+
}
88+
89+
@Override
90+
public void writeTo(StreamOutput out) throws IOException {
91+
super.writeTo(out);
92+
out.writeString(modelId);
93+
out.writeVLong(modelVersion);
94+
out.writeCollection(objectsToInfer, StreamOutput::writeMap);
95+
params.writeTo(out);
96+
}
97+
98+
@Override
99+
public boolean equals(Object o) {
100+
if (this == o) return true;
101+
if (o == null || getClass() != o.getClass()) return false;
102+
InferModelAction.Request that = (InferModelAction.Request) o;
103+
return Objects.equals(modelId, that.modelId)
104+
&& Objects.equals(modelVersion, that.modelVersion)
105+
&& Objects.equals(params, that.params)
106+
&& Objects.equals(objectsToInfer, that.objectsToInfer);
107+
}
108+
109+
@Override
110+
public int hashCode() {
111+
return Objects.hash(modelId, modelVersion, objectsToInfer, params);
112+
}
113+
114+
}
115+
116+
public static class Response extends ActionResponse {
117+
118+
private final List<InferenceResults> inferenceResults;
119+
120+
public Response(List<InferenceResults> inferenceResults) {
121+
super();
122+
this.inferenceResults = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(inferenceResults, "inferenceResults"));
123+
}
124+
125+
public Response(StreamInput in) throws IOException {
126+
super(in);
127+
this.inferenceResults = Collections.unmodifiableList(in.readNamedWriteableList(InferenceResults.class));
128+
}
129+
130+
public List<InferenceResults> getInferenceResults() {
131+
return inferenceResults;
132+
}
133+
134+
@Override
135+
public void writeTo(StreamOutput out) throws IOException {
136+
out.writeNamedWriteableList(inferenceResults);
137+
}
138+
139+
@Override
140+
public boolean equals(Object o) {
141+
if (this == o) return true;
142+
if (o == null || getClass() != o.getClass()) return false;
143+
InferModelAction.Response that = (InferModelAction.Response) o;
144+
return Objects.equals(inferenceResults, that.inferenceResults);
145+
}
146+
147+
@Override
148+
public int hashCode() {
149+
return Objects.hash(inferenceResults);
150+
}
151+
152+
}
153+
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
99
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
1010
import org.elasticsearch.plugins.spi.NamedXContentProvider;
11+
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
12+
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
13+
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
1114
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
1215
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
1316
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
@@ -100,6 +103,14 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
100103
WeightedMode.NAME.getPreferredName(),
101104
WeightedMode::new));
102105

106+
// Inference Results
107+
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
108+
ClassificationInferenceResults.NAME,
109+
ClassificationInferenceResults::new));
110+
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
111+
RegressionInferenceResults.NAME,
112+
RegressionInferenceResults::new));
113+
103114
return namedWriteables;
104115
}
105116
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
1919
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
2020
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
21+
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
22+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams;
2123
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
2224
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
2325
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
@@ -27,6 +29,7 @@
2729
import java.io.IOException;
2830
import java.util.Collections;
2931
import java.util.List;
32+
import java.util.Map;
3033
import java.util.Objects;
3134

3235
public class TrainedModelDefinition implements ToXContentObject, Writeable {
@@ -118,6 +121,15 @@ public Input getInput() {
118121
return input;
119122
}
120123

124+
private void preProcess(Map<String, Object> fields) {
125+
preProcessors.forEach(preProcessor -> preProcessor.process(fields));
126+
}
127+
128+
public InferenceResults infer(Map<String, Object> fields, InferenceParams params) {
129+
preProcess(fields);
130+
return trainedModel.infer(fields, params);
131+
}
132+
121133
@Override
122134
public String toString() {
123135
return Strings.toString(this);

0 commit comments

Comments
 (0)