[ML][Inference] adds lazy model loader and inference #47410
Conversation
Pinging @elastic/ml-core (:ml)
ac1d0ab to dadfec1
These are my comments regarding the inference results model objects. I still have to read through the models and the loading service.
public Request(String modelId, long modelVersion, List<Map<String, Object>> objectsToInfer, Integer topClasses) {
    this.modelId = modelId;
    this.modelVersion = modelVersion;
    this.objectsToInfer = objectsToInfer == null ?
Why do we tolerate null for the objects to infer?
public Request(String modelId, long modelVersion, Map<String, Object> objectToInfer, Integer topClasses) {
    this(modelId,
        modelVersion,
        objectToInfer == null ? null : Arrays.asList(objectToInfer),
Ditto regarding null tolerance. Also prefer Collections.singletonList(objectToInfer).
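For illustration, a minimal plain-Java sketch of that suggestion (the class and field names here are hypothetical, not the actual request code):

```java
import java.util.Collections;
import java.util.List;
import java.util.Map;

// Sketch: a single object-to-infer can be wrapped in an immutable one-element list
// with Collections.singletonList rather than Arrays.asList.
public class SingletonListSketch {
    public static void main(String[] args) {
        Map<String, Object> objectToInfer = Collections.singletonMap("feature_1", 42.0);
        List<Map<String, Object>> objectsToInfer = Collections.singletonList(objectToInfer);
        System.out.println(objectsToInfer); // prints [{feature_1=42.0}]
    }
}
```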
this.modelId = in.readString();
this.modelVersion = in.readVLong();
this.objectsToInfer = Collections.unmodifiableList(in.readList(StreamInput::readMap));
this.topClasses = in.readOptionalInt();
Could we have a default value for this so it's never null?
this.objectsToInfer = objectsToInfer == null ?
    Collections.emptyList() :
    Collections.unmodifiableList(objectsToInfer);
this.cacheModel = true;
This seems like it is hardcoded to true at the moment. Just double-checking we still need it as part of the request.
}

public static class RequestBuilder extends ActionRequestBuilder<Request, Response> {
I think we no longer need to declare those builders. I just realized recently I've also been adding them in vain.
import java.io.IOException;

public abstract class SingleValueInferenceResults implements InferenceResults<Double> {
I am not sure this helps more than it confuses.
- Its name is misleading, as value is more generic than double.
- It means classes that inherit this have to call their value field value. Not sure this is the case.
> its name is misleading as value is more generic than double

I disagree. double is exactly the numeric value we need. A whole double would be returned via classification, and any double could be returned via regression. This could be called SingleNumericValueInferenceResults, but I thought that was unnecessarily long-winded.
When we support more complex models that return strings or a whole collection of options, they will have their own subclass that covers those scenarios.

> it means classes that inherit this have to call their value field value. Not sure this is the case.

Could you provide a counterexample?
...n/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContentObject;

public interface InferenceResults<T> extends ToXContentObject, Writeable {
I would like us to consider an alternative idea here.
Right now this needs to be generic because of the T value() method. The paradigm is that we call value() to get the result. But how are we going to use this result? I think eventually the result is an object we append to the object-to-infer, right?
Could we thus have Map<String, Object> result()?
It could be that we can find a better type than Map<String, Object>. However, that would mean each implementation of the results could return an object flexibly. It's hard to discuss all this in text, so I'm sure it warrants a nice design discussion!
> I think eventually the result is an object we append on the object-to-infer, right?

No, the result could be used in any number of ways. We don't really append it to the mapped fields of the object to infer; we will supply it to the caller either via an API call or through the ingest processor (which will have a target_field parameter to tell us where to put the result).

> Could we thus have Map<String, Object> result()?

I would rather not pass around Map<String, Object>. If we were going down that path, what is the point of having an object defined at all?
I honestly think having a generic T covers our use cases:
- Regression: always returns a single numeric value.
- Classification: could be numeric or string (depending on whether we have field-mapped values).
- Future: covers the more exotic cases of List, Map, etc. without sacrificing type safety.
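To make the trade-off concrete, here is a heavily simplified sketch of the generic-T approach described above (plain Java with illustrative names; it is not the actual Elasticsearch class hierarchy):

```java
// Simplified sketch: a generic result type stays type-safe across regression and
// classification, and future result shapes get their own subclass.
public class InferenceResultsSketch {

    interface InferenceResults<T> {
        T value();
    }

    // Regression always yields a single numeric value.
    static final class RegressionResults implements InferenceResults<Double> {
        private final double value;
        RegressionResults(double value) { this.value = value; }
        @Override public Double value() { return value; }
    }

    // Classification yields the top class, which may be a label or a mapped numeric value.
    static final class ClassificationResults implements InferenceResults<String> {
        private final String topClass;
        ClassificationResults(String topClass) { this.topClass = topClass; }
        @Override public String value() { return topClass; }
    }

    public static void main(String[] args) {
        InferenceResults<Double> regression = new RegressionResults(0.42);
        InferenceResults<String> classification = new ClassificationResults("class_a");
        System.out.println(regression.value() + " / " + classification.value());
    }
}
```

Callers that need a single numeric value work against InferenceResults<Double>, while future result shapes can add their own implementation without widening the interface to Map<String, Object>.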
    // Always fail immediately and return an error
    ex -> true);
if (request.getTopClasses() != null) {
    request.getObjectsToInfer().forEach(stringObjectMap ->
So we don't do infer when topClasses is set? I was expecting that top classes would be additional info.
infer is a convenience method that (right now) either:
- returns the top class, or
- returns the regression value.

topClasses should only be set on classification models. If it is requested against a regression model, we throw an exception.
.../plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java
.../plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java
@Override
public void infer(Map<String, Object> fields, ActionListener<InferenceResults<?>> listener) {
    trainedModelDefinition.getPreProcessors().forEach(preProcessor -> preProcessor.process(fields));
Instead of exposing the preprocessors via a getPreProcessors() method, we could have a TrainedModelDefinition.preprocess(Map<String, Object> fields) method.
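A possible shape for that suggestion, sketched with simplified, hypothetical types (the real TrainedModelDefinition carries much more state):

```java
import java.util.Collections;
import java.util.List;
import java.util.Map;

// Simplified sketch: hide the pre-processor list behind a single preProcess(...) call
// so callers such as LocalModel never touch the raw list.
public class TrainedModelDefinitionSketch {

    interface PreProcessor {
        void process(Map<String, Object> fields);
    }

    static class TrainedModelDefinition {
        private final List<PreProcessor> preProcessors;

        TrainedModelDefinition(List<PreProcessor> preProcessors) {
            this.preProcessors = Collections.unmodifiableList(preProcessors);
        }

        // Callers invoke this instead of iterating getPreProcessors() themselves.
        void preProcess(Map<String, Object> fields) {
            preProcessors.forEach(p -> p.process(fields));
        }
    }
}
```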
@Override
public void infer(Map<String, Object> fields, ActionListener<InferenceResults<?>> listener) {
    trainedModelDefinition.getPreProcessors().forEach(preProcessor -> preProcessor.process(fields));
    double value = trainedModelDefinition.getTrainedModel().infer(fields);
Given that for both infer() and classificationProbability() we first preprocess, could we do the preprocessing privately in the underlying model so that preprocessing is hidden from the calling code?
trainedModelDefinition.getPreProcessors().forEach(preProcessor -> preProcessor.process(fields));
double value = trainedModelDefinition.getTrainedModel().infer(fields);
InferenceResults<?> inferenceResults;
if (trainedModelDefinition.getTrainedModel().targetType() == TargetType.CLASSIFICATION) {
This polymorphism-by-if is making me wonder if we're missing an abstraction in the design. Could we have top-level ClassificationModel and RegressionModel classes whose infer() method takes away implementation details like this one? As LocalModel is not aware of the model type, adding more model types would make this if quite messy.
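A rough sketch of the abstraction being proposed here, with simplified hypothetical types rather than the actual classes:

```java
import java.util.Map;

// Simplified sketch: per-target-type model classes so the calling code never needs an
// if/else on TargetType to decide how to interpret the raw model output.
public class ModelPolymorphismSketch {

    interface Model<T> {
        T infer(Map<String, Object> fields);
    }

    static class RegressionModel implements Model<Double> {
        @Override
        public Double infer(Map<String, Object> fields) {
            return 0.0; // placeholder: would evaluate the trained regression model
        }
    }

    static class ClassificationModel implements Model<String> {
        @Override
        public String infer(Map<String, Object> fields) {
            return "class_0"; // placeholder: would evaluate the model and map to a label
        }
    }
}
```

With per-type model classes, adding a new target type means adding a new implementation rather than extending an if/else chain.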
}

@Override
public void classificationProbability(Map<String, Object> fields, int topN, ActionListener<InferenceResults<?>> listener) {
Should we have a classification-related method on LocalModel? It seems to me LocalModel should not be aware of the type of inference. I might be getting this all wrong.
public void getModel(String modelId, long modelVersion, ActionListener<Model> modelActionListener) {
    String key = modelKey(modelId, modelVersion);
    Optional<Model> cachedModel = loadedModels.get(key);
Using loadedModels.getOrDefault(key, Optional.empty()) we can get rid of the null check here.
I will fix this. I will change the stored object to something like MaybeModel that also has an exception stored along with it. The thing is, if we get the cached model but there was some intermittent issue in loading it, we should probably just attempt to load it again.
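A rough sketch of what such a MaybeModel cache entry could look like (hypothetical and simplified; the real change may differ):

```java
// Simplified sketch: a cache entry that records either a successfully loaded model or
// the exception from the failed attempt, so callers can decide whether to retry.
public class MaybeModelSketch {

    interface Model {}

    static final class MaybeModel {
        private final Model model;       // non-null on success
        private final Exception failure; // non-null on failure

        static MaybeModel of(Model model) {
            return new MaybeModel(model, null);
        }

        static MaybeModel ofFailure(Exception failure) {
            return new MaybeModel(null, failure);
        }

        private MaybeModel(Model model, Exception failure) {
            this.model = model;
            this.failure = failure;
        }

        boolean isSuccess() {
            return model != null;
        }

        Model getModel() {
            return model;
        }

        Exception getFailure() {
            return failure;
        }
    }
}
```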
 */
private boolean loadModelIfNecessary(String key, String modelId, long modelVersion, ActionListener<Model> modelActionListener) {
    synchronized (loadingListeners) {
        Optional<Model> cachedModel = loadedModels.get(key);
also use getOrDefault(key, Optional.empty())
        }
    }
    if (listeners != null) {
        for(ActionListener<Model> listener = listeners.poll(); listener != null; listener = listeners.poll()) {
nit: space after for
 * Returns false if the model is not loaded or actively being loaded
 */
private boolean loadModelIfNecessary(String key, String modelId, long modelVersion, ActionListener<Model> modelActionListener) {
    synchronized (loadingListeners) {
Synchronizing on the listeners means we are loading one model at a time. I assume that is fine, but can you foresee any performance issues?
I am not sure this means we are loading one model at a time. This method does the following while in the synchronized block:
- Check if we have successfully loaded the model and get it.
- If there is a failed load attempt, create a new queue of listeners (adding the new listener to the queue), kick off an asynchronous load of the model, and exit.
- If we have not attempted to load the model at all (neither failed nor succeeded), see if there are existing listeners (which indicates a loading attempt is in progress), add the new listener, and exit.

Since load_model is an asynchronous method, we exit the synchronized block here and then enter another synchronized block later in handleLoadSuccess or handleLoadFailure.
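For readers following along, here is a heavily simplified sketch of that pattern (hypothetical names, plain Java instead of Elasticsearch's ActionListener types). The lock only guards cache and queue bookkeeping, never the load itself:

```java
import java.util.ArrayDeque;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Queue;
import java.util.function.Consumer;

// Simplified sketch of the flow described above: the synchronized block only inspects
// the cache and the listener queues; the slow model load happens asynchronously.
public class ListenerQueueLoadingSketch {

    interface Model {}

    private final Map<String, Optional<Model>> loadedModels = new HashMap<>();
    private final Map<String, Queue<Consumer<Model>>> loadingListeners = new HashMap<>();

    void getModel(String key, Consumer<Model> listener) {
        Model alreadyLoaded = null;
        synchronized (loadingListeners) {
            Optional<Model> cached = loadedModels.getOrDefault(key, Optional.empty());
            if (cached.isPresent()) {
                alreadyLoaded = cached.get();      // 1. loaded successfully before
            } else {
                Queue<Consumer<Model>> listeners = loadingListeners.get(key);
                if (listeners != null) {
                    listeners.add(listener);       // 3. a load is in flight: enqueue and exit
                    return;
                }
                Queue<Consumer<Model>> queue = new ArrayDeque<>();
                queue.add(listener);               // 2. no load in flight (or the last one
                loadingListeners.put(key, queue);  //    failed): queue up and start a new load
                startAsyncLoad(key);               // completes later in handleLoadSuccess /
                return;                            // handleLoadFailure, which drain the queue
            }
        }
        listener.accept(alreadyLoaded);            // notify outside the lock
    }

    private void startAsyncLoad(String key) {
        // The real service issues an asynchronous load here; omitted in this sketch.
    }
}
```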
run elasticsearch-ci/2
I think it's getting there now! Some more comments.
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java
...e/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java
    classificationLabels);
}

int count = numToInclude < 0 ? probabilities.size() : numToInclude;
Is it possible to have numToInclude < 0? If not, should we just have an assertion at the beginning of the method?
@dimitris-athanasiou I think allowing -1 to mean "include all" seems like a good option. What do you think?
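A tiny illustration of the "-1 means include all" convention under discussion (a hypothetical helper, not the actual InferenceHelpers code):

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical helper: numToInclude < 0 (for example -1) means "include everything".
public class TopClassesCountSketch {

    static <T> List<T> topN(List<T> sortedByProbability, int numToInclude) {
        int count = numToInclude < 0
            ? sortedByProbability.size()
            : Math.min(numToInclude, sortedByProbability.size());
        return sortedByProbability.subList(0, count);
    }

    public static void main(String[] args) {
        List<String> classes = Arrays.asList("cat", "dog", "fish");
        System.out.println(topN(classes, 2));  // [cat, dog]
        System.out.println(topN(classes, -1)); // [cat, dog, fish]
    }
}
```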
...l/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java
ingestMetadata.getPipelines().forEach((pipelineId, pipelineConfiguration) -> {
    Object processors = pipelineConfiguration.getConfigAsMap().get("processors");
    if (processors instanceof List<?>) {
        for(Object processor : (List<?>)processors) {
nit: space after for
private static Set<String> getReferencedModelKeys(IngestMetadata ingestMetadata) {
    Set<String> allReferencedModelKeys = new HashSet<>();
    if (ingestMetadata != null) {
        ingestMetadata.getPipelines().forEach((pipelineId, pipelineConfiguration) -> {
Could we replace this with
Pipeline.create(newConfiguration.getId(), newConfiguration.getConfigAsMap(), processorFactories, scriptService);
(as in line 544 of IngestService)? I would hope that would give us a parsed Pipeline from which we can get the processors and filter down to InferenceProcessors.
@dimitris-athanasiou sadly, I don't think this is possible.
The IngestService requires a view of all the loaded ingest plugins (of which ML will be one). Since ModelLoadingService is built within our createComponents method, it does not have access to a constructed IngestService anywhere.
I looked into how to create the processorFactories parameter myself, but again, it requires a view of all the loaded IngestPlugin classes, and I cannot find a place where we have access to the complete list of loaded plugins on the node in the createComponents method.
I think to support this type of thing, we would have to either:
A. inject the IngestService into the createComponents method, or
B. inject the list of loaded plugins into the createComponents method.

It seems to me that A is the least invasive, but it would require an update to a core method that every plugin uses... I'm not sure it is worth it.
Sad indeed. Thanks for looking into it!
public static InferenceParams EMPTY_PARAMS = new InferenceParams(0);

private final int numTopClasses;
Now we are coupling InferenceParams to classification. Since those are part of the request object, we'll have BWC issues if we change this in the future. I think we should consider having an InferenceConfig named writeable, which lets us do this more nicely and also gives us a chance to sanity-check that the model id matches the inference type the user expects.
I agree, I think this may clean up some execution paths here and there. Additionally, it will allow us to do the sanity check you mention.
I will complete this in a follow up PR. It would definitely add more LOC churn and this PR is already beefy.
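As a sketch of the direction agreed on here (hypothetical and simplified; the NamedWriteable serialization plumbing is omitted):

```java
// Simplified sketch: per-type inference configs replace a flat numTopClasses parameter,
// which also gives a natural place to check that the config matches the model type.
public class InferenceConfigSketch {

    interface InferenceConfig {
        String getName();
    }

    static final class ClassificationConfig implements InferenceConfig {
        private final int numTopClasses;

        ClassificationConfig(int numTopClasses) {
            this.numTopClasses = numTopClasses;
        }

        int getNumTopClasses() {
            return numTopClasses;
        }

        @Override
        public String getName() {
            return "classification";
        }
    }

    static final class RegressionConfig implements InferenceConfig {
        @Override
        public String getName() {
            return "regression";
        }
    }
}
```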
@dimitris-athanasiou I wrote up the changes and am ready to open a new PR for this change as soon as this one is closed. It would have added ~200 LOC to this PR.
Great!
LGTM
* [ML] ML Model Inference Ingest Processor (#49052)

* [ML][Inference] adds lazy model loader and inference (#47410)

This adds a couple of things:
- A model loader service that is accessible via transport calls. This service will load in models and cache them. They will stay loaded until a processor no longer references them
- A Model class and its first sub-class LocalModel. Used to cache model information and run inference.
- Transport action and handler for requests to infer against a local model

Related Feature PRs:
* [ML][Inference] Adjust inference configuration option API (#47812)
* [ML][Inference] adds logistic_regression output aggregator (#48075)
* [ML][Inference] Adding read/del trained models (#47882)
* [ML][Inference] Adding inference ingest processor (#47859)
* [ML][Inference] fixing classification inference for ensemble (#48463)
* [ML][Inference] Adding model memory estimations (#48323)
* [ML][Inference] adding more options to inference processor (#48545)
* [ML][Inference] handle string values better in feature extraction (#48584)
* [ML][Inference] Adding _stats endpoint for inference (#48492)
* [ML][Inference] add inference processors and trained models to usage (#47869)
* [ML][Inference] add new flag for optionally including model definition (#48718)
* [ML][Inference] adding license checks (#49056)
* [ML][Inference] Adding memory and compute estimates to inference (#48955)

* fixing version of indexed docs for model inference