Skip to content

Commit dadfec1

Browse files
committed
[ML][Inference] adds lazy model loader and inference
1 parent e2b9c1b commit dadfec1

File tree

12 files changed

+1146
-3
lines changed

12 files changed

+1146
-3
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ public List<String> getFeatureNames() {
107107

108108
@Override
109109
public double infer(Map<String, Object> fields) {
110-
List<Double> features = featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList());
110+
List<Double> features = featureNames.stream().map(f -> ((Number)fields.get(f)).doubleValue()).collect(Collectors.toList());
111111
return infer(features);
112112
}
113113

@@ -128,7 +128,7 @@ public List<Double> classificationProbability(Map<String, Object> fields) {
128128
throw new UnsupportedOperationException(
129129
"Cannot determine classification probability with target_type [" + targetType.toString() + "]");
130130
}
131-
List<Double> features = featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList());
131+
List<Double> features = featureNames.stream().map(f -> ((Number)fields.get(f)).doubleValue()).collect(Collectors.toList());
132132
return classificationProbability(features);
133133
}
134134

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@
161161
import org.elasticsearch.xpack.ml.action.TransportGetModelSnapshotsAction;
162162
import org.elasticsearch.xpack.ml.action.TransportGetOverallBucketsAction;
163163
import org.elasticsearch.xpack.ml.action.TransportGetRecordsAction;
164+
import org.elasticsearch.xpack.ml.action.TransportInferModelAction;
164165
import org.elasticsearch.xpack.ml.action.TransportIsolateDatafeedAction;
165166
import org.elasticsearch.xpack.ml.action.TransportKillProcessAction;
166167
import org.elasticsearch.xpack.ml.action.TransportMlInfoAction;
@@ -200,7 +201,10 @@
200201
import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory;
201202
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
202203
import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
204+
import org.elasticsearch.xpack.ml.inference.action.InferModelAction;
205+
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
203206
import org.elasticsearch.xpack.ml.inference.persistence.InferenceInternalIndex;
207+
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
204208
import org.elasticsearch.xpack.ml.job.JobManager;
205209
import org.elasticsearch.xpack.ml.job.JobManagerHolder;
206210
import org.elasticsearch.xpack.ml.job.UpdateJobProcessNotifier;
@@ -495,6 +499,8 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
495499
notifier,
496500
xContentRegistry);
497501

502+
final TrainedModelProvider trainedModelProvider = new TrainedModelProvider(client, xContentRegistry);
503+
final ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider);
498504
// special holder for @link(MachineLearningFeatureSetUsage) which needs access to job manager if ML is enabled
499505
JobManagerHolder jobManagerHolder = new JobManagerHolder(jobManager);
500506

@@ -607,7 +613,9 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
607613
analyticsProcessManager,
608614
memoryEstimationProcessManager,
609615
dataFrameAnalyticsConfigProvider,
610-
nativeStorageProvider
616+
nativeStorageProvider,
617+
modelLoadingService,
618+
trainedModelProvider
611619
);
612620
}
613621

