-
Notifications
You must be signed in to change notification settings - Fork 25.2k
[ML][Inference] adds lazy model loader and inference #47410
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
dadfec1
2d00f52
a4b7643
e2a98cd
ffaa0ed
8140446
f887781
8e2b563
6b795d6
c9a3b55
d92886b
7e1ab77
f753a22
94b83e4
d626a0d
7750906
a241410
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License; | ||
* you may not use this file except in compliance with the Elastic License. | ||
*/ | ||
package org.elasticsearch.xpack.core.ml.action; | ||
|
||
import org.elasticsearch.action.ActionRequest; | ||
import org.elasticsearch.action.ActionRequestBuilder; | ||
import org.elasticsearch.action.ActionRequestValidationException; | ||
import org.elasticsearch.action.ActionResponse; | ||
import org.elasticsearch.action.ActionType; | ||
import org.elasticsearch.client.ElasticsearchClient; | ||
import org.elasticsearch.common.io.stream.StreamInput; | ||
import org.elasticsearch.common.io.stream.StreamOutput; | ||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; | ||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; | ||
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; | ||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; | ||
|
||
import java.io.IOException; | ||
import java.util.Arrays; | ||
import java.util.Collections; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Objects; | ||
|
||
public class InferModelAction extends ActionType<InferModelAction.Response> { | ||
|
||
public static final InferModelAction INSTANCE = new InferModelAction(); | ||
public static final String NAME = "cluster:admin/xpack/ml/infer"; | ||
|
||
private InferModelAction() { | ||
super(NAME, Response::new); | ||
} | ||
|
||
public static class Request extends ActionRequest { | ||
|
||
private final String modelId; | ||
private final long modelVersion; | ||
private final List<Map<String, Object>> objectsToInfer; | ||
private final boolean cacheModel; | ||
private final Integer topClasses; | ||
|
||
public Request(String modelId, long modelVersion) { | ||
this(modelId, modelVersion, Collections.emptyList(), null); | ||
} | ||
|
||
public Request(String modelId, long modelVersion, List<Map<String, Object>> objectsToInfer, Integer topClasses) { | ||
this.modelId = modelId; | ||
this.modelVersion = modelVersion; | ||
this.objectsToInfer = objectsToInfer == null ? | ||
Collections.emptyList() : | ||
Collections.unmodifiableList(objectsToInfer); | ||
this.cacheModel = true; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems like it is hardcoded to |
||
this.topClasses = topClasses; | ||
} | ||
|
||
public Request(String modelId, long modelVersion, Map<String, Object> objectToInfer, Integer topClasses) { | ||
this(modelId, | ||
modelVersion, | ||
objectToInfer == null ? null : Arrays.asList(objectToInfer), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ditto regarding null tolerance. Also prefer |
||
topClasses); | ||
} | ||
|
||
public Request(StreamInput in) throws IOException { | ||
super(in); | ||
this.modelId = in.readString(); | ||
this.modelVersion = in.readVLong(); | ||
this.objectsToInfer = Collections.unmodifiableList(in.readList(StreamInput::readMap)); | ||
this.topClasses = in.readOptionalInt(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we have a default value for this so it's never null? |
||
this.cacheModel = in.readBoolean(); | ||
} | ||
|
||
public String getModelId() { | ||
return modelId; | ||
} | ||
|
||
public long getModelVersion() { | ||
return modelVersion; | ||
} | ||
|
||
public List<Map<String, Object>> getObjectsToInfer() { | ||
return objectsToInfer; | ||
} | ||
|
||
public boolean isCacheModel() { | ||
return cacheModel; | ||
} | ||
|
||
public Integer getTopClasses() { | ||
return topClasses; | ||
} | ||
|
||
@Override | ||
public ActionRequestValidationException validate() { | ||
return null; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
super.writeTo(out); | ||
out.writeString(modelId); | ||
out.writeVLong(modelVersion); | ||
out.writeCollection(objectsToInfer, StreamOutput::writeMap); | ||
out.writeOptionalInt(topClasses); | ||
out.writeBoolean(cacheModel); | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) return true; | ||
if (o == null || getClass() != o.getClass()) return false; | ||
InferModelAction.Request that = (InferModelAction.Request) o; | ||
return Objects.equals(modelId, that.modelId) | ||
&& Objects.equals(modelVersion, that.modelVersion) | ||
&& Objects.equals(topClasses, that.topClasses) | ||
&& Objects.equals(cacheModel, that.cacheModel) | ||
&& Objects.equals(objectsToInfer, that.objectsToInfer); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(modelId, modelVersion, objectsToInfer, topClasses, cacheModel); | ||
} | ||
|
||
} | ||
|
||
public static class RequestBuilder extends ActionRequestBuilder<Request, Response> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we no longer need to declare those builders. I just realized recently I've also been adding them in vain. |
||
public RequestBuilder(ElasticsearchClient client, Request request) { | ||
super(client, INSTANCE, request); | ||
} | ||
} | ||
|
||
public static class Response extends ActionResponse { | ||
|
||
private final List<InferenceResults<?>> inferenceResponse; | ||
private final String resultsType; | ||
|
||
public Response(List<InferenceResults<?>> inferenceResponse, String resultsType) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shall we call |
||
super(); | ||
this.resultsType = ExceptionsHelper.requireNonNull(resultsType, "resultsType"); | ||
this.inferenceResponse = inferenceResponse == null ? | ||
Collections.emptyList() : | ||
Collections.unmodifiableList(inferenceResponse); | ||
} | ||
|
||
public Response(StreamInput in) throws IOException { | ||
super(in); | ||
this.resultsType = in.readString(); | ||
if(resultsType.equals(ClassificationInferenceResults.RESULT_TYPE)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: space after There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This polymorphism via There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It does, but since I could maybe make |
||
this.inferenceResponse = Collections.unmodifiableList(in.readList(ClassificationInferenceResults::new)); | ||
} else if (this.resultsType.equals(RegressionInferenceResults.RESULT_TYPE)) { | ||
this.inferenceResponse = Collections.unmodifiableList(in.readList(RegressionInferenceResults::new)); | ||
} else { | ||
throw new IOException("Unrecognized result type [" + resultsType + "]"); | ||
} | ||
} | ||
|
||
public List<InferenceResults<?>> getInferenceResponse() { | ||
return inferenceResponse; | ||
} | ||
|
||
public String getResultsType() { | ||
return resultsType; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
out.writeString(resultsType); | ||
out.writeCollection(inferenceResponse); | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) return true; | ||
if (o == null || getClass() != o.getClass()) return false; | ||
InferModelAction.Response that = (InferModelAction.Response) o; | ||
return Objects.equals(resultsType, that.resultsType) && Objects.equals(inferenceResponse, that.inferenceResponse); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(resultsType, inferenceResponse); | ||
} | ||
|
||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License; | ||
* you may not use this file except in compliance with the Elastic License. | ||
*/ | ||
package org.elasticsearch.xpack.core.ml.inference.results; | ||
|
||
import org.elasticsearch.common.ParseField; | ||
import org.elasticsearch.common.io.stream.StreamInput; | ||
import org.elasticsearch.common.io.stream.StreamOutput; | ||
import org.elasticsearch.common.io.stream.Writeable; | ||
import org.elasticsearch.common.xcontent.ToXContentObject; | ||
import org.elasticsearch.common.xcontent.XContentBuilder; | ||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; | ||
|
||
import java.io.IOException; | ||
import java.util.Collections; | ||
import java.util.List; | ||
import java.util.Objects; | ||
|
||
public class ClassificationInferenceResults extends SingleValueInferenceResults { | ||
|
||
public static final String RESULT_TYPE = "classification"; | ||
public static final ParseField CLASSIFICATION_LABEL = new ParseField("classification_label"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this the predicted class? If yes should we call it There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, this field is just the optional string label for the ordinal value returned from the trained model. Might be null as it is the specific string label for the numeric value returned via the model. If the user already transformed their classes into ordinal numerics, there is no label, just the numeric value. |
||
public static final ParseField TOP_CLASSES = new ParseField("top_classes"); | ||
|
||
private final String classificationLabel; | ||
private final List<TopClassEntry> topClasses; | ||
|
||
public ClassificationInferenceResults(double value, String classificationLabel, List<TopClassEntry> topClasses) { | ||
super(value); | ||
dimitris-athanasiou marked this conversation as resolved.
Show resolved
Hide resolved
|
||
this.classificationLabel = classificationLabel; | ||
dimitris-athanasiou marked this conversation as resolved.
Show resolved
Hide resolved
|
||
this.topClasses = topClasses == null ? null : Collections.unmodifiableList(topClasses); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we ensure |
||
} | ||
|
||
public ClassificationInferenceResults(StreamInput in) throws IOException { | ||
super(in); | ||
this.classificationLabel = in.readOptionalString(); | ||
if (in.readBoolean()) { | ||
this.topClasses = Collections.unmodifiableList(in.readList(TopClassEntry::new)); | ||
} else { | ||
this.topClasses = null; | ||
} | ||
} | ||
|
||
public String getClassificationLabel() { | ||
return classificationLabel; | ||
} | ||
|
||
public List<TopClassEntry> getTopClasses() { | ||
return topClasses; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
super.writeTo(out); | ||
out.writeOptionalString(classificationLabel); | ||
out.writeBoolean(topClasses != null); | ||
if (topClasses != null) { | ||
out.writeCollection(topClasses); | ||
} | ||
} | ||
|
||
@Override | ||
XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException { | ||
if (classificationLabel != null) { | ||
builder.field(CLASSIFICATION_LABEL.getPreferredName(), classificationLabel); | ||
} | ||
if (topClasses != null) { | ||
builder.field(TOP_CLASSES.getPreferredName(), topClasses); | ||
} | ||
return builder; | ||
} | ||
|
||
@Override | ||
public boolean equals(Object object) { | ||
if (object == this) { return true; } | ||
if (object == null || getClass() != object.getClass()) { return false; } | ||
ClassificationInferenceResults that = (ClassificationInferenceResults) object; | ||
return Objects.equals(value(), that.value()) && | ||
Objects.equals(classificationLabel, that.classificationLabel) && | ||
Objects.equals(topClasses, that.topClasses); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(value(), classificationLabel, topClasses); | ||
} | ||
|
||
@Override | ||
public String resultType() { | ||
return RESULT_TYPE; | ||
} | ||
|
||
@Override | ||
public String valueAsString() { | ||
return classificationLabel == null ? super.valueAsString() : classificationLabel; | ||
} | ||
|
||
public static class TopClassEntry implements ToXContentObject, Writeable { | ||
|
||
public final ParseField LABEL = new ParseField("label"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we consider calling this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I opted for |
||
public final ParseField PROBABILITY = new ParseField("probability"); | ||
|
||
private final String label; | ||
private final double probability; | ||
|
||
public TopClassEntry(String label, Double probability) { | ||
this.label = ExceptionsHelper.requireNonNull(label, LABEL); | ||
this.probability = ExceptionsHelper.requireNonNull(probability, PROBABILITY); | ||
} | ||
|
||
public TopClassEntry(StreamInput in) throws IOException { | ||
this.label = in.readString(); | ||
this.probability = in.readDouble(); | ||
} | ||
|
||
public String getLabel() { | ||
return label; | ||
} | ||
|
||
public double getProbability() { | ||
return probability; | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
out.writeString(label); | ||
out.writeDouble(probability); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
builder.startObject(); | ||
builder.field(LABEL.getPreferredName(), label); | ||
builder.field(PROBABILITY.getPreferredName(), probability); | ||
builder.endObject(); | ||
return builder; | ||
} | ||
|
||
@Override | ||
public boolean equals(Object object) { | ||
if (object == this) { return true; } | ||
if (object == null || getClass() != object.getClass()) { return false; } | ||
TopClassEntry that = (TopClassEntry) object; | ||
return Objects.equals(label, that.label) && | ||
Objects.equals(probability, that.probability); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(label, probability); | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License; | ||
* you may not use this file except in compliance with the Elastic License. | ||
*/ | ||
package org.elasticsearch.xpack.core.ml.inference.results; | ||
|
||
import org.elasticsearch.common.io.stream.Writeable; | ||
import org.elasticsearch.common.xcontent.ToXContentObject; | ||
|
||
public interface InferenceResults<T> extends ToXContentObject, Writeable { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would like us to consider an alternative idea here. Right now this needs to be generic because of the I think eventually the result is an object we append on the object-to-infer, right? Could we thus have:
It could be we can find a better type than There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
No, the result could be used any number of ways. We don't really append it to the mapped fields of the inference fields, we will supply it to the caller either via an API call, or through the ingest processor (which will have a
I would rather not pass around I honestly think having a generic Regression: Always returns a single numeric |
||
|
||
String resultType(); | ||
|
||
T value(); | ||
|
||
String valueAsString(); | ||
|
||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we tolerate
null
for the objects to infer?