@@ -762,6 +770,7 @@ public List<RestHandler> getRestHandlers(Settings settings, RestController restC
762770
new ActionHandler<>(StopDataFrameAnalyticsAction.INSTANCE, TransportStopDataFrameAnalyticsAction.class),
763771
new ActionHandler<>(EvaluateDataFrameAction.INSTANCE, TransportEvaluateDataFrameAction.class),
764772
new ActionHandler<>(EstimateMemoryUsageAction.INSTANCE, TransportEstimateMemoryUsageAction.class),
773+
new ActionHandler<>(InferModelAction.INSTANCE, TransportInferModelAction.class),
765774
usageAction,
766775
infoAction);
767776
}
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
package org.elasticsearch.xpack.ml.action;
7+
8+
import org.elasticsearch.action.ActionListener;
9+
import org.elasticsearch.action.support.ActionFilters;
10+
import org.elasticsearch.action.support.HandledTransportAction;
11+
import org.elasticsearch.client.Client;
12+
import org.elasticsearch.common.inject.Inject;
13+
import org.elasticsearch.tasks.Task;
14+
import org.elasticsearch.threadpool.ThreadPool;
15+
import org.elasticsearch.transport.TransportService;
16+
import org.elasticsearch.xpack.ml.inference.action.InferModelAction;
17+
import org.elasticsearch.xpack.ml.inference.loadingservice.Model;
18+
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
19+
import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor;
20+
21+
import java.util.List;
22+
23+
public class TransportInferModelAction extends HandledTransportAction<InferModelAction.Request, InferModelAction.Response> {
24+
25+
private final ModelLoadingService modelLoadingService;
26+
private final Client client;
27+
28+
@Inject
29+
public TransportInferModelAction(String actionName,
30+
TransportService transportService,
31+
ActionFilters actionFilters,
32+
ModelLoadingService modelLoadingService,
33+
Client client) {
34+
super(actionName, transportService, actionFilters, InferModelAction.Request::new);
35+
this.modelLoadingService = modelLoadingService;
36+
this.client = client;
37+
}
38+
39+
@Override
40+
protected void doExecute(Task task, InferModelAction.Request request, ActionListener<InferModelAction.Response> listener) {
41+
42+
ActionListener<List<Object>> inferenceCompleteListener = ActionListener.wrap(
43+
inferenceResponse -> listener.onResponse(new InferModelAction.Response(inferenceResponse)),
44+
listener::onFailure
45+
);
46+
47+
ActionListener<Model> getModelListener = ActionListener.wrap(
48+
model -> {
49+
TypedChainTaskExecutor<Object> typedChainTaskExecutor =
50+
new TypedChainTaskExecutor<>(client.threadPool().executor(ThreadPool.Names.SAME),
51+
// run through all tasks
52+
r -> true,
53+
// Always fail immediately and return an error
54+
ex -> true);
55+
if (request.getTopClasses() != null) {
56+
request.getObjectsToInfer().forEach(stringObjectMap ->
57+
typedChainTaskExecutor.add(chainedTask -> model.confidence(stringObjectMap, request.getTopClasses(), chainedTask))
58+
);
59+
} else {
60+
request.getObjectsToInfer().forEach(stringObjectMap ->
61+
typedChainTaskExecutor.add(chainedTask -> model.infer(stringObjectMap, chainedTask))
62+
);
63+
}
64+
typedChainTaskExecutor.execute(inferenceCompleteListener);
65+
},
66+
listener::onFailure
67+
);
68+
69+
this.modelLoadingService.getModelAndCache(request.getModelId(), request.getModelVersion(), getModelListener);
70+
}
71+
}
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
package org.elasticsearch.xpack.ml.inference.action;
7+
8+
import org.elasticsearch.action.ActionRequest;
9+
import org.elasticsearch.action.ActionRequestBuilder;
10+
import org.elasticsearch.action.ActionRequestValidationException;
11+
import org.elasticsearch.action.ActionResponse;
12+
import org.elasticsearch.action.ActionType;
13+
import org.elasticsearch.client.ElasticsearchClient;
14+
import org.elasticsearch.common.io.stream.StreamInput;
15+
import org.elasticsearch.common.io.stream.StreamOutput;
16+
17+
import java.io.IOException;
18+
import java.util.ArrayList;
19+
import java.util.Collections;
20+
import java.util.List;
21+
import java.util.Map;
22+
import java.util.Objects;
23+
24+
public class InferModelAction extends ActionType<InferModelAction.Response> {
25+
26+
public static final InferModelAction INSTANCE = new InferModelAction();
27+
public static final String NAME = "cluster:admin/xpack/ml/infer";
28+
29+
private InferModelAction() {
30+
super(NAME, Response::new);
31+
}
32+
33+
public static class Request extends ActionRequest {
34+
35+
private final String modelId;
36+
private final long modelVersion;
37+
private final List<Map<String, Object>> objectsToInfer;
38+
private final boolean cacheModel;
39+
private final Integer topClasses;
40+
41+
public Request(String modelId, long modelVersion) {
42+
this(modelId, modelVersion, Collections.emptyList(), null);
43+
}
44+
45+
public Request(String modelId, long modelVersion, List<Map<String, Object>> objectsToInfer, Integer topClasses) {
46+
this.modelId = modelId;
47+
this.modelVersion = modelVersion;
48+
this.objectsToInfer = objectsToInfer == null ? Collections.emptyList() :
49+
Collections.unmodifiableList(new ArrayList<>(objectsToInfer));
50+
this.cacheModel = true;
51+
this.topClasses = topClasses;
52+
}
53+
54+
public Request(String modelId, long modelVersion, Map<String, Object> objectToInfer, Integer topClasses) {
55+
this(modelId,
56+
modelVersion,
57+
objectToInfer == null ? Collections.emptyList() : Collections.singletonList(objectToInfer),
58+
topClasses);
59+
}
60+
61+
public Request(StreamInput in) throws IOException {
62+
super(in);
63+
this.modelId = in.readString();
64+
this.modelVersion = in.readVLong();
65+
this.objectsToInfer = Collections.unmodifiableList(in.readList(StreamInput::readMap));
66+
this.topClasses = in.readOptionalInt();
67+
this.cacheModel = in.readBoolean();
68+
}
69+
70+
public String getModelId() {
71+
return modelId;
72+
}
73+
74+
public long getModelVersion() {
75+
return modelVersion;
76+
}
77+
78+
public List<Map<String, Object>> getObjectsToInfer() {
79+
return objectsToInfer;
80+
}
81+
82+
public boolean isCacheModel() {
83+
return cacheModel;
84+
}
85+
86+
public Integer getTopClasses() {
87+
return topClasses;
88+
}
89+
90+
@Override
91+
public ActionRequestValidationException validate() {
92+
return null;
93+
}
94+
95+
@Override
96+
public void writeTo(StreamOutput out) throws IOException {
97+
super.writeTo(out);
98+
out.writeString(modelId);
99+
out.writeVLong(modelVersion);
100+
out.writeCollection(objectsToInfer, StreamOutput::writeMap);
101+
out.writeOptionalInt(topClasses);
102+
out.writeBoolean(cacheModel);
103+
}
104+
105+
@Override
106+
public boolean equals(Object o) {
107+
if (this == o) return true;
108+
if (o == null || getClass() != o.getClass()) return false;
109+
InferModelAction.Request that = (InferModelAction.Request) o;
110+
return Objects.equals(modelId, that.modelId)
111+
&& Objects.equals(modelVersion, that.modelVersion)
112+
&& Objects.equals(topClasses, that.topClasses)
113+
&& Objects.equals(cacheModel, that.cacheModel)
114+
&& Objects.equals(objectsToInfer, that.objectsToInfer);
115+
}
116+
117+
@Override
118+
public int hashCode() {
119+
return Objects.hash(modelId, modelVersion, objectsToInfer, topClasses, cacheModel);
120+
}
121+
122+
}
123+
124+
public static class RequestBuilder extends ActionRequestBuilder<Request, Response> {
125+
public RequestBuilder(ElasticsearchClient client, Request request) {
126+
super(client, INSTANCE, request);
127+
}
128+
}
129+
130+
public static class Response extends ActionResponse {
131+
132+
// TODO come up with a better union type object
133+
private final List<Object> inferenceResponse;
134+
135+
public Response(List<Object> inferenceResponse) {
136+
super();
137+
this.inferenceResponse = Collections.unmodifiableList(inferenceResponse);
138+
}
139+
140+
public Response(StreamInput in) throws IOException {
141+
super(in);
142+
this.inferenceResponse = Collections.unmodifiableList(in.readList(StreamInput::readGenericValue));
143+
}
144+
145+
public List<Object> getInferenceResponse() {
146+
return inferenceResponse;
147+
}
148+
149+
@Override
150+
public void writeTo(StreamOutput out) throws IOException {
151+
out.writeCollection(inferenceResponse, StreamOutput::writeGenericValue);
152+
}
153+
154+
@Override
155+
public boolean equals(Object o) {
156+
if (this == o) return true;
157+
if (o == null || getClass() != o.getClass()) return false;
158+
InferModelAction.Response that = (InferModelAction.Response) o;
159+
return Objects.equals(inferenceResponse, that.inferenceResponse);
160+
}
161+
162+
@Override
163+
public int hashCode() {
164+
return Objects.hash(inferenceResponse);
165+
}
166+
167+
}
168+
}
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
package org.elasticsearch.xpack.ml.inference.loadingservice;
7+
8+
import org.elasticsearch.ElasticsearchStatusException;
9+
import org.elasticsearch.action.ActionListener;
10+
import org.elasticsearch.rest.RestStatus;
11+
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
12+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
13+
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
14+
15+
import java.util.Collections;
16+
import java.util.Comparator;
17+
import java.util.HashMap;
18+
import java.util.List;
19+
import java.util.Map;
20+
import java.util.stream.Collectors;
21+
import java.util.stream.IntStream;
22+
23+
public class LocalModel implements Model {
24+
25+
private final TrainedModelDefinition trainedModelDefinition;
26+
public LocalModel(TrainedModelDefinition trainedModelDefinition) {
27+
this.trainedModelDefinition = trainedModelDefinition;
28+
}
29+
30+
@Override
31+
public void infer(Map<String, Object> fields, ActionListener<Object> listener) {
32+
trainedModelDefinition.getPreProcessors().forEach(preProcessor -> preProcessor.process(fields));
33+
double value = trainedModelDefinition.getTrainedModel().infer(fields);
34+
if (trainedModelDefinition.getTrainedModel().targetType() == TargetType.CLASSIFICATION &&
35+
trainedModelDefinition.getTrainedModel().classificationLabels() != null) {
36+
assert value == Math.rint(value);
37+
int classIndex = Double.valueOf(value).intValue();
38+
if (classIndex < 0 || classIndex >= trainedModelDefinition.getTrainedModel().classificationLabels().size()) {
39+
listener.onFailure(new ElasticsearchStatusException("model returned classification [{}] which is invalid given labels {}",
40+
RestStatus.INTERNAL_SERVER_ERROR,
41+
classIndex,
42+
trainedModelDefinition.getTrainedModel().classificationLabels()));
43+
return;
44+
}
45+
listener.onResponse(trainedModelDefinition.getTrainedModel().classificationLabels().get(classIndex));
46+
return;
47+
}
48+
listener.onResponse(Double.valueOf(value));
49+
}
50+
51+
@Override
52+
public void confidence(Map<String, Object> fields, int topN, ActionListener<Object> listener) {
53+
if (topN == 0) {
54+
listener.onResponse(Collections.emptyMap());
55+
return;
56+
}
57+
if (trainedModelDefinition.getTrainedModel().targetType() != TargetType.CLASSIFICATION) {
58+
listener.onFailure(ExceptionsHelper
59+
.badRequestException("top result probabilities is only available for classification models"));
60+
return;
61+
}
62+
trainedModelDefinition.getPreProcessors().forEach(preProcessor -> preProcessor.process(fields));
63+
List<Double> probabilities = trainedModelDefinition.getTrainedModel().classificationProbability(fields);
64+
int[] sortedIndices = IntStream.range(0, probabilities.size())
65+
.boxed()
66+
.sorted(Comparator.comparing(probabilities::get).reversed())
67+
.mapToInt(i -> i)
68+
.toArray();
69+
if (trainedModelDefinition.getTrainedModel().classificationLabels() != null) {
70+
if (probabilities.size() != trainedModelDefinition.getTrainedModel().classificationLabels().size()) {
71+
listener.onFailure(ExceptionsHelper
72+
.badRequestException(
73+
"model returned classification probabilities of size [{}] which is not equal to classification labels size [{}]",
74+
probabilities.size(),
75+
trainedModelDefinition.getTrainedModel().classificationLabels()));
76+
return;
77+
}
78+
}
79+
List<String> labels = trainedModelDefinition.getTrainedModel().classificationLabels() == null ?
80+
// If we don't have the labels we should return the top classification values anyways, they will just be numeric
81+
IntStream.range(0, probabilities.size()).boxed().map(String::valueOf).collect(Collectors.toList()) :
82+
trainedModelDefinition.getTrainedModel().classificationLabels();
83+
84+
int count = topN < 0 ? probabilities.size() : topN;
85+
Map<String, Double> probabilityMap = new HashMap<>(count);
86+
for(int i = 0; i < count; i++) {
87+
int idx = sortedIndices[i];
88+
probabilityMap.put(labels.get(idx), probabilities.get(idx));
89+
}
90+
listener.onResponse(probabilityMap);
91+
}
92+
}

0 commit comments

Comments
 (0)