From f4d410630015a13612b587ae3fb39fd436cf194d Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Wed, 9 Oct 2019 11:27:04 -0400 Subject: [PATCH 01/17] [ML][Inference] adds lazy model loader and inference (#47410) This adds a couple of things: - A model loader service that is accessible via transport calls. This service will load in models and cache them. They will stay loaded until a processor no longer references them - A Model class and its first sub-class LocalModel. Used to cache model information and run inference. - Transport action and handler for requests to infer against a local model --- .../xpack/core/XPackClientPlugin.java | 20 ++ .../core/ml/action/InferModelAction.java | 153 ++++++++++ .../MlInferenceNamedXContentProvider.java | 11 + .../ml/inference/TrainedModelDefinition.java | 12 + .../ClassificationInferenceResults.java | 175 +++++++++++ .../inference/results/InferenceResults.java | 16 + .../results/RegressionInferenceResults.java | 68 +++++ .../results/SingleValueInferenceResults.java | 51 ++++ .../trainedmodel/InferenceHelpers.java | 75 +++++ .../trainedmodel/InferenceParams.java | 65 ++++ .../inference/trainedmodel/TrainedModel.java | 30 +- .../trainedmodel/ensemble/Ensemble.java | 60 ++-- .../ensemble/OutputAggregator.java | 2 + .../trainedmodel/ensemble/WeightedMode.java | 5 + .../trainedmodel/ensemble/WeightedSum.java | 5 + .../ml/inference/trainedmodel/tree/Tree.java | 60 ++-- .../core/ml/inference/utils/Statistics.java | 18 +- .../xpack/core/ml/utils/ExceptionsHelper.java | 4 + .../action/InferModelActionRequestTests.java | 47 +++ .../action/InferModelActionResponseTests.java | 58 ++++ .../ClassificationInferenceResultsTests.java | 83 +++++ .../RegressionInferenceResultsTests.java | 41 +++ .../trainedmodel/InferenceParamsTests.java | 27 ++ .../trainedmodel/ensemble/EnsembleTests.java | 56 ++-- .../trainedmodel/tree/TreeTests.java | 47 ++- .../xpack/ml/MachineLearning.java | 11 +- .../ml/action/TransportInferModelAction.java | 64 ++++ 
.../inference/ingest/InferenceProcessor.java | 41 +++ .../inference/loadingservice/LocalModel.java | 51 ++++ .../ml/inference/loadingservice/Model.java | 20 ++ .../loadingservice/ModelLoadingService.java | 287 ++++++++++++++++++ .../loadingservice/LocalModelTests.java | 213 +++++++++++++ .../ModelLoadingServiceTests.java | 212 +++++++++++++ .../integration/ModelInferenceActionIT.java | 185 +++++++++++ 34 files changed, 2151 insertions(+), 122 deletions(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParams.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java create mode 100644 
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParamsTests.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index b74bf82466eac..2c8553d68fe98 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -98,6 +98,7 @@ import org.elasticsearch.xpack.core.ml.action.GetModelSnapshotsAction; import org.elasticsearch.xpack.core.ml.action.GetOverallBucketsAction; import org.elasticsearch.xpack.core.ml.action.GetRecordsAction; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction; import org.elasticsearch.xpack.core.ml.action.KillProcessAction; import org.elasticsearch.xpack.core.ml.action.MlInfoAction; @@ -139,7 +140,14 @@ import 
org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Recall; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ScoreByThresholdResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding; import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; @@ -323,6 +331,7 @@ public List> getClientActions() { StartDataFrameAnalyticsAction.INSTANCE, EvaluateDataFrameAction.INSTANCE, EstimateMemoryUsageAction.INSTANCE, + InferModelAction.INSTANCE, // security ClearRealmCacheAction.INSTANCE, ClearRolesCacheAction.INSTANCE, @@ -451,6 +460,17 @@ public List getNamedWriteables() { new NamedWriteableRegistry.Entry(PreProcessor.class, TargetMeanEncoding.NAME.getPreferredName(), TargetMeanEncoding::new), // ML - Inference models new NamedWriteableRegistry.Entry(TrainedModel.class, Tree.NAME.getPreferredName(), Tree::new), + new NamedWriteableRegistry.Entry(TrainedModel.class, Ensemble.NAME.getPreferredName(), Ensemble::new), + // ML - Inference aggregators + new NamedWriteableRegistry.Entry(OutputAggregator.class, WeightedSum.NAME.getPreferredName(), 
WeightedSum::new), + new NamedWriteableRegistry.Entry(OutputAggregator.class, WeightedMode.NAME.getPreferredName(), WeightedMode::new), + // ML - Inference Results + new NamedWriteableRegistry.Entry(InferenceResults.class, + ClassificationInferenceResults.NAME, + ClassificationInferenceResults::new), + new NamedWriteableRegistry.Entry(InferenceResults.class, + RegressionInferenceResults.NAME, + RegressionInferenceResults::new), // monitoring new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.MONITORING, MonitoringFeatureSetUsage::new), diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java new file mode 100644 index 0000000000000..248d6180d3256 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java @@ -0,0 +1,153 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +public class InferModelAction extends ActionType { + + public static final InferModelAction INSTANCE = new InferModelAction(); + public static final String NAME = "cluster:admin/xpack/ml/infer"; + + private InferModelAction() { + super(NAME, Response::new); + } + + public static class Request extends ActionRequest { + + private final String modelId; + private final long modelVersion; + private final List> objectsToInfer; + private final InferenceParams params; + + public Request(String modelId, long modelVersion) { + this(modelId, modelVersion, Collections.emptyList(), InferenceParams.EMPTY_PARAMS); + } + + public Request(String modelId, long modelVersion, List> objectsToInfer, InferenceParams inferenceParams) { + this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID); + this.modelVersion = modelVersion; + this.objectsToInfer = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(objectsToInfer, "objects_to_infer")); + this.params = inferenceParams == null ? 
InferenceParams.EMPTY_PARAMS : inferenceParams; + } + + public Request(String modelId, long modelVersion, Map objectToInfer, InferenceParams params) { + this(modelId, + modelVersion, + Arrays.asList(ExceptionsHelper.requireNonNull(objectToInfer, "objects_to_infer")), + params); + } + + public Request(StreamInput in) throws IOException { + super(in); + this.modelId = in.readString(); + this.modelVersion = in.readVLong(); + this.objectsToInfer = Collections.unmodifiableList(in.readList(StreamInput::readMap)); + this.params = new InferenceParams(in); + } + + public String getModelId() { + return modelId; + } + + public long getModelVersion() { + return modelVersion; + } + + public List> getObjectsToInfer() { + return objectsToInfer; + } + + public InferenceParams getParams() { + return params; + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(modelId); + out.writeVLong(modelVersion); + out.writeCollection(objectsToInfer, StreamOutput::writeMap); + params.writeTo(out); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InferModelAction.Request that = (InferModelAction.Request) o; + return Objects.equals(modelId, that.modelId) + && Objects.equals(modelVersion, that.modelVersion) + && Objects.equals(params, that.params) + && Objects.equals(objectsToInfer, that.objectsToInfer); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, modelVersion, objectsToInfer, params); + } + + } + + public static class Response extends ActionResponse { + + private final List inferenceResults; + + public Response(List inferenceResults) { + super(); + this.inferenceResults = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(inferenceResults, "inferenceResults")); + } + + public Response(StreamInput in) throws 
IOException { + super(in); + this.inferenceResults = Collections.unmodifiableList(in.readNamedWriteableList(InferenceResults.class)); + } + + public List getInferenceResults() { + return inferenceResults; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeNamedWriteableList(inferenceResults); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InferModelAction.Response that = (InferModelAction.Response) o; + return Objects.equals(inferenceResults, that.inferenceResults); + } + + @Override + public int hashCode() { + return Objects.hash(inferenceResults); + } + + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java index 7fff4d6abbd3b..7b56a4c3b4da3 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java @@ -8,6 +8,9 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.plugins.spi.NamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; @@ -100,6 +103,14 @@ public List getNamedWriteables() { 
WeightedMode.NAME.getPreferredName(), WeightedMode::new)); + // Inference Results + namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, + ClassificationInferenceResults.NAME, + ClassificationInferenceResults::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, + RegressionInferenceResults.NAME, + RegressionInferenceResults::new)); + return namedWriteables; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java index f85c184646e1f..e936d60bf87b7 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java @@ -18,6 +18,8 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; @@ -27,6 +29,7 @@ import java.io.IOException; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Objects; public class TrainedModelDefinition implements ToXContentObject, Writeable { @@ -118,6 +121,15 @@ public Input getInput() { return input; } + private void preProcess(Map fields) { + preProcessors.forEach(preProcessor -> preProcessor.process(fields)); + } + 
+ public InferenceResults infer(Map fields, InferenceParams params) { + preProcess(fields); + return trainedModel.infer(fields, params); + } + @Override public String toString() { return Strings.toString(this); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java new file mode 100644 index 0000000000000..662585bedf51d --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java @@ -0,0 +1,175 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +public class ClassificationInferenceResults extends SingleValueInferenceResults { + + public static final String NAME = "classification"; + public static final ParseField CLASSIFICATION_LABEL = new ParseField("classification_label"); + public static final ParseField TOP_CLASSES = new ParseField("top_classes"); + + private final String classificationLabel; + 
private final List topClasses; + + public ClassificationInferenceResults(double value, String classificationLabel, List topClasses) { + super(value); + this.classificationLabel = classificationLabel; + this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses); + } + + public ClassificationInferenceResults(StreamInput in) throws IOException { + super(in); + this.classificationLabel = in.readOptionalString(); + this.topClasses = Collections.unmodifiableList(in.readList(TopClassEntry::new)); + } + + public String getClassificationLabel() { + return classificationLabel; + } + + public List getTopClasses() { + return topClasses; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeOptionalString(classificationLabel); + out.writeCollection(topClasses); + } + + @Override + XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException { + if (classificationLabel != null) { + builder.field(CLASSIFICATION_LABEL.getPreferredName(), classificationLabel); + } + if (topClasses.isEmpty() == false) { + builder.field(TOP_CLASSES.getPreferredName(), topClasses); + } + return builder; + } + + @Override + public boolean equals(Object object) { + if (object == this) { return true; } + if (object == null || getClass() != object.getClass()) { return false; } + ClassificationInferenceResults that = (ClassificationInferenceResults) object; + return Objects.equals(value(), that.value()) && + Objects.equals(classificationLabel, that.classificationLabel) && + Objects.equals(topClasses, that.topClasses); + } + + @Override + public int hashCode() { + return Objects.hash(value(), classificationLabel, topClasses); + } + + @Override + public String valueAsString() { + return classificationLabel == null ? 
super.valueAsString() : classificationLabel; + } + + @Override + public void writeResult(IngestDocument document, String resultField) { + ExceptionsHelper.requireNonNull(document, "document"); + ExceptionsHelper.requireNonNull(resultField, "resultField"); + if (topClasses.isEmpty()) { + document.setFieldValue(resultField, valueAsString()); + } else { + document.setFieldValue(resultField, topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList())); + } + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public String getName() { + return NAME; + } + + public static class TopClassEntry implements ToXContentObject, Writeable { + + public final ParseField CLASSIFICATION = new ParseField("classification"); + public final ParseField PROBABILITY = new ParseField("probability"); + + private final String classification; + private final double probability; + + public TopClassEntry(String classification, Double probability) { + this.classification = ExceptionsHelper.requireNonNull(classification, CLASSIFICATION); + this.probability = ExceptionsHelper.requireNonNull(probability, PROBABILITY); + } + + public TopClassEntry(StreamInput in) throws IOException { + this.classification = in.readString(); + this.probability = in.readDouble(); + } + + public String getClassification() { + return classification; + } + + public double getProbability() { + return probability; + } + + public Map asValueMap() { + Map map = new HashMap<>(2); + map.put(CLASSIFICATION.getPreferredName(), classification); + map.put(PROBABILITY.getPreferredName(), probability); + return map; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(classification); + out.writeDouble(probability); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CLASSIFICATION.getPreferredName(), classification); + 
builder.field(PROBABILITY.getPreferredName(), probability); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object object) { + if (object == this) { return true; } + if (object == null || getClass() != object.getClass()) { return false; } + TopClassEntry that = (TopClassEntry) object; + return Objects.equals(classification, that.classification) && + Objects.equals(probability, that.probability); + } + + @Override + public int hashCode() { + return Objects.hash(classification, probability); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java new file mode 100644 index 0000000000000..00744f6982f46 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; + +public interface InferenceResults extends NamedXContentObject, NamedWriteable { + + void writeResult(IngestDocument document, String resultField); + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java new file mode 100644 index 0000000000000..e186489b91dab --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Objects; + +public class RegressionInferenceResults extends SingleValueInferenceResults { + + public static final String NAME = "regression"; + + public RegressionInferenceResults(double value) { + super(value); + } + + public RegressionInferenceResults(StreamInput in) throws IOException { + super(in.readDouble()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + } + + @Override + XContentBuilder innerToXContent(XContentBuilder builder, Params params) { + return builder; + } + + @Override + public boolean equals(Object object) { + if (object == this) { return true; } + if (object == null || getClass() != object.getClass()) { return false; } + RegressionInferenceResults that = (RegressionInferenceResults) object; + return Objects.equals(value(), that.value()); + } + + @Override + public int hashCode() { + return Objects.hash(value()); + } + + @Override + public void writeResult(IngestDocument document, String resultField) { + ExceptionsHelper.requireNonNull(document, "document"); + ExceptionsHelper.requireNonNull(resultField, "resultField"); + document.setFieldValue(resultField, value()); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public String getName() { + return NAME; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java new file mode 100644 index 
0000000000000..2905a6679584c --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java @@ -0,0 +1,51 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; + +public abstract class SingleValueInferenceResults implements InferenceResults { + + public final ParseField VALUE = new ParseField("value"); + + private final double value; + + SingleValueInferenceResults(StreamInput in) throws IOException { + value = in.readDouble(); + } + + SingleValueInferenceResults(double value) { + this.value = value; + } + + public Double value() { + return value; + } + + public String valueAsString() { + return String.valueOf(value); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(value); + } + + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(VALUE.getPreferredName(), value); + innerToXContent(builder, params); + builder.endObject(); + return builder; + } + + abstract XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException; +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java new file mode 100644 index 0000000000000..5e37b237e9f79 --- /dev/null +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java @@ -0,0 +1,75 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public final class InferenceHelpers { + + private InferenceHelpers() { } + + public static List topClasses(List probabilities, + List classificationLabels, + int numToInclude) { + if (numToInclude == 0) { + return Collections.emptyList(); + } + int[] sortedIndices = IntStream.range(0, probabilities.size()) + .boxed() + .sorted(Comparator.comparing(probabilities::get).reversed()) + .mapToInt(i -> i) + .toArray(); + + if (classificationLabels != null && probabilities.size() != classificationLabels.size()) { + throw ExceptionsHelper + .serverError( + "model returned classification probabilities of size [{}] which is not equal to classification labels size [{}]", + null, + probabilities.size(), + classificationLabels); + } + + List labels = classificationLabels == null ? + // If we don't have the labels we should return the top classification values anyways, they will just be numeric + IntStream.range(0, probabilities.size()).boxed().map(String::valueOf).collect(Collectors.toList()) : + classificationLabels; + + int count = numToInclude < 0 ? 
probabilities.size() : Math.min(numToInclude, probabilities.size()); + List topClassEntries = new ArrayList<>(count); + for(int i = 0; i < count; i++) { + int idx = sortedIndices[i]; + topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities.get(idx))); + } + + return topClassEntries; + } + + + public static String classificationLabel(double inferenceValue, @Nullable List classificationLabels) { + assert inferenceValue == Math.rint(inferenceValue); + if (classificationLabels == null) { + return String.valueOf(inferenceValue); + } + int label = Double.valueOf(inferenceValue).intValue(); + if (label < 0 || label >= classificationLabels.size()) { + throw ExceptionsHelper.serverError( + "model returned classification value of [{}] which is not a valid index in classification labels [{}]", + null, + label, + classificationLabels); + } + return classificationLabels.get(label); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParams.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParams.java new file mode 100644 index 0000000000000..150bf3d483f26 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParams.java @@ -0,0 +1,65 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +public class InferenceParams implements ToXContentObject, Writeable { + +    public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); + +    public static final InferenceParams EMPTY_PARAMS = new InferenceParams(0); + +    private final int numTopClasses; + +    public InferenceParams(Integer numTopClasses) { +        this.numTopClasses = numTopClasses == null ? 0 : numTopClasses; +    } + +    public InferenceParams(StreamInput in) throws IOException { +        this.numTopClasses = in.readInt(); +    } + +    public int getNumTopClasses() { +        return numTopClasses; +    } + +    @Override +    public void writeTo(StreamOutput out) throws IOException { +        out.writeInt(numTopClasses); +    } + +    @Override +    public boolean equals(Object o) { +        if (this == o) return true; +        if (o == null || getClass() != o.getClass()) return false; +        InferenceParams that = (InferenceParams) o; +        return Objects.equals(numTopClasses, that.numTopClasses); +    } + +    @Override +    public int hashCode() { +        return Objects.hash(numTopClasses); +    } + +    @Override +    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { +        builder.startObject(); +        if (numTopClasses != 0) { +            builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses); +        } +        builder.endObject(); +        return builder; +    } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java index
cad5a6c0a8c74..a6c6f1eff011d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java @@ -7,6 +7,7 @@ import org.elasticsearch.common.Nullable; import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; import java.util.List; @@ -26,13 +27,7 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable { * @return The predicted value. For classification this will be discrete values (e.g. 0.0, or 1.0). * For regression this is continuous. */ - double infer(Map fields); - - /** - * @param fields similar to {@link TrainedModel#infer(Map)}, but fields are already in order and doubles - * @return The predicted value. - */ - double infer(List fields); + InferenceResults infer(Map fields, InferenceParams params); /** * @return {@link TargetType} for the model. @@ -40,26 +35,7 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable { TargetType targetType(); /** - * This gathers the probabilities for each potential classification value. - * - * The probabilities are indexed by classification ordinal label encoding. - * The length of this list is equal to the number of classification labels. 
- * - * This only should return if the implementation model is inferring classification values and not regression - * @param fields The fields and their values to infer against - * @return The probabilities of each classification value - */ - List classificationProbability(Map fields); - - /** - * @param fields similar to {@link TrainedModel#classificationProbability(Map)} but the fields are already in order and doubles - * @return The probabilities of each classification value - */ - List classificationProbability(List fields); - - /** - * The ordinal encoded list of the classification labels. - * @return Oridinal encoded list of classification labels. + * @return Ordinal encoded list of classification labels. */ @Nullable List classificationLabels(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java index 5e5199c24053f..09e418cec916d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java @@ -12,6 +12,12 @@ import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; import 
org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; @@ -26,6 +32,8 @@ import java.util.Objects; import java.util.stream.Collectors; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel; + public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel { // TODO should we have regression/classification sub-classes that accept the builder? @@ -106,14 +114,21 @@ public List getFeatureNames() { } @Override - public double infer(Map fields) { - List processedInferences = inferAndProcess(fields); - return outputAggregator.aggregate(processedInferences); - } - - @Override - public double infer(List fields) { - throw new UnsupportedOperationException("Ensemble requires map containing field names and values"); + public InferenceResults infer(Map fields, InferenceParams params) { + if (params.getNumTopClasses() != 0 && + (targetType != TargetType.CLASSIFICATION || outputAggregator.providesProbabilities() == false)) { + throw ExceptionsHelper.badRequestException( + "Cannot return top classes for target_type [{}] and aggregate_output [{}]", + targetType, + outputAggregator.getName()); + } + List inferenceResults = this.models.stream().map(model -> { + InferenceResults results = model.infer(fields, InferenceParams.EMPTY_PARAMS); + assert results instanceof SingleValueInferenceResults; + return ((SingleValueInferenceResults)results).value(); + }).collect(Collectors.toList()); + List processed = outputAggregator.processValues(inferenceResults); + return buildResults(processed, params); } @Override @@ -121,18 +136,20 @@ public TargetType targetType() { return targetType; } - @Override - public List classificationProbability(Map fields) { - if ((targetType == TargetType.CLASSIFICATION) == false) { - throw new 
UnsupportedOperationException( -            "Cannot determine classification probability with target_type [" + targetType.toString() + "]"); +    private InferenceResults buildResults(List processedInferences, InferenceParams params) { +        switch(targetType) { +            case REGRESSION: +                return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences)); +            case CLASSIFICATION: +                List topClasses = +                    InferenceHelpers.topClasses(processedInferences, classificationLabels, params.getNumTopClasses()); +                double value = outputAggregator.aggregate(processedInferences); +                return new ClassificationInferenceResults(value, +                    classificationLabel(value, classificationLabels), +                    topClasses); +            default: +                throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on ensemble model"); } -        return inferAndProcess(fields); -    } - -    @Override -    public List classificationProbability(List fields) { -        throw new UnsupportedOperationException("Ensemble requires map containing field names and values"); } @Override @@ -121,18 +136,20 @@ public TargetType targetType() { return targetType; } -    @Override -    public List classificationProbability(Map fields) { -        if ((targetType == TargetType.CLASSIFICATION) == false) { -            throw new @@ -140,11 +157,6 @@ public List classificationLabels() { return classificationLabels; } -    private List inferAndProcess(Map fields) { -        List modelInferences = models.stream().map(m -> m.infer(fields)).collect(Collectors.toList()); -        return outputAggregator.processValues(modelInferences); -    } - @Override public String getWriteableName() { return NAME.getPreferredName(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java index 1f882b724ee94..012f474ab0618 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java
@@ -44,4 +44,6 @@ public interface OutputAggregator extends NamedXContentObject, NamedWriteable { * @return The name of the output aggregator */ String getName(); + + boolean providesProbabilities(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java index 739a4e13d8659..0689d748b0ccb 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java @@ -158,4 +158,9 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(weights); } + + @Override + public boolean providesProbabilities() { + return true; + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java index f5812dabf88f2..9c5c2bf582e54 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java @@ -135,4 +135,9 @@ public int hashCode() { public Integer expectedValueSize() { return weights == null ? 
null : this.weights.size(); } + + @Override + public boolean providesProbabilities() { + return false; + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java index 3a91ec0cd86c1..bce6b08b6ed4b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java @@ -13,6 +13,11 @@ import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; @@ -31,6 +36,8 @@ import java.util.Set; import java.util.stream.Collectors; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel; + public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel { // TODO should we have regression/classification sub-classes that accept the builder? 
@@ -105,20 +112,36 @@ public List getNodes() { } @Override - public double infer(Map fields) { + public InferenceResults infer(Map fields, InferenceParams params) { + if (targetType != TargetType.CLASSIFICATION && params.getNumTopClasses() != 0) { + throw ExceptionsHelper.badRequestException( + "Cannot return top classes for target_type [{}]", targetType.toString()); + } List features = featureNames.stream().map(f -> - fields.get(f) instanceof Number ? ((Number)fields.get(f)).doubleValue() : null + fields.get(f) instanceof Number ? ((Number) fields.get(f)).doubleValue() : null ).collect(Collectors.toList()); - return infer(features); + return infer(features, params); } - @Override - public double infer(List features) { + private InferenceResults infer(List features, InferenceParams params) { TreeNode node = nodes.get(0); while(node.isLeaf() == false) { node = nodes.get(node.compare(features)); } - return node.getLeafValue(); + return buildResult(node.getLeafValue(), params); + } + + private InferenceResults buildResult(Double value, InferenceParams params) { + switch (targetType) { + case CLASSIFICATION: + List topClasses = + InferenceHelpers.topClasses(classificationProbability(value), classificationLabels, params.getNumTopClasses()); + return new ClassificationInferenceResults(value, classificationLabel(value, classificationLabels), topClasses); + case REGRESSION: + return new RegressionInferenceResults(value); + default: + throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model"); + } } /** @@ -142,34 +165,15 @@ public TargetType targetType() { return targetType; } - @Override - public List classificationProbability(Map fields) { - if ((targetType == TargetType.CLASSIFICATION) == false) { - throw new UnsupportedOperationException( - "Cannot determine classification probability with target_type [" + targetType.toString() + "]"); - } - List features = featureNames.stream().map(f -> - fields.get(f) 
instanceof Number ? ((Number)fields.get(f)).doubleValue() : null) - .collect(Collectors.toList()); - - return classificationProbability(features); - } - - @Override - public List classificationProbability(List fields) { - if ((targetType == TargetType.CLASSIFICATION) == false) { - throw new UnsupportedOperationException( - "Cannot determine classification probability with target_type [" + targetType.toString() + "]"); - } - double label = infer(fields); + private List classificationProbability(double inferenceValue) { // If we are classification, we should assume that the inference return value is whole. - assert label == Math.rint(label); + assert inferenceValue == Math.rint(inferenceValue); double maxCategory = this.highestOrderCategory.get(); // If we are classification, we should assume that the largest leaf value is whole. assert maxCategory == Math.rint(maxCategory); List list = new ArrayList<>(Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0)); // TODO, eventually have TreeNodes contain confidence levels - list.set(Double.valueOf(label).intValue(), 1.0); + list.set(Double.valueOf(inferenceValue).intValue(), 1.0); return list; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java index cb44d03e22bb2..44cca308ea794 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java @@ -5,6 +5,8 @@ */ package org.elasticsearch.xpack.core.ml.inference.utils; +import org.elasticsearch.common.Numbers; + import java.util.List; import java.util.stream.Collectors; @@ -22,31 +24,31 @@ private Statistics(){} */ public static List softMax(List values) { Double expSum = 0.0; - Double max = values.stream().filter(v -> isInvalid(v) == 
false).max(Double::compareTo).orElse(null); + Double max = values.stream().filter(Statistics::isValid).max(Double::compareTo).orElse(null); if (max == null) { throw new IllegalArgumentException("no valid values present"); } - List exps = values.stream().map(v -> isInvalid(v) ? Double.NEGATIVE_INFINITY : v - max) + List exps = values.stream().map(v -> isValid(v) ? v - max : Double.NEGATIVE_INFINITY) .collect(Collectors.toList()); for (int i = 0; i < exps.size(); i++) { - if (isInvalid(exps.get(i)) == false) { + if (isValid(exps.get(i))) { Double exp = Math.exp(exps.get(i)); expSum += exp; exps.set(i, exp); } } for (int i = 0; i < exps.size(); i++) { - if (isInvalid(exps.get(i))) { - exps.set(i, 0.0); - } else { + if (isValid(exps.get(i))) { exps.set(i, exps.get(i)/expSum); + } else { + exps.set(i, 0.0); } } return exps; } - public static boolean isInvalid(Double v) { - return v == null || Double.isInfinite(v) || Double.isNaN(v); + private static boolean isValid(Double v) { + return v != null && Numbers.isValidDouble(v); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java index 320eace983590..8dfcc5fc59977 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java @@ -51,6 +51,10 @@ public static ElasticsearchException serverError(String msg, Throwable cause) { return new ElasticsearchException(msg, cause); } + public static ElasticsearchException serverError(String msg, Throwable cause, Object... args) { + return new ElasticsearchException(msg, cause, args); + } + public static ElasticsearchStatusException conflictStatusException(String msg, Throwable cause, Object... 
args) { return new ElasticsearchStatusException(msg, RestStatus.CONFLICT, cause, args); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java new file mode 100644 index 0000000000000..a49643d081957 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.action.InferModelAction.Request; + +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParamsTests.randomInferenceParams; + +public class InferModelActionRequestTests extends AbstractWireSerializingTestCase { + + @Override + protected Request createTestInstance() { + return randomBoolean() ? + new Request( + randomAlphaOfLength(10), + randomLongBetween(1, 100), + Stream.generate(InferModelActionRequestTests::randomMap).limit(randomInt(10)).collect(Collectors.toList()), + randomBoolean() ? null : randomInferenceParams()) : + new Request( + randomAlphaOfLength(10), + randomLongBetween(1, 100), + randomMap(), + randomBoolean() ? 
null : randomInferenceParams()); + } + + private static Map randomMap() { + return Stream.generate(()-> randomAlphaOfLength(10)) + .limit(randomInt(10)) + .collect(Collectors.toMap(Function.identity(), (v) -> randomAlphaOfLength(10))); + } + + @Override + protected Writeable.Reader instanceReader() { + return Request::new; + } + +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java new file mode 100644 index 0000000000000..9e72d1c4e682a --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.action.InferModelAction.Response; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResultsTests; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResultsTests; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class InferModelActionResponseTests extends AbstractWireSerializingTestCase { + + @Override + protected Response createTestInstance() { + String resultType = randomFrom(ClassificationInferenceResults.NAME, RegressionInferenceResults.NAME); + return new Response( + Stream.generate(() -> randomInferenceResult(resultType)) + .limit(randomIntBetween(0, 10)) + .collect(Collectors.toList())); + } + + private static InferenceResults randomInferenceResult(String resultType) { + if (resultType.equals(ClassificationInferenceResults.NAME)) { + return ClassificationInferenceResultsTests.createRandomResults(); + } else if (resultType.equals(RegressionInferenceResults.NAME)) { + return RegressionInferenceResultsTests.createRandomResults(); + } else { + fail("unexpected result type [" + resultType + "]"); + return null; + } + } + + @Override + protected Writeable.Reader instanceReader() { + return Response::new; + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List entries = new 
ArrayList<>(); + entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); + return new NamedWriteableRegistry(entries); + } + +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java new file mode 100644 index 0000000000000..ba90fece02f2a --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java @@ -0,0 +1,83 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.equalTo; + +public class ClassificationInferenceResultsTests extends AbstractWireSerializingTestCase { + + public static ClassificationInferenceResults createRandomResults() { + return new ClassificationInferenceResults(randomDouble(), + randomBoolean() ? null : randomAlphaOfLength(10), + randomBoolean() ? 
null : + Stream.generate(ClassificationInferenceResultsTests::createRandomClassEntry) + .limit(randomIntBetween(0, 10)) + .collect(Collectors.toList())); + } + + private static ClassificationInferenceResults.TopClassEntry createRandomClassEntry() { + return new ClassificationInferenceResults.TopClassEntry(randomAlphaOfLength(10), randomDouble()); + } + + public void testWriteResultsWithClassificationLabel() { + ClassificationInferenceResults result = new ClassificationInferenceResults(1.0, "foo", Collections.emptyList()); + IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); + result.writeResult(document, "result_field"); + + assertThat(document.getFieldValue("result_field", String.class), equalTo("foo")); + } + + public void testWriteResultsWithoutClassificationLabel() { + ClassificationInferenceResults result = new ClassificationInferenceResults(1.0, null, Collections.emptyList()); + IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); + result.writeResult(document, "result_field"); + + assertThat(document.getFieldValue("result_field", String.class), equalTo("1.0")); + } + + @SuppressWarnings("unchecked") + public void testWriteResultsWithTopClasses() { + List entries = Arrays.asList( + new ClassificationInferenceResults.TopClassEntry("foo", 0.7), + new ClassificationInferenceResults.TopClassEntry("bar", 0.2), + new ClassificationInferenceResults.TopClassEntry("baz", 0.1)); + ClassificationInferenceResults result = new ClassificationInferenceResults(1.0, + "foo", + entries); + IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); + result.writeResult(document, "result_field"); + + List list = document.getFieldValue("result_field", List.class); + assertThat(list.size(), equalTo(3)); + + for(int i = 0; i < 3; i++) { + Map map = (Map)list.get(i); + assertThat(map, equalTo(entries.get(i).asValueMap())); + } + } + + @Override + protected ClassificationInferenceResults createTestInstance() 
{ + return createRandomResults(); + } + + @Override + protected Writeable.Reader instanceReader() { + return ClassificationInferenceResults::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java new file mode 100644 index 0000000000000..4f2d5926c84dc --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; + +import java.util.HashMap; + +import static org.hamcrest.Matchers.equalTo; + + +public class RegressionInferenceResultsTests extends AbstractWireSerializingTestCase { + + public static RegressionInferenceResults createRandomResults() { + return new RegressionInferenceResults(randomDouble()); + } + + public void testWriteResults() { + RegressionInferenceResults result = new RegressionInferenceResults(0.3); + IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); + result.writeResult(document, "result_field"); + + assertThat(document.getFieldValue("result_field", Double.class), equalTo(0.3)); + } + + @Override + protected RegressionInferenceResults createTestInstance() { + return createRandomResults(); + } + + @Override + protected Writeable.Reader instanceReader() { + return 
RegressionInferenceResults::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParamsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParamsTests.java new file mode 100644 index 0000000000000..2586cdd75de47 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParamsTests.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +public class InferenceParamsTests extends AbstractWireSerializingTestCase { + + public static InferenceParams randomInferenceParams() { + return randomBoolean() ? 
InferenceParams.EMPTY_PARAMS : new InferenceParams(randomIntBetween(-1, 100)); + } + + @Override + protected InferenceParams createTestInstance() { + return randomInferenceParams(); + } + + @Override + protected Writeable.Reader instanceReader() { + return InferenceParams::new; + } + +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java index eb537e247e994..8da0c15718f24 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -15,6 +15,9 @@ import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; @@ -239,27 +242,30 @@ public void testClassificationProbability() { List featureVector = Arrays.asList(0.4, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); - List expected = Arrays.asList(0.231475216, 0.768524783); + List expected = Arrays.asList(0.768524783, 0.231475216); double eps = 0.000001; - List probabilities = ensemble.classificationProbability(featureMap); + List probabilities = + ((ClassificationInferenceResults)ensemble.infer(featureMap, new 
InferenceParams(2))).getTopClasses(); for(int i = 0; i < expected.size(); i++) { - assertThat(probabilities.get(i), closeTo(expected.get(i), eps)); + assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); } featureVector = Arrays.asList(2.0, 0.7); featureMap = zipObjMap(featureNames, featureVector); - expected = Arrays.asList(0.3100255188, 0.689974481); - probabilities = ensemble.classificationProbability(featureMap); + expected = Arrays.asList(0.689974481, 0.3100255188); + probabilities = + ((ClassificationInferenceResults)ensemble.infer(featureMap, new InferenceParams(2))).getTopClasses(); for(int i = 0; i < expected.size(); i++) { - assertThat(probabilities.get(i), closeTo(expected.get(i), eps)); + assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); } featureVector = Arrays.asList(0.0, 1.0); featureMap = zipObjMap(featureNames, featureVector); - expected = Arrays.asList(0.231475216, 0.768524783); - probabilities = ensemble.classificationProbability(featureMap); + expected = Arrays.asList(0.768524783, 0.231475216); + probabilities = + ((ClassificationInferenceResults)ensemble.infer(featureMap, new InferenceParams(2))).getTopClasses(); for(int i = 0; i < expected.size(); i++) { - assertThat(probabilities.get(i), closeTo(expected.get(i), eps)); + assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); } // This should handle missing values and take the default_left path @@ -268,9 +274,10 @@ public void testClassificationProbability() { put("bar", null); }}; expected = Arrays.asList(0.6899744811, 0.3100255188); - probabilities = ensemble.classificationProbability(featureMap); + probabilities = + ((ClassificationInferenceResults)ensemble.infer(featureMap, new InferenceParams(2))).getTopClasses(); for(int i = 0; i < expected.size(); i++) { - assertThat(probabilities.get(i), closeTo(expected.get(i), eps)); + assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), 
eps)); } } @@ -320,21 +327,25 @@ public void testClassificationInference() { List featureVector = Arrays.asList(0.4, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); - assertEquals(1.0, ensemble.infer(featureMap), 0.00001); + assertThat(1.0, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); featureVector = Arrays.asList(2.0, 0.7); featureMap = zipObjMap(featureNames, featureVector); - assertEquals(1.0, ensemble.infer(featureMap), 0.00001); + assertThat(1.0, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); featureVector = Arrays.asList(0.0, 1.0); featureMap = zipObjMap(featureNames, featureVector); - assertEquals(1.0, ensemble.infer(featureMap), 0.00001); + assertThat(1.0, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); featureMap = new HashMap<>(2) {{ put("foo", 0.3); put("bar", null); }}; - assertEquals(0.0, ensemble.infer(featureMap), 0.00001); + assertThat(0.0, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); } public void testRegressionInference() { @@ -373,11 +384,13 @@ public void testRegressionInference() { List featureVector = Arrays.asList(0.4, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); - assertEquals(0.9, ensemble.infer(featureMap), 0.00001); + assertThat(0.9, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); featureVector = Arrays.asList(2.0, 0.7); featureMap = zipObjMap(featureNames, featureVector); - assertEquals(0.5, ensemble.infer(featureMap), 0.00001); + assertThat(0.5, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); // Test with NO aggregator supplied, verifies default behavior of non-weighted sum ensemble = 
Ensemble.builder() @@ -388,17 +401,20 @@ public void testRegressionInference() { featureVector = Arrays.asList(0.4, 0.0); featureMap = zipObjMap(featureNames, featureVector); - assertEquals(1.8, ensemble.infer(featureMap), 0.00001); + assertThat(1.8, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); featureVector = Arrays.asList(2.0, 0.7); featureMap = zipObjMap(featureNames, featureVector); - assertEquals(1.0, ensemble.infer(featureMap), 0.00001); + assertThat(1.0, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); featureMap = new HashMap<>(2) {{ put("foo", 0.3); put("bar", null); }}; - assertEquals(1.8, ensemble.infer(featureMap), 0.00001); + assertThat(1.8, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); } private static Map zipObjMap(List keys, List values) { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java index 81030585f1889..c362c17fd579d 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java @@ -10,6 +10,9 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; 
import org.junit.Before; @@ -120,26 +123,30 @@ public void testInfer() { // This feature vector should hit the right child of the root node List featureVector = Arrays.asList(0.6, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); - assertThat(0.3, closeTo(tree.infer(featureMap), 0.00001)); + assertThat(0.3, + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); // This should hit the left child of the left child of the root node // i.e. it takes the path left, left featureVector = Arrays.asList(0.3, 0.7); featureMap = zipObjMap(featureNames, featureVector); - assertThat(0.1, closeTo(tree.infer(featureMap), 0.00001)); + assertThat(0.1, + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); // This should hit the right child of the left child of the root node // i.e. it takes the path left, right featureVector = Arrays.asList(0.3, 0.9); featureMap = zipObjMap(featureNames, featureVector); - assertThat(0.2, closeTo(tree.infer(featureMap), 0.00001)); + assertThat(0.2, + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); // This should handle missing values and take the default_left path featureMap = new HashMap<>(2) {{ put("foo", 0.3); put("bar", null); }}; - assertThat(0.1, closeTo(tree.infer(featureMap), 0.00001)); + assertThat(0.1, + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); } public void testTreeClassificationProbability() { @@ -153,31 +160,43 @@ public void testTreeClassificationProbability() { builder.addLeaf(leftChildNode.getRightChild(), 0.0); List featureNames = Arrays.asList("foo", "bar"); - Tree tree = builder.setFeatureNames(featureNames).build(); + Tree tree = builder.setFeatureNames(featureNames).setClassificationLabels(Arrays.asList("cat", "dog")).build(); + double eps = 0.000001; // This feature 
vector should hit the right child of the root node List featureVector = Arrays.asList(0.6, 0.0); + List expectedProbs = Arrays.asList(1.0, 0.0); + List expectedFields = Arrays.asList("dog", "cat"); Map featureMap = zipObjMap(featureNames, featureVector); - assertEquals(Arrays.asList(0.0, 1.0), tree.classificationProbability(featureMap)); + List probabilities = + ((ClassificationInferenceResults)tree.infer(featureMap, new InferenceParams(2))).getTopClasses(); + for(int i = 0; i < expectedProbs.size(); i++) { + assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps)); + assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i))); + } // This should hit the left child of the left child of the root node // i.e. it takes the path left, left featureVector = Arrays.asList(0.3, 0.7); featureMap = zipObjMap(featureNames, featureVector); - assertEquals(Arrays.asList(0.0, 1.0), tree.classificationProbability(featureMap)); - - // This should hit the right child of the left child of the root node - // i.e. 
it takes the path left, right - featureVector = Arrays.asList(0.3, 0.9); - featureMap = zipObjMap(featureNames, featureVector); - assertEquals(Arrays.asList(1.0, 0.0), tree.classificationProbability(featureMap)); + probabilities = + ((ClassificationInferenceResults)tree.infer(featureMap, new InferenceParams(2))).getTopClasses(); + for(int i = 0; i < expectedProbs.size(); i++) { + assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps)); + assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i))); + } // This should handle missing values and take the default_left path featureMap = new HashMap<>(2) {{ put("foo", 0.3); put("bar", null); }}; - assertEquals(1.0, tree.infer(featureMap), 0.00001); + probabilities = + ((ClassificationInferenceResults)tree.infer(featureMap, new InferenceParams(2))).getTopClasses(); + for(int i = 0; i < expectedProbs.size(); i++) { + assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps)); + assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i))); + } } public void testTreeWithNullRoot() { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 277d7018e8d4e..0b09b0736bfbb 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -95,6 +95,7 @@ import org.elasticsearch.xpack.core.ml.action.GetOverallBucketsAction; import org.elasticsearch.xpack.core.ml.action.GetRecordsAction; import org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.action.KillProcessAction; import org.elasticsearch.xpack.core.ml.action.MlInfoAction; import 
org.elasticsearch.xpack.core.ml.action.OpenJobAction; @@ -161,6 +162,7 @@ import org.elasticsearch.xpack.ml.action.TransportGetModelSnapshotsAction; import org.elasticsearch.xpack.ml.action.TransportGetOverallBucketsAction; import org.elasticsearch.xpack.ml.action.TransportGetRecordsAction; +import org.elasticsearch.xpack.ml.action.TransportInferModelAction; import org.elasticsearch.xpack.ml.action.TransportIsolateDatafeedAction; import org.elasticsearch.xpack.ml.action.TransportKillProcessAction; import org.elasticsearch.xpack.ml.action.TransportMlInfoAction; @@ -200,7 +202,9 @@ import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult; +import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.inference.persistence.InferenceInternalIndex; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.job.JobManager; import org.elasticsearch.xpack.ml.job.JobManagerHolder; import org.elasticsearch.xpack.ml.job.UpdateJobProcessNotifier; @@ -501,6 +505,8 @@ public Collection createComponents(Client client, ClusterService cluster notifier, xContentRegistry); + final TrainedModelProvider trainedModelProvider = new TrainedModelProvider(client, xContentRegistry); + final ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, threadPool, clusterService); // special holder for @link(MachineLearningFeatureSetUsage) which needs access to job manager if ML is enabled JobManagerHolder jobManagerHolder = new JobManagerHolder(jobManager); @@ -613,7 +619,9 @@ public Collection createComponents(Client client, ClusterService cluster analyticsProcessManager, memoryEstimationProcessManager, dataFrameAnalyticsConfigProvider, - nativeStorageProvider + 
nativeStorageProvider, + modelLoadingService, + trainedModelProvider ); } @@ -768,6 +776,7 @@ public List getRestHandlers(Settings settings, RestController restC new ActionHandler<>(StopDataFrameAnalyticsAction.INSTANCE, TransportStopDataFrameAnalyticsAction.class), new ActionHandler<>(EvaluateDataFrameAction.INSTANCE, TransportEvaluateDataFrameAction.class), new ActionHandler<>(EstimateMemoryUsageAction.INSTANCE, TransportEstimateMemoryUsageAction.class), + new ActionHandler<>(InferModelAction.INSTANCE, TransportInferModelAction.class), usageAction, infoAction); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java new file mode 100644 index 0000000000000..a2f79fc9d0437 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.ml.inference.loadingservice.Model; +import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; +import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor; + + +public class TransportInferModelAction extends HandledTransportAction { + + private final ModelLoadingService modelLoadingService; + private final Client client; + + @Inject + public TransportInferModelAction(TransportService transportService, + ActionFilters actionFilters, + ModelLoadingService modelLoadingService, + Client client) { + super(InferModelAction.NAME, transportService, actionFilters, InferModelAction.Request::new); + this.modelLoadingService = modelLoadingService; + this.client = client; + } + + @Override + protected void doExecute(Task task, InferModelAction.Request request, ActionListener listener) { + + ActionListener getModelListener = ActionListener.wrap( + model -> { + TypedChainTaskExecutor typedChainTaskExecutor = + new TypedChainTaskExecutor<>(client.threadPool().executor(ThreadPool.Names.SAME), + // run through all tasks + r -> true, + // Always fail immediately and return an error + ex -> true); + request.getObjectsToInfer().forEach(stringObjectMap -> + typedChainTaskExecutor.add(chainedTask -> + model.infer(stringObjectMap, request.getParams(), chainedTask))); + + typedChainTaskExecutor.execute(ActionListener.wrap( + inferenceResultsInterfaces 
-> + listener.onResponse(new InferModelAction.Response(inferenceResultsInterfaces)), + listener::onFailure + )); + }, + listener::onFailure + ); + + this.modelLoadingService.getModel(request.getModelId(), request.getModelVersion(), getModelListener); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java new file mode 100644 index 0000000000000..59f6c62a7f55e --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.inference.ingest; + +import org.elasticsearch.client.Client; +import org.elasticsearch.ingest.AbstractProcessor; +import org.elasticsearch.ingest.IngestDocument; + +import java.util.function.BiConsumer; + +public class InferenceProcessor extends AbstractProcessor { + + public static final String TYPE = "inference"; + public static final String MODEL_ID = "model_id"; + + private final Client client; + + public InferenceProcessor(Client client, String tag) { + super(tag); + this.client = client; + } + + @Override + public void execute(IngestDocument ingestDocument, BiConsumer handler) { + //TODO actually work + handler.accept(ingestDocument, null); + } + + @Override + public IngestDocument execute(IngestDocument ingestDocument) { + throw new UnsupportedOperationException("should never be called"); + } + + @Override + public String getType() { + return TYPE; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java new file mode 100644 index 0000000000000..e5253b3d5b173 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java @@ -0,0 +1,51 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.inference.loadingservice; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; + +import java.util.Map; + +public class LocalModel implements Model { + + private final TrainedModelDefinition trainedModelDefinition; + private final String modelId; + + public LocalModel(String modelId, TrainedModelDefinition trainedModelDefinition) { + this.trainedModelDefinition = trainedModelDefinition; + this.modelId = modelId; + } + + @Override + public String getResultsType() { + switch (trainedModelDefinition.getTrainedModel().targetType()) { + case CLASSIFICATION: + return ClassificationInferenceResults.NAME; + case REGRESSION: + return RegressionInferenceResults.NAME; + default: + throw ExceptionsHelper.badRequestException("Model [{}] has unsupported target type [{}]", + modelId, + trainedModelDefinition.getTrainedModel().targetType()); + } + } + + @Override + public void infer(Map fields, InferenceParams params, ActionListener listener) { + try { + 
listener.onResponse(trainedModelDefinition.infer(fields, params)); + } catch (Exception e) { + listener.onFailure(e); + } + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java new file mode 100644 index 0000000000000..27924a47aa153 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.inference.loadingservice; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; + +import java.util.Map; + +public interface Model { + + String getResultsType(); + + void infer(Map fields, InferenceParams inferenceParams, ActionListener listener); + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java new file mode 100644 index 0000000000000..b4fc552ba5f93 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -0,0 +1,287 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.inference.loadingservice; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.ClusterChangedEvent; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.ClusterStateListener; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.ingest.IngestMetadata; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; + +public class ModelLoadingService implements ClusterStateListener { + + private static final Logger logger = LogManager.getLogger(ModelLoadingService.class); + private final Map loadedModels = new HashMap<>(); + private final Map>> loadingListeners = new HashMap<>(); + private final TrainedModelProvider provider; + private final ThreadPool threadPool; + + public ModelLoadingService(TrainedModelProvider trainedModelProvider, + ThreadPool threadPool, + ClusterService clusterService) { + this.provider = trainedModelProvider; + this.threadPool = threadPool; + clusterService.addListener(this); + } + + public void getModel(String modelId, long modelVersion, ActionListener modelActionListener) { + String key = modelKey(modelId, modelVersion); + MaybeModel cachedModel = loadedModels.get(key); + if (cachedModel != null) { + if 
(cachedModel.isSuccess()) { + modelActionListener.onResponse(cachedModel.getModel()); + return; + } + } + if (loadModelIfNecessary(key, modelId, modelVersion, modelActionListener) == false) { + // If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called + // by a simulated pipeline + logger.debug("[{}] version [{}] not actively loading, eager loading without cache", modelId, modelVersion); + provider.getTrainedModel(modelId, modelVersion, ActionListener.wrap( + trainedModelConfig -> + modelActionListener.onResponse(new LocalModel(trainedModelConfig.getModelId(), trainedModelConfig.getDefinition())), + modelActionListener::onFailure + )); + } else { + logger.debug("[{}] version [{}] is currently loading, added new listener to queue", modelId, modelVersion); + } + } + + /** + * Returns true if the model is loaded and the listener has been given the cached model + * Returns true if the model is CURRENTLY being loaded and the listener was added to be notified when it is loaded + * Returns false if the model is not loaded or actively being loaded + */ + private boolean loadModelIfNecessary(String key, String modelId, long modelVersion, ActionListener modelActionListener) { + synchronized (loadingListeners) { + MaybeModel cachedModel = loadedModels.get(key); + if (cachedModel != null) { + if (cachedModel.isSuccess()) { + modelActionListener.onResponse(cachedModel.getModel()); + return true; + } + // If the loaded model entry is there but is not present, that means the previous load attempt ran into an issue + // Attempt to load and cache the model if necessary + if (loadingListeners.computeIfPresent( + key, + (storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)) == null) { + logger.debug("[{}] version [{}] attempting to load and cache", modelId, modelVersion); + loadingListeners.put(key, addFluently(new ArrayDeque<>(), modelActionListener)); + loadModel(key, modelId, modelVersion); 
+ } + return true; + } + // if the cachedModel entry is null, but there are listeners present, that means it is being loaded + return loadingListeners.computeIfPresent(key, + (storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)) != null; + } + } + + private void loadModel(String modelKey, String modelId, long modelVersion) { + provider.getTrainedModel(modelId, modelVersion, ActionListener.wrap( + trainedModelConfig -> { + logger.debug("[{}] successfully loaded model", modelKey); + handleLoadSuccess(modelKey, trainedModelConfig); + }, + failure -> { + logger.warn(new ParameterizedMessage("[{}] failed to load model", modelKey), failure); + handleLoadFailure(modelKey, failure); + } + )); + } + + private void handleLoadSuccess(String modelKey, TrainedModelConfig trainedModelConfig) { + Queue> listeners; + Model loadedModel = new LocalModel(trainedModelConfig.getModelId(), trainedModelConfig.getDefinition()); + synchronized (loadingListeners) { + listeners = loadingListeners.remove(modelKey); + // If there is no loadingListener that means the loading was canceled and the listener was already notified as such + // Consequently, we should not store the retrieved model + if (listeners != null) { + loadedModels.put(modelKey, MaybeModel.of(loadedModel)); + } + } + if (listeners != null) { + for (ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { + listener.onResponse(loadedModel); + } + } + } + + private void handleLoadFailure(String modelKey, Exception failure) { + Queue> listeners; + synchronized (loadingListeners) { + listeners = loadingListeners.remove(modelKey); + if (listeners != null) { + // If we failed to load and there were listeners present, that means that this model is referenced by a processor + // Add an empty entry here so that we can attempt to load and cache the model again when it is accessed again. 
+ loadedModels.computeIfAbsent(modelKey, (key) -> MaybeModel.of(failure)); + } + } + if (listeners != null) { + for (ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { + listener.onFailure(failure); + } + } + } + + @Override + public void clusterChanged(ClusterChangedEvent event) { + if (event.changedCustomMetaDataSet().contains(IngestMetadata.TYPE)) { + ClusterState state = event.state(); + IngestMetadata currentIngestMetadata = state.metaData().custom(IngestMetadata.TYPE); + Set allReferencedModelKeys = getReferencedModelKeys(currentIngestMetadata); + // The listeners still waiting for a model and we are canceling the load? + List>>> drainWithFailure = new ArrayList<>(); + synchronized (loadingListeners) { + // If we had models still loading here but are no longer referenced + // we should remove them from loadingListeners and alert the listeners + for (String modelKey : loadingListeners.keySet()) { + if (allReferencedModelKeys.contains(modelKey) == false) { + drainWithFailure.add(Tuple.tuple(splitModelKey(modelKey).v1(), new ArrayList<>(loadingListeners.remove(modelKey)))); + } + } + + // Remove all cached models that are not referenced by any processors + loadedModels.keySet().retainAll(allReferencedModelKeys); + + // Remove all that are currently being loaded + allReferencedModelKeys.removeAll(loadingListeners.keySet()); + + // Remove all that are fully loaded, will attempt empty model loading again + loadedModels.forEach((id, optionalModel) -> { + if (optionalModel.isSuccess()) { + allReferencedModelKeys.remove(id); + } + }); + // Populate loadingListeners key so we know that we are currently loading the model + for (String modelId : allReferencedModelKeys) { + loadingListeners.put(modelId, new ArrayDeque<>()); + } + } + for (Tuple>> modelAndListeners : drainWithFailure) { + final String msg = new ParameterizedMessage( + "Cancelling load of model [{}] as it is no longer referenced by a pipeline", + 
modelAndListeners.v1()).getFormat(); + for (ActionListener listener : modelAndListeners.v2()) { + listener.onFailure(new ElasticsearchException(msg)); + } + } + loadModels(allReferencedModelKeys); + } + } + + private void loadModels(Set modelKeys) { + if (modelKeys.isEmpty()) { + return; + } + // Execute this on a utility thread as when the callbacks occur we don't want them tying up the cluster listener thread pool + threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> { + for (String modelKey : modelKeys) { + Tuple modelIdAndVersion = splitModelKey(modelKey); + this.loadModel(modelKey, modelIdAndVersion.v1(), modelIdAndVersion.v2()); + } + }); + } + + private static Queue addFluently(Queue queue, T object) { + queue.add(object); + return queue; + } + + private static String modelKey(String modelId, long modelVersion) { + return modelId + "_" + modelVersion; + } + + private static Tuple splitModelKey(String modelKey) { + int delim = modelKey.lastIndexOf('_'); + String modelId = modelKey.substring(0, delim); + Long version = Long.valueOf(modelKey.substring(delim + 1)); + return Tuple.tuple(modelId, version); + } + + private static Set getReferencedModelKeys(IngestMetadata ingestMetadata) { + Set allReferencedModelKeys = new HashSet<>(); + if (ingestMetadata != null) { + ingestMetadata.getPipelines().forEach((pipelineId, pipelineConfiguration) -> { + Object processors = pipelineConfiguration.getConfigAsMap().get("processors"); + if (processors instanceof List) { + for(Object processor : (List)processors) { + if (processor instanceof Map) { + Object processorConfig = ((Map)processor).get(InferenceProcessor.TYPE); + if (processorConfig instanceof Map) { + Object modelId = ((Map)processorConfig).get(InferenceProcessor.MODEL_ID); + if (modelId != null) { + assert modelId instanceof String; + // TODO also read model version + allReferencedModelKeys.add(modelKey(modelId.toString(), 0)); + } + } + } + } + } + }); + } + return allReferencedModelKeys; 
+ } + + private static class MaybeModel { + + private final Model model; + private final Exception exception; + + static MaybeModel of(Model model) { + return new MaybeModel(model, null); + } + + static MaybeModel of(Exception exception) { + return new MaybeModel(null, exception); + } + + private MaybeModel(Model model, Exception exception) { + this.model = model; + this.exception = exception; + } + + Model getModel() { + return model; + } + + Exception getException() { + return exception; + } + + boolean isSuccess() { + return this.model != null; + } + + boolean isFailure() { + return this.exception != null; + } + + } + +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java new file mode 100644 index 0000000000000..e66a5790d85e5 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java @@ -0,0 +1,213 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.inference.loadingservice; + +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; +import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +public class LocalModelTests extends ESTestCase { + + public void testClassificationInfer() throws Exception { + String modelId = "classification_model"; + TrainedModelDefinition definition = new TrainedModelDefinition.Builder() + .setInput(new TrainedModelDefinition.Input(Arrays.asList("field1", "field2"))) + .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) + .setTrainedModel(buildClassification(false)) + 
.build(); + + Model model = new LocalModel(modelId, definition); + Map fields = new HashMap<>() {{ + put("foo", 1.0); + put("bar", 0.5); + put("categorical", "dog"); + }}; + + SingleValueInferenceResults result = getSingleValue(model, fields, InferenceParams.EMPTY_PARAMS); + assertThat(result.value(), equalTo(0.0)); + assertThat(result.valueAsString(), is("0.0")); + + ClassificationInferenceResults classificationResult = + (ClassificationInferenceResults)getSingleValue(model, fields, new InferenceParams(1)); + assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); + assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("0")); + + // Test with labels + definition = new TrainedModelDefinition.Builder() + .setInput(new TrainedModelDefinition.Input(Arrays.asList("field1", "field2"))) + .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) + .setTrainedModel(buildClassification(true)) + .build(); + model = new LocalModel(modelId, definition); + result = getSingleValue(model, fields, InferenceParams.EMPTY_PARAMS); + assertThat(result.value(), equalTo(0.0)); + assertThat(result.valueAsString(), equalTo("not_to_be")); + + classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new InferenceParams(1)); + assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); + assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("not_to_be")); + + classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new InferenceParams(2)); + assertThat(classificationResult.getTopClasses(), hasSize(2)); + + classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new InferenceParams(-1)); + assertThat(classificationResult.getTopClasses(), hasSize(2)); + } + + public void testRegression() throws Exception { + 
TrainedModelDefinition trainedModelDefinition = new TrainedModelDefinition.Builder() + .setInput(new TrainedModelDefinition.Input(Arrays.asList("field1", "field2"))) + .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) + .setTrainedModel(buildRegression()) + .build(); + Model model = new LocalModel("regression_model", trainedModelDefinition); + + Map fields = new HashMap<>() {{ + put("foo", 1.0); + put("bar", 0.5); + put("categorical", "dog"); + }}; + + SingleValueInferenceResults results = getSingleValue(model, fields, InferenceParams.EMPTY_PARAMS); + assertThat(results.value(), equalTo(1.3)); + + PlainActionFuture failedFuture = new PlainActionFuture<>(); + model.infer(fields, new InferenceParams(2), failedFuture); + ExecutionException ex = expectThrows(ExecutionException.class, failedFuture::get); + assertThat(ex.getCause().getMessage(), + equalTo("Cannot return top classes for target_type [regression] and aggregate_output [weighted_sum]")); + } + + private static SingleValueInferenceResults getSingleValue(Model model, + Map fields, + InferenceParams params) throws Exception { + PlainActionFuture future = new PlainActionFuture<>(); + model.infer(fields, params, future); + return (SingleValueInferenceResults)future.get(); + } + + private static Map oneHotMap() { + Map oneHotEncoding = new HashMap<>(); + oneHotEncoding.put("cat", "animal_cat"); + oneHotEncoding.put("dog", "animal_dog"); + return oneHotEncoding; + } + + public static TrainedModel buildClassification(boolean includeLabels) { + List featureNames = Arrays.asList("foo", "bar", "animal_cat", "animal_dog"); + Tree tree1 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(0.5)) + .addNode(TreeNode.builder(1).setLeafValue(1.0)) + .addNode(TreeNode.builder(2) + .setThreshold(0.8) + .setSplitFeature(1) + .setLeftChild(3) + .setRightChild(4)) + 
.addNode(TreeNode.builder(3).setLeafValue(0.0)) + .addNode(TreeNode.builder(4).setLeafValue(1.0)).build(); + Tree tree2 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(3) + .setThreshold(1.0)) + .addNode(TreeNode.builder(1).setLeafValue(0.0)) + .addNode(TreeNode.builder(2).setLeafValue(1.0)) + .build(); + Tree tree3 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(1.0)) + .addNode(TreeNode.builder(1).setLeafValue(1.0)) + .addNode(TreeNode.builder(2).setLeafValue(0.0)) + .build(); + return Ensemble.builder() + .setClassificationLabels(includeLabels ? Arrays.asList("not_to_be", "to_be") : null) + .setTargetType(TargetType.CLASSIFICATION) + .setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList(tree1, tree2, tree3)) + .setOutputAggregator(new WeightedMode(Arrays.asList(0.7, 0.5, 1.0))) + .build(); + } + + public static TrainedModel buildRegression() { + List featureNames = Arrays.asList("foo", "bar", "animal_cat", "animal_dog"); + Tree tree1 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(0.5)) + .addNode(TreeNode.builder(1).setLeafValue(0.3)) + .addNode(TreeNode.builder(2) + .setThreshold(0.0) + .setSplitFeature(3) + .setLeftChild(3) + .setRightChild(4)) + .addNode(TreeNode.builder(3).setLeafValue(0.1)) + .addNode(TreeNode.builder(4).setLeafValue(0.2)).build(); + Tree tree2 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(2) + .setThreshold(1.0)) + .addNode(TreeNode.builder(1).setLeafValue(1.5)) + .addNode(TreeNode.builder(2).setLeafValue(0.9)) + .build(); + Tree tree3 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + 
.setLeftChild(1) + .setRightChild(2) + .setSplitFeature(1) + .setThreshold(0.2)) + .addNode(TreeNode.builder(1).setLeafValue(1.5)) + .addNode(TreeNode.builder(2).setLeafValue(0.9)) + .build(); + return Ensemble.builder() + .setTargetType(TargetType.REGRESSION) + .setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList(tree1, tree2, tree3)) + .setOutputAggregator(new WeightedSum(Arrays.asList(0.5, 0.5, 0.5))) + .build(); + } + +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java new file mode 100644 index 0000000000000..4fbf973aae7f0 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -0,0 +1,212 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.inference.loadingservice; + +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.ClusterChangedEvent; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.ClusterStateListener; +import org.elasticsearch.cluster.metadata.MetaData; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.ingest.IngestMetadata; +import org.elasticsearch.ingest.PipelineConfiguration; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ScalingExecutorBuilder; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; +import static 
org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.atMost; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class ModelLoadingServiceTests extends ESTestCase { + + private TrainedModelProvider trainedModelProvider; + private ThreadPool threadPool; + private ClusterService clusterService; + + @Before + public void setUpComponents() { + threadPool = new TestThreadPool("ModelLoadingServiceTests", new ScalingExecutorBuilder(UTILITY_THREAD_POOL_NAME, + 1, 4, TimeValue.timeValueMinutes(10), "xpack.ml.utility_thread_pool")); + trainedModelProvider = mock(TrainedModelProvider.class); + clusterService = mock(ClusterService.class); + doAnswer((invocationOnMock) -> null).when(clusterService).addListener(any(ClusterStateListener.class)); + when(clusterService.state()).thenReturn(ClusterState.builder(new ClusterName("_name")).build()); + } + + @After + public void terminateThreadPool() { + terminate(threadPool); + } + + public void testGetCachedModels() throws Exception { + String model1 = "test-load-model-1"; + String model2 = "test-load-model-2"; + String model3 = "test-load-model-3"; + withTrainedModel(model1, 0); + withTrainedModel(model2, 0); + withTrainedModel(model3, 0); + + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, threadPool, clusterService); + + modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3)); + + String[] modelIds = new String[]{model1, model2, model3}; + for(int i = 0; i < 10; i++) { + String model = modelIds[i%3]; + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModel(model, 0, future); + assertThat(future.get(), is(not(nullValue()))); + } + + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(0L), any()); + 
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(0L), any()); + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(0L), any()); + } + + public void testGetCachedMissingModel() throws Exception { + String model = "test-load-cached-missing-model"; + withMissingModel(model, 0); + + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, threadPool, clusterService); + modelLoadingService.clusterChanged(ingestChangedEvent(model)); + + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModel(model, 0, future); + + try { + future.get(); + fail("Should not have succeeded in loaded model"); + } catch (Exception ex) { + assertThat(ex.getCause().getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model, 0))); + } + + verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model), eq(0L), any()); + } + + public void testGetMissingModel() { + String model = "test-load-missing-model"; + withMissingModel(model, 0); + + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, threadPool, clusterService); + + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModel(model, 0, future); + try { + future.get(); + fail("Should not have succeeded"); + } catch (Exception ex) { + assertThat(ex.getCause().getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model, 0))); + } + } + + public void testGetModelEagerly() throws Exception { + String model = "test-get-model-eagerly"; + withTrainedModel(model, 0); + + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, threadPool, clusterService); + + for(int i = 0; i < 3; i++) { + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModel(model, 0, future); + assertThat(future.get(), is(not(nullValue()))); + } + + verify(trainedModelProvider, times(3)).getTrainedModel(eq(model), eq(0L), 
any()); + } + + @SuppressWarnings("unchecked") + private void withTrainedModel(String modelId, long modelVersion) { + TrainedModelConfig trainedModelConfig = buildTrainedModelConfigBuilder(modelId, modelVersion).build(Version.CURRENT); + doAnswer(invocationOnMock -> { + @SuppressWarnings("rawtypes") + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(trainedModelConfig); + return null; + }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(modelVersion), any()); + } + + private void withMissingModel(String modelId, long modelVersion) { + doAnswer(invocationOnMock -> { + @SuppressWarnings("rawtypes") + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onFailure(new ResourceNotFoundException( + Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId, modelVersion))); + return null; + }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(modelVersion), any()); + } + + private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId, long modelVersion) { + return TrainedModelConfig.builder() + .setCreatedBy("ml_test") + .setDefinition(TrainedModelDefinitionTests.createRandomBuilder()) + .setDescription("trained model config for test") + .setModelId(modelId) + .setModelType("binary_decision_tree") + .setModelVersion(modelVersion); + } + + private static ClusterChangedEvent ingestChangedEvent(String... modelId) throws IOException { + ClusterChangedEvent event = mock(ClusterChangedEvent.class); + when(event.changedCustomMetaDataSet()).thenReturn(Collections.singleton(IngestMetadata.TYPE)); + when(event.state()).thenReturn(buildClusterStateWithModelReferences(modelId)); + return event; + } + + private static ClusterState buildClusterStateWithModelReferences(String... 
modelId) throws IOException { + Map configurations = new HashMap<>(modelId.length); + for (String id : modelId) { + configurations.put("pipeline_with_model_" + id, newConfigurationWithInferenceProcessor(id)); + } + IngestMetadata ingestMetadata = new IngestMetadata(configurations); + + return ClusterState.builder(new ClusterName("_name")) + .metaData(MetaData.builder().putCustom(IngestMetadata.TYPE, ingestMetadata)) + .build(); + } + + private static PipelineConfiguration newConfigurationWithInferenceProcessor(String modelId) throws IOException { + try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(Collections.singletonMap("processors", + Collections.singletonList( + Collections.singletonMap(InferenceProcessor.TYPE, + Collections.singletonMap(InferenceProcessor.MODEL_ID, + modelId)))))) { + return new PipelineConfiguration("pipeline_with_model_" + modelId, BytesReference.bytes(xContentBuilder), XContentType.JSON); + } + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java new file mode 100644 index 0000000000000..032f84d16b52d --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -0,0 +1,185 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.integration; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.Version; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; +import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; + +import static org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests.buildClassification; +import static org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests.buildRegression; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.nullValue; + +public class 
ModelInferenceActionIT extends MlSingleNodeTestCase { + + private TrainedModelProvider trainedModelProvider; + + @Before + public void createComponents() throws Exception { + trainedModelProvider = new TrainedModelProvider(client(), xContentRegistry()); + waitForMlTemplates(); + } + + public void testInferModels() throws Exception { + String modelId1 = "test-load-models-regression"; + String modelId2 = "test-load-models-classification"; + Map oneHotEncoding = new HashMap<>(); + oneHotEncoding.put("cat", "animal_cat"); + oneHotEncoding.put("dog", "animal_dog"); + TrainedModelConfig config1 = buildTrainedModelConfigBuilder(modelId2, 0) + .setDefinition(new TrainedModelDefinition.Builder() + .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding))) + .setInput(new TrainedModelDefinition.Input(Arrays.asList("field1", "field2"))) + .setTrainedModel(buildClassification(true))) + .build(Version.CURRENT); + TrainedModelConfig config2 = buildTrainedModelConfigBuilder(modelId1, 0) + .setDefinition(new TrainedModelDefinition.Builder() + .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding))) + .setInput(new TrainedModelDefinition.Input(Arrays.asList("field1", "field2"))) + .setTrainedModel(buildRegression())) + .build(Version.CURRENT); + AtomicReference putConfigHolder = new AtomicReference<>(); + AtomicReference exceptionHolder = new AtomicReference<>(); + + blockingCall(listener -> trainedModelProvider.storeTrainedModel(config1, listener), putConfigHolder, exceptionHolder); + assertThat(putConfigHolder.get(), is(true)); + assertThat(exceptionHolder.get(), is(nullValue())); + blockingCall(listener -> trainedModelProvider.storeTrainedModel(config2, listener), putConfigHolder, exceptionHolder); + assertThat(putConfigHolder.get(), is(true)); + assertThat(exceptionHolder.get(), is(nullValue())); + + + List> toInfer = new ArrayList<>(); + toInfer.add(new HashMap<>() {{ + put("foo", 1.0); + put("bar", 0.5); + put("categorical", 
"dog"); + }}); + toInfer.add(new HashMap<>() {{ + put("foo", 0.9); + put("bar", 1.5); + put("categorical", "cat"); + }}); + + List> toInfer2 = new ArrayList<>(); + toInfer2.add(new HashMap<>() {{ + put("foo", 0.0); + put("bar", 0.01); + put("categorical", "dog"); + }}); + toInfer2.add(new HashMap<>() {{ + put("foo", 1.0); + put("bar", 0.0); + put("categorical", "cat"); + }}); + + // Test regression + InferModelAction.Request request = new InferModelAction.Request(modelId1, 0, toInfer, null); + InferModelAction.Response response = client().execute(InferModelAction.INSTANCE, request).actionGet(); + assertThat(response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults)i).value()).collect(Collectors.toList()), + contains(1.3, 1.25)); + + request = new InferModelAction.Request(modelId1, 0, toInfer2, null); + response = client().execute(InferModelAction.INSTANCE, request).actionGet(); + assertThat(response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults)i).value()).collect(Collectors.toList()), + contains(1.65, 1.55)); + + + // Test classification + request = new InferModelAction.Request(modelId2, 0, toInfer, null); + response = client().execute(InferModelAction.INSTANCE, request).actionGet(); + assertThat(response.getInferenceResults() + .stream() + .map(i -> ((SingleValueInferenceResults)i).valueAsString()) + .collect(Collectors.toList()), + contains("not_to_be", "to_be")); + + // Get top classes + request = new InferModelAction.Request(modelId2, 0, toInfer, new InferenceParams(2)); + response = client().execute(InferModelAction.INSTANCE, request).actionGet(); + + ClassificationInferenceResults classificationInferenceResults = + (ClassificationInferenceResults)response.getInferenceResults().get(0); + + assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("not_to_be")); + assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("to_be")); + 
assertThat(classificationInferenceResults.getTopClasses().get(0).getProbability(), + greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability())); + + classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(1); + assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("to_be")); + assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("not_to_be")); + // they should always be in order of Most probable to least + assertThat(classificationInferenceResults.getTopClasses().get(0).getProbability(), + greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability())); + + // Test that top classes restrict the number returned + request = new InferModelAction.Request(modelId2, 0, toInfer2, new InferenceParams(1)); + response = client().execute(InferModelAction.INSTANCE, request).actionGet(); + + classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(0); + assertThat(classificationInferenceResults.getTopClasses(), hasSize(1)); + assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("to_be")); + } + + public void testInferMissingModel() { + String model = "test-infer-missing-model"; + InferModelAction.Request request = new InferModelAction.Request(model, 0, Collections.emptyList(), null); + try { + client().execute(InferModelAction.INSTANCE, request).actionGet(); + } catch (ElasticsearchException ex) { + assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model, 0))); + } + } + + private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId, long modelVersion) { + return TrainedModelConfig.builder() + .setCreatedBy("ml_test") + .setDefinition(TrainedModelDefinitionTests.createRandomBuilder()) + .setDescription("trained model config for test") + .setModelId(modelId) + 
.setModelType("binary_decision_tree") + .setModelVersion(modelVersion); + } + + @Override + public NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + return new NamedXContentRegistry(namedXContent); + } + +} From 33e1880a52b0fde8815bb249039812b109c12c34 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Thu, 10 Oct 2019 07:44:51 -0400 Subject: [PATCH 02/17] [ML][Inference] Adjust inference configuration option API (#47812) * [ML][Inference] Adjust inference configuration option API * fixing method reference * fixing tests --- .../xpack/core/XPackClientPlugin.java | 6 ++ .../core/ml/action/InferModelAction.java | 27 ++++---- .../MlInferenceNamedXContentProvider.java | 7 ++ .../ml/inference/TrainedModelDefinition.java | 6 +- ...eParams.java => ClassificationConfig.java} | 32 +++++++--- .../trainedmodel/InferenceConfig.java | 16 +++++ .../trainedmodel/RegressionConfig.java | 64 +++++++++++++++++++ .../inference/trainedmodel/TrainedModel.java | 3 +- .../trainedmodel/ensemble/Ensemble.java | 32 ++++++---- .../ensemble/NullInferenceConfig.java | 47 ++++++++++++++ .../ensemble/OutputAggregator.java | 3 +- .../trainedmodel/ensemble/WeightedMode.java | 10 +-- .../trainedmodel/ensemble/WeightedSum.java | 5 +- .../ml/inference/trainedmodel/tree/Tree.java | 25 +++++--- .../action/InferModelActionRequestTests.java | 22 ++++++- .../ClassificationConfigTests.java | 27 ++++++++ ...sTests.java => RegressionConfigTests.java} | 14 ++-- .../trainedmodel/ensemble/EnsembleTests.java | 51 +++++++++++---- .../ensemble/WeightedModeTests.java | 8 +++ .../ensemble/WeightedSumTests.java | 8 +++ .../trainedmodel/tree/TreeTests.java | 17 ++--- .../ml/action/TransportInferModelAction.java | 2 +- .../inference/loadingservice/LocalModel.java | 6 +- 
.../ml/inference/loadingservice/Model.java | 4 +- .../loadingservice/LocalModelTests.java | 26 ++++---- .../integration/ModelInferenceActionIT.java | 24 +++---- 26 files changed, 376 insertions(+), 116 deletions(-) rename x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/{InferenceParams.java => ClassificationConfig.java} (67%) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/NullInferenceConfig.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java rename x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/{InferenceParamsTests.java => RegressionConfigTests.java} (50%) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index 2c8553d68fe98..6b7da4b0ff57b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -143,6 +143,9 @@ import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import 
org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator; @@ -471,6 +474,9 @@ public List getNamedWriteables() { new NamedWriteableRegistry.Entry(InferenceResults.class, RegressionInferenceResults.NAME, RegressionInferenceResults::new), + // ML - Inference Configuration + new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME, ClassificationConfig::new), + new NamedWriteableRegistry.Entry(InferenceConfig.class, RegressionConfig.NAME, RegressionConfig::new), // monitoring new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.MONITORING, MonitoringFeatureSetUsage::new), diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java index 248d6180d3256..67e3a75283d67 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java @@ -13,7 +13,8 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -37,24 +38,24 @@ public static class Request extends ActionRequest { private final String modelId; private final long modelVersion; private final List> objectsToInfer; - private final InferenceParams 
params; + private final InferenceConfig config; public Request(String modelId, long modelVersion) { - this(modelId, modelVersion, Collections.emptyList(), InferenceParams.EMPTY_PARAMS); + this(modelId, modelVersion, Collections.emptyList(), new RegressionConfig()); } - public Request(String modelId, long modelVersion, List> objectsToInfer, InferenceParams inferenceParams) { + public Request(String modelId, long modelVersion, List> objectsToInfer, InferenceConfig inferenceConfig) { this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID); this.modelVersion = modelVersion; this.objectsToInfer = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(objectsToInfer, "objects_to_infer")); - this.params = inferenceParams == null ? InferenceParams.EMPTY_PARAMS : inferenceParams; + this.config = ExceptionsHelper.requireNonNull(inferenceConfig, "inference_config"); } - public Request(String modelId, long modelVersion, Map objectToInfer, InferenceParams params) { + public Request(String modelId, long modelVersion, Map objectToInfer, InferenceConfig config) { this(modelId, modelVersion, Arrays.asList(ExceptionsHelper.requireNonNull(objectToInfer, "objects_to_infer")), - params); + config); } public Request(StreamInput in) throws IOException { @@ -62,7 +63,7 @@ public Request(StreamInput in) throws IOException { this.modelId = in.readString(); this.modelVersion = in.readVLong(); this.objectsToInfer = Collections.unmodifiableList(in.readList(StreamInput::readMap)); - this.params = new InferenceParams(in); + this.config = in.readNamedWriteable(InferenceConfig.class); } public String getModelId() { @@ -77,8 +78,8 @@ public List> getObjectsToInfer() { return objectsToInfer; } - public InferenceParams getParams() { - return params; + public InferenceConfig getConfig() { + return config; } @Override @@ -92,7 +93,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(modelId); out.writeVLong(modelVersion); 
out.writeCollection(objectsToInfer, StreamOutput::writeMap); - params.writeTo(out); + out.writeNamedWriteable(config); } @Override @@ -102,13 +103,13 @@ public boolean equals(Object o) { InferModelAction.Request that = (InferModelAction.Request) o; return Objects.equals(modelId, that.modelId) && Objects.equals(modelVersion, that.modelVersion) - && Objects.equals(params, that.params) + && Objects.equals(config, that.config) && Objects.equals(objectsToInfer, that.objectsToInfer); } @Override public int hashCode() { - return Objects.hash(modelId, modelVersion, objectsToInfer, params); + return Objects.hash(modelId, modelVersion, objectsToInfer, config); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java index 7b56a4c3b4da3..352271b6f27ab 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java @@ -11,7 +11,10 @@ import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import 
org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; @@ -111,6 +114,10 @@ public List getNamedWriteables() { RegressionInferenceResults.NAME, RegressionInferenceResults::new)); + // Inference Configs + namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME, ClassificationConfig::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class, RegressionConfig.NAME, RegressionConfig::new)); + return namedWriteables; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java index e936d60bf87b7..0798e721ed17f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java @@ -19,7 +19,7 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; @@ -125,9 +125,9 @@ private void preProcess(Map fields) { preProcessors.forEach(preProcessor -> preProcessor.process(fields)); } - public InferenceResults infer(Map fields, InferenceParams params) { + public InferenceResults infer(Map fields, InferenceConfig config) { preProcess(fields); - return trainedModel.infer(fields, 
params); + return trainedModel.infer(fields, config); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParams.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java similarity index 67% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParams.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java index 150bf3d483f26..5aa0403d94753 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParams.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java @@ -8,26 +8,26 @@ import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import java.io.IOException; import java.util.Objects; -public class InferenceParams implements ToXContentObject, Writeable { +public class ClassificationConfig implements InferenceConfig { - public static ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); + public static final String NAME = "classification"; - public static InferenceParams EMPTY_PARAMS = new InferenceParams(0); + public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); + + public static ClassificationConfig EMPTY_PARAMS = new ClassificationConfig(0); private final int numTopClasses; - public InferenceParams(Integer numTopClasses) { + public ClassificationConfig(Integer numTopClasses) { this.numTopClasses = numTopClasses == null ? 
0 : numTopClasses; } - public InferenceParams(StreamInput in) throws IOException { + public ClassificationConfig(StreamInput in) throws IOException { this.numTopClasses = in.readInt(); } @@ -44,7 +44,7 @@ public void writeTo(StreamOutput out) throws IOException { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - InferenceParams that = (InferenceParams) o; + ClassificationConfig that = (ClassificationConfig) o; return Objects.equals(numTopClasses, that.numTopClasses); } @@ -62,4 +62,20 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.endObject(); return builder; } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public String getName() { + return NAME; + } + + @Override + public boolean isTargetTypeSupported(TargetType targetType) { + return TargetType.CLASSIFICATION.equals(targetType); + } + } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java new file mode 100644 index 0000000000000..6129d71d5ff95 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; + + +public interface InferenceConfig extends NamedXContentObject, NamedWriteable { + + boolean isTargetTypeSupported(TargetType targetType); + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java new file mode 100644 index 0000000000000..bb7f772f86ba4 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +public class RegressionConfig implements InferenceConfig { + + public static final String NAME = "regression"; + + public RegressionConfig() { + } + + public RegressionConfig(StreamInput in) { + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + } + + @Override + public String getName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + return true; + } + + @Override + public int hashCode() { + return Objects.hash(NAME); + } + + @Override + public boolean isTargetTypeSupported(TargetType targetType) { + return TargetType.REGRESSION.equals(targetType); + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java index a6c6f1eff011d..d1215943cbe12 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java @@ -24,10 +24,11 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable { * Infer against the provided fields * * @param fields The fields and their values to infer against + * @param config The configuration options 
for inference * @return The predicted value. For classification this will be discrete values (e.g. 0.0, or 1.0). * For regression this is continuous. */ - InferenceResults infer(Map fields, InferenceParams params); + InferenceResults infer(Map fields, InferenceConfig config); /** * @return {@link TargetType} for the model. diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java index 09e418cec916d..ff03a621d99fa 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java @@ -16,8 +16,9 @@ import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; @@ -114,21 +115,18 @@ public List getFeatureNames() { } @Override - public InferenceResults infer(Map fields, InferenceParams params) { - if (params.getNumTopClasses() != 0 && - (targetType != TargetType.CLASSIFICATION || outputAggregator.providesProbabilities() == false)) { + public InferenceResults infer(Map 
fields, InferenceConfig config) { + if (config.isTargetTypeSupported(targetType) == false) { throw ExceptionsHelper.badRequestException( - "Cannot return top classes for target_type [{}] and aggregate_output [{}]", - targetType, - outputAggregator.getName()); + "Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString()); } List inferenceResults = this.models.stream().map(model -> { - InferenceResults results = model.infer(fields, InferenceParams.EMPTY_PARAMS); + InferenceResults results = model.infer(fields, NullInferenceConfig.INSTANCE); assert results instanceof SingleValueInferenceResults; return ((SingleValueInferenceResults)results).value(); }).collect(Collectors.toList()); List processed = outputAggregator.processValues(inferenceResults); - return buildResults(processed, params); + return buildResults(processed, config); } @Override @@ -136,13 +134,16 @@ public TargetType targetType() { return targetType; } - private InferenceResults buildResults(List processedInferences, InferenceParams params) { + private InferenceResults buildResults(List processedInferences, InferenceConfig config) { switch(targetType) { case REGRESSION: return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences)); case CLASSIFICATION: - List topClasses = - InferenceHelpers.topClasses(processedInferences, classificationLabels, params.getNumTopClasses()); + ClassificationConfig classificationConfig = (ClassificationConfig) config; + List topClasses = InferenceHelpers.topClasses( + processedInferences, + classificationLabels, + classificationConfig.getNumTopClasses()); double value = outputAggregator.aggregate(processedInferences); return new ClassificationInferenceResults(outputAggregator.aggregate(processedInferences), classificationLabel(value, classificationLabels), @@ -216,6 +217,13 @@ public int hashCode() { @Override public void validate() { + if (outputAggregator.compatibleWith(targetType) == false) { 
+ throw ExceptionsHelper.badRequestException( + "aggregate_output [{}] is not compatible with target_type [{}]", + this.targetType, + outputAggregator.getName() + ); + } if (outputAggregator.expectedValueSize() != null && outputAggregator.expectedValueSize() != models.size()) { throw ExceptionsHelper.badRequestException( diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/NullInferenceConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/NullInferenceConfig.java new file mode 100644 index 0000000000000..7628d0beec25f --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/NullInferenceConfig.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; + +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; + +import java.io.IOException; + +/** + * Used by ensemble to pass into sub-models. 
+ */ +class NullInferenceConfig implements InferenceConfig { + + public static final NullInferenceConfig INSTANCE = new NullInferenceConfig(); + + private NullInferenceConfig() { } + + @Override + public boolean isTargetTypeSupported(TargetType targetType) { + return true; + } + + @Override + public String getWriteableName() { + return "null"; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + } + + @Override + public String getName() { + return "null"; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java index 012f474ab0618..f19ae376f0e96 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; import java.util.List; @@ -45,5 +46,5 @@ public interface OutputAggregator extends NamedXContentObject, NamedWriteable { */ String getName(); - boolean providesProbabilities(); + boolean compatibleWith(TargetType targetType); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java index 0689d748b0ccb..a872565ad20b5 100644 --- 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import java.io.IOException; import java.util.ArrayList; @@ -123,6 +124,11 @@ public String getName() { return NAME.getPreferredName(); } + @Override + public boolean compatibleWith(TargetType targetType) { + return true; + } + @Override public String getWriteableName() { return NAME.getPreferredName(); @@ -159,8 +165,4 @@ public int hashCode() { return Objects.hash(weights); } - @Override - public boolean providesProbabilities() { - return true; - } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java index 9c5c2bf582e54..db70346c0849e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import java.io.IOException; import java.util.Collections; @@ -137,7 +138,7 @@ public Integer expectedValueSize() { } @Override - public boolean providesProbabilities() { - return false; + public boolean compatibleWith(TargetType 
targetType) { + return TargetType.REGRESSION.equals(targetType); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java index bce6b08b6ed4b..a48cca3873117 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java @@ -16,8 +16,9 @@ import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; @@ -112,30 +113,34 @@ public List getNodes() { } @Override - public InferenceResults infer(Map fields, InferenceParams params) { - if (targetType != TargetType.CLASSIFICATION && params.getNumTopClasses() != 0) { + public InferenceResults infer(Map fields, InferenceConfig config) { + if (config.isTargetTypeSupported(targetType) == false) { throw ExceptionsHelper.badRequestException( - "Cannot return top classes for target_type [{}]", targetType.toString()); + "Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString()); } + List features = 
featureNames.stream().map(f -> fields.get(f) instanceof Number ? ((Number) fields.get(f)).doubleValue() : null ).collect(Collectors.toList()); - return infer(features, params); + return infer(features, config); } - private InferenceResults infer(List features, InferenceParams params) { + private InferenceResults infer(List features, InferenceConfig config) { TreeNode node = nodes.get(0); while(node.isLeaf() == false) { node = nodes.get(node.compare(features)); } - return buildResult(node.getLeafValue(), params); + return buildResult(node.getLeafValue(), config); } - private InferenceResults buildResult(Double value, InferenceParams params) { + private InferenceResults buildResult(Double value, InferenceConfig config) { switch (targetType) { case CLASSIFICATION: - List topClasses = - InferenceHelpers.topClasses(classificationProbability(value), classificationLabels, params.getNumTopClasses()); + ClassificationConfig classificationConfig = (ClassificationConfig) config; + List topClasses = InferenceHelpers.topClasses( + classificationProbability(value), + classificationLabels, + classificationConfig.getNumTopClasses()); return new ClassificationInferenceResults(value, classificationLabel(value, classificationLabels), topClasses); case REGRESSION: return new RegressionInferenceResults(value); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java index a49643d081957..ed782a04e0c86 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java @@ -5,16 +5,22 @@ */ package org.elasticsearch.xpack.core.ml.action; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import 
org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.core.ml.action.InferModelAction.Request; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; -import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParamsTests.randomInferenceParams; public class InferModelActionRequestTests extends AbstractWireSerializingTestCase { @@ -25,12 +31,16 @@ protected Request createTestInstance() { randomAlphaOfLength(10), randomLongBetween(1, 100), Stream.generate(InferModelActionRequestTests::randomMap).limit(randomInt(10)).collect(Collectors.toList()), - randomBoolean() ? null : randomInferenceParams()) : + randomInferenceConfig()) : new Request( randomAlphaOfLength(10), randomLongBetween(1, 100), randomMap(), - randomBoolean() ? 
null : randomInferenceParams()); + randomInferenceConfig()); + } + + private static InferenceConfig randomInferenceConfig() { + return randomFrom(RegressionConfigTests.randomRegressionConfig(), ClassificationConfigTests.randomClassificationConfig()); } private static Map randomMap() { @@ -44,4 +54,10 @@ protected Writeable.Reader instanceReader() { return Request::new; } + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List entries = new ArrayList<>(); + entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); + return new NamedWriteableRegistry(entries); + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java new file mode 100644 index 0000000000000..4df3263215f63 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +public class ClassificationConfigTests extends AbstractWireSerializingTestCase { + + public static ClassificationConfig randomClassificationConfig() { + return new ClassificationConfig(randomBoolean() ? 
null : randomIntBetween(-1, 10)); + } + + @Override + protected ClassificationConfig createTestInstance() { + return randomClassificationConfig(); + } + + @Override + protected Writeable.Reader instanceReader() { + return ClassificationConfig::new; + } + +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParamsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java similarity index 50% rename from x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParamsTests.java rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java index 2586cdd75de47..57efcdd15009a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParamsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java @@ -8,20 +8,20 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; -public class InferenceParamsTests extends AbstractWireSerializingTestCase { +public class RegressionConfigTests extends AbstractWireSerializingTestCase { - public static InferenceParams randomInferenceParams() { - return randomBoolean() ? 
InferenceParams.EMPTY_PARAMS : new InferenceParams(randomIntBetween(-1, 100)); + public static RegressionConfig randomRegressionConfig() { + return new RegressionConfig(); } @Override - protected InferenceParams createTestInstance() { - return randomInferenceParams(); + protected RegressionConfig createTestInstance() { + return randomRegressionConfig(); } @Override - protected Writeable.Reader instanceReader() { - return InferenceParams::new; + protected Writeable.Reader instanceReader() { + return RegressionConfig::new; } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java index 8da0c15718f24..816fabf7b0b67 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -17,7 +17,8 @@ import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; @@ -159,6 +160,27 @@ public void testEnsembleWithInvalidModel() { }); } + public void testEnsembleWithAggregatorOutputNotSupportingTargetType() { + List featureNames = Arrays.asList("foo", 
"bar"); + ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> { + Ensemble.builder() + .setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList( + Tree.builder() + .setNodes(TreeNode.builder(0) + .setLeftChild(1) + .setSplitFeature(1) + .setThreshold(randomDouble())) + .setFeatureNames(featureNames) + .build())) + .setClassificationLabels(Arrays.asList("label1", "label2")) + .setTargetType(TargetType.CLASSIFICATION) + .setOutputAggregator(new WeightedSum()) + .build() + .validate(); + }); + } + public void testEnsembleWithTargetTypeAndLabelsMismatch() { List featureNames = Arrays.asList("foo", "bar"); String msg = "[target_type] should be [classification] if [classification_labels] is provided, and vice versa"; @@ -190,6 +212,7 @@ public void testEnsembleWithTargetTypeAndLabelsMismatch() { .setFeatureNames(featureNames) .build())) .setTargetType(TargetType.CLASSIFICATION) + .setOutputAggregator(new WeightedMode()) .build() .validate(); }); @@ -245,7 +268,7 @@ public void testClassificationProbability() { List expected = Arrays.asList(0.768524783, 0.231475216); double eps = 0.000001; List probabilities = - ((ClassificationInferenceResults)ensemble.infer(featureMap, new InferenceParams(2))).getTopClasses(); + ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); for(int i = 0; i < expected.size(); i++) { assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); } @@ -254,7 +277,7 @@ public void testClassificationProbability() { featureMap = zipObjMap(featureNames, featureVector); expected = Arrays.asList(0.689974481, 0.3100255188); probabilities = - ((ClassificationInferenceResults)ensemble.infer(featureMap, new InferenceParams(2))).getTopClasses(); + ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); for(int i = 0; i < expected.size(); i++) { assertThat(probabilities.get(i).getProbability(), 
closeTo(expected.get(i), eps)); } @@ -263,7 +286,7 @@ public void testClassificationProbability() { featureMap = zipObjMap(featureNames, featureVector); expected = Arrays.asList(0.768524783, 0.231475216); probabilities = - ((ClassificationInferenceResults)ensemble.infer(featureMap, new InferenceParams(2))).getTopClasses(); + ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); for(int i = 0; i < expected.size(); i++) { assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); } @@ -275,7 +298,7 @@ public void testClassificationProbability() { }}; expected = Arrays.asList(0.6899744811, 0.3100255188); probabilities = - ((ClassificationInferenceResults)ensemble.infer(featureMap, new InferenceParams(2))).getTopClasses(); + ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); for(int i = 0; i < expected.size(); i++) { assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); } @@ -328,24 +351,24 @@ public void testClassificationInference() { List featureVector = Arrays.asList(0.4, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); assertThat(1.0, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); featureVector = Arrays.asList(2.0, 0.7); featureMap = zipObjMap(featureNames, featureVector); assertThat(1.0, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); featureVector = Arrays.asList(0.0, 1.0); featureMap = zipObjMap(featureNames, featureVector); assertThat(1.0, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, 
InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); featureMap = new HashMap<>(2) {{ put("foo", 0.3); put("bar", null); }}; assertThat(0.0, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); } public void testRegressionInference() { @@ -385,12 +408,12 @@ public void testRegressionInference() { List featureVector = Arrays.asList(0.4, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); assertThat(0.9, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new RegressionConfig())).value(), 0.00001)); featureVector = Arrays.asList(2.0, 0.7); featureMap = zipObjMap(featureNames, featureVector); assertThat(0.5, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new RegressionConfig())).value(), 0.00001)); // Test with NO aggregator supplied, verifies default behavior of non-weighted sum ensemble = Ensemble.builder() @@ -402,19 +425,19 @@ public void testRegressionInference() { featureVector = Arrays.asList(0.4, 0.0); featureMap = zipObjMap(featureNames, featureVector); assertThat(1.8, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new RegressionConfig())).value(), 0.00001)); featureVector = Arrays.asList(2.0, 0.7); featureMap = zipObjMap(featureNames, featureVector); assertThat(1.0, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, 
InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new RegressionConfig())).value(), 0.00001)); featureMap = new HashMap<>(2) {{ put("foo", 0.3); put("bar", null); }}; assertThat(1.8, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new RegressionConfig())).value(), 0.00001)); } private static Map zipObjMap(List keys, List values) { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java index 7849d6d071ef1..683115e63879e 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java @@ -8,6 +8,7 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import java.io.IOException; import java.util.Arrays; @@ -15,6 +16,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.equalTo; public class WeightedModeTests extends WeightedAggregatorTests { @@ -55,4 +57,10 @@ public void testAggregate() { weightedMode = new WeightedMode(); assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0)); } + + public void testCompatibleWith() { + WeightedMode weightedMode = createTestInstance(); + assertThat(weightedMode.compatibleWith(TargetType.CLASSIFICATION), is(true)); + 
assertThat(weightedMode.compatibleWith(TargetType.REGRESSION), is(true)); + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java index 89222365c83d8..fa372f043a410 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java @@ -8,6 +8,7 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import java.io.IOException; import java.util.Arrays; @@ -15,6 +16,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.equalTo; public class WeightedSumTests extends WeightedAggregatorTests { @@ -55,4 +57,10 @@ public void testAggregate() { weightedSum = new WeightedSum(); assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(13.0)); } + + public void testCompatibleWith() { + WeightedSum weightedSum = createTestInstance(); + assertThat(weightedSum.compatibleWith(TargetType.CLASSIFICATION), is(false)); + assertThat(weightedSum.compatibleWith(TargetType.REGRESSION), is(true)); + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java index c362c17fd579d..075bfbe912270 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java +++ 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java @@ -12,7 +12,8 @@ import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.junit.Before; @@ -124,21 +125,21 @@ public void testInfer() { List featureVector = Arrays.asList(0.6, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); assertThat(0.3, - closeTo(((SingleValueInferenceResults)tree.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001)); // This should hit the left child of the left child of the root node // i.e. it takes the path left, left featureVector = Arrays.asList(0.3, 0.7); featureMap = zipObjMap(featureNames, featureVector); assertThat(0.1, - closeTo(((SingleValueInferenceResults)tree.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001)); // This should hit the right child of the left child of the root node // i.e. 
it takes the path left, right featureVector = Arrays.asList(0.3, 0.9); featureMap = zipObjMap(featureNames, featureVector); assertThat(0.2, - closeTo(((SingleValueInferenceResults)tree.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001)); // This should handle missing values and take the default_left path featureMap = new HashMap<>(2) {{ @@ -146,7 +147,7 @@ public void testInfer() { put("bar", null); }}; assertThat(0.1, - closeTo(((SingleValueInferenceResults)tree.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001)); } public void testTreeClassificationProbability() { @@ -169,7 +170,7 @@ public void testTreeClassificationProbability() { List expectedFields = Arrays.asList("dog", "cat"); Map featureMap = zipObjMap(featureNames, featureVector); List probabilities = - ((ClassificationInferenceResults)tree.infer(featureMap, new InferenceParams(2))).getTopClasses(); + ((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); for(int i = 0; i < expectedProbs.size(); i++) { assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps)); assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i))); @@ -180,7 +181,7 @@ public void testTreeClassificationProbability() { featureVector = Arrays.asList(0.3, 0.7); featureMap = zipObjMap(featureNames, featureVector); probabilities = - ((ClassificationInferenceResults)tree.infer(featureMap, new InferenceParams(2))).getTopClasses(); + ((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); for(int i = 0; i < expectedProbs.size(); i++) { assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps)); 
assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i))); @@ -192,7 +193,7 @@ public void testTreeClassificationProbability() { put("bar", null); }}; probabilities = - ((ClassificationInferenceResults)tree.infer(featureMap, new InferenceParams(2))).getTopClasses(); + ((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); for(int i = 0; i < expectedProbs.size(); i++) { assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps)); assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i))); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java index a2f79fc9d0437..b5d0b7c4e330a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java @@ -48,7 +48,7 @@ protected void doExecute(Task task, InferModelAction.Request request, ActionList ex -> true); request.getObjectsToInfer().forEach(stringObjectMap -> typedChainTaskExecutor.add(chainedTask -> - model.infer(stringObjectMap, request.getParams(), chainedTask))); + model.infer(stringObjectMap, request.getConfig(), chainedTask))); typedChainTaskExecutor.execute(ActionListener.wrap( inferenceResultsInterfaces -> diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java index e5253b3d5b173..9bbf42915410d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java @@ -7,7 +7,7 @@ import 
org.elasticsearch.action.ActionListener; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; @@ -40,9 +40,9 @@ public String getResultsType() { } @Override - public void infer(Map fields, InferenceParams params, ActionListener listener) { + public void infer(Map fields, InferenceConfig config, ActionListener listener) { try { - listener.onResponse(trainedModelDefinition.infer(fields, params)); + listener.onResponse(trainedModelDefinition.infer(fields, config)); } catch (Exception e) { listener.onFailure(e); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java index 27924a47aa153..c66a23d78f98e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java @@ -7,7 +7,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import java.util.Map; @@ -15,6 +15,6 @@ public interface Model { String getResultsType(); - void infer(Map fields, InferenceParams inferenceParams, ActionListener listener); + void infer(Map fields, InferenceConfig inferenceConfig, ActionListener listener); } diff --git 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java index e66a5790d85e5..48aa70dec74f2 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java @@ -10,7 +10,9 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; @@ -49,12 +51,12 @@ public void testClassificationInfer() throws Exception { put("categorical", "dog"); }}; - SingleValueInferenceResults result = getSingleValue(model, fields, InferenceParams.EMPTY_PARAMS); + SingleValueInferenceResults result = getSingleValue(model, fields, new ClassificationConfig(0)); assertThat(result.value(), equalTo(0.0)); assertThat(result.valueAsString(), is("0.0")); ClassificationInferenceResults classificationResult = - (ClassificationInferenceResults)getSingleValue(model, fields, new InferenceParams(1)); + (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(1)); assertThat(classificationResult.getTopClasses().get(0).getProbability(), 
closeTo(0.5498339973124778, 0.0000001)); assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("0")); @@ -65,18 +67,18 @@ public void testClassificationInfer() throws Exception { .setTrainedModel(buildClassification(true)) .build(); model = new LocalModel(modelId, definition); - result = getSingleValue(model, fields, InferenceParams.EMPTY_PARAMS); + result = getSingleValue(model, fields, new ClassificationConfig(0)); assertThat(result.value(), equalTo(0.0)); assertThat(result.valueAsString(), equalTo("not_to_be")); - classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new InferenceParams(1)); + classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(1)); assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("not_to_be")); - classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new InferenceParams(2)); + classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(2)); assertThat(classificationResult.getTopClasses(), hasSize(2)); - classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new InferenceParams(-1)); + classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(-1)); assertThat(classificationResult.getTopClasses(), hasSize(2)); } @@ -94,21 +96,21 @@ public void testRegression() throws Exception { put("categorical", "dog"); }}; - SingleValueInferenceResults results = getSingleValue(model, fields, InferenceParams.EMPTY_PARAMS); + SingleValueInferenceResults results = getSingleValue(model, fields, new RegressionConfig()); assertThat(results.value(), equalTo(1.3)); PlainActionFuture failedFuture = new PlainActionFuture<>(); - model.infer(fields, 
new InferenceParams(2), failedFuture); + model.infer(fields, new ClassificationConfig(2), failedFuture); ExecutionException ex = expectThrows(ExecutionException.class, failedFuture::get); assertThat(ex.getCause().getMessage(), - equalTo("Cannot return top classes for target_type [regression] and aggregate_output [weighted_sum]")); + equalTo("Cannot infer using configuration for [classification] when model target_type is [regression]")); } private static SingleValueInferenceResults getSingleValue(Model model, Map fields, - InferenceParams params) throws Exception { + InferenceConfig config) throws Exception { PlainActionFuture future = new PlainActionFuture<>(); - model.infer(fields, params, future); + model.infer(fields, config, future); return (SingleValueInferenceResults)future.get(); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java index 032f84d16b52d..36f5817e40f3a 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -16,7 +16,8 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; @@ -58,13 +59,13 @@ 
public void testInferModels() throws Exception { Map oneHotEncoding = new HashMap<>(); oneHotEncoding.put("cat", "animal_cat"); oneHotEncoding.put("dog", "animal_dog"); - TrainedModelConfig config1 = buildTrainedModelConfigBuilder(modelId2, 0) + TrainedModelConfig config1 = buildTrainedModelConfigBuilder(modelId2) .setDefinition(new TrainedModelDefinition.Builder() .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding))) .setInput(new TrainedModelDefinition.Input(Arrays.asList("field1", "field2"))) .setTrainedModel(buildClassification(true))) .build(Version.CURRENT); - TrainedModelConfig config2 = buildTrainedModelConfigBuilder(modelId1, 0) + TrainedModelConfig config2 = buildTrainedModelConfigBuilder(modelId1) .setDefinition(new TrainedModelDefinition.Builder() .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding))) .setInput(new TrainedModelDefinition.Input(Arrays.asList("field1", "field2"))) @@ -106,19 +107,19 @@ public void testInferModels() throws Exception { }}); // Test regression - InferModelAction.Request request = new InferModelAction.Request(modelId1, 0, toInfer, null); + InferModelAction.Request request = new InferModelAction.Request(modelId1, 0, toInfer, new RegressionConfig()); InferModelAction.Response response = client().execute(InferModelAction.INSTANCE, request).actionGet(); assertThat(response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults)i).value()).collect(Collectors.toList()), contains(1.3, 1.25)); - request = new InferModelAction.Request(modelId1, 0, toInfer2, null); + request = new InferModelAction.Request(modelId1, 0, toInfer2, new RegressionConfig()); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); assertThat(response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults)i).value()).collect(Collectors.toList()), contains(1.65, 1.55)); // Test classification - request = new InferModelAction.Request(modelId2, 0, 
toInfer, null); + request = new InferModelAction.Request(modelId2, 0, toInfer, new ClassificationConfig(0)); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); assertThat(response.getInferenceResults() .stream() @@ -127,7 +128,7 @@ public void testInferModels() throws Exception { contains("not_to_be", "to_be")); // Get top classes - request = new InferModelAction.Request(modelId2, 0, toInfer, new InferenceParams(2)); + request = new InferModelAction.Request(modelId2, 0, toInfer, new ClassificationConfig(2)); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); ClassificationInferenceResults classificationInferenceResults = @@ -146,7 +147,7 @@ public void testInferModels() throws Exception { greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability())); // Test that top classes restrict the number returned - request = new InferModelAction.Request(modelId2, 0, toInfer2, new InferenceParams(1)); + request = new InferModelAction.Request(modelId2, 0, toInfer2, new ClassificationConfig(1)); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(0); @@ -156,7 +157,7 @@ public void testInferModels() throws Exception { public void testInferMissingModel() { String model = "test-infer-missing-model"; - InferModelAction.Request request = new InferModelAction.Request(model, 0, Collections.emptyList(), null); + InferModelAction.Request request = new InferModelAction.Request(model, 0, Collections.emptyList(), new RegressionConfig()); try { client().execute(InferModelAction.INSTANCE, request).actionGet(); } catch (ElasticsearchException ex) { @@ -164,14 +165,13 @@ public void testInferMissingModel() { } } - private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId, long modelVersion) { + private static TrainedModelConfig.Builder 
buildTrainedModelConfigBuilder(String modelId) { return TrainedModelConfig.builder() .setCreatedBy("ml_test") .setDefinition(TrainedModelDefinitionTests.createRandomBuilder()) .setDescription("trained model config for test") .setModelId(modelId) - .setModelType("binary_decision_tree") - .setModelVersion(modelVersion); + .setModelType("binary_decision_tree"); } @Override From 861ea4704da30ea713bfc01f92560c2e3043a21c Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 15 Oct 2019 12:21:51 -0400 Subject: [PATCH 03/17] removing model version param --- .../core/ml/action/InferModelAction.java | 20 ++--- .../xpack/core/ml/job/messages/Messages.java | 2 +- .../action/InferModelActionRequestTests.java | 2 - .../ml/action/TransportInferModelAction.java | 2 +- .../loadingservice/ModelLoadingService.java | 79 ++++++++----------- .../ModelLoadingServiceTests.java | 60 +++++++------- .../integration/ModelInferenceActionIT.java | 24 +++--- 7 files changed, 85 insertions(+), 104 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java index 67e3a75283d67..24b5e345430df 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java @@ -36,24 +36,21 @@ private InferModelAction() { public static class Request extends ActionRequest { private final String modelId; - private final long modelVersion; private final List> objectsToInfer; private final InferenceConfig config; - public Request(String modelId, long modelVersion) { - this(modelId, modelVersion, Collections.emptyList(), new RegressionConfig()); + public Request(String modelId) { + this(modelId, Collections.emptyList(), new RegressionConfig()); } - public Request(String modelId, 
long modelVersion, List> objectsToInfer, InferenceConfig inferenceConfig) { + public Request(String modelId, List> objectsToInfer, InferenceConfig inferenceConfig) { this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID); - this.modelVersion = modelVersion; this.objectsToInfer = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(objectsToInfer, "objects_to_infer")); this.config = ExceptionsHelper.requireNonNull(inferenceConfig, "inference_config"); } - public Request(String modelId, long modelVersion, Map objectToInfer, InferenceConfig config) { + public Request(String modelId, Map objectToInfer, InferenceConfig config) { this(modelId, - modelVersion, Arrays.asList(ExceptionsHelper.requireNonNull(objectToInfer, "objects_to_infer")), config); } @@ -61,7 +58,6 @@ public Request(String modelId, long modelVersion, Map objectToIn public Request(StreamInput in) throws IOException { super(in); this.modelId = in.readString(); - this.modelVersion = in.readVLong(); this.objectsToInfer = Collections.unmodifiableList(in.readList(StreamInput::readMap)); this.config = in.readNamedWriteable(InferenceConfig.class); } @@ -70,10 +66,6 @@ public String getModelId() { return modelId; } - public long getModelVersion() { - return modelVersion; - } - public List> getObjectsToInfer() { return objectsToInfer; } @@ -91,7 +83,6 @@ public ActionRequestValidationException validate() { public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(modelId); - out.writeVLong(modelVersion); out.writeCollection(objectsToInfer, StreamOutput::writeMap); out.writeNamedWriteable(config); } @@ -102,14 +93,13 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; InferModelAction.Request that = (InferModelAction.Request) o; return Objects.equals(modelId, that.modelId) - && Objects.equals(modelVersion, that.modelVersion) && Objects.equals(config, that.config) && Objects.equals(objectsToInfer, 
that.objectsToInfer); } @Override public int hashCode() { - return Objects.hash(modelId, modelVersion, objectsToInfer, config); + return Objects.hash(modelId, objectsToInfer, config); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index 75cc468160d17..c6558d781bbbc 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -83,7 +83,7 @@ public final class Messages { public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists"; public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]"; public static final String INFERENCE_FAILED_TO_SERIALIZE_MODEL = - "Failed to serialize the trained model [{0}] with version [{1}] for storage"; + "Failed to serialize the trained model [{0}] for storage"; public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]"; public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again"; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java index ed782a04e0c86..051da354c2e59 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java @@ -29,12 +29,10 @@ protected Request createTestInstance() { return randomBoolean() ? 
new Request( randomAlphaOfLength(10), - randomLongBetween(1, 100), Stream.generate(InferModelActionRequestTests::randomMap).limit(randomInt(10)).collect(Collectors.toList()), randomInferenceConfig()) : new Request( randomAlphaOfLength(10), - randomLongBetween(1, 100), randomMap(), randomInferenceConfig()); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java index b5d0b7c4e330a..b01063cac48dd 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java @@ -59,6 +59,6 @@ protected void doExecute(Task task, InferModelAction.Request request, ActionList listener::onFailure ); - this.modelLoadingService.getModel(request.getModelId(), request.getModelVersion(), getModelListener); + this.modelLoadingService.getModel(request.getModelId(), getModelListener); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index b4fc552ba5f93..e6a5baf42ba12 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -47,26 +47,25 @@ public ModelLoadingService(TrainedModelProvider trainedModelProvider, clusterService.addListener(this); } - public void getModel(String modelId, long modelVersion, ActionListener modelActionListener) { - String key = modelKey(modelId, modelVersion); - MaybeModel cachedModel = loadedModels.get(key); + public void getModel(String modelId, ActionListener modelActionListener) { + MaybeModel cachedModel = 
loadedModels.get(modelId); if (cachedModel != null) { if (cachedModel.isSuccess()) { modelActionListener.onResponse(cachedModel.getModel()); return; } } - if (loadModelIfNecessary(key, modelId, modelVersion, modelActionListener) == false) { + if (loadModelIfNecessary(modelId, modelActionListener) == false) { // If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called // by a simulated pipeline - logger.debug("[{}] version [{}] not actively loading, eager loading without cache", modelId, modelVersion); - provider.getTrainedModel(modelId, modelVersion, ActionListener.wrap( + logger.debug("[{}] not actively loading, eager loading without cache", modelId); + provider.getTrainedModel(modelId, ActionListener.wrap( trainedModelConfig -> modelActionListener.onResponse(new LocalModel(trainedModelConfig.getModelId(), trainedModelConfig.getDefinition())), modelActionListener::onFailure )); } else { - logger.debug("[{}] version [{}] is currently loading, added new listener to queue", modelId, modelVersion); + logger.debug("[{}] is currently loading, added new listener to queue", modelId); } } @@ -75,9 +74,9 @@ public void getModel(String modelId, long modelVersion, ActionListener mo * Returns true if the model is CURRENTLY being loaded and the listener was added to be notified when it is loaded * Returns false if the model is not loaded or actively being loaded */ - private boolean loadModelIfNecessary(String key, String modelId, long modelVersion, ActionListener modelActionListener) { + private boolean loadModelIfNecessary(String modelId, ActionListener modelActionListener) { synchronized (loadingListeners) { - MaybeModel cachedModel = loadedModels.get(key); + MaybeModel cachedModel = loadedModels.get(modelId); if (cachedModel != null) { if (cachedModel.isSuccess()) { modelActionListener.onResponse(cachedModel.getModel()); @@ -86,42 +85,42 @@ private boolean loadModelIfNecessary(String key, String modelId, long 
modelVersi // If the loaded model entry is there but is not present, that means the previous load attempt ran into an issue // Attempt to load and cache the model if necessary if (loadingListeners.computeIfPresent( - key, + modelId, (storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)) == null) { - logger.debug("[{}] version [{}] attempting to load and cache", modelId, modelVersion); - loadingListeners.put(key, addFluently(new ArrayDeque<>(), modelActionListener)); - loadModel(key, modelId, modelVersion); + logger.debug("[{}] attempting to load and cache", modelId); + loadingListeners.put(modelId, addFluently(new ArrayDeque<>(), modelActionListener)); + loadModel(modelId); } return true; } // if the cachedModel entry is null, but there are listeners present, that means it is being loaded - return loadingListeners.computeIfPresent(key, + return loadingListeners.computeIfPresent(modelId, (storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)) != null; } } - private void loadModel(String modelKey, String modelId, long modelVersion) { - provider.getTrainedModel(modelId, modelVersion, ActionListener.wrap( + private void loadModel(String modelId) { + provider.getTrainedModel(modelId, ActionListener.wrap( trainedModelConfig -> { - logger.debug("[{}] successfully loaded model", modelKey); - handleLoadSuccess(modelKey, trainedModelConfig); + logger.debug("[{}] successfully loaded model", modelId); + handleLoadSuccess(modelId, trainedModelConfig); }, failure -> { - logger.warn(new ParameterizedMessage("[{}] failed to load model", modelKey), failure); - handleLoadFailure(modelKey, failure); + logger.warn(new ParameterizedMessage("[{}] failed to load model", modelId), failure); + handleLoadFailure(modelId, failure); } )); } - private void handleLoadSuccess(String modelKey, TrainedModelConfig trainedModelConfig) { + private void handleLoadSuccess(String modelId, TrainedModelConfig trainedModelConfig) { Queue> listeners; 
Model loadedModel = new LocalModel(trainedModelConfig.getModelId(), trainedModelConfig.getDefinition()); synchronized (loadingListeners) { - listeners = loadingListeners.remove(modelKey); + listeners = loadingListeners.remove(modelId); // If there is no loadingListener that means the loading was canceled and the listener was already notified as such // Consequently, we should not store the retrieved model if (listeners != null) { - loadedModels.put(modelKey, MaybeModel.of(loadedModel)); + loadedModels.put(modelId, MaybeModel.of(loadedModel)); } } if (listeners != null) { @@ -131,14 +130,14 @@ private void handleLoadSuccess(String modelKey, TrainedModelConfig trainedModelC } } - private void handleLoadFailure(String modelKey, Exception failure) { + private void handleLoadFailure(String modelId, Exception failure) { Queue> listeners; synchronized (loadingListeners) { - listeners = loadingListeners.remove(modelKey); + listeners = loadingListeners.remove(modelId); if (listeners != null) { // If we failed to load and there were listeners present, that means that this model is referenced by a processor // Add an empty entry here so that we can attempt to load and cache the model again when it is accessed again. 
- loadedModels.computeIfAbsent(modelKey, (key) -> MaybeModel.of(failure)); + loadedModels.computeIfAbsent(modelId, (key) -> MaybeModel.of(failure)); } } if (listeners != null) { @@ -159,9 +158,9 @@ public void clusterChanged(ClusterChangedEvent event) { synchronized (loadingListeners) { // If we had models still loading here but are no longer referenced // we should remove them from loadingListeners and alert the listeners - for (String modelKey : loadingListeners.keySet()) { - if (allReferencedModelKeys.contains(modelKey) == false) { - drainWithFailure.add(Tuple.tuple(splitModelKey(modelKey).v1(), new ArrayList<>(loadingListeners.remove(modelKey)))); + for (String modelId : loadingListeners.keySet()) { + if (allReferencedModelKeys.contains(modelId) == false) { + drainWithFailure.add(Tuple.tuple(modelId, new ArrayList<>(loadingListeners.remove(modelId)))); } } @@ -194,15 +193,14 @@ public void clusterChanged(ClusterChangedEvent event) { } } - private void loadModels(Set modelKeys) { - if (modelKeys.isEmpty()) { + private void loadModels(Set modelIds) { + if (modelIds.isEmpty()) { return; } // Execute this on a utility thread as when the callbacks occur we don't want them tying up the cluster listener thread pool threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> { - for (String modelKey : modelKeys) { - Tuple modelIdAndVersion = splitModelKey(modelKey); - this.loadModel(modelKey, modelIdAndVersion.v1(), modelIdAndVersion.v2()); + for (String modelId : modelIds) { + this.loadModel(modelId); } }); } @@ -212,17 +210,6 @@ private static Queue addFluently(Queue queue, T object) { return queue; } - private static String modelKey(String modelId, long modelVersion) { - return modelId + "_" + modelVersion; - } - - private static Tuple splitModelKey(String modelKey) { - int delim = modelKey.lastIndexOf('_'); - String modelId = modelKey.substring(0, delim); - Long version = Long.valueOf(modelKey.substring(delim + 1)); - return Tuple.tuple(modelId, 
version); - } - private static Set getReferencedModelKeys(IngestMetadata ingestMetadata) { Set allReferencedModelKeys = new HashSet<>(); if (ingestMetadata != null) { @@ -237,7 +224,7 @@ private static Set getReferencedModelKeys(IngestMetadata ingestMetadata) if (modelId != null) { assert modelId instanceof String; // TODO also read model version - allReferencedModelKeys.add(modelKey(modelId.toString(), 0)); + allReferencedModelKeys.add(modelId.toString()); } } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java index 4fbf973aae7f0..36cfe23d9398a 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -35,6 +35,7 @@ import org.junit.Before; import java.io.IOException; +import java.time.Instant; import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -78,9 +79,9 @@ public void testGetCachedModels() throws Exception { String model1 = "test-load-model-1"; String model2 = "test-load-model-2"; String model3 = "test-load-model-3"; - withTrainedModel(model1, 0); - withTrainedModel(model2, 0); - withTrainedModel(model3, 0); + withTrainedModel(model1); + withTrainedModel(model2); + withTrainedModel(model3); ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, threadPool, clusterService); @@ -90,95 +91,96 @@ public void testGetCachedModels() throws Exception { for(int i = 0; i < 10; i++) { String model = modelIds[i%3]; PlainActionFuture future = new PlainActionFuture<>(); - modelLoadingService.getModel(model, 0, future); + modelLoadingService.getModel(model, future); assertThat(future.get(), is(not(nullValue()))); } - 
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(0L), any()); - verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(0L), any()); - verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(0L), any()); + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), any()); + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), any()); + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), any()); } public void testGetCachedMissingModel() throws Exception { String model = "test-load-cached-missing-model"; - withMissingModel(model, 0); + withMissingModel(model); ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, threadPool, clusterService); modelLoadingService.clusterChanged(ingestChangedEvent(model)); PlainActionFuture future = new PlainActionFuture<>(); - modelLoadingService.getModel(model, 0, future); + modelLoadingService.getModel(model, future); try { future.get(); fail("Should not have succeeded in loaded model"); } catch (Exception ex) { - assertThat(ex.getCause().getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model, 0))); + assertThat(ex.getCause().getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model))); } - verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model), eq(0L), any()); + verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model), any()); } public void testGetMissingModel() { String model = "test-load-missing-model"; - withMissingModel(model, 0); + withMissingModel(model); ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, threadPool, clusterService); PlainActionFuture future = new PlainActionFuture<>(); - modelLoadingService.getModel(model, 0, future); + modelLoadingService.getModel(model, future); try { future.get(); fail("Should not have succeeded"); } catch (Exception ex) { - assertThat(ex.getCause().getMessage(), 
equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model, 0))); + assertThat(ex.getCause().getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model))); } } public void testGetModelEagerly() throws Exception { String model = "test-get-model-eagerly"; - withTrainedModel(model, 0); + withTrainedModel(model); ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, threadPool, clusterService); for(int i = 0; i < 3; i++) { PlainActionFuture future = new PlainActionFuture<>(); - modelLoadingService.getModel(model, 0, future); + modelLoadingService.getModel(model, future); assertThat(future.get(), is(not(nullValue()))); } - verify(trainedModelProvider, times(3)).getTrainedModel(eq(model), eq(0L), any()); + verify(trainedModelProvider, times(3)).getTrainedModel(eq(model), any()); } @SuppressWarnings("unchecked") - private void withTrainedModel(String modelId, long modelVersion) { - TrainedModelConfig trainedModelConfig = buildTrainedModelConfigBuilder(modelId, modelVersion).build(Version.CURRENT); + private void withTrainedModel(String modelId) { + TrainedModelConfig trainedModelConfig = buildTrainedModelConfigBuilder(modelId) + .setVersion(Version.CURRENT) + .setCreateTime(Instant.now()) + .build(); doAnswer(invocationOnMock -> { @SuppressWarnings("rawtypes") - ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; listener.onResponse(trainedModelConfig); return null; - }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(modelVersion), any()); + }).when(trainedModelProvider).getTrainedModel(eq(modelId), any()); } - private void withMissingModel(String modelId, long modelVersion) { + private void withMissingModel(String modelId) { doAnswer(invocationOnMock -> { @SuppressWarnings("rawtypes") - ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + ActionListener listener = 
(ActionListener) invocationOnMock.getArguments()[1]; listener.onFailure(new ResourceNotFoundException( - Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId, modelVersion))); + Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); return null; - }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(modelVersion), any()); + }).when(trainedModelProvider).getTrainedModel(eq(modelId), any()); } - private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId, long modelVersion) { + private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) { return TrainedModelConfig.builder() .setCreatedBy("ml_test") .setDefinition(TrainedModelDefinitionTests.createRandomBuilder()) .setDescription("trained model config for test") - .setModelId(modelId) - .setModelType("binary_decision_tree") - .setModelVersion(modelVersion); + .setModelId(modelId); } private static ClusterChangedEvent ingestChangedEvent(String... modelId) throws IOException { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java index 36f5817e40f3a..c2cfbe4f15498 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -25,6 +25,7 @@ import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.junit.Before; +import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -64,13 +65,17 @@ public void testInferModels() throws Exception { .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding))) .setInput(new TrainedModelDefinition.Input(Arrays.asList("field1", "field2"))) .setTrainedModel(buildClassification(true))) - 
.build(Version.CURRENT); + .setVersion(Version.CURRENT) + .setCreateTime(Instant.now()) + .build(); TrainedModelConfig config2 = buildTrainedModelConfigBuilder(modelId1) .setDefinition(new TrainedModelDefinition.Builder() .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding))) .setInput(new TrainedModelDefinition.Input(Arrays.asList("field1", "field2"))) .setTrainedModel(buildRegression())) - .build(Version.CURRENT); + .setVersion(Version.CURRENT) + .setCreateTime(Instant.now()) + .build(); AtomicReference putConfigHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); @@ -107,19 +112,19 @@ public void testInferModels() throws Exception { }}); // Test regression - InferModelAction.Request request = new InferModelAction.Request(modelId1, 0, toInfer, new RegressionConfig()); + InferModelAction.Request request = new InferModelAction.Request(modelId1, toInfer, new RegressionConfig()); InferModelAction.Response response = client().execute(InferModelAction.INSTANCE, request).actionGet(); assertThat(response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults)i).value()).collect(Collectors.toList()), contains(1.3, 1.25)); - request = new InferModelAction.Request(modelId1, 0, toInfer2, new RegressionConfig()); + request = new InferModelAction.Request(modelId1, toInfer2, new RegressionConfig()); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); assertThat(response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults)i).value()).collect(Collectors.toList()), contains(1.65, 1.55)); // Test classification - request = new InferModelAction.Request(modelId2, 0, toInfer, new ClassificationConfig(0)); + request = new InferModelAction.Request(modelId2, toInfer, new ClassificationConfig(0)); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); assertThat(response.getInferenceResults() .stream() @@ -128,7 +133,7 @@ public void 
testInferModels() throws Exception { contains("not_to_be", "to_be")); // Get top classes - request = new InferModelAction.Request(modelId2, 0, toInfer, new ClassificationConfig(2)); + request = new InferModelAction.Request(modelId2, toInfer, new ClassificationConfig(2)); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); ClassificationInferenceResults classificationInferenceResults = @@ -147,7 +152,7 @@ public void testInferModels() throws Exception { greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability())); // Test that top classes restrict the number returned - request = new InferModelAction.Request(modelId2, 0, toInfer2, new ClassificationConfig(1)); + request = new InferModelAction.Request(modelId2, toInfer2, new ClassificationConfig(1)); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(0); @@ -157,7 +162,7 @@ public void testInferModels() throws Exception { public void testInferMissingModel() { String model = "test-infer-missing-model"; - InferModelAction.Request request = new InferModelAction.Request(model, 0, Collections.emptyList(), new RegressionConfig()); + InferModelAction.Request request = new InferModelAction.Request(model, Collections.emptyList(), new RegressionConfig()); try { client().execute(InferModelAction.INSTANCE, request).actionGet(); } catch (ElasticsearchException ex) { @@ -170,8 +175,7 @@ private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String .setCreatedBy("ml_test") .setDefinition(TrainedModelDefinitionTests.createRandomBuilder()) .setDescription("trained model config for test") - .setModelId(modelId) - .setModelType("binary_decision_tree"); + .setModelId(modelId); } @Override From 4dc58e5038dc96d8c1cc95bc9d74f2459d217c77 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Fri, 18 Oct 2019 07:00:36 -0400 Subject: [PATCH 04/17] 
[ML][Inference] adds logistic_regression output aggregator (#48075) * [ML][Inference] adds logistic_regression output aggregator * Addressing PR comments --- .../MlInferenceNamedXContentProvider.java | 4 + .../ensemble/LogisticRegression.java | 83 ++++++++++ .../client/RestHighLevelClientTests.java | 7 +- .../trainedmodel/ensemble/EnsembleTests.java | 2 +- .../ensemble/LogisticRegressionTests.java | 51 ++++++ .../xpack/core/XPackClientPlugin.java | 4 + .../MlInferenceNamedXContentProvider.java | 10 ++ .../ensemble/LogisticRegression.java | 153 ++++++++++++++++++ .../trainedmodel/ensemble/WeightedMode.java | 2 +- .../trainedmodel/ensemble/WeightedSum.java | 2 +- .../core/ml/inference/utils/Statistics.java | 4 + .../trainedmodel/ensemble/EnsembleTests.java | 4 +- .../ensemble/LogisticRegressionTests.java | 66 ++++++++ .../ml/inference/utils/StatisticsTests.java | 17 ++ 14 files changed, 402 insertions(+), 7 deletions(-) create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/LogisticRegression.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/LogisticRegressionTests.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegressionTests.java diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java index 2325bbf27baa0..6b7fb6d3cd17c 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java +++ 
b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java @@ -20,6 +20,7 @@ import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble; +import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.LogisticRegression; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum; @@ -60,6 +61,9 @@ public List getNamedXContentParsers() { namedXContent.add(new NamedXContentRegistry.Entry(OutputAggregator.class, new ParseField(WeightedSum.NAME), WeightedSum::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(OutputAggregator.class, + new ParseField(LogisticRegression.NAME), + LogisticRegression::fromXContent)); return namedXContent; } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/LogisticRegression.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/LogisticRegression.java new file mode 100644 index 0000000000000..e2e2f4f2b0d50 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/LogisticRegression.java @@ -0,0 +1,83 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference.trainedmodel.ensemble; + + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + + +public class LogisticRegression implements OutputAggregator { + + public static final String NAME = "logistic_regression"; + public static final ParseField WEIGHTS = new ParseField("weights"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME, + true, + a -> new LogisticRegression((List)a[0])); + static { + PARSER.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS); + } + + public static LogisticRegression fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final List weights; + + public LogisticRegression(List weights) { + this.weights = weights; + } + + @Override + public String getName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (weights != null) { + builder.field(WEIGHTS.getPreferredName(), weights); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + 
LogisticRegression that = (LogisticRegression) o; + return Objects.equals(weights, that.weights); + } + + @Override + public int hashCode() { + return Objects.hash(weights); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index cd8b518ac8311..e2ccc5a1fd573 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -68,6 +68,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble; +import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.LogisticRegression; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode; import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum; import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree; @@ -686,7 +687,7 @@ public void testDefaultNamedXContents() { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(48, namedXContents.size()); + assertEquals(49, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -750,9 +751,9 @@ public void testProvidedNamedXContents() { assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME)); assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class)); assertThat(names, hasItems(Tree.NAME, Ensemble.NAME)); - assertEquals(Integer.valueOf(2), + 
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator.class)); - assertThat(names, hasItems(WeightedMode.NAME, WeightedSum.NAME)); + assertThat(names, hasItems(WeightedMode.NAME, WeightedSum.NAME, LogisticRegression.NAME)); } public void testApiNamingConventions() throws Exception { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java index 774ab26bc17c7..d3431fe6b8961 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -68,7 +68,7 @@ public static Ensemble createRandom() { OutputAggregator outputAggregator = null; if (randomBoolean()) { List weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList()); - outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights)); + outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights), new LogisticRegression(weights)); } List categoryLabels = null; if (randomBoolean()) { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/LogisticRegressionTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/LogisticRegressionTests.java new file mode 100644 index 0000000000000..a345db3da1748 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/LogisticRegressionTests.java @@ -0,0 +1,51 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference.trainedmodel.ensemble; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; +import java.util.stream.Collectors; +import java.util.stream.Stream; + + +public class LogisticRegressionTests extends AbstractXContentTestCase { + + LogisticRegression createTestInstance(int numberOfWeights) { + return new LogisticRegression(Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList())); + } + + @Override + protected LogisticRegression doParseInstance(XContentParser parser) throws IOException { + return LogisticRegression.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected LogisticRegression createTestInstance() { + return randomBoolean() ? 
new LogisticRegression(null) : createTestInstance(randomIntBetween(1, 100)); + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index b5f973878b35a..c87604357df38 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -148,6 +148,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LogisticRegression; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum; @@ -470,6 +471,9 @@ public List getNamedWriteables() { // ML - Inference aggregators new NamedWriteableRegistry.Entry(OutputAggregator.class, WeightedSum.NAME.getPreferredName(), WeightedSum::new), new NamedWriteableRegistry.Entry(OutputAggregator.class, WeightedMode.NAME.getPreferredName(), WeightedMode::new), + new NamedWriteableRegistry.Entry(OutputAggregator.class, + LogisticRegression.NAME.getPreferredName(), + LogisticRegression::new), // ML - Inference Results new NamedWriteableRegistry.Entry(InferenceResults.class, ClassificationInferenceResults.NAME, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java index 352271b6f27ab..ca380dac2bf00 100644 --- 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java @@ -19,6 +19,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LenientlyParsedOutputAggregator; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LogisticRegression; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.StrictlyParsedOutputAggregator; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode; @@ -67,6 +68,9 @@ public List getNamedXContentParsers() { namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedOutputAggregator.class, WeightedSum.NAME, WeightedSum::fromXContentLenient)); + namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedOutputAggregator.class, + LogisticRegression.NAME, + LogisticRegression::fromXContentLenient)); // Model Strict namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentStrict)); @@ -79,6 +83,9 @@ public List getNamedXContentParsers() { namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedOutputAggregator.class, WeightedSum.NAME, WeightedSum::fromXContentStrict)); + namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedOutputAggregator.class, + LogisticRegression.NAME, + LogisticRegression::fromXContentStrict)); return namedXContent; } @@ -105,6 +112,9 @@ public List getNamedWriteables() { namedWriteables.add(new NamedWriteableRegistry.Entry(OutputAggregator.class, WeightedMode.NAME.getPreferredName(), WeightedMode::new)); + namedWriteables.add(new 
NamedWriteableRegistry.Entry(OutputAggregator.class, + LogisticRegression.NAME.getPreferredName(), + LogisticRegression::new)); // Inference Results namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java new file mode 100644 index 0000000000000..c8d06c2c1eb79 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java @@ -0,0 +1,153 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; + + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.stream.IntStream; + +import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.sigmoid; + +public class LogisticRegression implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator { + + public static final ParseField NAME = new ParseField("logistic_regression"); + public static final ParseField WEIGHTS = new ParseField("weights"); + + private static final 
ConstructingObjectParser LENIENT_PARSER = createParser(true); + private static final ConstructingObjectParser STRICT_PARSER = createParser(false); + + @SuppressWarnings("unchecked") + private static ConstructingObjectParser createParser(boolean lenient) { + ConstructingObjectParser parser = new ConstructingObjectParser<>( + NAME.getPreferredName(), + lenient, + a -> new LogisticRegression((List)a[0])); + parser.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS); + return parser; + } + + public static LogisticRegression fromXContentStrict(XContentParser parser) { + return STRICT_PARSER.apply(parser, null); + } + + public static LogisticRegression fromXContentLenient(XContentParser parser) { + return LENIENT_PARSER.apply(parser, null); + } + + private final List weights; + + LogisticRegression() { + this((List) null); + } + + public LogisticRegression(List weights) { + this.weights = weights == null ? null : Collections.unmodifiableList(weights); + } + + public LogisticRegression(StreamInput in) throws IOException { + if (in.readBoolean()) { + this.weights = Collections.unmodifiableList(in.readList(StreamInput::readDouble)); + } else { + this.weights = null; + } + } + + @Override + public Integer expectedValueSize() { + return this.weights == null ? null : this.weights.size(); + } + + @Override + public List processValues(List values) { + Objects.requireNonNull(values, "values must not be null"); + if (weights != null && values.size() != weights.size()) { + throw new IllegalArgumentException("values must be the same length as weights."); + } + double summation = weights == null ? 
+ values.stream().mapToDouble(Double::valueOf).sum() : + IntStream.range(0, weights.size()).mapToDouble(i -> values.get(i) * weights.get(i)).sum(); + double probOfClassOne = sigmoid(summation); + assert 0.0 <= probOfClassOne && probOfClassOne <= 1.0; + return Arrays.asList(1.0 - probOfClassOne, probOfClassOne); + } + + @Override + public double aggregate(List values) { + Objects.requireNonNull(values, "values must not be null"); + assert values.size() == 2; + int bestValue = 0; + double bestProb = Double.NEGATIVE_INFINITY; + for (int i = 0; i < values.size(); i++) { + if (values.get(i) == null) { + throw new IllegalArgumentException("values must not contain null values"); + } + if (values.get(i) > bestProb) { + bestProb = values.get(i); + bestValue = i; + } + } + return bestValue; + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public boolean compatibleWith(TargetType targetType) { + return true; + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeBoolean(weights != null); + if (weights != null) { + out.writeCollection(weights, StreamOutput::writeDouble); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (weights != null) { + builder.field(WEIGHTS.getPreferredName(), weights); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + LogisticRegression that = (LogisticRegression) o; + return Objects.equals(weights, that.weights); + } + + @Override + public int hashCode() { + return Objects.hash(weights); + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java index a872565ad20b5..525425db66d08 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java @@ -51,7 +51,7 @@ public static WeightedMode fromXContentLenient(XContentParser parser) { private final List weights; WeightedMode() { - this.weights = null; + this((List) null); } public WeightedMode(List weights) { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java index db70346c0849e..f2ba7514b0c1c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java @@ -51,7 +51,7 @@ public static WeightedSum fromXContentLenient(XContentParser parser) { private final List weights; WeightedSum() { - this.weights = null; + this((List) null); } public WeightedSum(List weights) { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java index 44cca308ea794..1cdddcd7af26b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java @@ -47,6 +47,10 @@ public static List softMax(List values) { return exps; } + public static double sigmoid(double value) { + return 1/(1 + Math.exp(-value)); + } + private static 
boolean isValid(Double v) { return v != null && Numbers.isValidDouble(v); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java index 816fabf7b0b67..52f317c2595c3 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -75,7 +75,9 @@ public static Ensemble createRandom() { List weights = randomBoolean() ? null : Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList()); - OutputAggregator outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights)); + OutputAggregator outputAggregator = randomFrom(new WeightedMode(weights), + new WeightedSum(weights), + new LogisticRegression(weights)); List categoryLabels = null; if (randomBoolean()) { categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false)); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegressionTests.java new file mode 100644 index 0000000000000..142046d34ef75 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegressionTests.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.equalTo; + +public class LogisticRegressionTests extends WeightedAggregatorTests { + + @Override + LogisticRegression createTestInstance(int numberOfWeights) { + List weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList()); + return new LogisticRegression(weights); + } + + @Override + protected LogisticRegression doParseInstance(XContentParser parser) throws IOException { + return lenient ? LogisticRegression.fromXContentLenient(parser) : LogisticRegression.fromXContentStrict(parser); + } + + @Override + protected LogisticRegression createTestInstance() { + return randomBoolean() ? 
new LogisticRegression() : createTestInstance(randomIntBetween(1, 100)); + } + + @Override + protected Writeable.Reader instanceReader() { + return LogisticRegression::new; + } + + public void testAggregate() { + List ones = Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0); + List values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0); + + LogisticRegression logisticRegression = new LogisticRegression(ones); + assertThat(logisticRegression.aggregate(logisticRegression.processValues(values)), equalTo(1.0)); + + List variedWeights = Arrays.asList(.01, -1.0, .1, 0.0, 0.0); + + logisticRegression = new LogisticRegression(variedWeights); + assertThat(logisticRegression.aggregate(logisticRegression.processValues(values)), equalTo(0.0)); + + logisticRegression = new LogisticRegression(); + assertThat(logisticRegression.aggregate(logisticRegression.processValues(values)), equalTo(1.0)); + } + + public void testCompatibleWith() { + LogisticRegression logisticRegression = createTestInstance(); + assertThat(logisticRegression.compatibleWith(TargetType.CLASSIFICATION), is(true)); + assertThat(logisticRegression.compatibleWith(TargetType.REGRESSION), is(true)); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/utils/StatisticsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/utils/StatisticsTests.java index 5fb69238b1579..cc99a19b38a73 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/utils/StatisticsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/utils/StatisticsTests.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.inference.utils; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.test.ESTestCase; import java.util.Arrays; @@ -30,4 +31,20 @@ public void testSoftMaxWithNoValidValues() { expectThrows(IllegalArgumentException.class, () -> Statistics.softMax(values)); } + public void testSigmoid() 
{ + double eps = 0.000001; + List> paramsAndExpectedReturns = Arrays.asList( + Tuple.tuple(0.0, 0.5), + Tuple.tuple(0.5, 0.62245933), + Tuple.tuple(1.0, 0.73105857), + Tuple.tuple(10000.0, 1.0), + Tuple.tuple(-0.5, 0.3775406), + Tuple.tuple(-1.0, 0.2689414), + Tuple.tuple(-10000.0, 0.0) + ); + for (Tuple expectation : paramsAndExpectedReturns) { + assertThat(Statistics.sigmoid(expectation.v1()), closeTo(expectation.v2(), eps)); + } + } + } From 27d14a0b6b78ced91ec5544b4acea79992885787 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Fri, 18 Oct 2019 07:02:32 -0400 Subject: [PATCH 05/17] [ML][Inference] Adding read/del trained models (#47882) * [ML][Inference] Adding read/del trained models * addressing PR comments and fixing tests * adding error tests to ml_security blacklist * fixing tests --- .../ml/action/DeleteTrainedModelAction.java | 81 +++++++++ .../ml/action/GetTrainedModelsAction.java | 77 ++++++++ .../notifications/InferenceAuditMessage.java | 37 ++++ .../xpack/core/ml/utils/ExceptionsHelper.java | 4 + .../DeleteTrainedModelsRequestTests.java | 23 +++ .../action/GetTrainedModelsRequestTests.java | 26 +++ .../AnomalyDetectionAuditMessageTests.java | 16 +- .../ml/notifications/AuditMessageTests.java | 27 +++ .../DataFrameAnalyticsAuditMessageTests.java | 18 +- .../InferenceAuditMessageTests.java | 35 ++++ .../ml/qa/ml-with-security/build.gradle | 5 + .../xpack/ml/integration/TrainedModelIT.java | 168 ++++++++++++++++++ .../xpack/ml/MachineLearning.java | 29 ++- .../TransportDeleteTrainedModelAction.java | 134 ++++++++++++++ .../TransportGetTrainedModelsAction.java | 90 ++++++++++ .../inference/ingest/InferenceProcessor.java | 39 +++- .../persistence/TrainedModelProvider.java | 27 +++ .../ml/notifications/InferenceAuditor.java | 20 +++ .../RestDeleteTrainedModelAction.java | 39 ++++ .../inference/RestGetTrainedModelsAction.java | 52 ++++++ .../integration/ModelInferenceActionIT.java | 2 +- .../api/ml.delete_trained_model.json | 24 +++ 
.../api/ml.get_trained_models.json | 48 +++++ .../rest-api-spec/test/ml/inference_crud.yml | 110 ++++++++++++ 24 files changed, 1102 insertions(+), 29 deletions(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAction.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/notifications/InferenceAuditMessage.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelsRequestTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/AuditMessageTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/InferenceAuditMessageTests.java create mode 100644 x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/notifications/InferenceAuditor.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestDeleteTrainedModelAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/api/ml.delete_trained_model.json create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json create mode 100644 
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAction.java new file mode 100644 index 0000000000000..521070a959db6 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAction.java @@ -0,0 +1,81 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.master.AcknowledgedRequest; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ToXContentFragment; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Objects; + +public class DeleteTrainedModelAction extends ActionType { + + public static final DeleteTrainedModelAction INSTANCE = new DeleteTrainedModelAction(); + public static final String NAME = "cluster:admin/xpack/ml/inference/delete"; + + private DeleteTrainedModelAction() { + super(NAME, AcknowledgedResponse::new); + } + + public static class Request extends AcknowledgedRequest implements ToXContentFragment { + + private String id; + + public Request(StreamInput in) throws IOException { + super(in); + id = in.readString(); + } + + 
public Request() {} + + public Request(String id) { + this.id = ExceptionsHelper.requireNonNull(id, TrainedModelConfig.MODEL_ID); + } + + public String getId() { + return id; + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), id); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DeleteTrainedModelAction.Request request = (DeleteTrainedModelAction.Request) o; + return Objects.equals(id, request.id); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(id); + } + + @Override + public int hashCode() { + return Objects.hash(id); + } + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java new file mode 100644 index 0000000000000..005f0d180cdc1 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java @@ -0,0 +1,77 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.ActionRequestBuilder; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.client.ElasticsearchClient; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest; +import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse; +import org.elasticsearch.xpack.core.action.util.QueryPage; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; + +import java.io.IOException; + +public class GetTrainedModelsAction extends ActionType { + + public static final GetTrainedModelsAction INSTANCE = new GetTrainedModelsAction(); + public static final String NAME = "cluster:monitor/xpack/ml/inference/get"; + + private GetTrainedModelsAction() { + super(NAME, Response::new); + } + + public static class Request extends AbstractGetResourcesRequest { + + public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); + + public Request() { + setAllowNoResources(true); + } + + public Request(String id) { + setResourceId(id); + setAllowNoResources(true); + } + + public Request(StreamInput in) throws IOException { + super(in); + } + + @Override + public String getResourceIdField() { + return TrainedModelConfig.MODEL_ID.getPreferredName(); + } + + } + + public static class Response extends AbstractGetResourcesResponse { + + public static final ParseField RESULTS_FIELD = new ParseField("trained_model_configs"); + + public Response(StreamInput in) throws IOException { + super(in); + } + + public Response(QueryPage trainedModels) { + super(trainedModels); + } + + @Override + protected Reader getReader() { + return TrainedModelConfig::new; + } + } + + public static class RequestBuilder extends ActionRequestBuilder { + + public RequestBuilder(ElasticsearchClient client) { + super(client, INSTANCE, new Request()); + } + } +} diff 
--git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/notifications/InferenceAuditMessage.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/notifications/InferenceAuditMessage.java new file mode 100644 index 0000000000000..7c1b93786bc91 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/notifications/InferenceAuditMessage.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.notifications; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.xpack.core.common.notifications.AbstractAuditMessage; +import org.elasticsearch.xpack.core.common.notifications.Level; +import org.elasticsearch.xpack.core.ml.job.config.Job; + +import java.util.Date; + + +public class InferenceAuditMessage extends AbstractAuditMessage { + + //TODO this should be MODEL_ID... 
+ private static final ParseField JOB_ID = Job.ID; + public static final ConstructingObjectParser PARSER = + createParser("ml_inference_audit_message", InferenceAuditMessage::new, JOB_ID); + + public InferenceAuditMessage(String resourceId, String message, Level level, Date timestamp, String nodeName) { + super(resourceId, message, level, timestamp, nodeName); + } + + @Override + public final String getJobType() { + return "inference"; + } + + @Override + protected String getResourceField() { + return JOB_ID.getPreferredName(); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java index d2e5a207355f3..9cc4a5cdfbf53 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java @@ -43,6 +43,10 @@ public static ResourceAlreadyExistsException dataFrameAnalyticsAlreadyExists(Str return new ResourceAlreadyExistsException("A data frame analytics with id [{}] already exists", id); } + public static ResourceNotFoundException missingTrainedModel(String modelId) { + return new ResourceNotFoundException("No known trained model with model_id [{}]", modelId); + } + public static ElasticsearchException serverError(String msg) { return new ElasticsearchException(msg); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelsRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelsRequestTests.java new file mode 100644 index 0000000000000..0797b20d438be --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelsRequestTests.java @@ -0,0 +1,23 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction.Request; + +public class DeleteTrainedModelsRequestTests extends AbstractWireSerializingTestCase { + + @Override + protected Request createTestInstance() { + return new Request(randomAlphaOfLengthBetween(1, 20)); + } + + @Override + protected Writeable.Reader instanceReader() { + return Request::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java new file mode 100644 index 0000000000000..0abc0318e215e --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.action.util.PageParams; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request; + +public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCase { + + @Override + protected Request createTestInstance() { + Request request = new Request(randomAlphaOfLength(20)); + request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100))); + return request; + } + + @Override + protected Writeable.Reader instanceReader() { + return Request::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/AnomalyDetectionAuditMessageTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/AnomalyDetectionAuditMessageTests.java index c6a904228b6a7..f6a319dab7ab3 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/AnomalyDetectionAuditMessageTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/AnomalyDetectionAuditMessageTests.java @@ -6,19 +6,16 @@ package org.elasticsearch.xpack.core.ml.notifications; import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.xpack.core.common.notifications.Level; import org.elasticsearch.xpack.core.ml.job.config.Job; import java.util.Date; -import static org.hamcrest.Matchers.equalTo; +public class AnomalyDetectionAuditMessageTests extends AuditMessageTests { -public class AnomalyDetectionAuditMessageTests extends AbstractXContentTestCase { - - public void testGetJobType() { - AnomalyDetectionAuditMessage message = createTestInstance(); - assertThat(message.getJobType(), equalTo(Job.ANOMALY_DETECTOR_JOB_TYPE)); + @Override + public String 
getJobType() { + return Job.ANOMALY_DETECTOR_JOB_TYPE; } @Override @@ -26,11 +23,6 @@ protected AnomalyDetectionAuditMessage doParseInstance(XContentParser parser) { return AnomalyDetectionAuditMessage.PARSER.apply(parser, null); } - @Override - protected boolean supportsUnknownFields() { - return true; - } - @Override protected AnomalyDetectionAuditMessage createTestInstance() { return new AnomalyDetectionAuditMessage( diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/AuditMessageTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/AuditMessageTests.java new file mode 100644 index 0000000000000..2ccb1fbcbf4b3 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/AuditMessageTests.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.notifications; + +import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.xpack.core.common.notifications.AbstractAuditMessage; + + +import static org.hamcrest.Matchers.equalTo; + +public abstract class AuditMessageTests extends AbstractXContentTestCase { + + public abstract String getJobType(); + + public void testGetJobType() { + AbstractAuditMessage message = createTestInstance(); + assertThat(message.getJobType(), equalTo(getJobType())); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/DataFrameAnalyticsAuditMessageTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/DataFrameAnalyticsAuditMessageTests.java index 139e76160d4a6..9637af79a947c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/DataFrameAnalyticsAuditMessageTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/DataFrameAnalyticsAuditMessageTests.java @@ -6,28 +6,20 @@ package org.elasticsearch.xpack.core.ml.notifications; import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.xpack.core.common.notifications.Level; import java.util.Date; -import static org.hamcrest.Matchers.equalTo; +public class DataFrameAnalyticsAuditMessageTests extends AuditMessageTests { -public class DataFrameAnalyticsAuditMessageTests extends AbstractXContentTestCase { - - public void testGetJobType() { - DataFrameAnalyticsAuditMessage message = createTestInstance(); - assertThat(message.getJobType(), equalTo("data_frame_analytics")); - } - @Override - protected DataFrameAnalyticsAuditMessage doParseInstance(XContentParser parser) { - return DataFrameAnalyticsAuditMessage.PARSER.apply(parser, null); + public String 
getJobType() { + return "data_frame_analytics"; } @Override - protected boolean supportsUnknownFields() { - return true; + protected DataFrameAnalyticsAuditMessage doParseInstance(XContentParser parser) { + return DataFrameAnalyticsAuditMessage.PARSER.apply(parser, null); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/InferenceAuditMessageTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/InferenceAuditMessageTests.java new file mode 100644 index 0000000000000..5a9b86578ef59 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/InferenceAuditMessageTests.java @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.notifications; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.common.notifications.Level; + +import java.util.Date; + +public class InferenceAuditMessageTests extends AuditMessageTests { + + @Override + public String getJobType() { + return "inference"; + } + + @Override + protected InferenceAuditMessage doParseInstance(XContentParser parser) { + return InferenceAuditMessage.PARSER.apply(parser, null); + } + + @Override + protected InferenceAuditMessage createTestInstance() { + return new InferenceAuditMessage( + randomBoolean() ? null : randomAlphaOfLength(10), + randomAlphaOfLengthBetween(1, 20), + randomFrom(Level.values()), + new Date(), + randomBoolean() ? 
null : randomAlphaOfLengthBetween(1, 20) + ); + } +} diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index e330d032c0a0d..2dd63883b523a 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -125,6 +125,11 @@ integTest.runner { 'ml/filter_crud/Test get all filter given index exists but no mapping for filter_id', 'ml/get_datafeed_stats/Test get datafeed stats given missing datafeed_id', 'ml/get_datafeeds/Test get datafeed given missing datafeed_id', + 'ml/inference_crud/Test delete given used trained model', + 'ml/inference_crud/Test delete given unused trained model', + 'ml/inference_crud/Test delete with missing model', + 'ml/inference_crud/Test get given missing trained model', + 'ml/inference_crud/Test get given expression without matches and allow_no_match is false', 'ml/jobs_crud/Test cannot create job with existing categorizer state document', 'ml/jobs_crud/Test cannot create job with existing quantiles document', 'ml/jobs_crud/Test cannot create job with existing result document', diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java new file mode 100644 index 0000000000000..6d3fe32332a72 --- /dev/null +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java @@ -0,0 +1,168 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.integration; + +import org.apache.http.util.EntityUtils; +import org.elasticsearch.Version; +import org.elasticsearch.client.Request; +import org.elasticsearch.client.Response; +import org.elasticsearch.client.ResponseException; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.test.SecuritySettingsSourceField; +import org.elasticsearch.test.rest.ESRestTestCase; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; +import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests; +import org.junit.After; + +import java.io.IOException; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; + +import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; + +public class TrainedModelIT extends ESRestTestCase { + + private static final String BASIC_AUTH_VALUE = basicAuthHeaderValue("x_pack_rest_user", + SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING); + + @Override + protected Settings restClientSettings() { 
+ return Settings.builder().put(super.restClientSettings()).put(ThreadContext.PREFIX + ".Authorization", BASIC_AUTH_VALUE).build(); + } + + @Override + protected boolean preserveTemplatesUponCompletion() { + return true; + } + + public void testGetTrainedModels() throws IOException { + String modelId = "test_regression_model"; + String modelId2 = "test_regression_model-2"; + Request model1 = new Request("PUT", + InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId); + model1.setJsonEntity(buildRegressionModel(modelId)); + assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201)); + + Request model2 = new Request("PUT", + InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId2); + model2.setJsonEntity(buildRegressionModel(modelId2)); + assertThat(client().performRequest(model2).getStatusLine().getStatusCode(), equalTo(201)); + + adminClient().performRequest(new Request("POST", InferenceIndexConstants.LATEST_INDEX_NAME + "/_refresh")); + Response getModel = client().performRequest(new Request("GET", + MachineLearning.BASE_PATH + "inference/" + modelId)); + + assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); + String response = EntityUtils.toString(getModel.getEntity()); + + assertThat(response, containsString("\"model_id\":\"test_regression_model\"")); + assertThat(response, containsString("\"count\":1")); + + getModel = client().performRequest(new Request("GET", + MachineLearning.BASE_PATH + "inference/test_regression*")); + assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); + + response = EntityUtils.toString(getModel.getEntity()); + assertThat(response, containsString("\"model_id\":\"test_regression_model\"")); + assertThat(response, containsString("\"model_id\":\"test_regression_model-2\"")); + assertThat(response, containsString("\"count\":2")); + + getModel = client().performRequest(new Request("GET", + MachineLearning.BASE_PATH + 
"inference/test_regression_model,test_regression_model-2")); + assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); + + response = EntityUtils.toString(getModel.getEntity()); + assertThat(response, containsString("\"model_id\":\"test_regression_model\"")); + assertThat(response, containsString("\"model_id\":\"test_regression_model-2\"")); + assertThat(response, containsString("\"count\":2")); + + getModel = client().performRequest(new Request("GET", + MachineLearning.BASE_PATH + "inference/classification*?allow_no_match=true")); + assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); + + response = EntityUtils.toString(getModel.getEntity()); + assertThat(response, containsString("\"count\":0")); + + ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(new Request("GET", + MachineLearning.BASE_PATH + "inference/classification*?allow_no_match=false"))); + assertThat(ex.getResponse().getStatusLine().getStatusCode(), equalTo(404)); + + getModel = client().performRequest(new Request("GET", MachineLearning.BASE_PATH + "inference?from=0&size=1")); + assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); + + response = EntityUtils.toString(getModel.getEntity()); + assertThat(response, containsString("\"count\":2")); + assertThat(response, containsString("\"model_id\":\"test_regression_model\"")); + assertThat(response, not(containsString("\"model_id\":\"test_regression_model-2\""))); + + getModel = client().performRequest(new Request("GET", MachineLearning.BASE_PATH + "inference?from=1&size=1")); + assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); + + response = EntityUtils.toString(getModel.getEntity()); + assertThat(response, containsString("\"count\":2")); + assertThat(response, not(containsString("\"model_id\":\"test_regression_model\""))); + assertThat(response, containsString("\"model_id\":\"test_regression_model-2\"")); + } + + public void testDeleteTrainedModels() 
throws IOException { + String modelId = "test_delete_regression_model"; + Request model1 = new Request("PUT", + InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId); + model1.setJsonEntity(buildRegressionModel(modelId)); + assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201)); + + adminClient().performRequest(new Request("POST", InferenceIndexConstants.LATEST_INDEX_NAME + "/_refresh")); + + Response delModel = client().performRequest(new Request("DELETE", + MachineLearning.BASE_PATH + "inference/" + modelId)); + String response = EntityUtils.toString(delModel.getEntity()); + assertThat(response, containsString("\"acknowledged\":true")); + + ResponseException responseException = expectThrows(ResponseException.class, + () -> client().performRequest(new Request("DELETE", MachineLearning.BASE_PATH + "inference/" + modelId))); + assertThat(responseException.getResponse().getStatusLine().getStatusCode(), equalTo(404)); + } + + private static String buildRegressionModel(String modelId) throws IOException { + try(XContentBuilder builder = XContentFactory.jsonBuilder()) { + TrainedModelConfig.builder() + .setModelId(modelId) + .setCreatedBy("ml_test") + .setDefinition(new TrainedModelDefinition.Builder() + .setInput(new TrainedModelDefinition.Input(Arrays.asList("col1", "col2", "col3"))) + .setPreProcessors(Collections.emptyList()) + .setTrainedModel(LocalModelTests.buildRegression())) + .setVersion(Version.CURRENT) + .setCreateTime(Instant.now()) + .build() + .toXContent(builder, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"))); + return XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON); + } + } + + + @After + public void clearMlState() throws Exception { + new MlRestTestStateCleaner(logger, adminClient()).clearMlMetadata(); + ESRestTestCase.waitForPendingTasks(adminClient()); + } +} diff --git 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 46084836b6b89..500a71b3a9416 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -42,12 +42,14 @@ import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.analysis.TokenizerFactory; import org.elasticsearch.indices.analysis.AnalysisModule.AnalysisProvider; +import org.elasticsearch.ingest.Processor; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.monitor.os.OsProbe; import org.elasticsearch.monitor.os.OsStats; import org.elasticsearch.persistent.PersistentTasksExecutor; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.AnalysisPlugin; +import org.elasticsearch.plugins.IngestPlugin; import org.elasticsearch.plugins.PersistentTaskPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.rest.RestController; @@ -73,6 +75,7 @@ import org.elasticsearch.xpack.core.ml.action.DeleteForecastAction; import org.elasticsearch.xpack.core.ml.action.DeleteJobAction; import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction; +import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction; import org.elasticsearch.xpack.core.ml.action.EstimateMemoryUsageAction; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.action.FinalizeJobExecutionAction; @@ -94,6 +97,7 @@ import org.elasticsearch.xpack.core.ml.action.GetModelSnapshotsAction; import org.elasticsearch.xpack.core.ml.action.GetOverallBucketsAction; import org.elasticsearch.xpack.core.ml.action.GetRecordsAction; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction; import 
org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.action.KillProcessAction; @@ -141,6 +145,7 @@ import org.elasticsearch.xpack.ml.action.TransportDeleteForecastAction; import org.elasticsearch.xpack.ml.action.TransportDeleteJobAction; import org.elasticsearch.xpack.ml.action.TransportDeleteModelSnapshotAction; +import org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAction; import org.elasticsearch.xpack.ml.action.TransportEstimateMemoryUsageAction; import org.elasticsearch.xpack.ml.action.TransportEvaluateDataFrameAction; import org.elasticsearch.xpack.ml.action.TransportFinalizeJobExecutionAction; @@ -163,6 +168,7 @@ import org.elasticsearch.xpack.ml.action.TransportGetOverallBucketsAction; import org.elasticsearch.xpack.ml.action.TransportGetRecordsAction; import org.elasticsearch.xpack.ml.action.TransportInferModelAction; +import org.elasticsearch.xpack.ml.action.TransportGetTrainedModelsAction; import org.elasticsearch.xpack.ml.action.TransportIsolateDatafeedAction; import org.elasticsearch.xpack.ml.action.TransportKillProcessAction; import org.elasticsearch.xpack.ml.action.TransportMlInfoAction; @@ -203,6 +209,7 @@ import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; +import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.persistence.InferenceInternalIndex; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.job.JobManager; @@ -225,6 +232,7 @@ import org.elasticsearch.xpack.ml.job.process.normalizer.NormalizerProcessFactory; import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; +import 
org.elasticsearch.xpack.ml.notifications.InferenceAuditor; import org.elasticsearch.xpack.ml.process.DummyController; import org.elasticsearch.xpack.ml.process.MlController; import org.elasticsearch.xpack.ml.process.MlControllerHolder; @@ -263,6 +271,8 @@ import org.elasticsearch.xpack.ml.rest.filter.RestGetFiltersAction; import org.elasticsearch.xpack.ml.rest.filter.RestPutFilterAction; import org.elasticsearch.xpack.ml.rest.filter.RestUpdateFilterAction; +import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAction; +import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction; import org.elasticsearch.xpack.ml.rest.job.RestCloseJobAction; import org.elasticsearch.xpack.ml.rest.job.RestDeleteForecastAction; import org.elasticsearch.xpack.ml.rest.job.RestDeleteJobAction; @@ -303,7 +313,7 @@ import static java.util.Collections.emptyList; import static org.elasticsearch.index.mapper.MapperService.SINGLE_MAPPING_NAME; -public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlugin, PersistentTaskPlugin { +public class MachineLearning extends Plugin implements ActionPlugin, IngestPlugin, AnalysisPlugin, PersistentTaskPlugin { public static final String NAME = "ml"; public static final String BASE_PATH = "/_ml/"; public static final String PRE_V7_BASE_PATH = "/_xpack/ml/"; @@ -327,6 +337,15 @@ protected Setting roleSetting() { }; + @Override + public Map getProcessors(Processor.Parameters parameters) { + InferenceProcessor.Factory inferenceFactory = new InferenceProcessor.Factory(parameters.client, + parameters.ingestService.getClusterService(), + this.settings); + parameters.ingestService.addIngestClusterStateListener(inferenceFactory); + return Collections.singletonMap(InferenceProcessor.TYPE, inferenceFactory); + } + @Override public Set getRoles() { return Collections.singleton(ML_ROLE); @@ -487,6 +506,7 @@ public Collection createComponents(Client client, ClusterService cluster AnomalyDetectionAuditor 
anomalyDetectionAuditor = new AnomalyDetectionAuditor(client, clusterService.getNodeName()); DataFrameAnalyticsAuditor dataFrameAnalyticsAuditor = new DataFrameAnalyticsAuditor(client, clusterService.getNodeName()); + InferenceAuditor inferenceAuditor = new InferenceAuditor(client, clusterService.getNodeName()); this.dataFrameAnalyticsAuditor.set(dataFrameAnalyticsAuditor); JobResultsProvider jobResultsProvider = new JobResultsProvider(client, settings); JobResultsPersister jobResultsPersister = new JobResultsPersister(client); @@ -619,6 +639,7 @@ public Collection createComponents(Client client, ClusterService cluster datafeedManager, anomalyDetectionAuditor, dataFrameAnalyticsAuditor, + inferenceAuditor, mlAssignmentNotifier, memoryTracker, analyticsProcessManager, @@ -709,7 +730,9 @@ public List getRestHandlers(Settings settings, RestController restC new RestStartDataFrameAnalyticsAction(restController), new RestStopDataFrameAnalyticsAction(restController), new RestEvaluateDataFrameAction(restController), - new RestEstimateMemoryUsageAction(restController) + new RestEstimateMemoryUsageAction(restController), + new RestGetTrainedModelsAction(restController), + new RestDeleteTrainedModelAction(restController) ); } @@ -782,6 +805,8 @@ public List getRestHandlers(Settings settings, RestController restC new ActionHandler<>(EvaluateDataFrameAction.INSTANCE, TransportEvaluateDataFrameAction.class), new ActionHandler<>(EstimateMemoryUsageAction.INSTANCE, TransportEstimateMemoryUsageAction.class), new ActionHandler<>(InferModelAction.INSTANCE, TransportInferModelAction.class), + new ActionHandler<>(GetTrainedModelsAction.INSTANCE, TransportGetTrainedModelsAction.class), + new ActionHandler<>(DeleteTrainedModelAction.INSTANCE, TransportDeleteTrainedModelAction.class), usageAction, infoAction); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java new file mode 100644 index 0000000000000..aadcb9dd34708 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java @@ -0,0 +1,134 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.action; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.action.support.master.TransportMasterNodeAction; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.block.ClusterBlockException; +import org.elasticsearch.cluster.block.ClusterBlockLevel; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.ingest.IngestMetadata; +import org.elasticsearch.ingest.IngestService; +import org.elasticsearch.ingest.Pipeline; +import org.elasticsearch.ingest.PipelineConfiguration; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction; +import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; +import 
org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; + +import java.io.IOException; +import java.util.HashSet; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + + +/** + * The action is a master node action to ensure it reads an up-to-date cluster + * state in order to determine if there is a processor referencing the trained model + */ +public class TransportDeleteTrainedModelAction + extends TransportMasterNodeAction { + + private static final Logger LOGGER = LogManager.getLogger(TransportDeleteTrainedModelAction.class); + + private final TrainedModelProvider trainedModelProvider; + private final InferenceAuditor auditor; + private final IngestService ingestService; + + @Inject + public TransportDeleteTrainedModelAction(TransportService transportService, ClusterService clusterService, + ThreadPool threadPool, ActionFilters actionFilters, + IndexNameExpressionResolver indexNameExpressionResolver, + TrainedModelProvider configProvider, InferenceAuditor auditor, + IngestService ingestService) { + super(DeleteTrainedModelAction.NAME, transportService, clusterService, threadPool, actionFilters, + DeleteTrainedModelAction.Request::new, indexNameExpressionResolver); + this.trainedModelProvider = configProvider; + this.ingestService = ingestService; + this.auditor = Objects.requireNonNull(auditor); + } + + @Override + protected String executor() { + return ThreadPool.Names.SAME; + } + + @Override + protected AcknowledgedResponse read(StreamInput in) throws IOException { + return new AcknowledgedResponse(in); + } + + @Override + protected void masterOperation(Task task, + DeleteTrainedModelAction.Request request, + ClusterState state, + ActionListener listener) { + String id = request.getId(); + IngestMetadata currentIngestMetadata = state.metaData().custom(IngestMetadata.TYPE); + Set referencedModels = getReferencedModelKeys(currentIngestMetadata); + + if 
(referencedModels.contains(id)) { + listener.onFailure(new ElasticsearchStatusException("Cannot delete model [{}] as it is still referenced by ingest processors", + RestStatus.CONFLICT, + id)); + return; + } + + trainedModelProvider.deleteTrainedModel(request.getId(), ActionListener.wrap( + r -> { + auditor.info(request.getId(), "trained model deleted"); + listener.onResponse(new AcknowledgedResponse(true)); + }, + listener::onFailure + )); + } + + private Set getReferencedModelKeys(IngestMetadata ingestMetadata) { + Set allReferencedModelKeys = new HashSet<>(); + if (ingestMetadata == null) { + return allReferencedModelKeys; + } + for(Map.Entry entry : ingestMetadata.getPipelines().entrySet()) { + String pipelineId = entry.getKey(); + Map config = entry.getValue().getConfigAsMap(); + try { + Pipeline pipeline = Pipeline.create(pipelineId, + config, + ingestService.getProcessorFactories(), + ingestService.getScriptService()); + pipeline.getProcessors().stream() + .filter(p -> p instanceof InferenceProcessor) + .map(p -> (InferenceProcessor) p) + .map(InferenceProcessor::getModelId) + .forEach(allReferencedModelKeys::add); + } catch (Exception ex) { + LOGGER.warn(new ParameterizedMessage("failed to load pipeline [{}]", pipelineId), ex); + } + } + return allReferencedModelKeys; + } + + + @Override + protected ClusterBlockException checkBlock(DeleteTrainedModelAction.Request request, ClusterState state) { + return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java new file mode 100644 index 0000000000000..ee95ddbd9670d --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java @@ -0,0 +1,90 @@ +/* + * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.sort.SortOrder; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.action.AbstractTransportGetResourcesAction; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Response; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; + +public class TransportGetTrainedModelsAction extends AbstractTransportGetResourcesAction { + + @Inject + public TransportGetTrainedModelsAction(TransportService transportService, ActionFilters actionFilters, Client client, + NamedXContentRegistry xContentRegistry) { + super(GetTrainedModelsAction.NAME, transportService, actionFilters, 
GetTrainedModelsAction.Request::new, client, + xContentRegistry); + } + + @Override + protected ParseField getResultsField() { + return GetTrainedModelsAction.Response.RESULTS_FIELD; + } + + @Override + protected String[] getIndices() { + return new String[] { InferenceIndexConstants.INDEX_PATTERN }; + } + + @Override + protected TrainedModelConfig parse(XContentParser parser) { + return TrainedModelConfig.LENIENT_PARSER.apply(parser, null).build(); + } + + @Override + protected ResourceNotFoundException notFoundException(String resourceId) { + return ExceptionsHelper.missingTrainedModel(resourceId); + } + + @Override + protected void doExecute(Task task, GetTrainedModelsAction.Request request, + ActionListener listener) { + searchResources(request, ActionListener.wrap( + queryPage -> listener.onResponse(new GetTrainedModelsAction.Response(queryPage)), + listener::onFailure + )); + } + + @Override + protected String executionOrigin() { + return ML_ORIGIN; + } + + @Override + protected String extractIdFromResource(TrainedModelConfig config) { + return config.getModelId(); + } + + @Override + protected SearchSourceBuilder customSearchOptions(SearchSourceBuilder searchSourceBuilder) { + return searchSourceBuilder.sort("_index", SortOrder.DESC); + } + + @Nullable + protected QueryBuilder additionalQuery() { + return QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelConfig.NAME); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java index 59f6c62a7f55e..b8cccc0d45e23 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -10,6 +10,14 @@ import org.elasticsearch.ingest.IngestDocument; import 
java.util.function.BiConsumer; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.ingest.ConfigurationUtils; +import org.elasticsearch.ingest.Processor; + +import java.util.Map; +import java.util.function.Consumer; public class InferenceProcessor extends AbstractProcessor { @@ -17,10 +25,16 @@ public class InferenceProcessor extends AbstractProcessor { public static final String MODEL_ID = "model_id"; private final Client client; + private final String modelId; - public InferenceProcessor(Client client, String tag) { + public InferenceProcessor(Client client, String tag, String modelId) { super(tag); this.client = client; + this.modelId = modelId; + } + + public String getModelId() { + return modelId; } @Override @@ -38,4 +52,27 @@ public IngestDocument execute(IngestDocument ingestDocument) { public String getType() { return TYPE; } + + public static class Factory implements Processor.Factory, Consumer { + + private final Client client; + private final ClusterService clusterService; + + public Factory(Client client, ClusterService clusterService, Settings settings) { + this.client = client; + this.clusterService = clusterService; + } + + @Override + public Processor create(Map processorFactories, String tag, Map config) + throws Exception { + String modelId = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_ID); + return new InferenceProcessor(client, tag, modelId); + } + + @Override + public void accept(ClusterState clusterState) { + + } + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index 6f1e543896c9d..3ad5004b9a032 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java 
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -28,9 +28,12 @@ import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.engine.VersionConflictEngineException; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.index.reindex.DeleteByQueryAction; +import org.elasticsearch.index.reindex.DeleteByQueryRequest; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; @@ -117,6 +120,30 @@ public void getTrainedModel(String modelId, ActionListener l listener::onFailure)); } + public void deleteTrainedModel(String modelId, ActionListener listener) { + DeleteByQueryRequest request = new DeleteByQueryRequest().setAbortOnVersionConflict(false); + + request.indices(InferenceIndexConstants.INDEX_PATTERN); + QueryBuilder query = QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId); + request.setQuery(query); + request.setRefresh(true); + + executeAsyncWithOrigin(client, ML_ORIGIN, DeleteByQueryAction.INSTANCE, request, ActionListener.wrap(deleteResponse -> { + if (deleteResponse.getDeleted() == 0) { + listener.onFailure(new ResourceNotFoundException( + Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); + return; + } + listener.onResponse(true); + }, e -> { + if (e.getClass() == IndexNotFoundException.class) { + listener.onFailure(new ResourceNotFoundException( + Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); + } else { + listener.onFailure(e); + } + })); + } private void parseInferenceDocLenientlyFromSource(BytesReference source, String modelId, diff --git 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/notifications/InferenceAuditor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/notifications/InferenceAuditor.java new file mode 100644 index 0000000000000..dfce44af7c9a4 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/notifications/InferenceAuditor.java @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.notifications; + +import org.elasticsearch.client.Client; +import org.elasticsearch.xpack.core.common.notifications.AbstractAuditor; +import org.elasticsearch.xpack.core.ml.notifications.AuditorField; +import org.elasticsearch.xpack.core.ml.notifications.InferenceAuditMessage; + +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; + +public class InferenceAuditor extends AbstractAuditor { + + public InferenceAuditor(Client client, String nodeName) { + super(client, nodeName, AuditorField.NOTIFICATIONS_INDEX, ML_ORIGIN, InferenceAuditMessage::new); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestDeleteTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestDeleteTrainedModelAction.java new file mode 100644 index 0000000000000..e9675be4d29fd --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestDeleteTrainedModelAction.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.rest.inference; + +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestController; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.ml.MachineLearning; + +import java.io.IOException; + +import static org.elasticsearch.rest.RestRequest.Method.DELETE; + +public class RestDeleteTrainedModelAction extends BaseRestHandler { + + public RestDeleteTrainedModelAction(RestController controller) { + controller.registerHandler( + DELETE, MachineLearning.BASE_PATH + "inference/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}", this); + } + + @Override + public String getName() { + return "ml_delete_trained_models_action"; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName()); + DeleteTrainedModelAction.Request request = new DeleteTrainedModelAction.Request(modelId); + return channel -> client.execute(DeleteTrainedModelAction.INSTANCE, request, new RestToXContentListener<>(channel)); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java new file mode 100644 index 0000000000000..40ddd05827043 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java @@ -0,0 +1,52 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.rest.inference; + +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.cluster.metadata.MetaData; +import org.elasticsearch.common.Strings; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestController; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.action.util.PageParams; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.ml.MachineLearning; + +import java.io.IOException; + +import static org.elasticsearch.rest.RestRequest.Method.GET; +import static org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request.ALLOW_NO_MATCH; + +public class RestGetTrainedModelsAction extends BaseRestHandler { + + public RestGetTrainedModelsAction(RestController controller) { + controller.registerHandler( + GET, MachineLearning.BASE_PATH + "inference/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}", this); + controller.registerHandler(GET, MachineLearning.BASE_PATH + "inference", this); + } + + @Override + public String getName() { + return "ml_get_trained_models_action"; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName()); + if (Strings.isNullOrEmpty(modelId)) { + modelId = MetaData.ALL; + } + GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId); + if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) { + request.setPageParams(new 
PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM), + restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE))); + } + request.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(), request.isAllowNoResources())); + return channel -> client.execute(GetTrainedModelsAction.INSTANCE, request, new RestToXContentListener<>(channel)); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java index c2cfbe4f15498..04dff88417abb 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -166,7 +166,7 @@ public void testInferMissingModel() { try { client().execute(InferModelAction.INSTANCE, request).actionGet(); } catch (ElasticsearchException ex) { - assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model, 0))); + assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model))); } } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.delete_trained_model.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.delete_trained_model.json new file mode 100644 index 0000000000000..edfc157646f91 --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.delete_trained_model.json @@ -0,0 +1,24 @@ +{ + "ml.delete_trained_model":{ + "documentation":{ + "url":"TODO" + }, + "stability":"experimental", + "url":{ + "paths":[ + { + "path":"/_ml/inference/{model_id}", + "methods":[ + "DELETE" + ], + "parts":{ + "model_id":{ + "type":"string", + "description":"The ID of the trained model to delete" + } + } + } + ] + } + } +} diff --git 
a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json new file mode 100644 index 0000000000000..481f8b25975bb --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json @@ -0,0 +1,48 @@ +{ + "ml.get_trained_models":{ + "documentation":{ + "url":"TODO" + }, + "stability":"experimental", + "url":{ + "paths":[ + { + "path":"/_ml/inference/{model_id}", + "methods":[ + "GET" + ], + "parts":{ + "model_id":{ + "type":"string", + "description":"The ID of the trained models to fetch" + } + } + }, + { + "path":"/_ml/inference", + "methods":[ + "GET" + ] + } + ] + }, + "params":{ + "allow_no_match":{ + "type":"boolean", + "required":false, + "description":"Whether to ignore if a wildcard expression matches no trained models. (This includes `_all` string or when no trained models have been specified)", + "default":true + }, + "from":{ + "type":"int", + "description":"skips a number of trained models", + "default":0 + }, + "size":{ + "type":"int", + "description":"specifies a max number of trained models to get", + "default":100 + } + } + } +} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml new file mode 100644 index 0000000000000..a18b29487eac5 --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml @@ -0,0 +1,110 @@ +--- +"Test get-all given no trained models exist": + + - do: + ml.get_trained_models: + model_id: "_all" + - match: { count: 0 } + - match: { trained_model_configs: [] } + + - do: + ml.get_trained_models: + model_id: "*" + - match: { count: 0 } + - match: { trained_model_configs: [] } + +--- +"Test get given missing trained model": + + - do: + catch: missing + ml.get_trained_models: + model_id: "missing-trained-model" +--- +"Test get given expression without matches and 
allow_no_match is false": + + - do: + catch: missing + ml.get_trained_models: + model_id: "missing-trained-model*" + allow_no_match: false + +--- +"Test get given expression without matches and allow_no_match is true": + + - do: + ml.get_trained_models: + model_id: "missing-trained-model*" + allow_no_match: true + - match: { count: 0 } + - match: { trained_model_configs: [] } +--- +"Test delete given unused trained model": + + - do: + index: + id: trained_model_config-unused-regression-model-0 + index: .ml-inference-000001 + body: > + { + "model_id": "unused-regression-model", + "created_by": "ml_tests", + "version": "8.0.0", + "description": "empty model for tests", + "create_time": 0, + "model_version": 0, + "model_type": "local" + } + - do: + indices.refresh: {} + + - do: + ml.delete_trained_model: + model_id: "unused-regression-model" + - match: { acknowledged: true } + +--- +"Test delete with missing model": + - do: + catch: missing + ml.delete_trained_model: + model_id: "missing-trained-model" + +--- +"Test delete given used trained model": + - do: + index: + id: trained_model_config-used-regression-model-0 + index: .ml-inference-000001 + body: > + { + "model_id": "used-regression-model", + "created_by": "ml_tests", + "version": "8.0.0", + "description": "empty model for tests", + "create_time": 0, + "model_version": 0, + "model_type": "local" + } + - do: + indices.refresh: {} + + - do: + ingest.put_pipeline: + id: "regression-model-pipeline" + body: > + { + "processors": [ + { + "inference" : { + "model_id" : "used-regression-model" + } + } + ] + } + - match: { acknowledged: true } + + - do: + catch: conflict + ml.delete_trained_model: + model_id: "used-regression-model" From d839e6b369633e89739478bc8dd6f672b3a6e136 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Mon, 21 Oct 2019 08:55:27 -0400 Subject: [PATCH 06/17] [ML][Inference] Adding inference ingest processor (#47859) * [ML][Inference] Adding ingest processor * optionally including tag in model 
metadata injection in processor * fixing test * addressing PR comments * adding comment --- .../core/ml/action/InferModelAction.java | 2 +- .../trainedmodel/ClassificationConfig.java | 19 + .../trainedmodel/InferenceConfig.java | 5 + .../trainedmodel/RegressionConfig.java | 16 + .../ensemble/NullInferenceConfig.java | 6 + .../xpack/core/ml/job/messages/Messages.java | 2 + .../ClassificationConfigTests.java | 20 + .../trainedmodel/RegressionConfigTests.java | 16 + .../ml/integration/InferenceIngestIT.java | 530 ++++++++++++++++++ .../xpack/ml/MachineLearning.java | 10 +- .../inference/ingest/InferenceProcessor.java | 217 ++++++- .../loadingservice/ModelLoadingService.java | 21 +- .../InferenceProcessorFactoryTests.java | 266 +++++++++ .../ingest/InferenceProcessorTests.java | 191 +++++++ .../rest-api-spec/test/ml/inference_crud.yml | 5 +- 15 files changed, 1302 insertions(+), 24 deletions(-) create mode 100644 x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java index 24b5e345430df..29cab602dab06 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java @@ -27,7 +27,7 @@ public class InferModelAction extends ActionType { public static final InferModelAction INSTANCE = new InferModelAction(); - public static final String NAME = "cluster:admin/xpack/ml/infer"; + public static final String NAME = 
"cluster:admin/xpack/ml/inference/infer"; private InferModelAction() { super(NAME, Response::new); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java index 5aa0403d94753..4c9fc4d89e93b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java @@ -5,12 +5,16 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; +import org.elasticsearch.Version; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; +import java.util.HashMap; +import java.util.Map; import java.util.Objects; public class ClassificationConfig implements InferenceConfig { @@ -18,11 +22,21 @@ public class ClassificationConfig implements InferenceConfig { public static final String NAME = "classification"; public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); + private static final Version MIN_SUPPORTED_VERSION = Version.V_8_0_0; public static ClassificationConfig EMPTY_PARAMS = new ClassificationConfig(0); private final int numTopClasses; + public static ClassificationConfig fromMap(Map map) { + Map options = new HashMap<>(map); + Integer numTopClasses = (Integer)options.remove(NUM_TOP_CLASSES.getPreferredName()); + if (options.isEmpty() == false) { + throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet()); + } + return new ClassificationConfig(numTopClasses); + } + public ClassificationConfig(Integer 
numTopClasses) { this.numTopClasses = numTopClasses == null ? 0 : numTopClasses; } @@ -78,4 +92,9 @@ public boolean isTargetTypeSupported(TargetType targetType) { return TargetType.CLASSIFICATION.equals(targetType); } + @Override + public Version getMinimalSupportedVersion() { + return MIN_SUPPORTED_VERSION; + } + } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java index 6129d71d5ff95..5d1dc7983ff3c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; @@ -13,4 +14,8 @@ public interface InferenceConfig extends NamedXContentObject, NamedWriteable { boolean isTargetTypeSupported(TargetType targetType); + /** + * All nodes in the cluster must be at least this version + */ + Version getMinimalSupportedVersion(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java index bb7f772f86ba4..58bd7bbd3d558 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java @@ -5,16 +5,27 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; +import org.elasticsearch.Version; import 
org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; +import java.util.Map; import java.util.Objects; public class RegressionConfig implements InferenceConfig { public static final String NAME = "regression"; + private static final Version MIN_SUPPORTED_VERSION = Version.V_8_0_0; + + public static RegressionConfig fromMap(Map map) { + if (map.isEmpty() == false) { + throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", map.keySet()); + } + return new RegressionConfig(); + } public RegressionConfig() { } @@ -61,4 +72,9 @@ public boolean isTargetTypeSupported(TargetType targetType) { return TargetType.REGRESSION.equals(targetType); } + @Override + public Version getMinimalSupportedVersion() { + return MIN_SUPPORTED_VERSION; + } + } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/NullInferenceConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/NullInferenceConfig.java index 7628d0beec25f..42757d889818e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/NullInferenceConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/NullInferenceConfig.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; @@ -26,6 +27,11 @@ public boolean isTargetTypeSupported(TargetType targetType) { return true; } + @Override + public Version getMinimalSupportedVersion() { + return 
Version.CURRENT; + } + @Override public String getWriteableName() { return "null"; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index c6558d781bbbc..05bc8250333e0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -85,6 +85,8 @@ public final class Messages { public static final String INFERENCE_FAILED_TO_SERIALIZE_MODEL = "Failed to serialize the trained model [{0}] for storage"; public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]"; + public static final String INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION = + "Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]"; public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again"; public static final String JOB_AUDIT_CREATED = "Job created"; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java index 4df3263215f63..808aaf960f4e1 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java @@ -5,15 +5,35 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import java.util.Collections; + +import static org.hamcrest.Matchers.equalTo; + 
public class ClassificationConfigTests extends AbstractWireSerializingTestCase { public static ClassificationConfig randomClassificationConfig() { return new ClassificationConfig(randomBoolean() ? null : randomIntBetween(-1, 10)); } + public void testFromMap() { + ClassificationConfig expected = new ClassificationConfig(0); + assertThat(ClassificationConfig.fromMap(Collections.emptyMap()), equalTo(expected)); + + expected = new ClassificationConfig(3); + assertThat(ClassificationConfig.fromMap(Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 3)), + equalTo(expected)); + } + + public void testFromMapWithUnknownField() { + ElasticsearchException ex = expectThrows(ElasticsearchException.class, + () -> ClassificationConfig.fromMap(Collections.singletonMap("some_key", 1))); + assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key].")); + } + @Override protected ClassificationConfig createTestInstance() { return randomClassificationConfig(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java index 57efcdd15009a..bdb0e6d03201f 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java @@ -5,15 +5,31 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import java.util.Collections; + +import static org.hamcrest.Matchers.equalTo; + public class RegressionConfigTests extends AbstractWireSerializingTestCase { public static RegressionConfig randomRegressionConfig() { return 
new RegressionConfig(); } + public void testFromMap() { + RegressionConfig expected = new RegressionConfig(); + assertThat(RegressionConfig.fromMap(Collections.emptyMap()), equalTo(expected)); + } + + public void testFromMapWithUnknownField() { + ElasticsearchException ex = expectThrows(ElasticsearchException.class, + () -> RegressionConfig.fromMap(Collections.singletonMap("some_key", 1))); + assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key].")); + } + @Override protected RegressionConfig createTestInstance() { return randomRegressionConfig(); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java new file mode 100644 index 0000000000000..852b3fcea0f0e --- /dev/null +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java @@ -0,0 +1,530 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.ml.integration;

import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.action.ingest.SimulateDocumentBaseResult;
import org.elasticsearch.action.ingest.SimulatePipelineResponse;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.junit.Before;

import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;

/**
 * End-to-end tests for the {@code inference} ingest processor.
 *
 * <p>Before each test two trained model documents (one classification ensemble,
 * one regression ensemble) are indexed into the inference index so the model
 * loading service can resolve them when a pipeline references their model ids.
 * The tests then exercise pipeline create/delete cycles, indexing documents
 * through inference pipelines, and the simulate-pipeline API (including the
 * missing-model failure path).
 */
public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {

    @Before
    public void createBothModels() {
        // Index both model configs with an IMMEDIATE refresh so they are
        // searchable the moment a pipeline first asks for them.
        assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME,
            "_doc",
            "test_classification")
            .setSource(CLASSIFICATION_MODEL, XContentType.JSON)
            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
            .get().status(), equalTo(RestStatus.CREATED));
        assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME,
            "_doc",
            "test_regression")
            .setSource(REGRESSION_MODEL, XContentType.JSON)
            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
            .get().status(), equalTo(RestStatus.CREATED));
    }

    /**
     * Cycles pipeline creation and deletion while indexing through the pipelines.
     * First loop: create a pipeline, index one doc through it, delete it — ten
     * times for each model. Then both pipelines are created once, ten docs are
     * indexed through each, and finally the test asserts that exactly 20 docs
     * carry each inference result field (10 from the cycle loop + 10 here).
     */
    public void testPipelineCreationAndDeletion() throws Exception {

        for (int i = 0; i < 10; i++) {
            assertThat(client().admin().cluster().preparePutPipeline("simple_classification_pipeline",
                new BytesArray(CLASSIFICATION_PIPELINE.getBytes(StandardCharsets.UTF_8)),
                XContentType.JSON).get().isAcknowledged(),
                is(true));

            client().prepareIndex("index_for_inference_test", "_doc")
                .setSource(generateSourceDoc())
                .setPipeline("simple_classification_pipeline")
                .get();

            assertThat(client().admin().cluster().prepareDeletePipeline("simple_classification_pipeline").get().isAcknowledged(),
                is(true));

            assertThat(client().admin().cluster().preparePutPipeline("simple_regression_pipeline",
                new BytesArray(REGRESSION_PIPELINE.getBytes(StandardCharsets.UTF_8)),
                XContentType.JSON).get().isAcknowledged(), is(true));

            client().prepareIndex("index_for_inference_test", "_doc")
                .setSource(generateSourceDoc())
                .setPipeline("simple_regression_pipeline")
                .get();

            assertThat(client().admin().cluster().prepareDeletePipeline("simple_regression_pipeline").get().isAcknowledged(),
                is(true));
        }

        assertThat(client().admin().cluster().preparePutPipeline("simple_classification_pipeline",
            new BytesArray(CLASSIFICATION_PIPELINE.getBytes(StandardCharsets.UTF_8)),
            XContentType.JSON).get().isAcknowledged(), is(true));

        assertThat(client().admin().cluster().preparePutPipeline("simple_regression_pipeline",
            new BytesArray(REGRESSION_PIPELINE.getBytes(StandardCharsets.UTF_8)),
            XContentType.JSON).get().isAcknowledged(), is(true));

        for (int i = 0; i < 10; i++) {
            client().prepareIndex("index_for_inference_test", "_doc")
                .setSource(generateSourceDoc())
                .setPipeline("simple_classification_pipeline")
                .get();

            client().prepareIndex("index_for_inference_test", "_doc")
                .setSource(generateSourceDoc())
                .setPipeline("simple_regression_pipeline")
                .get();
        }

        assertThat(client().admin().cluster().prepareDeletePipeline("simple_classification_pipeline").get().isAcknowledged(),
            is(true));

        assertThat(client().admin().cluster().prepareDeletePipeline("simple_regression_pipeline").get().isAcknowledged(),
            is(true));

        client().admin().indices().refresh(new RefreshRequest("index_for_inference_test")).get();

        // 10 docs from the create/delete cycle + 10 from the second loop = 20 per model
        assertThat(client().search(new SearchRequest().indices("index_for_inference_test")
            .source(new SearchSourceBuilder()
                .size(0)
                .trackTotalHits(true)
                .query(QueryBuilders.boolQuery()
                    .filter(QueryBuilders.existsQuery("regression_value"))))).get().getHits().getTotalHits().value,
            equalTo(20L));

        assertThat(client().search(new SearchRequest().indices("index_for_inference_test")
            .source(new SearchSourceBuilder()
                .size(0)
                .trackTotalHits(true)
                .query(QueryBuilders.boolQuery()
                    .filter(QueryBuilders.existsQuery("result_class"))))).get().getHits().getTotalHits().value,
            equalTo(20L));

    }

    /**
     * Runs the simulate-pipeline API against three inference processors
     * (classification, classification with top-2 class probabilities, and
     * regression) and verifies the per-document results; then simulates a
     * pipeline referencing a model id that does not exist and asserts the
     * document-level failure message.
     */
    public void testSimulate() {
        String source = "{\n" +
            "  \"pipeline\": {\n" +
            "    \"processors\": [\n" +
            "      {\n" +
            "        \"inference\": {\n" +
            "          \"target_field\": \"result_class\",\n" +
            "          \"inference_config\": {\"classification\":{}},\n" +
            "          \"model_id\": \"test_classification\",\n" +
            "          \"field_mappings\": {\n" +
            "            \"col1\": \"col1\",\n" +
            "            \"col2\": \"col2\",\n" +
            "            \"col3\": \"col3\",\n" +
            "            \"col4\": \"col4\"\n" +
            "          }\n" +
            "        }\n" +
            "      },\n" +
            "      {\n" +
            "        \"inference\": {\n" +
            "          \"target_field\": \"result_class_prob\",\n" +
            "          \"inference_config\": {\"classification\": {\"num_top_classes\":2}},\n" +
            "          \"model_id\": \"test_classification\",\n" +
            "          \"field_mappings\": {\n" +
            "            \"col1\": \"col1\",\n" +
            "            \"col2\": \"col2\",\n" +
            "            \"col3\": \"col3\",\n" +
            "            \"col4\": \"col4\"\n" +
            "          }\n" +
            "        }\n" +
            "      },\n" +
            "      {\n" +
            "        \"inference\": {\n" +
            "          \"target_field\": \"regression_value\",\n" +
            "          \"model_id\": \"test_regression\",\n" +
            "          \"inference_config\": {\"regression\":{}},\n" +
            "          \"field_mappings\": {\n" +
            "            \"col1\": \"col1\",\n" +
            "            \"col2\": \"col2\",\n" +
            "            \"col3\": \"col3\",\n" +
            "            \"col4\": \"col4\"\n" +
            "          }\n" +
            "        }\n" +
            "      }\n" +
            "    ]\n" +
            "  },\n" +
            "  \"docs\": [\n" +
            "    {\"_source\": {\n" +
            "      \"col1\": \"female\",\n" +
            "      \"col2\": \"M\",\n" +
            "      \"col3\": \"none\",\n" +
            "      \"col4\": 10\n" +
            "    }}]\n" +
            "}";

        SimulatePipelineResponse response = client().admin().cluster()
            .prepareSimulatePipeline(new BytesArray(source.getBytes(StandardCharsets.UTF_8)),
                XContentType.JSON).get();
        SimulateDocumentBaseResult baseResult = (SimulateDocumentBaseResult)response.getResults().get(0);
        assertThat(baseResult.getIngestDocument().getFieldValue("regression_value", Double.class), equalTo(1.0));
        assertThat(baseResult.getIngestDocument().getFieldValue("result_class", String.class), equalTo("second"));
        assertThat(baseResult.getIngestDocument().getFieldValue("result_class_prob", List.class).size(), equalTo(2));

        String sourceWithMissingModel = "{\n" +
            "  \"pipeline\": {\n" +
            "    \"processors\": [\n" +
            "      {\n" +
            "        \"inference\": {\n" +
            "          \"target_field\": \"result_class\",\n" +
            "          \"model_id\": \"test_classification_missing\",\n" +
            "          \"inference_config\": {\"classification\":{}},\n" +
            "          \"field_mappings\": {\n" +
            "            \"col1\": \"col1\",\n" +
            "            \"col2\": \"col2\",\n" +
            "            \"col3\": \"col3\",\n" +
            "            \"col4\": \"col4\"\n" +
            "          }\n" +
            "        }\n" +
            "      }\n" +
            "    ]\n" +
            "  },\n" +
            "  \"docs\": [\n" +
            "    {\"_source\": {\n" +
            "      \"col1\": \"female\",\n" +
            "      \"col2\": \"M\",\n" +
            "      \"col3\": \"none\",\n" +
            "      \"col4\": 10\n" +
            "    }}]\n" +
            "}";

        response = client().admin().cluster()
            .prepareSimulatePipeline(new BytesArray(sourceWithMissingModel.getBytes(StandardCharsets.UTF_8)),
                XContentType.JSON).get();

        // A missing model surfaces as a per-document failure, not a request failure.
        assertThat(((SimulateDocumentBaseResult) response.getResults().get(0)).getFailure().getMessage(),
            containsString("Could not find trained model [test_classification_missing]"));
    }

    /**
     * Builds a random source document containing the four feature columns the
     * test models expect. A plain {@link HashMap} is used instead of double-brace
     * initialization so no anonymous inner class (holding a reference to the
     * enclosing test instance) is created per document.
     */
    private Map<String, Object> generateSourceDoc() {
        Map<String, Object> doc = new HashMap<>();
        doc.put("col1", randomFrom("female", "male"));
        doc.put("col2", randomFrom("S", "M", "L", "XL"));
        doc.put("col3", randomFrom("true", "false", "none", "other"));
        doc.put("col4", randomIntBetween(0, 10));
        return doc;
    }

    // Regression ensemble of two trees over one-hot / target-mean / frequency
    // encoded features, aggregated with a weighted sum.
    private static final String REGRESSION_MODEL = "{" +
        " \"model_id\": \"test_regression\",\n" +
        " \"model_version\": 0,\n" +
        " \"definition\": {\n" +
        "   \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
        "   \"preprocessors\": [\n" +
        "     {\n" +
        "       \"one_hot_encoding\": {\n" +
        "         \"field\": \"col1\",\n" +
        "         \"hot_map\": {\n" +
        "           \"male\": \"col1_male\",\n" +
        "           \"female\": \"col1_female\"\n" +
        "         }\n" +
        "       }\n" +
        "     },\n" +
        "     {\n" +
        "       \"target_mean_encoding\": {\n" +
        "         \"field\": \"col2\",\n" +
        "         \"feature_name\": \"col2_encoded\",\n" +
        "         \"target_map\": {\n" +
        "           \"S\": 5.0,\n" +
        "           \"M\": 10.0,\n" +
        "           \"L\": 20\n" +
        "         },\n" +
        "         \"default_value\": 5.0\n" +
        "       }\n" +
        "     },\n" +
        "     {\n" +
        "       \"frequency_encoding\": {\n" +
        "         \"field\": \"col3\",\n" +
        "         \"feature_name\": \"col3_encoded\",\n" +
        "         \"frequency_map\": {\n" +
        "           \"none\": 0.75,\n" +
        "           \"true\": 0.10,\n" +
        "           \"false\": 0.15\n" +
        "         }\n" +
        "       }\n" +
        "     }\n" +
        "   ],\n" +
        "   \"trained_model\": {\n" +
        "     \"ensemble\": {\n" +
        "       \"feature_names\": [\n" +
        "         \"col1_male\",\n" +
        "         \"col1_female\",\n" +
        "         \"col2_encoded\",\n" +
        "         \"col3_encoded\",\n" +
        "         \"col4\"\n" +
        "       ],\n" +
        "       \"aggregate_output\": {\n" +
        "         \"weighted_sum\": {\n" +
        "           \"weights\": [\n" +
        "             0.5,\n" +
        "             0.5\n" +
        "           ]\n" +
        "         }\n" +
        "       },\n" +
        "       \"target_type\": \"regression\",\n" +
        "       \"trained_models\": [\n" +
        "         {\n" +
        "           \"tree\": {\n" +
        "             \"feature_names\": [\n" +
        "               \"col1_male\",\n" +
        "               \"col1_female\",\n" +
        "               \"col4\"\n" +
        "             ],\n" +
        "             \"tree_structure\": [\n" +
        "               {\n" +
        "                 \"node_index\": 0,\n" +
        "                 \"split_feature\": 0,\n" +
        "                 \"split_gain\": 12.0,\n" +
        "                 \"threshold\": 10.0,\n" +
        "                 \"decision_type\": \"lte\",\n" +
        "                 \"default_left\": true,\n" +
        "                 \"left_child\": 1,\n" +
        "                 \"right_child\": 2\n" +
        "               },\n" +
        "               {\n" +
        "                 \"node_index\": 1,\n" +
        "                 \"leaf_value\": 1\n" +
        "               },\n" +
        "               {\n" +
        "                 \"node_index\": 2,\n" +
        "                 \"leaf_value\": 2\n" +
        "               }\n" +
        "             ],\n" +
        "             \"target_type\": \"regression\"\n" +
        "           }\n" +
        "         },\n" +
        "         {\n" +
        "           \"tree\": {\n" +
        "             \"feature_names\": [\n" +
        "               \"col2_encoded\",\n" +
        "               \"col3_encoded\",\n" +
        "               \"col4\"\n" +
        "             ],\n" +
        "             \"tree_structure\": [\n" +
        "               {\n" +
        "                 \"node_index\": 0,\n" +
        "                 \"split_feature\": 0,\n" +
        "                 \"split_gain\": 12.0,\n" +
        "                 \"threshold\": 10.0,\n" +
        "                 \"decision_type\": \"lte\",\n" +
        "                 \"default_left\": true,\n" +
        "                 \"left_child\": 1,\n" +
        "                 \"right_child\": 2\n" +
        "               },\n" +
        "               {\n" +
        "                 \"node_index\": 1,\n" +
        "                 \"leaf_value\": 1\n" +
        "               },\n" +
        "               {\n" +
        "                 \"node_index\": 2,\n" +
        "                 \"leaf_value\": 2\n" +
        "               }\n" +
        "             ],\n" +
        "             \"target_type\": \"regression\"\n" +
        "           }\n" +
        "         }\n" +
        "       ]\n" +
        "     }\n" +
        "   }\n" +
        " },\n" +
        " \"description\": \"test model for regression\",\n" +
        " \"version\": \"8.0.0\",\n" +
        " \"created_by\": \"ml_test\",\n" +
        " \"model_type\": \"local\",\n" +
        " \"created_time\": 0" +
        "}";

    // Classification ensemble with the same encoders and trees as the regression
    // model but aggregated with a weighted mode and labelled [first, second].
    private static final String CLASSIFICATION_MODEL = "" +
        "{\n" +
        " \"model_id\": \"test_classification\",\n" +
        " \"model_version\": 0,\n" +
        " \"definition\":{\n" +
        "   \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
        "   \"preprocessors\": [\n" +
        "     {\n" +
        "       \"one_hot_encoding\": {\n" +
        "         \"field\": \"col1\",\n" +
        "         \"hot_map\": {\n" +
        "           \"male\": \"col1_male\",\n" +
        "           \"female\": \"col1_female\"\n" +
        "         }\n" +
        "       }\n" +
        "     },\n" +
        "     {\n" +
        "       \"target_mean_encoding\": {\n" +
        "         \"field\": \"col2\",\n" +
        "         \"feature_name\": \"col2_encoded\",\n" +
        "         \"target_map\": {\n" +
        "           \"S\": 5.0,\n" +
        "           \"M\": 10.0,\n" +
        "           \"L\": 20\n" +
        "         },\n" +
        "         \"default_value\": 5.0\n" +
        "       }\n" +
        "     },\n" +
        "     {\n" +
        "       \"frequency_encoding\": {\n" +
        "         \"field\": \"col3\",\n" +
        "         \"feature_name\": \"col3_encoded\",\n" +
        "         \"frequency_map\": {\n" +
        "           \"none\": 0.75,\n" +
        "           \"true\": 0.10,\n" +
        "           \"false\": 0.15\n" +
        "         }\n" +
        "       }\n" +
        "     }\n" +
        "   ],\n" +
        "   \"trained_model\": {\n" +
        "     \"ensemble\": {\n" +
        "       \"feature_names\": [\n" +
        "         \"col1_male\",\n" +
        "         \"col1_female\",\n" +
        "         \"col2_encoded\",\n" +
        "         \"col3_encoded\",\n" +
        "         \"col4\"\n" +
        "       ],\n" +
        "       \"aggregate_output\": {\n" +
        "         \"weighted_mode\": {\n" +
        "           \"weights\": [\n" +
        "             0.5,\n" +
        "             0.5\n" +
        "           ]\n" +
        "         }\n" +
        "       },\n" +
        "       \"target_type\": \"classification\",\n" +
        "       \"classification_labels\": [\"first\", \"second\"],\n" +
        "       \"trained_models\": [\n" +
        "         {\n" +
        "           \"tree\": {\n" +
        "             \"feature_names\": [\n" +
        "               \"col1_male\",\n" +
        "               \"col1_female\",\n" +
        "               \"col4\"\n" +
        "             ],\n" +
        "             \"tree_structure\": [\n" +
        "               {\n" +
        "                 \"node_index\": 0,\n" +
        "                 \"split_feature\": 0,\n" +
        "                 \"split_gain\": 12.0,\n" +
        "                 \"threshold\": 10.0,\n" +
        "                 \"decision_type\": \"lte\",\n" +
        "                 \"default_left\": true,\n" +
        "                 \"left_child\": 1,\n" +
        "                 \"right_child\": 2\n" +
        "               },\n" +
        "               {\n" +
        "                 \"node_index\": 1,\n" +
        "                 \"leaf_value\": 1\n" +
        "               },\n" +
        "               {\n" +
        "                 \"node_index\": 2,\n" +
        "                 \"leaf_value\": 2\n" +
        "               }\n" +
        "             ],\n" +
        "             \"target_type\": \"regression\"\n" +
        "           }\n" +
        "         },\n" +
        "         {\n" +
        "           \"tree\": {\n" +
        "             \"feature_names\": [\n" +
        "               \"col2_encoded\",\n" +
        "               \"col3_encoded\",\n" +
        "               \"col4\"\n" +
        "             ],\n" +
        "             \"tree_structure\": [\n" +
        "               {\n" +
        "                 \"node_index\": 0,\n" +
        "                 \"split_feature\": 0,\n" +
        "                 \"split_gain\": 12.0,\n" +
        "                 \"threshold\": 10.0,\n" +
        "                 \"decision_type\": \"lte\",\n" +
        "                 \"default_left\": true,\n" +
        "                 \"left_child\": 1,\n" +
        "                 \"right_child\": 2\n" +
        "               },\n" +
        "               {\n" +
        "                 \"node_index\": 1,\n" +
        "                 \"leaf_value\": 1\n" +
        "               },\n" +
        "               {\n" +
        "                 \"node_index\": 2,\n" +
        "                 \"leaf_value\": 2\n" +
        "               }\n" +
        "             ],\n" +
        "             \"target_type\": \"regression\"\n" +
        "           }\n" +
        "         }\n" +
        "       ]\n" +
        "     }\n" +
        "   }\n" +
        " },\n" +
        " \"description\": \"test model for classification\",\n" +
        " \"version\": \"8.0.0\",\n" +
        " \"created_by\": \"benwtrent\",\n" +
        " \"model_type\": \"local\",\n" +
        " \"created_time\": 0\n" +
        "}";

    private static final String CLASSIFICATION_PIPELINE = "{" +
        " \"processors\": [\n" +
        "   {\n" +
        "     \"inference\": {\n" +
        "       \"target_field\": \"result_class\",\n" +
        "       \"model_id\": \"test_classification\",\n" +
        "       \"inference_config\": {\"classification\": {}},\n" +
        "       \"field_mappings\": {\n" +
        "         \"col1\": \"col1\",\n" +
        "         \"col2\": \"col2\",\n" +
        "         \"col3\": \"col3\",\n" +
        "         \"col4\": \"col4\"\n" +
        "       }\n" +
        "     }\n" +
        "   }]}\n";

    private static final String REGRESSION_PIPELINE = "{" +
        " \"processors\": [\n" +
        "   {\n" +
        "     \"inference\": {\n" +
        "       \"target_field\": \"regression_value\",\n" +
        "       \"model_id\": \"test_regression\",\n" +
        "       \"inference_config\": {\"regression\": {}},\n" +
        "       \"field_mappings\": {\n" +
        "         \"col1\": \"col1\",\n" +
        "         \"col2\": \"col2\",\n" +
        "         \"col3\": \"col3\",\n" +
        "         \"col4\": \"col4\"\n" +
        "       }\n" +
        "     }\n" +
        "   }]}\n";

}
org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult; +import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.persistence.InferenceInternalIndex; @@ -313,7 +314,7 @@ import static java.util.Collections.emptyList; import static org.elasticsearch.index.mapper.MapperService.SINGLE_MAPPING_NAME; -public class MachineLearning extends Plugin implements ActionPlugin, IngestPlugin, AnalysisPlugin, PersistentTaskPlugin { +public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlugin, IngestPlugin, PersistentTaskPlugin { public static final String NAME = "ml"; public static final String BASE_PATH = "/_ml/"; public static final String PRE_V7_BASE_PATH = "/_xpack/ml/"; @@ -341,7 +342,8 @@ protected Setting roleSetting() { public Map getProcessors(Processor.Parameters parameters) { InferenceProcessor.Factory inferenceFactory = new InferenceProcessor.Factory(parameters.client, parameters.ingestService.getClusterService(), - this.settings); + this.settings, + parameters.ingestService); parameters.ingestService.addIngestClusterStateListener(inferenceFactory); return Collections.singletonMap(InferenceProcessor.TYPE, inferenceFactory); } @@ -435,7 +437,9 @@ public List> getSettings() { AutodetectBuilder.MAX_ANOMALY_RECORDS_SETTING_DYNAMIC, MAX_OPEN_JOBS_PER_NODE, MIN_DISK_SPACE_OFF_HEAP, - MlConfigMigrationEligibilityCheck.ENABLE_CONFIG_MIGRATION); + MlConfigMigrationEligibilityCheck.ENABLE_CONFIG_MIGRATION, + InferenceProcessor.MAX_INFERENCE_PROCESSORS + ); } public Settings additionalSettings() { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java index 
b8cccc0d45e23..40ca9ba253594 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -5,32 +5,86 @@ */ package org.elasticsearch.xpack.ml.inference.ingest; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.Client; -import org.elasticsearch.ingest.AbstractProcessor; -import org.elasticsearch.ingest.IngestDocument; - -import java.util.function.BiConsumer; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.MetaData; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.ingest.AbstractProcessor; import org.elasticsearch.ingest.ConfigurationUtils; +import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.ingest.IngestMetadata; +import org.elasticsearch.ingest.IngestService; +import org.elasticsearch.ingest.Pipeline; +import org.elasticsearch.ingest.PipelineConfiguration; import org.elasticsearch.ingest.Processor; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import 
java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; import java.util.Map; +import java.util.function.BiConsumer; import java.util.function.Consumer; +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; +import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; + public class InferenceProcessor extends AbstractProcessor { + // How many total inference processors are allowed to be used in the cluster. + public static final Setting MAX_INFERENCE_PROCESSORS = Setting.intSetting("xpack.ml.max_inference_processors", + 50, + 1, + Setting.Property.Dynamic, + Setting.Property.NodeScope); + public static final String TYPE = "inference"; public static final String MODEL_ID = "model_id"; + public static final String INFERENCE_CONFIG = "inference_config"; + public static final String TARGET_FIELD = "target_field"; + public static final String FIELD_MAPPINGS = "field_mappings"; + public static final String MODEL_INFO_FIELD = "model_info_field"; private final Client client; private final String modelId; - public InferenceProcessor(Client client, String tag, String modelId) { + private final String targetField; + private final String modelInfoField; + private final Map modelInfo; + private final InferenceConfig inferenceConfig; + private final Map fieldMapping; + + public InferenceProcessor(Client client, + String tag, + String targetField, + String modelId, + InferenceConfig inferenceConfig, + Map fieldMapping, + String modelInfoField) { super(tag); this.client = client; + this.targetField = targetField; + this.modelInfoField = modelInfoField; this.modelId = modelId; + this.inferenceConfig = inferenceConfig; + this.fieldMapping = fieldMapping; + this.modelInfo = new HashMap<>(); + this.modelInfo.put("model_id", modelId); } public String getModelId() { @@ -39,8 +93,44 @@ public String getModelId() { @Override public void execute(IngestDocument ingestDocument, BiConsumer handler) { - //TODO actually work - 
handler.accept(ingestDocument, null); + executeAsyncWithOrigin(client, + ML_ORIGIN, + InferModelAction.INSTANCE, + this.buildRequest(ingestDocument), + ActionListener.wrap( + r -> { + try { + mutateDocument(r, ingestDocument); + handler.accept(ingestDocument, null); + } catch(ElasticsearchException ex) { + handler.accept(ingestDocument, ex); + } + }, + e -> handler.accept(ingestDocument, e) + )); + } + + InferModelAction.Request buildRequest(IngestDocument ingestDocument) { + Map fields = new HashMap<>(ingestDocument.getSourceAndMetadata()); + if (fieldMapping != null) { + fieldMapping.forEach((src, dest) -> { + Object srcValue = fields.remove(src); + if (srcValue != null) { + fields.put(dest, srcValue); + } + }); + } + return new InferModelAction.Request(modelId, fields, inferenceConfig); + } + + void mutateDocument(InferModelAction.Response response, IngestDocument ingestDocument) { + if (response.getInferenceResults().isEmpty()) { + throw new ElasticsearchStatusException("Unexpected empty inference response", RestStatus.INTERNAL_SERVER_ERROR); + } + response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField); + if (modelInfoField != null) { + ingestDocument.setFieldValue(modelInfoField, modelInfo); + } } @Override @@ -53,26 +143,123 @@ public String getType() { return TYPE; } - public static class Factory implements Processor.Factory, Consumer { + public static final class Factory implements Processor.Factory, Consumer { + + private static final Logger logger = LogManager.getLogger(Factory.class); private final Client client; - private final ClusterService clusterService; + private final IngestService ingestService; + private volatile int currentInferenceProcessors; + private volatile int maxIngestProcessors; + private volatile Version minNodeVersion = Version.CURRENT; - public Factory(Client client, ClusterService clusterService, Settings settings) { + public Factory(Client client, ClusterService clusterService, Settings settings, 
IngestService ingestService) { this.client = client; - this.clusterService = clusterService; + this.maxIngestProcessors = MAX_INFERENCE_PROCESSORS.get(settings); + this.ingestService = ingestService; + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_INFERENCE_PROCESSORS, this::setMaxIngestProcessors); + } + + @Override + public void accept(ClusterState state) { + minNodeVersion = state.nodes().getMinNodeVersion(); + MetaData metaData = state.getMetaData(); + if (metaData == null) { + currentInferenceProcessors = 0; + return; + } + IngestMetadata ingestMetadata = metaData.custom(IngestMetadata.TYPE); + if (ingestMetadata == null) { + currentInferenceProcessors = 0; + return; + } + + int count = 0; + for (PipelineConfiguration configuration : ingestMetadata.getPipelines().values()) { + try { + Pipeline pipeline = Pipeline.create(configuration.getId(), + configuration.getConfigAsMap(), + ingestService.getProcessorFactories(), + ingestService.getScriptService()); + count += pipeline.getProcessors().stream().filter(processor -> processor instanceof InferenceProcessor).count(); + } catch (Exception ex) { + logger.warn(new ParameterizedMessage("failure parsing pipeline config [{}]", configuration.getId()), ex); + } + } + currentInferenceProcessors = count; + } + + // Used for testing + int numInferenceProcessors() { + return currentInferenceProcessors; } @Override - public Processor create(Map processorFactories, String tag, Map config) + public InferenceProcessor create(Map processorFactories, String tag, Map config) throws Exception { + + if (this.maxIngestProcessors <= currentInferenceProcessors) { + throw new ElasticsearchStatusException("Max number of inference processors reached, total inference processors [{}]. 
" + + "Adjust the setting [{}]: [{}] if a greater number is desired.", + RestStatus.CONFLICT, + currentInferenceProcessors, + MAX_INFERENCE_PROCESSORS.getKey(), + maxIngestProcessors); + } + String modelId = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_ID); - return new InferenceProcessor(client, tag, modelId); + String targetField = ConfigurationUtils.readStringProperty(TYPE, tag, config, TARGET_FIELD); + Map fieldMapping = ConfigurationUtils.readOptionalMap(TYPE, tag, config, FIELD_MAPPINGS); + InferenceConfig inferenceConfig = inferenceConfigFromMap(ConfigurationUtils.readMap(TYPE, tag, config, INFERENCE_CONFIG)); + String modelInfoField = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_INFO_FIELD, "_model_info"); + if (modelInfoField != null && tag != null) { + modelInfoField += "." + tag; + } + return new InferenceProcessor(client, tag, targetField, modelId, inferenceConfig, fieldMapping, modelInfoField); } - @Override - public void accept(ClusterState clusterState) { + // Package private for testing + void setMaxIngestProcessors(int maxIngestProcessors) { + logger.trace("updating setting maxIngestProcessors from [{}] to [{}]", this.maxIngestProcessors, maxIngestProcessors); + this.maxIngestProcessors = maxIngestProcessors; + } + + InferenceConfig inferenceConfigFromMap(Map inferenceConfig) throws IOException { + ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG); + + if (inferenceConfig.size() != 1) { + throw ExceptionsHelper.badRequestException("{} must be an object with one inference type mapped to an object.", + INFERENCE_CONFIG); + } + Object value = inferenceConfig.values().iterator().next(); + + if ((value instanceof Map) == false) { + throw ExceptionsHelper.badRequestException("{} must be an object with one inference type mapped to an object.", + INFERENCE_CONFIG); + } + @SuppressWarnings("unchecked") + Map valueMap = (Map)value; + + if (inferenceConfig.containsKey(ClassificationConfig.NAME)) { + 
checkSupportedVersion(new ClassificationConfig(0)); + return ClassificationConfig.fromMap(valueMap); + } else if (inferenceConfig.containsKey(RegressionConfig.NAME)) { + checkSupportedVersion(new RegressionConfig()); + return RegressionConfig.fromMap(valueMap); + } else { + throw ExceptionsHelper.badRequestException("unrecognized inference configuration type {}. Supported types {}", + inferenceConfig.keySet(), + Arrays.asList(ClassificationConfig.NAME, RegressionConfig.NAME)); + } + } + void checkSupportedVersion(InferenceConfig config) { + if (config.getMinimalSupportedVersion().after(minNodeVersion)) { + throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION, + config.getName(), + config.getMinimalSupportedVersion(), + minNodeVersion)); + } } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index e6a5baf42ba12..da294f1e1580b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -52,20 +52,21 @@ public void getModel(String modelId, ActionListener modelActionListener) if (cachedModel != null) { if (cachedModel.isSuccess()) { modelActionListener.onResponse(cachedModel.getModel()); + logger.trace("[{}] loaded from cache", modelId); return; } } if (loadModelIfNecessary(modelId, modelActionListener) == false) { // If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called // by a simulated pipeline - logger.debug("[{}] not actively loading, eager loading without cache", modelId); + logger.trace("[{}] not actively loading, eager loading without cache", modelId); 
provider.getTrainedModel(modelId, ActionListener.wrap( trainedModelConfig -> modelActionListener.onResponse(new LocalModel(trainedModelConfig.getModelId(), trainedModelConfig.getDefinition())), modelActionListener::onFailure )); } else { - logger.debug("[{}] is currently loading, added new listener to queue", modelId); + logger.trace("[{}] is loading or loaded, added new listener to queue", modelId); } } @@ -87,7 +88,7 @@ private boolean loadModelIfNecessary(String modelId, ActionListener model if (loadingListeners.computeIfPresent( modelId, (storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)) == null) { - logger.debug("[{}] attempting to load and cache", modelId); + logger.trace("[{}] attempting to load and cache", modelId); loadingListeners.put(modelId, addFluently(new ArrayDeque<>(), modelActionListener)); loadModel(modelId); } @@ -156,6 +157,8 @@ public void clusterChanged(ClusterChangedEvent event) { // The listeners still waiting for a model and we are canceling the load? List>>> drainWithFailure = new ArrayList<>(); synchronized (loadingListeners) { + HashSet loadedModelBeforeClusterState = logger.isTraceEnabled() ? new HashSet<>(loadedModels.keySet()) : null; + HashSet loadingModelBeforeClusterState = logger.isTraceEnabled() ? 
new HashSet<>(loadingListeners.keySet()) : null; // If we had models still loading here but are no longer referenced // we should remove them from loadingListeners and alert the listeners for (String modelId : loadingListeners.keySet()) { @@ -180,6 +183,17 @@ public void clusterChanged(ClusterChangedEvent event) { for (String modelId : allReferencedModelKeys) { loadingListeners.put(modelId, new ArrayDeque<>()); } + if (loadedModelBeforeClusterState != null && loadingModelBeforeClusterState != null) { + if (loadingListeners.keySet().equals(loadingModelBeforeClusterState) == false) { + logger.trace("cluster state event changed loading models: before {} after {}", loadingModelBeforeClusterState, + loadingListeners.keySet()); + } + if (loadedModels.keySet().equals(loadedModelBeforeClusterState) == false) { + logger.trace("cluster state event changed loaded models: before {} after {}", loadedModelBeforeClusterState, + loadedModels.keySet()); + } + } + } for (Tuple>> modelAndListeners : drainWithFailure) { final String msg = new ParameterizedMessage( @@ -223,7 +237,6 @@ private static Set getReferencedModelKeys(IngestMetadata ingestMetadata) Object modelId = ((Map)processorConfig).get(InferenceProcessor.MODEL_ID); if (modelId != null) { assert modelId instanceof String; - // TODO also read model version allReferencedModelKeys.add(modelId.toString()); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java new file mode 100644 index 0000000000000..322b5cfb4ec2c --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java @@ -0,0 +1,266 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.inference.ingest; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.Version; +import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.MetaData; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.ClusterSettings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.ingest.IngestMetadata; +import org.elasticsearch.ingest.IngestService; +import org.elasticsearch.ingest.PipelineConfiguration; +import org.elasticsearch.ingest.Processor; +import org.elasticsearch.plugins.IngestPlugin; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.junit.Before; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.equalTo; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class InferenceProcessorFactoryTests extends ESTestCase { + + private static final IngestPlugin SKINNY_PLUGIN = new 
IngestPlugin() { + @Override + public Map getProcessors(Processor.Parameters parameters) { + return Collections.singletonMap(InferenceProcessor.TYPE, + new InferenceProcessor.Factory(parameters.client, + parameters.ingestService.getClusterService(), + Settings.EMPTY, + parameters.ingestService)); + } + }; + private Client client; + private ClusterService clusterService; + private IngestService ingestService; + + @Before + public void setUpVariables() { + ThreadPool tp = mock(ThreadPool.class); + client = mock(Client.class); + clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, + Collections.singleton(InferenceProcessor.MAX_INFERENCE_PROCESSORS)); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + ingestService = new IngestService(clusterService, tp, null, null, + null, Collections.singletonList(SKINNY_PLUGIN), client); + } + + public void testNumInferenceProcessors() throws Exception { + MetaData metaData = null; + + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, clusterService, Settings.EMPTY, ingestService); + processorFactory.accept(buildClusterState(metaData)); + + assertThat(processorFactory.numInferenceProcessors(), equalTo(0)); + metaData = MetaData.builder().build(); + + processorFactory.accept(buildClusterState(metaData)); + assertThat(processorFactory.numInferenceProcessors(), equalTo(0)); + + processorFactory.accept(buildClusterStateWithModelReferences("model1", "model2", "model3")); + assertThat(processorFactory.numInferenceProcessors(), equalTo(3)); + } + + public void testCreateProcessorWithTooManyExisting() throws Exception { + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, + clusterService, + Settings.builder().put(InferenceProcessor.MAX_INFERENCE_PROCESSORS.getKey(), 1).build(), + ingestService); + + processorFactory.accept(buildClusterStateWithModelReferences("model1")); + + 
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, + () -> processorFactory.create(Collections.emptyMap(), "my_inference_processor", Collections.emptyMap())); + + assertThat(ex.getMessage(), equalTo("Max number of inference processors reached, total inference processors [1]. " + + "Adjust the setting [xpack.ml.max_inference_processors]: [1] if a greater number is desired.")); + } + + public void testCreateProcessorWithInvalidInferenceConfig() { + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, + clusterService, + Settings.EMPTY, + ingestService); + + Map config = new HashMap<>() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "result"); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap("unknown_type", Collections.emptyMap())); + }}; + + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, + () -> processorFactory.create(Collections.emptyMap(), "my_inference_processor", config)); + assertThat(ex.getMessage(), + equalTo("unrecognized inference configuration type [unknown_type]. 
Supported types [classification, regression]")); + + Map config2 = new HashMap<>() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "result"); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap("regression", "boom")); + }}; + ex = expectThrows(ElasticsearchStatusException.class, + () -> processorFactory.create(Collections.emptyMap(), "my_inference_processor", config2)); + assertThat(ex.getMessage(), + equalTo("inference_config must be an object with one inference type mapped to an object.")); + + Map config3 = new HashMap<>() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "result"); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.emptyMap()); + }}; + ex = expectThrows(ElasticsearchStatusException.class, + () -> processorFactory.create(Collections.emptyMap(), "my_inference_processor", config3)); + assertThat(ex.getMessage(), + equalTo("inference_config must be an object with one inference type mapped to an object.")); + } + + public void testCreateProcessorWithTooOldMinNodeVersion() throws IOException { + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, + clusterService, + Settings.EMPTY, + ingestService); + processorFactory.accept(builderClusterStateWithModelReferences(Version.V_7_5_0, "model1")); + + Map regression = new HashMap<>() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "result"); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME, Collections.emptyMap())); + }}; + + try { + processorFactory.create(Collections.emptyMap(), "my_inference_processor", regression); + fail("Should not have successfully created"); + } catch 
(ElasticsearchException ex) { + assertThat(ex.getMessage(), + equalTo("Configuration [regression] requires minimum node version [8.0.0] (current minimum node version [7.5.0]")); + } catch (Exception ex) { + fail(ex.getMessage()); + } + + Map classification = new HashMap<>() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "result"); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(ClassificationConfig.NAME, + Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 1))); + }}; + + try { + processorFactory.create(Collections.emptyMap(), "my_inference_processor", classification); + fail("Should not have successfully created"); + } catch (ElasticsearchException ex) { + assertThat(ex.getMessage(), + equalTo("Configuration [classification] requires minimum node version [8.0.0] (current minimum node version [7.5.0]")); + } catch (Exception ex) { + fail(ex.getMessage()); + } + } + + public void testCreateProcessor() { + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, + clusterService, + Settings.EMPTY, + ingestService); + + Map regression = new HashMap<>() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "result"); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME, Collections.emptyMap())); + }}; + + try { + processorFactory.create(Collections.emptyMap(), "my_inference_processor", regression); + } catch (Exception ex) { + fail(ex.getMessage()); + } + + Map classification = new HashMap<>() {{ + put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "result"); + put(InferenceProcessor.INFERENCE_CONFIG, 
Collections.singletonMap(ClassificationConfig.NAME, + Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 1))); + }}; + + try { + processorFactory.create(Collections.emptyMap(), "my_inference_processor", classification); + } catch (Exception ex) { + fail(ex.getMessage()); + } + } + + private static ClusterState buildClusterState(MetaData metaData) { + return ClusterState.builder(new ClusterName("_name")).metaData(metaData).build(); + } + + private static ClusterState buildClusterStateWithModelReferences(String... modelId) throws IOException { + return builderClusterStateWithModelReferences(Version.CURRENT, modelId); + } + + private static ClusterState builderClusterStateWithModelReferences(Version minNodeVersion, String... modelId) throws IOException { + Map configurations = new HashMap<>(modelId.length); + for (String id : modelId) { + configurations.put("pipeline_with_model_" + id, newConfigurationWithInferenceProcessor(id)); + } + IngestMetadata ingestMetadata = new IngestMetadata(configurations); + + return ClusterState.builder(new ClusterName("_name")) + .metaData(MetaData.builder().putCustom(IngestMetadata.TYPE, ingestMetadata)) + .nodes(DiscoveryNodes.builder() + .add(new DiscoveryNode("min_node", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + minNodeVersion)) + .add(new DiscoveryNode("current_node", + new TransportAddress(InetAddress.getLoopbackAddress(), 9302), + Version.CURRENT)) + .localNodeId("_node_id") + .masterNodeId("_node_id")) + .build(); + } + + private static PipelineConfiguration newConfigurationWithInferenceProcessor(String modelId) throws IOException { + try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(Collections.singletonMap("processors", + Collections.singletonList( + Collections.singletonMap(InferenceProcessor.TYPE, + new HashMap<>() {{ + put(InferenceProcessor.MODEL_ID, modelId); + put(InferenceProcessor.INFERENCE_CONFIG, 
Collections.singletonMap(RegressionConfig.NAME, Collections.emptyMap())); + put(InferenceProcessor.TARGET_FIELD, "new_field"); + put(InferenceProcessor.FIELD_MAPPINGS, Collections.singletonMap("source", "dest")); + }}))))) { + return new PipelineConfiguration("pipeline_with_model_" + modelId, BytesReference.bytes(xContentBuilder), XContentType.JSON); + } + } + +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java new file mode 100644 index 0000000000000..4f55768407339 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java @@ -0,0 +1,191 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.inference.ingest; + +import org.elasticsearch.client.Client; +import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.core.Is.is; +import static org.mockito.Mockito.mock; + +public class InferenceProcessorTests extends ESTestCase { + + private Client client; + + @Before + public void setUpVariables() { + client = mock(Client.class); + } + + public void testMutateDocumentWithClassification() { + String targetField = "classification_value"; + InferenceProcessor inferenceProcessor = new InferenceProcessor(client, + "my_processor", + targetField, + "classification_model", + new ClassificationConfig(0), + Collections.emptyMap(), + "_ml_model.my_processor"); + + Map source = new HashMap<>(); + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + InferModelAction.Response response = new InferModelAction.Response( + Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", null))); + inferenceProcessor.mutateDocument(response, document); + + assertThat(document.getFieldValue(targetField, String.class), equalTo("foo")); + assertThat(document.getFieldValue("_ml_model", Map.class), + equalTo(Collections.singletonMap("my_processor", 
Collections.singletonMap("model_id", "classification_model")))); + } + + @SuppressWarnings("unchecked") + public void testMutateDocumentClassificationTopNClasses() { + String targetField = "classification_value_probabilities"; + InferenceProcessor inferenceProcessor = new InferenceProcessor(client, + "my_processor", + targetField, + "classification_model", + new ClassificationConfig(2), + Collections.emptyMap(), + "_ml_model.my_processor"); + + Map source = new HashMap<>(); + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + List classes = new ArrayList<>(2); + classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6)); + classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4)); + + InferModelAction.Response response = new InferModelAction.Response( + Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", classes))); + inferenceProcessor.mutateDocument(response, document); + + assertThat((List>)document.getFieldValue(targetField, List.class), + contains(classes.stream().map(ClassificationInferenceResults.TopClassEntry::asValueMap).toArray(Map[]::new))); + assertThat(document.getFieldValue("_ml_model", Map.class), + equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "classification_model")))); + } + + public void testMutateDocumentRegression() { + String targetField = "regression_value"; + InferenceProcessor inferenceProcessor = new InferenceProcessor(client, + "my_processor", + targetField, + "regression_model", + new RegressionConfig(), + Collections.emptyMap(), + "_ml_model.my_processor"); + + Map source = new HashMap<>(); + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + InferModelAction.Response response = new InferModelAction.Response( + Collections.singletonList(new RegressionInferenceResults(0.7))); + inferenceProcessor.mutateDocument(response, 
document); + + assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7)); + assertThat(document.getFieldValue("_ml_model", Map.class), + equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "regression_model")))); + } + + public void testMutateDocumentNoModelMetaData() { + String targetField = "regression_value"; + InferenceProcessor inferenceProcessor = new InferenceProcessor(client, + "my_processor", + targetField, + "regression_model", + new RegressionConfig(), + Collections.emptyMap(), + null); + + Map source = new HashMap<>(); + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + InferModelAction.Response response = new InferModelAction.Response( + Collections.singletonList(new RegressionInferenceResults(0.7))); + inferenceProcessor.mutateDocument(response, document); + + assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7)); + assertThat(document.hasField("_ml_model"), is(false)); + } + + public void testGenerateRequestWithEmptyMapping() { + String modelId = "model"; + Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10); + + InferenceProcessor processor = new InferenceProcessor(client, + "my_processor", + "my_field", + modelId, + new ClassificationConfig(topNClasses), + Collections.emptyMap(), + null); + + Map source = new HashMap<>(){{ + put("value1", 1); + put("value2", 4); + put("categorical", "foo"); + }}; + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(source)); + } + + public void testGenerateWithMapping() { + String modelId = "model"; + Integer topNClasses = randomBoolean() ? 
null : randomIntBetween(1, 10); + + Map fieldMapping = new HashMap<>(3) {{ + put("value1", "new_value1"); + put("value2", "new_value2"); + put("categorical", "new_categorical"); + }}; + + InferenceProcessor processor = new InferenceProcessor(client, + "my_processor", + "my_field", + modelId, + new ClassificationConfig(topNClasses), + fieldMapping, + null); + + Map source = new HashMap<>(3){{ + put("value1", 1); + put("categorical", "foo"); + put("un_touched", "bar"); + }}; + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + Map expectedMap = new HashMap<>(2) {{ + put("new_value1", 1); + put("new_categorical", "foo"); + put("un_touched", "bar"); + }}; + assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(expectedMap)); + } +} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml index a18b29487eac5..a8b199a7a3b59 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml @@ -97,7 +97,10 @@ "processors": [ { "inference" : { - "model_id" : "used-regression-model" + "model_id" : "used-regression-model", + "inference_config": {"regression": {}}, + "target_field": "regression_field", + "field_mappings": {} } } ] From 1b8e288f37af283a8193f3ffec0c4115fd60b4c5 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Thu, 24 Oct 2019 09:38:11 -0400 Subject: [PATCH 07/17] fixing inference ingest tests --- .../xpack/ml/integration/InferenceIngestIT.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java 
b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java index 852b3fcea0f0e..3533afe947535 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java @@ -52,7 +52,7 @@ public void testPipelineCreationAndDeletion() throws Exception { new BytesArray(CLASSIFICATION_PIPELINE.getBytes(StandardCharsets.UTF_8)), XContentType.JSON).get().isAcknowledged(), is(true)); - client().prepareIndex("index_for_inference_test", "_doc") + client().prepareIndex("index_for_inference_test") .setSource(new HashMap<>(){{ put("col1", randomFrom("female", "male")); put("col2", randomFrom("S", "M", "L", "XL")); @@ -69,7 +69,7 @@ public void testPipelineCreationAndDeletion() throws Exception { new BytesArray(REGRESSION_PIPELINE.getBytes(StandardCharsets.UTF_8)), XContentType.JSON).get().isAcknowledged(), is(true)); - client().prepareIndex("index_for_inference_test", "_doc") + client().prepareIndex("index_for_inference_test") .setSource(new HashMap<>(){{ put("col1", randomFrom("female", "male")); put("col2", randomFrom("S", "M", "L", "XL")); @@ -92,12 +92,12 @@ public void testPipelineCreationAndDeletion() throws Exception { XContentType.JSON).get().isAcknowledged(), is(true)); for (int i = 0; i < 10; i++) { - client().prepareIndex("index_for_inference_test", "_doc") + client().prepareIndex("index_for_inference_test") .setSource(generateSourceDoc()) .setPipeline("simple_classification_pipeline") .get(); - client().prepareIndex("index_for_inference_test", "_doc") + client().prepareIndex("index_for_inference_test") .setSource(generateSourceDoc()) .setPipeline("simple_regression_pipeline") .get(); From 18ea1440b669a3f9b2fc96ee451c237dc0c1f3bf Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Fri, 25 Oct 2019 09:01:05 -0400 Subject: 
[PATCH 08/17] [ML][Inference] fixing classification inference for ensemble (#48463) --- .../results/RawInferenceResults.java | 65 +++++++++++++++++++ .../{ensemble => }/NullInferenceConfig.java | 6 +- .../trainedmodel/ensemble/Ensemble.java | 6 ++ .../ml/inference/trainedmodel/tree/Tree.java | 6 ++ .../results/RawInferenceResultsTests.java | 26 ++++++++ .../trainedmodel/ensemble/EnsembleTests.java | 6 +- 6 files changed, 110 insertions(+), 5 deletions(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java rename x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/{ensemble => }/NullInferenceConfig.java (80%) create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResultsTests.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java new file mode 100644 index 0000000000000..884d66032b564 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java @@ -0,0 +1,65 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.ingest.IngestDocument; + +import java.io.IOException; +import java.util.Objects; + +public class RawInferenceResults extends SingleValueInferenceResults { + + public static final String NAME = "raw"; + + public RawInferenceResults(double value) { + super(value); + } + + public RawInferenceResults(StreamInput in) throws IOException { + super(in.readDouble()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + } + + @Override + XContentBuilder innerToXContent(XContentBuilder builder, Params params) { + return builder; + } + + @Override + public boolean equals(Object object) { + if (object == this) { return true; } + if (object == null || getClass() != object.getClass()) { return false; } + RawInferenceResults that = (RawInferenceResults) object; + return Objects.equals(value(), that.value()); + } + + @Override + public int hashCode() { + return Objects.hash(value()); + } + + @Override + public void writeResult(IngestDocument document, String resultField) { + throw new UnsupportedOperationException("[raw] does not support writing inference results"); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public String getName() { + return NAME; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/NullInferenceConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NullInferenceConfig.java similarity index 80% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/NullInferenceConfig.java rename to 
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NullInferenceConfig.java index 42757d889818e..b7c4a71b3e79e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/NullInferenceConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NullInferenceConfig.java @@ -3,20 +3,18 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import java.io.IOException; /** * Used by ensemble to pass into sub-models. 
*/ -class NullInferenceConfig implements InferenceConfig { +public class NullInferenceConfig implements InferenceConfig { public static final NullInferenceConfig INSTANCE = new NullInferenceConfig(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java index ff03a621d99fa..3bea5ad80ba0c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java @@ -14,12 +14,14 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; @@ -135,6 +137,10 @@ public TargetType targetType() { } private InferenceResults buildResults(List processedInferences, InferenceConfig 
config) { + // Indicates that the config is useless and the caller just wants the raw value + if (config instanceof NullInferenceConfig) { + return new RawInferenceResults(outputAggregator.aggregate(processedInferences)); + } switch(targetType) { case REGRESSION: return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences)); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java index a48cca3873117..7427a7cc70037 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java @@ -15,11 +15,13 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -134,6 +136,10 @@ private InferenceResults infer(List features, InferenceConfig config) { } private 
InferenceResults buildResult(Double value, InferenceConfig config) { + // Indicates that the config is useless and the caller just wants the raw value + if (config instanceof NullInferenceConfig) { + return new RawInferenceResults(value); + } switch (targetType) { case CLASSIFICATION: ClassificationConfig classificationConfig = (ClassificationConfig) config; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResultsTests.java new file mode 100644 index 0000000000000..d9d4e9933b24d --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResultsTests.java @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +public class RawInferenceResultsTests extends AbstractWireSerializingTestCase { + + public static RawInferenceResults createRandomResults() { + return new RawInferenceResults(randomDouble()); + } + + @Override + protected RawInferenceResults createTestInstance() { + return createRandomResults(); + } + + @Override + protected Writeable.Reader instanceReader() { + return RawInferenceResults::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java index 52f317c2595c3..a81c210e33067 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -322,7 +322,9 @@ public void testClassificationInference() { .setLeftChild(3) .setRightChild(4)) .addNode(TreeNode.builder(3).setLeafValue(0.0)) - .addNode(TreeNode.builder(4).setLeafValue(1.0)).build(); + .addNode(TreeNode.builder(4).setLeafValue(1.0)) + .setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION)) + .build(); Tree tree2 = Tree.builder() .setFeatureNames(featureNames) .setRoot(TreeNode.builder(0) @@ -332,6 +334,7 @@ public void testClassificationInference() { .setThreshold(0.5)) .addNode(TreeNode.builder(1).setLeafValue(0.0)) .addNode(TreeNode.builder(2).setLeafValue(1.0)) + .setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION)) .build(); Tree tree3 = Tree.builder() .setFeatureNames(featureNames) @@ -342,6 +345,7 @@ public void testClassificationInference() { .setThreshold(1.0)) 
.addNode(TreeNode.builder(1).setLeafValue(1.0)) .addNode(TreeNode.builder(2).setLeafValue(0.0)) + .setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION)) .build(); Ensemble ensemble = Ensemble.builder() .setTargetType(TargetType.CLASSIFICATION) From c6d977c352083803d9fa764d64abe252c364a70e Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Fri, 25 Oct 2019 11:46:14 -0400 Subject: [PATCH 09/17] [ML][Inference] Adding model memory estimations (#48323) * [ML][Inference] Adding model memory estimations * addressing PR comments --- .../core/ml/inference/TrainedModelConfig.java | 2 +- .../ml/inference/TrainedModelDefinition.java | 48 ++- .../preprocessing/FrequencyEncoding.java | 17 + .../preprocessing/OneHotEncoding.java | 15 + .../inference/preprocessing/PreProcessor.java | 3 +- .../preprocessing/TargetMeanEncoding.java | 16 + .../inference/trainedmodel/TrainedModel.java | 3 +- .../trainedmodel/ensemble/Ensemble.java | 26 ++ .../ensemble/LogisticRegression.java | 31 +- .../ensemble/OutputAggregator.java | 3 +- .../trainedmodel/ensemble/WeightedMode.java | 31 +- .../trainedmodel/ensemble/WeightedSum.java | 33 +- .../ml/inference/trainedmodel/tree/Tree.java | 25 +- .../inference/trainedmodel/tree/TreeNode.java | 97 ++--- .../ml/inference/TrainedModelConfigTests.java | 12 + .../TrainedModelDefinitionTests.java | 18 + .../trainedmodel/ensemble/EnsembleTests.java | 15 +- .../ensemble/LogisticRegressionTests.java | 7 +- .../ensemble/WeightedModeTests.java | 7 +- .../ensemble/WeightedSumTests.java | 7 +- .../trainedmodel/tree/TreeTests.java | 13 + .../xpack/ml/integration/TrainedModelIT.java | 4 +- .../xpack/ml/MachineLearning.java | 11 +- .../process/results/AnalyticsResult.java | 7 +- .../inference/loadingservice/LocalModel.java | 9 + .../ml/inference/loadingservice/Model.java | 1 + .../loadingservice/ModelLoadingService.java | 343 +++++++++++------- .../loadingservice/LocalModelTests.java | 4 +- .../ModelLoadingServiceTests.java | 189 ++++++++-- 29 
files changed, 734 insertions(+), 263 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index e1c24eee02b88..c3ba88e42eb79 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -174,7 +174,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli()); if (definition != null) { - builder.field(DEFINITION.getPreferredName(), definition); + builder.field(DEFINITION.getPreferredName(), definition, params); } builder.field(TAGS.getPreferredName(), tags); if (metadata != null) { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java index 0798e721ed17f..4d3875b85c81e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java @@ -5,11 +5,15 @@ */ package org.elasticsearch.xpack.core.ml.inference; +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.Accountables; +import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.unit.ByteSizeValue; import 
org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; @@ -25,16 +29,21 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Objects; -public class TrainedModelDefinition implements ToXContentObject, Writeable { +public class TrainedModelDefinition implements ToXContentObject, Writeable, Accountable { + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TrainedModelDefinition.class); public static final String NAME = "trained_mode_definition"; + public static final String HEAP_MEMORY_ESTIMATION = "heap_memory_estimation"; public static final ParseField TRAINED_MODEL = new ParseField("trained_model"); public static final ParseField PREPROCESSORS = new ParseField("preprocessors"); @@ -105,6 +114,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws PREPROCESSORS.getPreferredName(), preProcessors); builder.field(INPUT.getPreferredName(), input); + if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) { + builder.humanReadableField(HEAP_MEMORY_ESTIMATION + "_bytes", + HEAP_MEMORY_ESTIMATION, + new ByteSizeValue(ramBytesUsed())); + } builder.endObject(); return builder; } @@ -150,6 +164,26 @@ public int hashCode() { return Objects.hash(trainedModel, input, preProcessors); } + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += RamUsageEstimator.sizeOf(trainedModel); + size += RamUsageEstimator.sizeOf(input); + size += 
RamUsageEstimator.sizeOfCollection(preProcessors); + return size; + } + + @Override + public Collection getChildResources() { + List accountables = new ArrayList<>(preProcessors.size() + 2); + accountables.add(Accountables.namedAccountable("input", input)); + accountables.add(Accountables.namedAccountable("trained_model", trainedModel)); + for(PreProcessor preProcessor : preProcessors) { + accountables.add(Accountables.namedAccountable("pre_processor_" + preProcessor.getName(), preProcessor)); + } + return accountables; + } + public static class Builder { private List preProcessors; @@ -204,8 +238,9 @@ public TrainedModelDefinition build() { } } - public static class Input implements ToXContentObject, Writeable { + public static class Input implements ToXContentObject, Writeable, Accountable { + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Input.class); public static final String NAME = "trained_mode_definition_input"; public static final ParseField FIELD_NAMES = new ParseField("field_names"); @@ -265,6 +300,15 @@ public int hashCode() { return Objects.hash(fieldNames); } + @Override + public long ramBytesUsed() { + return SHALLOW_SIZE + RamUsageEstimator.sizeOfCollection(fieldNames); + } + + @Override + public String toString() { + return Strings.toString(this); + } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java index 351c0f05960f0..cea99d3edc8f6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java @@ -5,7 +5,9 @@ */ package org.elasticsearch.xpack.core.ml.inference.preprocessing; +import org.apache.lucene.util.RamUsageEstimator; import 
org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ConstructingObjectParser; @@ -25,6 +27,8 @@ */ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor { + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(FrequencyEncoding.class); + public static final ParseField NAME = new ParseField("frequency_encoding"); public static final ParseField FIELD = new ParseField("field"); public static final ParseField FEATURE_NAME = new ParseField("feature_name"); @@ -143,4 +147,17 @@ public int hashCode() { return Objects.hash(field, featureName, frequencyMap); } + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += RamUsageEstimator.sizeOf(field); + size += RamUsageEstimator.sizeOf(featureName); + size += RamUsageEstimator.sizeOfMap(frequencyMap); + return size; + } + + @Override + public String toString() { + return Strings.toString(this); + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java index 106cb1e26c1c8..9784ed8cbe7aa 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java @@ -5,7 +5,9 @@ */ package org.elasticsearch.xpack.core.ml.inference.preprocessing; +import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import 
org.elasticsearch.common.xcontent.ConstructingObjectParser; @@ -23,6 +25,7 @@ */ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor { + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(OneHotEncoding.class); public static final ParseField NAME = new ParseField("one_hot_encoding"); public static final ParseField FIELD = new ParseField("field"); public static final ParseField HOT_MAP = new ParseField("hot_map"); @@ -127,4 +130,16 @@ public int hashCode() { return Objects.hash(field, hotMap); } + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += RamUsageEstimator.sizeOf(field); + size += RamUsageEstimator.sizeOfMap(hotMap); + return size; + } + + @Override + public String toString() { + return Strings.toString(this); + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessor.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessor.java index 79e1ce16ad80e..f5c2ff7398068 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessor.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessor.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.inference.preprocessing; +import org.apache.lucene.util.Accountable; import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; @@ -14,7 +15,7 @@ * Describes a pre-processor for a defined machine learning model * This processor should take a set of fields and return the modified set of fields. */ -public interface PreProcessor extends NamedXContentObject, NamedWriteable { +public interface PreProcessor extends NamedXContentObject, NamedWriteable, Accountable { /** * Process the given fields and their values and return the modified map. 
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java index d8f413b3b1754..914b43f98e967 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java @@ -5,7 +5,9 @@ */ package org.elasticsearch.xpack.core.ml.inference.preprocessing; +import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ConstructingObjectParser; @@ -25,6 +27,7 @@ */ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor { + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TargetMeanEncoding.class); public static final ParseField NAME = new ParseField("target_mean_encoding"); public static final ParseField FIELD = new ParseField("field"); public static final ParseField FEATURE_NAME = new ParseField("feature_name"); @@ -158,4 +161,17 @@ public int hashCode() { return Objects.hash(field, featureName, meanMap, defaultValue); } + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += RamUsageEstimator.sizeOf(field); + size += RamUsageEstimator.sizeOf(featureName); + size += RamUsageEstimator.sizeOfMap(meanMap); + return size; + } + + @Override + public String toString() { + return Strings.toString(this); + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java index d1215943cbe12..a9028efdffa94 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; +import org.apache.lucene.util.Accountable; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; @@ -13,7 +14,7 @@ import java.util.List; import java.util.Map; -public interface TrainedModel extends NamedXContentObject, NamedWriteable { +public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accountable { /** * @return List of featureNames expected by the model. In the order that they are expected diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java index 3bea5ad80ba0c..79883f4db4b4e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java @@ -5,6 +5,9 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.Accountables; +import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; @@ -29,6 +32,8 @@ import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; import 
java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Map; @@ -39,6 +44,7 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel { + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Ensemble.class); // TODO should we have regression/classification sub-classes that accept the builder? public static final ParseField NAME = new ParseField("ensemble"); public static final ParseField FEATURE_NAMES = new ParseField("feature_names"); @@ -249,6 +255,26 @@ public static Builder builder() { return new Builder(); } + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += RamUsageEstimator.sizeOfCollection(featureNames); + size += RamUsageEstimator.sizeOfCollection(classificationLabels); + size += RamUsageEstimator.sizeOfCollection(models); + size += outputAggregator.ramBytesUsed(); + return size; + } + + @Override + public Collection getChildResources() { + List accountables = new ArrayList<>(models.size() + 1); + for (TrainedModel model : models) { + accountables.add(Accountables.namedAccountable(model.getName(), model)); + } + accountables.add(Accountables.namedAccountable(outputAggregator.getName(), outputAggregator)); + return Collections.unmodifiableCollection(accountables); + } + public static class Builder { private List featureNames; private List trainedModels; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java index c8d06c2c1eb79..14f2b1b64b523 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; +import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -16,7 +17,6 @@ import java.io.IOException; import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.stream.IntStream; @@ -25,6 +25,7 @@ public class LogisticRegression implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator { + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(LogisticRegression.class); public static final ParseField NAME = new ParseField("logistic_regression"); public static final ParseField WEIGHTS = new ParseField("weights"); @@ -49,19 +50,23 @@ public static LogisticRegression fromXContentLenient(XContentParser parser) { return LENIENT_PARSER.apply(parser, null); } - private final List weights; + private final double[] weights; LogisticRegression() { this((List) null); } - public LogisticRegression(List weights) { - this.weights = weights == null ? null : Collections.unmodifiableList(weights); + private LogisticRegression(List weights) { + this(weights == null ? null : weights.stream().mapToDouble(Double::valueOf).toArray()); + } + + public LogisticRegression(double[] weights) { + this.weights = weights; } public LogisticRegression(StreamInput in) throws IOException { if (in.readBoolean()) { - this.weights = Collections.unmodifiableList(in.readList(StreamInput::readDouble)); + this.weights = in.readDoubleArray(); } else { this.weights = null; } @@ -69,18 +74,18 @@ public LogisticRegression(StreamInput in) throws IOException { @Override public Integer expectedValueSize() { - return this.weights == null ? 
null : this.weights.size(); + return this.weights == null ? null : this.weights.length; } @Override public List processValues(List values) { Objects.requireNonNull(values, "values must not be null"); - if (weights != null && values.size() != weights.size()) { + if (weights != null && values.size() != weights.length) { throw new IllegalArgumentException("values must be the same length as weights."); } double summation = weights == null ? values.stream().mapToDouble(Double::valueOf).sum() : - IntStream.range(0, weights.size()).mapToDouble(i -> values.get(i) * weights.get(i)).sum(); + IntStream.range(0, weights.length).mapToDouble(i -> values.get(i) * weights[i]).sum(); double probOfClassOne = sigmoid(summation); assert 0.0 <= probOfClassOne && probOfClassOne <= 1.0; return Arrays.asList(1.0 - probOfClassOne, probOfClassOne); @@ -123,7 +128,7 @@ public String getWriteableName() { public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(weights != null); if (weights != null) { - out.writeCollection(weights, StreamOutput::writeDouble); + out.writeDoubleArray(weights); } } @@ -142,12 +147,16 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; LogisticRegression that = (LogisticRegression) o; - return Objects.equals(weights, that.weights); + return Arrays.equals(weights, that.weights); } @Override public int hashCode() { - return Objects.hash(weights); + return Arrays.hashCode(weights); } + @Override + public long ramBytesUsed() { + return SHALLOW_SIZE + RamUsageEstimator.sizeOf(weights); + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java index f19ae376f0e96..16b1fd7c4051e 100644 --- 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java @@ -5,13 +5,14 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; +import org.apache.lucene.util.Accountable; import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; import java.util.List; -public interface OutputAggregator extends NamedXContentObject, NamedWriteable { +public interface OutputAggregator extends NamedXContentObject, NamedWriteable, Accountable { /** * @return The expected size of the values array when aggregating. `null` implies there is no expected size. diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java index 525425db66d08..29b311794be06 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; +import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -16,6 +17,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Objects; @@ -24,6 +26,7 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, 
LenientlyParsedOutputAggregator { + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(WeightedMode.class); public static final ParseField NAME = new ParseField("weighted_mode"); public static final ParseField WEIGHTS = new ParseField("weights"); @@ -48,19 +51,23 @@ public static WeightedMode fromXContentLenient(XContentParser parser) { return LENIENT_PARSER.apply(parser, null); } - private final List weights; + private final double[] weights; WeightedMode() { this((List) null); } - public WeightedMode(List weights) { - this.weights = weights == null ? null : Collections.unmodifiableList(weights); + private WeightedMode(List weights) { + this(weights == null ? null : weights.stream().mapToDouble(Double::valueOf).toArray()); + } + + public WeightedMode(double[] weights) { + this.weights = weights; } public WeightedMode(StreamInput in) throws IOException { if (in.readBoolean()) { - this.weights = Collections.unmodifiableList(in.readList(StreamInput::readDouble)); + this.weights = in.readDoubleArray(); } else { this.weights = null; } @@ -68,13 +75,13 @@ public WeightedMode(StreamInput in) throws IOException { @Override public Integer expectedValueSize() { - return this.weights == null ? null : this.weights.size(); + return this.weights == null ? null : this.weights.length; } @Override public List processValues(List values) { Objects.requireNonNull(values, "values must not be null"); - if (weights != null && values.size() != weights.size()) { + if (weights != null && values.size() != weights.length) { throw new IllegalArgumentException("values must be the same length as weights."); } List freqArray = new ArrayList<>(); @@ -94,7 +101,7 @@ public List processValues(List values) { } List frequencies = new ArrayList<>(Collections.nCopies(maxVal + 1, Double.NEGATIVE_INFINITY)); for (int i = 0; i < freqArray.size(); i++) { - Double weight = weights == null ? 1.0 : weights.get(i); + Double weight = weights == null ? 
1.0 : weights[i]; Integer value = freqArray.get(i); Double frequency = frequencies.get(value) == Double.NEGATIVE_INFINITY ? weight : frequencies.get(value) + weight; frequencies.set(value, frequency); @@ -138,7 +145,7 @@ public String getWriteableName() { public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(weights != null); if (weights != null) { - out.writeCollection(weights, StreamOutput::writeDouble); + out.writeDoubleArray(weights); } } @@ -157,12 +164,16 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; WeightedMode that = (WeightedMode) o; - return Objects.equals(weights, that.weights); + return Arrays.equals(weights, that.weights); } @Override public int hashCode() { - return Objects.hash(weights); + return Arrays.hashCode(weights); } + @Override + public long ramBytesUsed() { + return SHALLOW_SIZE + RamUsageEstimator.sizeOf(weights); + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java index f2ba7514b0c1c..b9b34508b88ba 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; +import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -15,7 +16,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import java.io.IOException; -import java.util.Collections; +import java.util.Arrays; import java.util.List; import 
java.util.Objects; import java.util.Optional; @@ -24,6 +25,7 @@ public class WeightedSum implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator { + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(WeightedSum.class); public static final ParseField NAME = new ParseField("weighted_sum"); public static final ParseField WEIGHTS = new ParseField("weights"); @@ -48,19 +50,23 @@ public static WeightedSum fromXContentLenient(XContentParser parser) { return LENIENT_PARSER.apply(parser, null); } - private final List weights; + private final double[] weights; WeightedSum() { this((List) null); } - public WeightedSum(List weights) { - this.weights = weights == null ? null : Collections.unmodifiableList(weights); + private WeightedSum(List weights) { + this(weights == null ? null : weights.stream().mapToDouble(Double::valueOf).toArray()); + } + + public WeightedSum(double[] weights) { + this.weights = weights; } public WeightedSum(StreamInput in) throws IOException { if (in.readBoolean()) { - this.weights = Collections.unmodifiableList(in.readList(StreamInput::readDouble)); + this.weights = in.readDoubleArray(); } else { this.weights = null; } @@ -72,10 +78,10 @@ public List processValues(List values) { if (weights == null) { return values; } - if (values.size() != weights.size()) { + if (values.size() != weights.length) { throw new IllegalArgumentException("values must be the same length as weights."); } - return IntStream.range(0, weights.size()).mapToDouble(i -> values.get(i) * weights.get(i)).boxed().collect(Collectors.toList()); + return IntStream.range(0, weights.length).mapToDouble(i -> values.get(i) * weights[i]).boxed().collect(Collectors.toList()); } @Override @@ -105,7 +111,7 @@ public String getWriteableName() { public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(weights != null); if (weights != null) { - out.writeCollection(weights, StreamOutput::writeDouble); + 
out.writeDoubleArray(weights); } } @@ -124,21 +130,26 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; WeightedSum that = (WeightedSum) o; - return Objects.equals(weights, that.weights); + return Arrays.equals(weights, that.weights); } @Override public int hashCode() { - return Objects.hash(weights); + return Arrays.hashCode(weights); } @Override public Integer expectedValueSize() { - return weights == null ? null : this.weights.size(); + return weights == null ? null : this.weights.length; } @Override public boolean compatibleWith(TargetType targetType) { return TargetType.REGRESSION.equals(targetType); } + + @Override + public long ramBytesUsed() { + return SHALLOW_SIZE + RamUsageEstimator.sizeOf(weights); + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java index 7427a7cc70037..7036447adef7d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java @@ -5,6 +5,9 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree; +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.Accountables; +import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; @@ -30,6 +33,7 @@ import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.HashSet; import java.util.List; @@ -41,8 +45,9 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel; -public class Tree 
implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel { +public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel, Accountable { + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Tree.class); // TODO should we have regression/classification sub-classes that accept the builder? public static final ParseField NAME = new ParseField("tree"); @@ -317,6 +322,24 @@ private Double maxLeafValue() { null; } + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += RamUsageEstimator.sizeOfCollection(classificationLabels); + size += RamUsageEstimator.sizeOfCollection(featureNames); + size += RamUsageEstimator.sizeOfCollection(nodes); + return size; + } + + @Override + public Collection getChildResources() { + List accountables = new ArrayList<>(nodes.size()); + for (TreeNode node : nodes) { + accountables.add(Accountables.namedAccountable("tree_node_" + node.getNodeIndex(), node)); + } + return Collections.unmodifiableCollection(accountables); + } + public static class Builder { private List featureNames; private ArrayList nodes; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNode.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNode.java index 9beda88e2c50a..9d58c280905ad 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNode.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNode.java @@ -5,6 +5,9 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree; +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.common.Numbers; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; @@ 
-15,14 +18,15 @@ import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.job.config.Operator; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; import java.util.List; import java.util.Objects; -public class TreeNode implements ToXContentObject, Writeable { +public class TreeNode implements ToXContentObject, Writeable, Accountable { + + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TreeNode.class); public static final String NAME = "tree_node"; public static final ParseField DECISION_TYPE = new ParseField("decision_type"); @@ -63,31 +67,31 @@ public static TreeNode.Builder fromXContent(XContentParser parser, boolean lenie } private final Operator operator; - private final Double threshold; - private final Integer splitFeature; + private final double threshold; + private final int splitFeature; private final int nodeIndex; - private final Double splitGain; - private final Double leafValue; + private final double splitGain; + private final double leafValue; private final boolean defaultLeft; private final int leftChild; private final int rightChild; - TreeNode(Operator operator, - Double threshold, - Integer splitFeature, - Integer nodeIndex, - Double splitGain, - Double leafValue, - Boolean defaultLeft, - Integer leftChild, - Integer rightChild) { + private TreeNode(Operator operator, + Double threshold, + Integer splitFeature, + int nodeIndex, + Double splitGain, + Double leafValue, + Boolean defaultLeft, + Integer leftChild, + Integer rightChild) { this.operator = operator == null ? Operator.LTE : operator; - this.threshold = threshold; - this.splitFeature = splitFeature; - this.nodeIndex = ExceptionsHelper.requireNonNull(nodeIndex, NODE_INDEX.getPreferredName()); - this.splitGain = splitGain; - this.leafValue = leafValue; + this.threshold = threshold == null ? 
Double.NaN : threshold; + this.splitFeature = splitFeature == null ? -1 : splitFeature; + this.nodeIndex = nodeIndex; + this.splitGain = splitGain == null ? Double.NaN : splitGain; + this.leafValue = leafValue == null ? Double.NaN : leafValue; this.defaultLeft = defaultLeft == null ? false : defaultLeft; this.leftChild = leftChild == null ? -1 : leftChild; this.rightChild = rightChild == null ? -1 : rightChild; @@ -95,11 +99,11 @@ public static TreeNode.Builder fromXContent(XContentParser parser, boolean lenie public TreeNode(StreamInput in) throws IOException { operator = Operator.readFromStream(in); - threshold = in.readOptionalDouble(); - splitFeature = in.readOptionalInt(); - splitGain = in.readOptionalDouble(); - nodeIndex = in.readInt(); - leafValue = in.readOptionalDouble(); + threshold = in.readDouble(); + splitFeature = in.readInt(); + splitGain = in.readDouble(); + nodeIndex = in.readVInt(); + leafValue = in.readDouble(); defaultLeft = in.readBoolean(); leftChild = in.readInt(); rightChild = in.readInt(); @@ -110,23 +114,23 @@ public Operator getOperator() { return operator; } - public Double getThreshold() { + public double getThreshold() { return threshold; } - public Integer getSplitFeature() { + public int getSplitFeature() { return splitFeature; } - public Integer getNodeIndex() { + public int getNodeIndex() { return nodeIndex; } - public Double getSplitGain() { + public double getSplitGain() { return splitGain; } - public Double getLeafValue() { + public double getLeafValue() { return leafValue; } @@ -164,11 +168,11 @@ private boolean isMissing(Double feature) { @Override public void writeTo(StreamOutput out) throws IOException { operator.writeTo(out); - out.writeOptionalDouble(threshold); - out.writeOptionalInt(splitFeature); - out.writeOptionalDouble(splitGain); - out.writeInt(nodeIndex); - out.writeOptionalDouble(leafValue); + out.writeDouble(threshold); + out.writeInt(splitFeature); + out.writeDouble(splitGain); + out.writeVInt(nodeIndex); + 
out.writeDouble(leafValue); out.writeBoolean(defaultLeft); out.writeInt(leftChild); out.writeInt(rightChild); @@ -177,12 +181,14 @@ public void writeTo(StreamOutput out) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - addOptionalField(builder, DECISION_TYPE, operator); - addOptionalField(builder, THRESHOLD, threshold); - addOptionalField(builder, SPLIT_FEATURE, splitFeature); - addOptionalField(builder, SPLIT_GAIN, splitGain); + builder.field(DECISION_TYPE.getPreferredName(), operator); + addOptionalDouble(builder, THRESHOLD, threshold); + if (splitFeature > -1) { + builder.field(SPLIT_FEATURE.getPreferredName(), splitFeature); + } + addOptionalDouble(builder, SPLIT_GAIN, splitGain); builder.field(NODE_INDEX.getPreferredName(), nodeIndex); - addOptionalField(builder, LEAF_VALUE, leafValue); + addOptionalDouble(builder, LEAF_VALUE, leafValue); builder.field(DEFAULT_LEFT.getPreferredName(), defaultLeft); if (leftChild >= 0) { builder.field(LEFT_CHILD.getPreferredName(), leftChild); @@ -194,8 +200,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } - private void addOptionalField(XContentBuilder builder, ParseField field, Object value) throws IOException { - if (value != null) { + private void addOptionalDouble(XContentBuilder builder, ParseField field, double value) throws IOException { + if (Numbers.isValidDouble(value)) { builder.field(field.getPreferredName(), value); } } @@ -237,7 +243,12 @@ public String toString() { public static Builder builder(int nodeIndex) { return new Builder(nodeIndex); } - + + @Override + public long ramBytesUsed() { + return SHALLOW_SIZE; + } + public static class Builder { private Operator operator; private Double threshold; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java index 678bb8a2982b4..a4855274c075b 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.AbstractSerializingTestCase; @@ -28,6 +29,7 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; +import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE; import static org.hamcrest.Matchers.equalTo; public class TrainedModelConfigTests extends AbstractSerializingTestCase { @@ -88,6 +90,16 @@ protected NamedWriteableRegistry getNamedWriteableRegistry() { return new NamedWriteableRegistry(entries); } + @Override + protected ToXContent.Params getToXContentParams() { + return lenient ? 
ToXContent.EMPTY_PARAMS : new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true")); + } + + @Override + protected boolean assertToXContentEquivalence() { + return false; + } + public void testValidateWithNullDefinition() { IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> TrainedModelConfig.builder().validate()); assertThat(ex.getMessage(), equalTo("[definition] must not be null.")); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java index 5339d93bf9100..a6c7998d4c25a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.DeprecationHandler; import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; @@ -31,7 +32,9 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; public class TrainedModelDefinitionTests extends AbstractSerializingTestCase { @@ -58,6 +61,16 @@ protected Predicate getRandomFieldsExcludeFilter() { return field -> !field.isEmpty(); } + @Override + protected ToXContent.Params getToXContentParams() { + return lenient ? 
ToXContent.EMPTY_PARAMS : new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true")); + } + + @Override + protected boolean assertToXContentEquivalence() { + return false; + } + public static TrainedModelDefinition.Builder createRandomBuilder() { int numberOfProcessors = randomIntBetween(1, 10); return new TrainedModelDefinition.Builder() @@ -316,4 +329,9 @@ protected NamedWriteableRegistry getNamedWriteableRegistry() { return new NamedWriteableRegistry(entries); } + public void testRamUsageEstimation() { + TrainedModelDefinition test = createTestInstance(); + assertThat(test.ramBytesUsed(), greaterThan(0L)); + } + } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java index a81c210e33067..753a9d3dd3cad 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -25,7 +25,6 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests; import org.junit.Before; - import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; @@ -72,9 +71,9 @@ public static Ensemble createRandom() { List models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6)) .limit(numberOfModels) .collect(Collectors.toList()); - List weights = randomBoolean() ? + double[] weights = randomBoolean() ? 
null : - Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList()); + Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).mapToDouble(Double::valueOf).toArray(); OutputAggregator outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights), new LogisticRegression(weights)); @@ -118,9 +117,9 @@ protected NamedWriteableRegistry getNamedWriteableRegistry() { public void testEnsembleWithAggregatedOutputDifferingFromTrainedModels() { List featureNames = Arrays.asList("foo", "bar"); int numberOfModels = 5; - List weights = new ArrayList<>(numberOfModels + 2); + double[] weights = new double[numberOfModels + 2]; for (int i = 0; i < numberOfModels + 2; i++) { - weights.add(randomDouble()); + weights[i] = randomDouble(); } OutputAggregator outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights)); @@ -262,7 +261,7 @@ public void testClassificationProbability() { .setTargetType(TargetType.CLASSIFICATION) .setFeatureNames(featureNames) .setTrainedModels(Arrays.asList(tree1, tree2, tree3)) - .setOutputAggregator(new WeightedMode(Arrays.asList(0.7, 0.5, 1.0))) + .setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0})) .build(); List featureVector = Arrays.asList(0.4, 0.0); @@ -351,7 +350,7 @@ public void testClassificationInference() { .setTargetType(TargetType.CLASSIFICATION) .setFeatureNames(featureNames) .setTrainedModels(Arrays.asList(tree1, tree2, tree3)) - .setOutputAggregator(new WeightedMode(Arrays.asList(0.7, 0.5, 1.0))) + .setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0})) .build(); List featureVector = Arrays.asList(0.4, 0.0); @@ -408,7 +407,7 @@ public void testRegressionInference() { .setTargetType(TargetType.REGRESSION) .setFeatureNames(featureNames) .setTrainedModels(Arrays.asList(tree1, tree2)) - .setOutputAggregator(new WeightedSum(Arrays.asList(0.5, 0.5))) + .setOutputAggregator(new WeightedSum(new double[]{0.5, 0.5})) .build(); List 
featureVector = Arrays.asList(0.4, 0.0); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegressionTests.java index 142046d34ef75..e630a5874fc79 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegressionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegressionTests.java @@ -13,7 +13,6 @@ import java.io.IOException; import java.util.Arrays; import java.util.List; -import java.util.stream.Collectors; import java.util.stream.Stream; import static org.hamcrest.CoreMatchers.is; @@ -23,7 +22,7 @@ public class LogisticRegressionTests extends WeightedAggregatorTests weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList()); + double[] weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).mapToDouble(Double::valueOf).toArray(); return new LogisticRegression(weights); } @@ -43,13 +42,13 @@ protected Writeable.Reader instanceReader() { } public void testAggregate() { - List ones = Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0); + double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0}; List values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0); LogisticRegression logisticRegression = new LogisticRegression(ones); assertThat(logisticRegression.aggregate(logisticRegression.processValues(values)), equalTo(1.0)); - List variedWeights = Arrays.asList(.01, -1.0, .1, 0.0, 0.0); + double[] variedWeights = new double[]{.01, -1.0, .1, 0.0, 0.0}; logisticRegression = new LogisticRegression(variedWeights); assertThat(logisticRegression.aggregate(logisticRegression.processValues(values)), equalTo(0.0)); diff --git 
a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java index 683115e63879e..4421a8fbb938b 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java @@ -13,7 +13,6 @@ import java.io.IOException; import java.util.Arrays; import java.util.List; -import java.util.stream.Collectors; import java.util.stream.Stream; import static org.hamcrest.CoreMatchers.is; @@ -23,7 +22,7 @@ public class WeightedModeTests extends WeightedAggregatorTests { @Override WeightedMode createTestInstance(int numberOfWeights) { - List weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList()); + double[] weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).mapToDouble(Double::valueOf).toArray(); return new WeightedMode(weights); } @@ -43,13 +42,13 @@ protected Writeable.Reader instanceReader() { } public void testAggregate() { - List ones = Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0); + double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0}; List values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0); WeightedMode weightedMode = new WeightedMode(ones); assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0)); - List variedWeights = Arrays.asList(1.0, -1.0, .5, 1.0, 5.0); + double[] variedWeights = new double[]{1.0, -1.0, .5, 1.0, 5.0}; weightedMode = new WeightedMode(variedWeights); assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(5.0)); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java index fa372f043a410..8e4a6577dbb27 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java @@ -13,7 +13,6 @@ import java.io.IOException; import java.util.Arrays; import java.util.List; -import java.util.stream.Collectors; import java.util.stream.Stream; import static org.hamcrest.CoreMatchers.is; @@ -23,7 +22,7 @@ public class WeightedSumTests extends WeightedAggregatorTests { @Override WeightedSum createTestInstance(int numberOfWeights) { - List weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList()); + double[] weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).mapToDouble(Double::valueOf).toArray(); return new WeightedSum(weights); } @@ -43,13 +42,13 @@ protected Writeable.Reader instanceReader() { } public void testAggregate() { - List ones = Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0); + double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0}; List values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0); WeightedSum weightedSum = new WeightedSum(ones); assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(13.0)); - List variedWeights = Arrays.asList(1.0, -1.0, .5, 1.0, 5.0); + double[] variedWeights = new double[]{1.0, -1.0, .5, 1.0, 5.0}; weightedSum = new WeightedSum(variedWeights); assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(28.0)); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java index 075bfbe912270..86e257469d82a 100644 --- 
a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java @@ -108,6 +108,19 @@ protected Writeable.Reader instanceReader() { return Tree::new; } + public void testInferWithStump() { + Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION); + builder.setRoot(TreeNode.builder(0).setLeafValue(42.0)); + builder.setFeatureNames(Collections.emptyList()); + + Tree tree = builder.build(); + List featureNames = Arrays.asList("foo", "bar"); + List featureVector = Arrays.asList(0.6, 0.0); + Map featureMap = zipObjMap(featureNames, featureVector); // does not really matter as this is a stump + assertThat(42.0, + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001)); + } + public void testInfer() { // Build a tree with 2 nodes and 3 leaves using 2 features // The leaves have unique values 0.1, 0.2, 0.3 diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java index 6d3fe32332a72..856ab440ea3fd 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java @@ -78,12 +78,14 @@ public void testGetTrainedModels() throws IOException { assertThat(response, containsString("\"count\":1")); getModel = client().performRequest(new Request("GET", - MachineLearning.BASE_PATH + "inference/test_regression*")); + MachineLearning.BASE_PATH + "inference/test_regression*?human")); assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); response = 
EntityUtils.toString(getModel.getEntity()); assertThat(response, containsString("\"model_id\":\"test_regression_model\"")); assertThat(response, containsString("\"model_id\":\"test_regression_model-2\"")); + assertThat(response, containsString("\"heap_memory_estimation_bytes\"")); + assertThat(response, containsString("\"heap_memory_estimation\"")); assertThat(response, containsString("\"count\":2")); getModel = client().performRequest(new Request("GET", diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 72e903d79baf5..74a1a01278c68 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -210,7 +210,6 @@ import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; -import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.persistence.InferenceInternalIndex; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.job.JobManager; @@ -438,7 +437,9 @@ public List> getSettings() { MAX_OPEN_JOBS_PER_NODE, MIN_DISK_SPACE_OFF_HEAP, MlConfigMigrationEligibilityCheck.ENABLE_CONFIG_MIGRATION, - InferenceProcessor.MAX_INFERENCE_PROCESSORS + InferenceProcessor.MAX_INFERENCE_PROCESSORS, + ModelLoadingService.INFERENCE_MODEL_CACHE_SIZE, + ModelLoadingService.INFERENCE_MODEL_CACHE_TTL ); } @@ -597,7 +598,11 @@ public Collection createComponents(Client client, ClusterService cluster // Inference components final TrainedModelProvider trainedModelProvider = new TrainedModelProvider(client, xContentRegistry); - final ModelLoadingService 
modelLoadingService = new ModelLoadingService(trainedModelProvider, threadPool, clusterService); + final ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, + inferenceAuditor, + threadPool, + clusterService, + settings); // Data frame analytics components AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java index c383fd195767a..d6a9c0a0c63fd 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java @@ -7,14 +7,17 @@ import org.elasticsearch.common.ParseField; import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import java.io.IOException; +import java.util.Collections; import java.util.Objects; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; +import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE; public class AnalyticsResult implements ToXContentObject { @@ -65,7 +68,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(PROGRESS_PERCENT.getPreferredName(), progressPercent); } if (inferenceModel != null) { - builder.field(INFERENCE_MODEL.getPreferredName(), inferenceModel); + builder.field(INFERENCE_MODEL.getPreferredName(), + inferenceModel, + new 
ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true"))); } builder.endObject(); return builder; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java index 9bbf42915410d..403f10dd7d83b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java @@ -25,6 +25,15 @@ public LocalModel(String modelId, TrainedModelDefinition trainedModelDefinition) this.modelId = modelId; } + long ramBytesUsed() { + return trainedModelDefinition.ramBytesUsed(); + } + + @Override + public String getModelId() { + return modelId; + } + @Override public String getResultsType() { switch (trainedModelDefinition.getTrainedModel().targetType()) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java index c66a23d78f98e..fb32ce7f646df 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java @@ -17,4 +17,5 @@ public interface Model { void infer(Map fields, InferenceConfig inferenceConfig, ActionListener listener); + String getModelId(); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index da294f1e1580b..35c5cf0ab1eb6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -14,13 +14,23 @@ import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterStateListener; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.cache.Cache; +import org.elasticsearch.common.cache.CacheBuilder; +import org.elasticsearch.common.cache.RemovalNotification; import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.ingest.IngestMetadata; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; import java.util.ArrayDeque; import java.util.ArrayList; @@ -30,31 +40,94 @@ import java.util.Map; import java.util.Queue; import java.util.Set; +import java.util.concurrent.TimeUnit; +/** + * This is a thread safe model loading service. + * + * It will cache local models that are referenced by processors in memory (as long as it is instantiated on an ingest node). + * + * If more than one processor references the same model, that model will only be cached once. 
+ */ public class ModelLoadingService implements ClusterStateListener { + /** + * The maximum size of the local model cache here in the loading service + * + * Once the limit is reached, LRU models are evicted in favor of new models + */ + public static final Setting INFERENCE_MODEL_CACHE_SIZE = + Setting.byteSizeSetting("xpack.ml.inference_model.cache_size", + new ByteSizeValue(1, ByteSizeUnit.GB), + Setting.Property.NodeScope); + + /** + * How long should a model stay in the cache since its last access + * + * If nothing references a model via getModel for this configured timeValue, it will be evicted. + * + * Specifically, in the ingest scenario, a processor will call getModel whenever it needs to run inference. So, if a processor is not + * executed for an extended period of time, the model will be evicted and will have to be loaded again when getModel is called. + * + */ + public static final Setting INFERENCE_MODEL_CACHE_TTL = + Setting.timeSetting("xpack.ml.inference_model.time_to_live", + new TimeValue(5, TimeUnit.MINUTES), + new TimeValue(1, TimeUnit.MILLISECONDS), + Setting.Property.NodeScope); + private static final Logger logger = LogManager.getLogger(ModelLoadingService.class); - private final Map loadedModels = new HashMap<>(); + private final Cache localModelCache; + private final Set referencedModels = new HashSet<>(); private final Map>> loadingListeners = new HashMap<>(); private final TrainedModelProvider provider; + private final Set shouldNotAudit; private final ThreadPool threadPool; + private final InferenceAuditor auditor; + private final ByteSizeValue maxCacheSize; public ModelLoadingService(TrainedModelProvider trainedModelProvider, + InferenceAuditor auditor, ThreadPool threadPool, - ClusterService clusterService) { + ClusterService clusterService, + Settings settings) { this.provider = trainedModelProvider; this.threadPool = threadPool; + this.maxCacheSize = INFERENCE_MODEL_CACHE_SIZE.get(settings); + this.auditor = auditor; + 
this.shouldNotAudit = new HashSet<>(); + this.localModelCache = CacheBuilder.builder() + .setMaximumWeight(this.maxCacheSize.getBytes()) + .weigher((id, localModel) -> localModel.ramBytesUsed()) + .removalListener(this::cacheEvictionListener) + .setExpireAfterAccess(INFERENCE_MODEL_CACHE_TTL.get(settings)) + .build(); clusterService.addListener(this); } + /** + * Gets the model referenced by `modelId` and responds to the listener. + * + * This method first checks the local LRU cache for the model. If it is present, it is returned from cache. + * + * If it is not present, one of the following occurs: + * + * - If the model is referenced by a pipeline and is currently being loaded, the `modelActionListener` + * is added to the list of listeners to be alerted when the model is fully loaded. + * - If the model is referenced by a pipeline and is currently NOT being loaded, a new load attempt is made and the resulting + * model will attempt to be cached for future reference + * - If the models is NOT referenced by a pipeline, the model is simply loaded from the index and given to the listener. + * It is not cached. + * + * @param modelId the model to get + * @param modelActionListener the listener to alert when the model has been retrieved. 
+ */ public void getModel(String modelId, ActionListener modelActionListener) { - MaybeModel cachedModel = loadedModels.get(modelId); + LocalModel cachedModel = localModelCache.get(modelId); if (cachedModel != null) { - if (cachedModel.isSuccess()) { - modelActionListener.onResponse(cachedModel.getModel()); - logger.trace("[{}] loaded from cache", modelId); - return; - } + modelActionListener.onResponse(cachedModel); + logger.trace("[{}] loaded from cache", modelId); + return; } if (loadModelIfNecessary(modelId, modelActionListener) == false) { // If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called @@ -77,13 +150,15 @@ public void getModel(String modelId, ActionListener modelActionListener) */ private boolean loadModelIfNecessary(String modelId, ActionListener modelActionListener) { synchronized (loadingListeners) { - MaybeModel cachedModel = loadedModels.get(modelId); + Model cachedModel = localModelCache.get(modelId); if (cachedModel != null) { - if (cachedModel.isSuccess()) { - modelActionListener.onResponse(cachedModel.getModel()); - return true; - } - // If the loaded model entry is there but is not present, that means the previous load attempt ran into an issue + modelActionListener.onResponse(cachedModel); + return true; + } + // It is referenced by a pipeline, but the cache does not contain it + if (referencedModels.contains(modelId)) { + // If the loaded model is referenced there but is not present, + // that means the previous load attempt failed or the model has been evicted // Attempt to load and cache the model if necessary if (loadingListeners.computeIfPresent( modelId, @@ -97,7 +172,7 @@ private boolean loadModelIfNecessary(String modelId, ActionListener model // if the cachedModel entry is null, but there are listeners present, that means it is being loaded return loadingListeners.computeIfPresent(modelId, (storedModelKey, listenerQueue) -> addFluently(listenerQueue, 
modelActionListener)) != null; - } + } // synchronized (loadingListeners) } private void loadModel(String modelId) { @@ -115,19 +190,19 @@ private void loadModel(String modelId) { private void handleLoadSuccess(String modelId, TrainedModelConfig trainedModelConfig) { Queue> listeners; - Model loadedModel = new LocalModel(trainedModelConfig.getModelId(), trainedModelConfig.getDefinition()); + LocalModel loadedModel = new LocalModel(trainedModelConfig.getModelId(), trainedModelConfig.getDefinition()); synchronized (loadingListeners) { listeners = loadingListeners.remove(modelId); // If there is no loadingListener that means the loading was canceled and the listener was already notified as such // Consequently, we should not store the retrieved model - if (listeners != null) { - loadedModels.put(modelId, MaybeModel.of(loadedModel)); - } - } - if (listeners != null) { - for (ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { - listener.onResponse(loadedModel); + if (listeners == null) { + return; } + localModelCache.put(modelId, loadedModel); + shouldNotAudit.remove(modelId); + } // synchronized (loadingListeners) + for (ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { + listener.onResponse(loadedModel); } } @@ -135,76 +210,110 @@ private void handleLoadFailure(String modelId, Exception failure) { Queue> listeners; synchronized (loadingListeners) { listeners = loadingListeners.remove(modelId); - if (listeners != null) { - // If we failed to load and there were listeners present, that means that this model is referenced by a processor - // Add an empty entry here so that we can attempt to load and cache the model again when it is accessed again. 
- loadedModels.computeIfAbsent(modelId, (key) -> MaybeModel.of(failure)); + if (listeners == null) { + return; } + } // synchronized (loadingListeners) + // If we failed to load and there were listeners present, that means that this model is referenced by a processor + // Alert the listeners to the failure + for (ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { + listener.onFailure(failure); } - if (listeners != null) { - for (ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { - listener.onFailure(failure); - } + } + + private void cacheEvictionListener(RemovalNotification notification) { + if (notification.getRemovalReason() == RemovalNotification.RemovalReason.EVICTED) { + String msg = new ParameterizedMessage( + "model cache entry evicted." + + "current cache [{}] current max [{}] model size [{}]. " + + "If this is undesired, consider updating setting [{}] or [{}].", + new ByteSizeValue(localModelCache.weight()).getStringRep(), + maxCacheSize.getStringRep(), + new ByteSizeValue(notification.getValue().ramBytesUsed()).getStringRep(), + INFERENCE_MODEL_CACHE_SIZE.getKey(), + INFERENCE_MODEL_CACHE_TTL.getKey()).getFormattedMessage(); + auditIfNecessary(notification.getKey(), msg); } } @Override public void clusterChanged(ClusterChangedEvent event) { - if (event.changedCustomMetaDataSet().contains(IngestMetadata.TYPE)) { - ClusterState state = event.state(); - IngestMetadata currentIngestMetadata = state.metaData().custom(IngestMetadata.TYPE); - Set allReferencedModelKeys = getReferencedModelKeys(currentIngestMetadata); - // The listeners still waiting for a model and we are canceling the load? - List>>> drainWithFailure = new ArrayList<>(); - synchronized (loadingListeners) { - HashSet loadedModelBeforeClusterState = logger.isTraceEnabled() ? new HashSet<>(loadedModels.keySet()) : null; - HashSet loadingModelBeforeClusterState = logger.isTraceEnabled() ? 
new HashSet<>(loadingListeners.keySet()) : null; - // If we had models still loading here but are no longer referenced - // we should remove them from loadingListeners and alert the listeners - for (String modelId : loadingListeners.keySet()) { - if (allReferencedModelKeys.contains(modelId) == false) { - drainWithFailure.add(Tuple.tuple(modelId, new ArrayList<>(loadingListeners.remove(modelId)))); - } - } + // If ingest data has not changed or if the current node is not an ingest node, don't bother caching models + if (event.changedCustomMetaDataSet().contains(IngestMetadata.TYPE) == false || + event.state().nodes().getLocalNode().isIngestNode() == false) { + return; + } - // Remove all cached models that are not referenced by any processors - loadedModels.keySet().retainAll(allReferencedModelKeys); + ClusterState state = event.state(); + IngestMetadata currentIngestMetadata = state.metaData().custom(IngestMetadata.TYPE); + Set allReferencedModelKeys = getReferencedModelKeys(currentIngestMetadata); + if (allReferencedModelKeys.equals(referencedModels)) { + return; + } + // The listeners still waiting for a model and we are canceling the load? 
+ List>>> drainWithFailure = new ArrayList<>(); + Set referencedModelsBeforeClusterState = null; + Set loadingModelBeforeClusterState = null; + Set removedModels = null; + synchronized (loadingListeners) { + referencedModelsBeforeClusterState = new HashSet<>(referencedModels); + if (logger.isTraceEnabled()) { + loadingModelBeforeClusterState = new HashSet<>(loadingListeners.keySet()); + } + // If we had models still loading here but are no longer referenced + // we should remove them from loadingListeners and alert the listeners + for (String modelId : loadingListeners.keySet()) { + if (allReferencedModelKeys.contains(modelId) == false) { + drainWithFailure.add(Tuple.tuple(modelId, new ArrayList<>(loadingListeners.remove(modelId)))); + } + } + removedModels = Sets.difference(referencedModelsBeforeClusterState, allReferencedModelKeys); - // Remove all that are currently being loaded - allReferencedModelKeys.removeAll(loadingListeners.keySet()); + // Remove all cached models that are not referenced by any processors + removedModels.forEach(localModelCache::invalidate); + // Remove the models that are no longer referenced + referencedModels.removeAll(removedModels); + shouldNotAudit.removeAll(removedModels); - // Remove all that are fully loaded, will attempt empty model loading again - loadedModels.forEach((id, optionalModel) -> { - if (optionalModel.isSuccess()) { - allReferencedModelKeys.remove(id); - } - }); - // Populate loadingListeners key so we know that we are currently loading the model - for (String modelId : allReferencedModelKeys) { - loadingListeners.put(modelId, new ArrayDeque<>()); - } - if (loadedModelBeforeClusterState != null && loadingModelBeforeClusterState != null) { - if (loadingListeners.keySet().equals(loadingModelBeforeClusterState) == false) { - logger.trace("cluster state event changed loading models: before {} after {}", loadingModelBeforeClusterState, - loadingListeners.keySet()); - } - if 
(loadedModels.keySet().equals(loadedModelBeforeClusterState) == false) { - logger.trace("cluster state event changed loaded models: before {} after {}", loadedModelBeforeClusterState, - loadedModels.keySet()); - } - } + // Remove all that are still referenced, i.e. the intersection of allReferencedModelKeys and referencedModels + allReferencedModelKeys.removeAll(referencedModels); + referencedModels.addAll(allReferencedModelKeys); + // Populate loadingListeners key so we know that we are currently loading the model + for (String modelId : allReferencedModelKeys) { + loadingListeners.put(modelId, new ArrayDeque<>()); } - for (Tuple>> modelAndListeners : drainWithFailure) { - final String msg = new ParameterizedMessage( - "Cancelling load of model [{}] as it is no longer referenced by a pipeline", - modelAndListeners.v1()).getFormat(); - for (ActionListener listener : modelAndListeners.v2()) { - listener.onFailure(new ElasticsearchException(msg)); - } + } // synchronized (loadingListeners) + if (logger.isTraceEnabled()) { + if (loadingListeners.keySet().equals(loadingModelBeforeClusterState) == false) { + logger.trace("cluster state event changed loading models: before {} after {}", loadingModelBeforeClusterState, + loadingListeners.keySet()); } - loadModels(allReferencedModelKeys); + if (referencedModels.equals(referencedModelsBeforeClusterState) == false) { + logger.trace("cluster state event changed referenced models: before {} after {}", referencedModelsBeforeClusterState, + referencedModels); + } + } + for (Tuple>> modelAndListeners : drainWithFailure) { + final String msg = new ParameterizedMessage( + "Cancelling load of model [{}] as it is no longer referenced by a pipeline", + modelAndListeners.v1()).getFormat(); + for (ActionListener listener : modelAndListeners.v2()) { + listener.onFailure(new ElasticsearchException(msg)); + } + } + removedModels.forEach(this::auditUnreferencedModel); + loadModels(allReferencedModelKeys); + } + + private void 
auditIfNecessary(String modelId, String msg) { + if (shouldNotAudit.contains(modelId)) { + logger.trace("[{}] {}", modelId, msg); + return; } + auditor.warning(modelId, msg); + shouldNotAudit.add(modelId); + logger.warn("[{}] {}", modelId, msg); } private void loadModels(Set modelIds) { @@ -214,11 +323,20 @@ private void loadModels(Set modelIds) { // Execute this on a utility thread as when the callbacks occur we don't want them tying up the cluster listener thread pool threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> { for (String modelId : modelIds) { + auditNewReferencedModel(modelId); this.loadModel(modelId); } }); } + private void auditNewReferencedModel(String modelId) { + auditor.info(modelId, "referenced by ingest processors. Attempting to load model into cache"); + } + + private void auditUnreferencedModel(String modelId) { + auditor.info(modelId, "no longer referenced by any processors"); + } + private static Queue addFluently(Queue queue, T object) { queue.add(object); return queue; @@ -226,62 +344,27 @@ private static Queue addFluently(Queue queue, T object) { private static Set getReferencedModelKeys(IngestMetadata ingestMetadata) { Set allReferencedModelKeys = new HashSet<>(); - if (ingestMetadata != null) { - ingestMetadata.getPipelines().forEach((pipelineId, pipelineConfiguration) -> { - Object processors = pipelineConfiguration.getConfigAsMap().get("processors"); - if (processors instanceof List) { - for(Object processor : (List)processors) { - if (processor instanceof Map) { - Object processorConfig = ((Map)processor).get(InferenceProcessor.TYPE); - if (processorConfig instanceof Map) { - Object modelId = ((Map)processorConfig).get(InferenceProcessor.MODEL_ID); - if (modelId != null) { - assert modelId instanceof String; - allReferencedModelKeys.add(modelId.toString()); - } + if (ingestMetadata == null) { + return allReferencedModelKeys; + } + ingestMetadata.getPipelines().forEach((pipelineId, pipelineConfiguration) -> { 
+ Object processors = pipelineConfiguration.getConfigAsMap().get("processors"); + if (processors instanceof List) { + for(Object processor : (List)processors) { + if (processor instanceof Map) { + Object processorConfig = ((Map)processor).get(InferenceProcessor.TYPE); + if (processorConfig instanceof Map) { + Object modelId = ((Map)processorConfig).get(InferenceProcessor.MODEL_ID); + if (modelId != null) { + assert modelId instanceof String; + allReferencedModelKeys.add(modelId.toString()); } } } } - }); - } + } + }); return allReferencedModelKeys; } - private static class MaybeModel { - - private final Model model; - private final Exception exception; - - static MaybeModel of(Model model) { - return new MaybeModel(model, null); - } - - static MaybeModel of(Exception exception) { - return new MaybeModel(null, exception); - } - - private MaybeModel(Model model, Exception exception) { - this.model = model; - this.exception = exception; - } - - Model getModel() { - return model; - } - - Exception getException() { - return exception; - } - - boolean isSuccess() { - return this.model != null; - } - - boolean isFailure() { - return this.exception != null; - } - - } - } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java index 48aa70dec74f2..47670da69398d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java @@ -163,7 +163,7 @@ public static TrainedModel buildClassification(boolean includeLabels) { .setTargetType(TargetType.CLASSIFICATION) .setFeatureNames(featureNames) .setTrainedModels(Arrays.asList(tree1, tree2, tree3)) - .setOutputAggregator(new WeightedMode(Arrays.asList(0.7, 0.5, 1.0))) + .setOutputAggregator(new 
WeightedMode(new double[]{0.7, 0.5, 1.0})) .build(); } @@ -208,7 +208,7 @@ public static TrainedModel buildRegression() { .setTargetType(TargetType.REGRESSION) .setFeatureNames(featureNames) .setTrainedModels(Arrays.asList(tree1, tree2, tree3)) - .setOutputAggregator(new WeightedSum(Arrays.asList(0.5, 0.5, 0.5))) + .setOutputAggregator(new WeightedSum(new double[]{0.5, 0.5, 0.5})) .build(); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java index 36cfe23d9398a..80fd54e3a6bfb 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -14,8 +14,14 @@ import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterStateListener; import org.elasticsearch.cluster.metadata.MetaData; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.node.DiscoveryNodeRole; +import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; @@ -27,15 +33,17 @@ import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; +import 
org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; import org.junit.After; import org.junit.Before; +import org.mockito.Mockito; import java.io.IOException; -import java.time.Instant; +import java.net.InetAddress; import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -59,6 +67,7 @@ public class ModelLoadingServiceTests extends ESTestCase { private TrainedModelProvider trainedModelProvider; private ThreadPool threadPool; private ClusterService clusterService; + private InferenceAuditor auditor; @Before public void setUpComponents() { @@ -66,6 +75,10 @@ public void setUpComponents() { 1, 4, TimeValue.timeValueMinutes(10), "xpack.ml.utility_thread_pool")); trainedModelProvider = mock(TrainedModelProvider.class); clusterService = mock(ClusterService.class); + auditor = mock(InferenceAuditor.class); + doAnswer(a -> null).when(auditor).error(any(String.class), any(String.class)); + doAnswer(a -> null).when(auditor).info(any(String.class), any(String.class)); + doAnswer(a -> null).when(auditor).warning(any(String.class), any(String.class)); doAnswer((invocationOnMock) -> null).when(clusterService).addListener(any(ClusterStateListener.class)); when(clusterService.state()).thenReturn(ClusterState.builder(new ClusterName("_name")).build()); } @@ -79,11 +92,15 @@ public void testGetCachedModels() throws Exception { String model1 = "test-load-model-1"; String model2 = "test-load-model-2"; String model3 = "test-load-model-3"; - withTrainedModel(model1); - withTrainedModel(model2); - withTrainedModel(model3); + withTrainedModel(model1, 1L); + withTrainedModel(model2, 1L); + withTrainedModel(model3, 1L); - ModelLoadingService modelLoadingService = new 
ModelLoadingService(trainedModelProvider, threadPool, clusterService); + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, + auditor, + threadPool, + clusterService, + Settings.EMPTY); modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3)); @@ -98,13 +115,124 @@ public void testGetCachedModels() throws Exception { verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), any()); verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), any()); verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), any()); + + // Test invalidate cache for model3 + modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2)); + for(int i = 0; i < 10; i++) { + String model = modelIds[i%3]; + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModel(model, future); + assertThat(future.get(), is(not(nullValue()))); + } + + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), any()); + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), any()); + // It is not referenced, so called eagerly + verify(trainedModelProvider, times(4)).getTrainedModel(eq(model3), any()); + } + + public void testMaxCachedLimitReached() throws Exception { + String model1 = "test-cached-limit-load-model-1"; + String model2 = "test-cached-limit-load-model-2"; + String model3 = "test-cached-limit-load-model-3"; + withTrainedModel(model1, 10L); + withTrainedModel(model2, 5L); + withTrainedModel(model3, 15L); + + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, + auditor, + threadPool, + clusterService, + Settings.builder().put(ModelLoadingService.INFERENCE_MODEL_CACHE_SIZE.getKey(), new ByteSizeValue(20L)).build()); + + modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3)); + + String[] modelIds = new String[]{model1, model2, model3}; + for(int i = 0; i < 10; i++) { + // Only reference models 1 and 
2, so that cache is only invalidated once for model3 (after initial load) + String model = modelIds[i%2]; + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModel(model, future); + assertThat(future.get(), is(not(nullValue()))); + } + + verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model1), any()); + verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model2), any()); + // Only loaded requested once on the initial load from the change event + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), any()); + + // Load model 3, should invalidate 1 + for(int i = 0; i < 10; i++) { + PlainActionFuture future3 = new PlainActionFuture<>(); + modelLoadingService.getModel(model3, future3); + assertThat(future3.get(), is(not(nullValue()))); + } + verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model3), any()); + + // Load model 1, should invalidate 2 + for(int i = 0; i < 10; i++) { + PlainActionFuture future1 = new PlainActionFuture<>(); + modelLoadingService.getModel(model1, future1); + assertThat(future1.get(), is(not(nullValue()))); + } + verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model1), any()); + + // Load model 2, should invalidate 3 + for(int i = 0; i < 10; i++) { + PlainActionFuture future2 = new PlainActionFuture<>(); + modelLoadingService.getModel(model2, future2); + assertThat(future2.get(), is(not(nullValue()))); + } + verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model2), any()); + + + // Test invalidate cache for model3 + // Now both model 1 and 2 should fit in cache without issues + modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2)); + for(int i = 0; i < 10; i++) { + String model = modelIds[i%3]; + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModel(model, future); + assertThat(future.get(), is(not(nullValue()))); + } + + verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model1), any()); + 
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model2), any()); + verify(trainedModelProvider, Mockito.atLeast(4)).getTrainedModel(eq(model3), any()); + verify(trainedModelProvider, Mockito.atMost(5)).getTrainedModel(eq(model3), any()); + } + + + public void testWhenCacheEnabledButNotIngestNode() throws Exception { + String model1 = "test-uncached-not-ingest-model-1"; + withTrainedModel(model1, 1L); + + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, + auditor, + threadPool, + clusterService, + Settings.EMPTY); + + modelLoadingService.clusterChanged(ingestChangedEvent(false, model1)); + + for(int i = 0; i < 10; i++) { + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModel(model1, future); + assertThat(future.get(), is(not(nullValue()))); + } + + verify(trainedModelProvider, times(10)).getTrainedModel(eq(model1), any()); } public void testGetCachedMissingModel() throws Exception { String model = "test-load-cached-missing-model"; withMissingModel(model); - ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, threadPool, clusterService); + ModelLoadingService modelLoadingService =new ModelLoadingService(trainedModelProvider, + auditor, + threadPool, + clusterService, + Settings.EMPTY); modelLoadingService.clusterChanged(ingestChangedEvent(model)); PlainActionFuture future = new PlainActionFuture<>(); @@ -124,7 +252,11 @@ public void testGetMissingModel() { String model = "test-load-missing-model"; withMissingModel(model); - ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, threadPool, clusterService); + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, + auditor, + threadPool, + clusterService, + Settings.EMPTY); PlainActionFuture future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); @@ -138,9 +270,13 @@ public void testGetMissingModel() { public 
void testGetModelEagerly() throws Exception { String model = "test-get-model-eagerly"; - withTrainedModel(model); + withTrainedModel(model, 1L); - ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, threadPool, clusterService); + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, + auditor, + threadPool, + clusterService, + Settings.EMPTY); for(int i = 0; i < 3; i++) { PlainActionFuture future = new PlainActionFuture<>(); @@ -152,11 +288,11 @@ public void testGetModelEagerly() throws Exception { } @SuppressWarnings("unchecked") - private void withTrainedModel(String modelId) { - TrainedModelConfig trainedModelConfig = buildTrainedModelConfigBuilder(modelId) - .setVersion(Version.CURRENT) - .setCreateTime(Instant.now()) - .build(); + private void withTrainedModel(String modelId, long size) { + TrainedModelDefinition definition = mock(TrainedModelDefinition.class); + when(definition.ramBytesUsed()).thenReturn(size); + TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class); + when(trainedModelConfig.getDefinition()).thenReturn(definition); doAnswer(invocationOnMock -> { @SuppressWarnings("rawtypes") ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; @@ -175,22 +311,18 @@ private void withMissingModel(String modelId) { }).when(trainedModelProvider).getTrainedModel(eq(modelId), any()); } - private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) { - return TrainedModelConfig.builder() - .setCreatedBy("ml_test") - .setDefinition(TrainedModelDefinitionTests.createRandomBuilder()) - .setDescription("trained model config for test") - .setModelId(modelId); + private static ClusterChangedEvent ingestChangedEvent(String... modelId) throws IOException { + return ingestChangedEvent(true, modelId); } - private static ClusterChangedEvent ingestChangedEvent(String... 
modelId) throws IOException { + private static ClusterChangedEvent ingestChangedEvent(boolean isIngestNode, String... modelId) throws IOException { ClusterChangedEvent event = mock(ClusterChangedEvent.class); when(event.changedCustomMetaDataSet()).thenReturn(Collections.singleton(IngestMetadata.TYPE)); - when(event.state()).thenReturn(buildClusterStateWithModelReferences(modelId)); + when(event.state()).thenReturn(buildClusterStateWithModelReferences(isIngestNode, modelId)); return event; } - private static ClusterState buildClusterStateWithModelReferences(String... modelId) throws IOException { + private static ClusterState buildClusterStateWithModelReferences(boolean isIngestNode, String... modelId) throws IOException { Map configurations = new HashMap<>(modelId.length); for (String id : modelId) { configurations.put("pipeline_with_model_" + id, newConfigurationWithInferenceProcessor(id)); @@ -199,6 +331,15 @@ private static ClusterState buildClusterStateWithModelReferences(String... model return ClusterState.builder(new ClusterName("_name")) .metaData(MetaData.builder().putCustom(IngestMetadata.TYPE, ingestMetadata)) + .nodes(DiscoveryNodes.builder().add( + new DiscoveryNode("node_name", + "node_id", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + isIngestNode ? 
Collections.singleton(DiscoveryNodeRole.INGEST_ROLE) : Collections.emptySet(), + Version.CURRENT)) + .localNodeId("node_id") + .build()) .build(); } From 484886a81baea1fbec792e4f892f8dcaa8a35772 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Mon, 28 Oct 2019 07:38:25 -0400 Subject: [PATCH 10/17] [ML][Inference] adding more options to inference processor (#48545) * [ML][Inference] adding more options to inference processor, fixing minor bug * addressing PR comments --- .../inference/ingest/InferenceProcessor.java | 41 +++++++---- .../ingest/InferenceProcessorTests.java | 69 +++++++++++++++---- 2 files changed, 80 insertions(+), 30 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java index 40ca9ba253594..5a0e50cd84bae 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -59,15 +59,16 @@ public class InferenceProcessor extends AbstractProcessor { public static final String TARGET_FIELD = "target_field"; public static final String FIELD_MAPPINGS = "field_mappings"; public static final String MODEL_INFO_FIELD = "model_info_field"; + public static final String INCLUDE_MODEL_METADATA = "include_model_metadata"; private final Client client; private final String modelId; private final String targetField; private final String modelInfoField; - private final Map modelInfo; private final InferenceConfig inferenceConfig; private final Map fieldMapping; + private final boolean includeModelMetadata; public InferenceProcessor(Client client, String tag, @@ -75,16 +76,16 @@ public InferenceProcessor(Client client, String modelId, InferenceConfig inferenceConfig, Map fieldMapping, - String modelInfoField) { + String modelInfoField, + boolean 
includeModelMetadata) { super(tag); - this.client = client; - this.targetField = targetField; - this.modelInfoField = modelInfoField; - this.modelId = modelId; - this.inferenceConfig = inferenceConfig; - this.fieldMapping = fieldMapping; - this.modelInfo = new HashMap<>(); - this.modelInfo.put("model_id", modelId); + this.client = ExceptionsHelper.requireNonNull(client, "client"); + this.targetField = ExceptionsHelper.requireNonNull(targetField, TARGET_FIELD); + this.modelInfoField = ExceptionsHelper.requireNonNull(modelInfoField, MODEL_INFO_FIELD); + this.includeModelMetadata = includeModelMetadata; + this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID); + this.inferenceConfig = ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG); + this.fieldMapping = ExceptionsHelper.requireNonNull(fieldMapping, FIELD_MAPPINGS); } public String getModelId() { @@ -128,8 +129,8 @@ void mutateDocument(InferModelAction.Response response, IngestDocument ingestDoc throw new ElasticsearchStatusException("Unexpected empty inference response", RestStatus.INTERNAL_SERVER_ERROR); } response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField); - if (modelInfoField != null) { - ingestDocument.setFieldValue(modelInfoField, modelInfo); + if (includeModelMetadata) { + ingestDocument.setFieldValue(modelInfoField + "." 
+ MODEL_ID, modelId); } } @@ -207,15 +208,25 @@ public InferenceProcessor create(Map processorFactori maxIngestProcessors); } + boolean includeModelMetadata = ConfigurationUtils.readBooleanProperty(TYPE, tag, config, INCLUDE_MODEL_METADATA, true); String modelId = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_ID); String targetField = ConfigurationUtils.readStringProperty(TYPE, tag, config, TARGET_FIELD); Map fieldMapping = ConfigurationUtils.readOptionalMap(TYPE, tag, config, FIELD_MAPPINGS); InferenceConfig inferenceConfig = inferenceConfigFromMap(ConfigurationUtils.readMap(TYPE, tag, config, INFERENCE_CONFIG)); - String modelInfoField = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_INFO_FIELD, "_model_info"); - if (modelInfoField != null && tag != null) { + String modelInfoField = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_INFO_FIELD, "ml"); + // If multiple inference processors are in the same pipeline, it is wise to tag them + // The tag will keep metadata entries from stepping on each other + if (tag != null) { modelInfoField += "." 
+ tag; } - return new InferenceProcessor(client, tag, targetField, modelId, inferenceConfig, fieldMapping, modelInfoField); + return new InferenceProcessor(client, + tag, + targetField, + modelId, + inferenceConfig, + fieldMapping, + modelInfoField, + includeModelMetadata); } // Package private for testing diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java index 4f55768407339..81f3be6aeefbf 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java @@ -43,7 +43,8 @@ public void testMutateDocumentWithClassification() { "classification_model", new ClassificationConfig(0), Collections.emptyMap(), - "_ml_model.my_processor"); + "ml.my_processor", + true); Map source = new HashMap<>(); Map ingestMetadata = new HashMap<>(); @@ -54,7 +55,7 @@ public void testMutateDocumentWithClassification() { inferenceProcessor.mutateDocument(response, document); assertThat(document.getFieldValue(targetField, String.class), equalTo("foo")); - assertThat(document.getFieldValue("_ml_model", Map.class), + assertThat(document.getFieldValue("ml", Map.class), equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "classification_model")))); } @@ -67,7 +68,8 @@ public void testMutateDocumentClassificationTopNClasses() { "classification_model", new ClassificationConfig(2), Collections.emptyMap(), - "_ml_model.my_processor"); + "ml.my_processor", + true); Map source = new HashMap<>(); Map ingestMetadata = new HashMap<>(); @@ -83,7 +85,7 @@ public void testMutateDocumentClassificationTopNClasses() { assertThat((List>)document.getFieldValue(targetField, List.class), 
contains(classes.stream().map(ClassificationInferenceResults.TopClassEntry::asValueMap).toArray(Map[]::new))); - assertThat(document.getFieldValue("_ml_model", Map.class), + assertThat(document.getFieldValue("ml", Map.class), equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "classification_model")))); } @@ -95,7 +97,8 @@ public void testMutateDocumentRegression() { "regression_model", new RegressionConfig(), Collections.emptyMap(), - "_ml_model.my_processor"); + "ml.my_processor", + true); Map source = new HashMap<>(); Map ingestMetadata = new HashMap<>(); @@ -106,7 +109,7 @@ public void testMutateDocumentRegression() { inferenceProcessor.mutateDocument(response, document); assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7)); - assertThat(document.getFieldValue("_ml_model", Map.class), + assertThat(document.getFieldValue("ml", Map.class), equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "regression_model")))); } @@ -118,7 +121,8 @@ public void testMutateDocumentNoModelMetaData() { "regression_model", new RegressionConfig(), Collections.emptyMap(), - null); + "ml.my_processor", + false); Map source = new HashMap<>(); Map ingestMetadata = new HashMap<>(); @@ -129,7 +133,40 @@ public void testMutateDocumentNoModelMetaData() { inferenceProcessor.mutateDocument(response, document); assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7)); - assertThat(document.hasField("_ml_model"), is(false)); + assertThat(document.hasField("ml"), is(false)); + } + + public void testMutateDocumentModelMetaDataExistingField() { + String targetField = "regression_value"; + InferenceProcessor inferenceProcessor = new InferenceProcessor(client, + "my_processor", + targetField, + "regression_model", + new RegressionConfig(), + Collections.emptyMap(), + "ml.my_processor", + true); + + //cannot use singleton map as attempting to mutate later + Map ml = new HashMap<>(){{ + 
put("regression_prediction", 0.55); + }}; + Map source = new HashMap<>(){{ + put("ml", ml); + }}; + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + InferModelAction.Response response = new InferModelAction.Response( + Collections.singletonList(new RegressionInferenceResults(0.7))); + inferenceProcessor.mutateDocument(response, document); + + assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7)); + assertThat(document.getFieldValue("ml", Map.class), + equalTo(new HashMap<>(){{ + put("my_processor", Collections.singletonMap("model_id", "regression_model")); + put("regression_prediction", 0.55); + }})); } public void testGenerateRequestWithEmptyMapping() { @@ -137,12 +174,13 @@ public void testGenerateRequestWithEmptyMapping() { Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10); InferenceProcessor processor = new InferenceProcessor(client, - "my_processor", - "my_field", - modelId, - new ClassificationConfig(topNClasses), - Collections.emptyMap(), - null); + "my_processor", + "my_field", + modelId, + new ClassificationConfig(topNClasses), + Collections.emptyMap(), + "ml.my_processor", + false); Map source = new HashMap<>(){{ put("value1", 1); @@ -171,7 +209,8 @@ public void testGenerateWithMapping() { modelId, new ClassificationConfig(topNClasses), fieldMapping, - null); + "ml.my_processor", + false); Map source = new HashMap<>(3){{ put("value1", 1); From 3788a70926400e29c39262344b2a8622687b5936 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Mon, 28 Oct 2019 10:54:22 -0400 Subject: [PATCH 11/17] [ML][Inference] handle string values better in feature extraction (#48584) * [ML][Inference] handle string values better in feature extraction * adding tests for InferenceHelpers --- .../trainedmodel/InferenceHelpers.java | 16 +++++- .../ml/inference/trainedmodel/tree/Tree.java | 4 +- .../trainedmodel/InferenceHelpersTests.java | 55 +++++++++++++++++++ 
.../trainedmodel/tree/TreeTests.java | 8 ++- 4 files changed, 78 insertions(+), 5 deletions(-) create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpersTests.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java index 5e37b237e9f79..86bf076cd6bf1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java @@ -56,7 +56,6 @@ public static List topClasses(List return topClassEntries; } - public static String classificationLabel(double inferenceValue, @Nullable List classificationLabels) { assert inferenceValue == Math.rint(inferenceValue); if (classificationLabels == null) { @@ -72,4 +71,19 @@ public static String classificationLabel(double inferenceValue, @Nullable List fields, InferenceConfig config "Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString()); } - List features = featureNames.stream().map(f -> - fields.get(f) instanceof Number ? 
((Number) fields.get(f)).doubleValue() : null - ).collect(Collectors.toList()); + List features = featureNames.stream().map(f -> InferenceHelpers.toDouble(fields.get(f))).collect(Collectors.toList()); return infer(features, config); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpersTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpersTests.java new file mode 100644 index 0000000000000..ec5093f625bfb --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpersTests.java @@ -0,0 +1,55 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.test.ESTestCase; + +import java.util.HashMap; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; + + +public class InferenceHelpersTests extends ESTestCase { + + public void testToDoubleFromNumbers() { + assertThat(0.5, equalTo(InferenceHelpers.toDouble(0.5))); + assertThat(0.5, equalTo(InferenceHelpers.toDouble(0.5))); + assertThat(5.0, equalTo(InferenceHelpers.toDouble(5L))); + assertThat(5.0, equalTo(InferenceHelpers.toDouble(5))); + assertThat(0.5, equalTo(InferenceHelpers.toDouble(0.5f))); + } + + public void testToDoubleFromString() { + assertThat(0.5, equalTo(InferenceHelpers.toDouble("0.5"))); + assertThat(-0.5, equalTo(InferenceHelpers.toDouble("-0.5"))); + assertThat(5.0, equalTo(InferenceHelpers.toDouble("5"))); + assertThat(-5.0, equalTo(InferenceHelpers.toDouble("-5"))); + + // if ae are turned off, then we should get a null value + // otherwise, we should 
expect an assertion failure telling us that the string is improperly formatted + try { + assertThat(InferenceHelpers.toDouble(""), is(nullValue())); + } catch (AssertionError ae) { + assertThat(ae.getMessage(), equalTo("value is not properly formatted double []")); + } + try { + assertThat(InferenceHelpers.toDouble("notADouble"), is(nullValue())); + } catch (AssertionError ae) { + assertThat(ae.getMessage(), equalTo("value is not properly formatted double [notADouble]")); + } + } + + public void testToDoubleFromNull() { + assertThat(InferenceHelpers.toDouble(null), is(nullValue())); + } + + public void testDoubleFromUnknownObj() { + assertThat(InferenceHelpers.toDouble(new HashMap<>()), is(nullValue())); + } + +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java index 86e257469d82a..11bf44fd165e4 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java @@ -154,6 +154,12 @@ public void testInfer() { assertThat(0.2, closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001)); + // This should still work if the internal values are strings + List featureVectorStrings = Arrays.asList("0.3", "0.9"); + featureMap = zipObjMap(featureNames, featureVectorStrings); + assertThat(0.2, + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001)); + // This should handle missing values and take the default_left path featureMap = new HashMap<>(2) {{ put("foo", 0.3); @@ -294,7 +300,7 @@ public void testTreeWithTargetTypeAndLabelsMismatch() { assertThat(ex.getMessage(), equalTo(msg)); } - private static Map zipObjMap(List keys, List values) 
{ + private static Map zipObjMap(List keys, List values) { return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get)); } } From c2946d17af51e21cce683eb7a66ccfc68f43b80f Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Tue, 29 Oct 2019 09:24:30 -0400 Subject: [PATCH 12/17] [ML][Inference] Adding _stats endpoint for inference (#48492) [ML][Inference] Adding _stats endpoint for inference. Initially only contains ingest stats and pipeline counts. --- .../org/elasticsearch/ingest/IngestStats.java | 61 ++++ .../xpack/core/XPackClientPlugin.java | 6 + .../action/GetTrainedModelsStatsAction.java | 194 ++++++++++ ...TrainedModelsStatsActionResponseTests.java | 60 ++++ .../ml/qa/ml-with-security/build.gradle | 2 + .../ml/integration/InferenceIngestIT.java | 10 +- .../xpack/ml/MachineLearning.java | 7 +- .../TransportGetTrainedModelsStatsAction.java | 336 ++++++++++++++++++ .../RestGetTrainedModelsStatsAction.java | 52 +++ ...sportGetTrainedModelsStatsActionTests.java | 282 +++++++++++++++ .../api/ml.get_trained_models_stats.json | 48 +++ .../test/ml/inference_stats_crud.yml | 226 ++++++++++++ 12 files changed, 1277 insertions(+), 7 deletions(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsStatsAction.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models_stats.json create mode 100644 
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml diff --git a/server/src/main/java/org/elasticsearch/ingest/IngestStats.java b/server/src/main/java/org/elasticsearch/ingest/IngestStats.java index f140c5f155563..29b088745ee37 100644 --- a/server/src/main/java/org/elasticsearch/ingest/IngestStats.java +++ b/server/src/main/java/org/elasticsearch/ingest/IngestStats.java @@ -32,6 +32,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.concurrent.TimeUnit; public class IngestStats implements Writeable, ToXContentFragment { @@ -135,6 +136,21 @@ public Map> getProcessorStats() { return processorStats; } + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + IngestStats that = (IngestStats) o; + return Objects.equals(totalStats, that.totalStats) + && Objects.equals(pipelineStats, that.pipelineStats) + && Objects.equals(processorStats, that.processorStats); + } + + @Override + public int hashCode() { + return Objects.hash(totalStats, pipelineStats, processorStats); + } + public static class Stats implements Writeable, ToXContentFragment { private final long ingestCount; @@ -203,6 +219,22 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field("failed", ingestFailedCount); return builder; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + IngestStats.Stats that = (IngestStats.Stats) o; + return Objects.equals(ingestCount, that.ingestCount) + && Objects.equals(ingestTimeInMillis, that.ingestTimeInMillis) + && Objects.equals(ingestFailedCount, that.ingestFailedCount) + && Objects.equals(ingestCurrent, that.ingestCurrent); + } + + @Override + public int hashCode() { + return Objects.hash(ingestCount, ingestTimeInMillis, ingestFailedCount, ingestCurrent); + } } 
/** @@ -255,6 +287,20 @@ public String getPipelineId() { public Stats getStats() { return stats; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + IngestStats.PipelineStat that = (IngestStats.PipelineStat) o; + return Objects.equals(pipelineId, that.pipelineId) + && Objects.equals(stats, that.stats); + } + + @Override + public int hashCode() { + return Objects.hash(pipelineId, stats); + } } /** @@ -276,5 +322,20 @@ public String getName() { public Stats getStats() { return stats; } + + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + IngestStats.ProcessorStat that = (IngestStats.ProcessorStat) o; + return Objects.equals(name, that.name) + && Objects.equals(stats, that.stats); + } + + @Override + public int hashCode() { + return Objects.hash(name, stats); + } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index 08e595dc9e057..ec4d1cdce0fed 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -77,6 +77,7 @@ import org.elasticsearch.xpack.core.ml.action.DeleteForecastAction; import org.elasticsearch.xpack.core.ml.action.DeleteJobAction; import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction; +import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction; import org.elasticsearch.xpack.core.ml.action.EstimateMemoryUsageAction; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.action.FinalizeJobExecutionAction; @@ -98,6 +99,8 @@ import org.elasticsearch.xpack.core.ml.action.GetModelSnapshotsAction; import 
org.elasticsearch.xpack.core.ml.action.GetOverallBucketsAction; import org.elasticsearch.xpack.core.ml.action.GetRecordsAction; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction; import org.elasticsearch.xpack.core.ml.action.KillProcessAction; @@ -344,6 +347,9 @@ public List> getClientActions() { EvaluateDataFrameAction.INSTANCE, EstimateMemoryUsageAction.INSTANCE, InferModelAction.INSTANCE, + GetTrainedModelsAction.INSTANCE, + DeleteTrainedModelAction.INSTANCE, + GetTrainedModelsStatsAction.INSTANCE, // security ClearRealmCacheAction.INSTANCE, ClearRolesCacheAction.INSTANCE, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java new file mode 100644 index 0000000000000..f3cb43e8ef790 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java @@ -0,0 +1,194 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionRequestBuilder; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.client.ElasticsearchClient; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.ingest.IngestStats; +import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest; +import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse; +import org.elasticsearch.xpack.core.action.util.QueryPage; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +public class GetTrainedModelsStatsAction extends ActionType { + + public static final GetTrainedModelsStatsAction INSTANCE = new GetTrainedModelsStatsAction(); + public static final String NAME = "cluster:monitor/xpack/ml/inference/stats/get"; + + public static final ParseField MODEL_ID = new ParseField("model_id"); + public static final ParseField PIPELINE_COUNT = new ParseField("pipeline_count"); + + private GetTrainedModelsStatsAction() { + super(NAME, GetTrainedModelsStatsAction.Response::new); + } + + public static class Request extends AbstractGetResourcesRequest { + + public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); + + public Request() { + setAllowNoResources(true); + } + + public Request(String id) { + setResourceId(id); + setAllowNoResources(true); + } + + public Request(StreamInput in) throws IOException { + super(in); + } + + @Override + 
public String getResourceIdField() { + return TrainedModelConfig.MODEL_ID.getPreferredName(); + } + + } + + public static class RequestBuilder extends ActionRequestBuilder { + + public RequestBuilder(ElasticsearchClient client, GetTrainedModelsStatsAction action) { + super(client, action, new Request()); + } + } + + public static class Response extends AbstractGetResourcesResponse { + + public static class TrainedModelStats implements ToXContentObject, Writeable { + private final String modelId; + private final IngestStats ingestStats; + private final int pipelineCount; + + private static final IngestStats EMPTY_INGEST_STATS = new IngestStats(new IngestStats.Stats(0, 0, 0, 0), + Collections.emptyList(), + Collections.emptyMap()); + + public TrainedModelStats(String modelId, IngestStats ingestStats, int pipelineCount) { + this.modelId = Objects.requireNonNull(modelId); + this.ingestStats = ingestStats == null ? EMPTY_INGEST_STATS : ingestStats; + if (pipelineCount < 0) { + throw new ElasticsearchException("[{}] must be a greater than or equal to 0", PIPELINE_COUNT.getPreferredName()); + } + this.pipelineCount = pipelineCount; + } + + public TrainedModelStats(StreamInput in) throws IOException { + modelId = in.readString(); + ingestStats = new IngestStats(in); + pipelineCount = in.readVInt(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(MODEL_ID.getPreferredName(), modelId); + builder.field(PIPELINE_COUNT.getPreferredName(), pipelineCount); + if (pipelineCount > 0) { + // Ingest stats is a fragment + ingestStats.toXContent(builder, params); + } + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + ingestStats.writeTo(out); + out.writeVInt(pipelineCount); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, ingestStats, pipelineCount); + } + + 
@Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + TrainedModelStats other = (TrainedModelStats) obj; + return Objects.equals(this.modelId, other.modelId) + && Objects.equals(this.ingestStats, other.ingestStats) + && Objects.equals(this.pipelineCount, other.pipelineCount); + } + } + + public static final ParseField RESULTS_FIELD = new ParseField("trained_model_stats"); + + public Response(StreamInput in) throws IOException { + super(in); + } + + public Response(QueryPage trainedModels) { + super(trainedModels); + } + + @Override + protected Reader getReader() { + return Response.TrainedModelStats::new; + } + + public static class Builder { + + private long totalModelCount; + private Set expandedIds; + private Map ingestStatsMap; + + public Builder setTotalModelCount(long totalModelCount) { + this.totalModelCount = totalModelCount; + return this; + } + + public Builder setExpandedIds(Set expandedIds) { + this.expandedIds = expandedIds; + return this; + } + + public Set getExpandedIds() { + return this.expandedIds; + } + + public Builder setIngestStatsByModelId(Map ingestStatsByModelId) { + this.ingestStatsMap = ingestStatsByModelId; + return this; + } + + public Response build() { + List trainedModelStats = new ArrayList<>(expandedIds.size()); + expandedIds.forEach(id -> { + IngestStats ingestStats = ingestStatsMap.get(id); + trainedModelStats.add(new TrainedModelStats(id, ingestStats, ingestStats == null ? 
+ 0 : + ingestStats.getPipelineStats().size())); + }); + return new Response(new QueryPage<>(trainedModelStats, totalModelCount, RESULTS_FIELD)); + } + } + } + +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java new file mode 100644 index 0000000000000..3e49f73917a44 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java @@ -0,0 +1,60 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.ingest.IngestStats; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.action.util.QueryPage; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction.Response; + +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + + +public class GetTrainedModelsStatsActionResponseTests extends AbstractWireSerializingTestCase { + + @Override + protected Response createTestInstance() { + int listSize = randomInt(10); + List trainedModelStats = Stream.generate(() -> randomAlphaOfLength(10)) + .limit(listSize).map(id -> + new Response.TrainedModelStats(id, + randomBoolean() ? 
randomIngestStats() : null, + randomIntBetween(0, 10)) + ) + .collect(Collectors.toList()); + return new Response(new QueryPage<>(trainedModelStats, randomLongBetween(listSize, 1000), Response.RESULTS_FIELD)); + } + + private IngestStats randomIngestStats() { + List pipelineIds = Stream.generate(()-> randomAlphaOfLength(10)) + .limit(randomIntBetween(0, 10)) + .collect(Collectors.toList()); + return new IngestStats( + new IngestStats.Stats(randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong()), + pipelineIds.stream().map(id -> new IngestStats.PipelineStat(id, randomStats())).collect(Collectors.toList()), + pipelineIds.stream().collect(Collectors.toMap(Function.identity(), (v) -> randomProcessorStats()))); + } + + private IngestStats.Stats randomStats(){ + return new IngestStats.Stats(randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong()); + } + + private List randomProcessorStats() { + return Stream.generate(() -> randomAlphaOfLength(10)) + .limit(randomIntBetween(0, 10)) + .map(name -> new IngestStats.ProcessorStat(name, randomStats())) + .collect(Collectors.toList()); + } + + @Override + protected Writeable.Reader instanceReader() { + return Response::new; + } +} diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index 2dd63883b523a..a8481a9dee966 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -130,6 +130,8 @@ integTest.runner { 'ml/inference_crud/Test delete with missing model', 'ml/inference_crud/Test get given missing trained model', 'ml/inference_crud/Test get given expression without matches and allow_no_match is false', + 'ml/inference_stats_crud/Test get stats given missing trained model', + 'ml/inference_stats_crud/Test get stats given expression without matches and allow_no_match is false', 'ml/jobs_crud/Test cannot create job 
with existing categorizer state document', 'ml/jobs_crud/Test cannot create job with existing quantiles document', 'ml/jobs_crud/Test cannot create job with existing result document', diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java index 3533afe947535..1099c70ee238a 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java @@ -31,15 +31,13 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase { @Before public void createBothModels() { - assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, - "_doc", - "test_classification") + assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME) + .setId("test_classification") .setSource(CLASSIFICATION_MODEL, XContentType.JSON) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) .get().status(), equalTo(RestStatus.CREATED)); - assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, - "_doc", - "test_regression") + assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME) + .setId("test_regression") .setSource(REGRESSION_MODEL, XContentType.JSON) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) .get().status(), equalTo(RestStatus.CREATED)); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 74a1a01278c68..2338136f51c3d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ 
-98,6 +98,7 @@ import org.elasticsearch.xpack.core.ml.action.GetOverallBucketsAction; import org.elasticsearch.xpack.core.ml.action.GetRecordsAction; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction; import org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.action.KillProcessAction; @@ -167,6 +168,7 @@ import org.elasticsearch.xpack.ml.action.TransportGetModelSnapshotsAction; import org.elasticsearch.xpack.ml.action.TransportGetOverallBucketsAction; import org.elasticsearch.xpack.ml.action.TransportGetRecordsAction; +import org.elasticsearch.xpack.ml.action.TransportGetTrainedModelsStatsAction; import org.elasticsearch.xpack.ml.action.TransportInferModelAction; import org.elasticsearch.xpack.ml.action.TransportGetTrainedModelsAction; import org.elasticsearch.xpack.ml.action.TransportIsolateDatafeedAction; @@ -273,6 +275,7 @@ import org.elasticsearch.xpack.ml.rest.filter.RestUpdateFilterAction; import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAction; import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction; +import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsStatsAction; import org.elasticsearch.xpack.ml.rest.job.RestCloseJobAction; import org.elasticsearch.xpack.ml.rest.job.RestDeleteForecastAction; import org.elasticsearch.xpack.ml.rest.job.RestDeleteJobAction; @@ -741,7 +744,8 @@ public List getRestHandlers(Settings settings, RestController restC new RestEvaluateDataFrameAction(restController), new RestEstimateMemoryUsageAction(restController), new RestGetTrainedModelsAction(restController), - new RestDeleteTrainedModelAction(restController) + new RestDeleteTrainedModelAction(restController), + new RestGetTrainedModelsStatsAction(restController) ); } @@ -816,6 +820,7 @@ public List 
getRestHandlers(Settings settings, RestController restC new ActionHandler<>(InferModelAction.INSTANCE, TransportInferModelAction.class), new ActionHandler<>(GetTrainedModelsAction.INSTANCE, TransportGetTrainedModelsAction.class), new ActionHandler<>(DeleteTrainedModelAction.INSTANCE, TransportDeleteTrainedModelAction.class), + new ActionHandler<>(GetTrainedModelsStatsAction.INSTANCE, TransportGetTrainedModelsStatsAction.class), usageAction, infoAction); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java new file mode 100644 index 0000000000000..773577d92d1a0 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java @@ -0,0 +1,336 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.admin.cluster.node.stats.NodeStats; +import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsAction; +import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequest; +import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.metrics.CounterMetric; +import org.elasticsearch.common.regex.Regex; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.ingest.IngestMetadata; +import org.elasticsearch.ingest.IngestService; +import org.elasticsearch.ingest.IngestStats; +import org.elasticsearch.ingest.Pipeline; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.sort.SortBuilders; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.action.util.ExpandedIdsMatcher; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import 
org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; +import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; + + +public class TransportGetTrainedModelsStatsAction extends HandledTransportAction { + + private final Client client; + private final ClusterService clusterService; + private final IngestService ingestService; + + @Inject + public TransportGetTrainedModelsStatsAction(TransportService transportService, + ActionFilters actionFilters, + ClusterService clusterService, + IngestService ingestService, + Client client) { + super(GetTrainedModelsStatsAction.NAME, transportService, actionFilters, GetTrainedModelsStatsAction.Request::new); + this.client = client; + this.clusterService = clusterService; + this.ingestService = ingestService; + } + + @Override + protected void doExecute(Task task, + GetTrainedModelsStatsAction.Request request, + ActionListener listener) { + + GetTrainedModelsStatsAction.Response.Builder responseBuilder = new GetTrainedModelsStatsAction.Response.Builder(); + + ActionListener nodesStatsListener = ActionListener.wrap( + nodesStatsResponse -> { + Map modelIdIngestStats = inferenceIngestStatsByPipelineId(nodesStatsResponse, + pipelineIdsByModelIds(clusterService.state(), + ingestService, + responseBuilder.getExpandedIds())); + listener.onResponse(responseBuilder.setIngestStatsByModelId(modelIdIngestStats).build()); + }, + listener::onFailure + ); + + ActionListener>> idsListener = ActionListener.wrap( + tuple -> { + 
responseBuilder.setExpandedIds(tuple.v2()) + .setTotalModelCount(tuple.v1()); + String[] ingestNodes = ingestNodes(clusterService.state()); + NodesStatsRequest nodesStatsRequest = new NodesStatsRequest(ingestNodes).clear().ingest(true); + executeAsyncWithOrigin(client, ML_ORIGIN, NodesStatsAction.INSTANCE, nodesStatsRequest, nodesStatsListener); + }, + listener::onFailure + ); + + expandIds(request, idsListener); + } + + static Map inferenceIngestStatsByPipelineId(NodesStatsResponse response, + Map> modelIdToPipelineId) { + + Map ingestStatsMap = new HashMap<>(); + + modelIdToPipelineId.forEach((modelId, pipelineIds) -> { + List collectedStats = response.getNodes() + .stream() + .map(nodeStats -> ingestStatsForPipelineIds(nodeStats, pipelineIds)) + .collect(Collectors.toList()); + ingestStatsMap.put(modelId, mergeStats(collectedStats)); + }); + + return ingestStatsMap; + } + + + private void expandIds(GetTrainedModelsStatsAction.Request request, ActionListener>> idsListener) { + String[] tokens = Strings.tokenizeToStringArray(request.getResourceId(), ","); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder() + .sort(SortBuilders.fieldSort(request.getResourceIdField()) + // If there are no resources, there might be no mapping for the id field. + // This makes sure we don't get an error if that happens. + .unmappedType("long")) + .query(buildQuery(tokens, request.getResourceIdField())); + if (request.getPageParams() != null) { + sourceBuilder.from(request.getPageParams().getFrom()) + .size(request.getPageParams().getSize()); + } + sourceBuilder.trackTotalHits(true) + // we only care about the item id's, there is no need to load large model definitions. 
+ .fetchSource(TrainedModelConfig.MODEL_ID.getPreferredName(), null); + + IndicesOptions indicesOptions = SearchRequest.DEFAULT_INDICES_OPTIONS; + SearchRequest searchRequest = new SearchRequest(InferenceIndexConstants.INDEX_PATTERN) + .indicesOptions(IndicesOptions.fromOptions(true, + indicesOptions.allowNoIndices(), + indicesOptions.expandWildcardsOpen(), + indicesOptions.expandWildcardsClosed(), + indicesOptions)) + .source(sourceBuilder); + + executeAsyncWithOrigin(client.threadPool().getThreadContext(), + ML_ORIGIN, + searchRequest, + ActionListener.wrap( + response -> { + Set foundResourceIds = new LinkedHashSet<>(); + long totalHitCount = response.getHits().getTotalHits().value; + for (SearchHit hit : response.getHits().getHits()) { + Map docSource = hit.getSourceAsMap(); + if (docSource == null) { + continue; + } + Object idValue = docSource.get(TrainedModelConfig.MODEL_ID.getPreferredName()); + if (idValue instanceof String) { + foundResourceIds.add(idValue.toString()); + } + } + ExpandedIdsMatcher requiredMatches = new ExpandedIdsMatcher(tokens, request.isAllowNoResources()); + requiredMatches.filterMatchedIds(foundResourceIds); + if (requiredMatches.hasUnmatchedIds()) { + idsListener.onFailure(ExceptionsHelper.missingTrainedModel(requiredMatches.unmatchedIdsString())); + } else { + idsListener.onResponse(Tuple.tuple(totalHitCount, foundResourceIds)); + } + }, + idsListener::onFailure + ), + client::search); + + } + + private QueryBuilder buildQuery(String[] tokens, String resourceIdField) { + BoolQueryBuilder boolQuery = QueryBuilders.boolQuery() + .filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelConfig.NAME)); + + if (Strings.isAllOrWildcard(tokens)) { + return boolQuery; + } + // If the resourceId is not _all or *, we should see if it is a comma delimited string with wild-cards + // e.g. 
id1,id2*,id3 + BoolQueryBuilder shouldQueries = new BoolQueryBuilder(); + List terms = new ArrayList<>(); + for (String token : tokens) { + if (Regex.isSimpleMatchPattern(token)) { + shouldQueries.should(QueryBuilders.wildcardQuery(resourceIdField, token)); + } else { + terms.add(token); + } + } + if (terms.isEmpty() == false) { + shouldQueries.should(QueryBuilders.termsQuery(resourceIdField, terms)); + } + + if (shouldQueries.should().isEmpty() == false) { + boolQuery.filter(shouldQueries); + } + return boolQuery; + } + + static String[] ingestNodes(final ClusterState clusterState) { + String[] ingestNodes = new String[clusterState.nodes().getIngestNodes().size()]; + Iterator nodeIterator = clusterState.nodes().getIngestNodes().keysIt(); + int i = 0; + while(nodeIterator.hasNext()) { + ingestNodes[i++] = nodeIterator.next(); + } + return ingestNodes; + } + + static Map> pipelineIdsByModelIds(ClusterState state, IngestService ingestService, Set modelIds) { + IngestMetadata ingestMetadata = state.metaData().custom(IngestMetadata.TYPE); + Map> pipelineIdsByModelIds = new HashMap<>(); + if (ingestMetadata == null) { + return pipelineIdsByModelIds; + } + + ingestMetadata.getPipelines().forEach((pipelineId, pipelineConfiguration) -> { + try { + Pipeline pipeline = Pipeline.create(pipelineId, + pipelineConfiguration.getConfigAsMap(), + ingestService.getProcessorFactories(), + ingestService.getScriptService()); + pipeline.getProcessors().forEach(processor -> { + if (processor instanceof InferenceProcessor) { + InferenceProcessor inferenceProcessor = (InferenceProcessor) processor; + if (modelIds.contains(inferenceProcessor.getModelId())) { + pipelineIdsByModelIds.computeIfAbsent(inferenceProcessor.getModelId(), + m -> new LinkedHashSet<>()).add(pipelineId); + } + } + }); + } catch (Exception ex) { + throw new ElasticsearchException("unexpected failure gathering pipeline information", ex); + } + }); + + return pipelineIdsByModelIds; + } + + static IngestStats 
ingestStatsForPipelineIds(NodeStats nodeStats, Set pipelineIds) { + IngestStats fullNodeStats = nodeStats.getIngestStats(); + Map> filteredProcessorStats = new HashMap<>(fullNodeStats.getProcessorStats()); + filteredProcessorStats.keySet().retainAll(pipelineIds); + List filteredPipelineStats = fullNodeStats.getPipelineStats() + .stream() + .filter(pipelineStat -> pipelineIds.contains(pipelineStat.getPipelineId())) + .collect(Collectors.toList()); + CounterMetric ingestCount = new CounterMetric(); + CounterMetric ingestTimeInMillis = new CounterMetric(); + CounterMetric ingestCurrent = new CounterMetric(); + CounterMetric ingestFailedCount = new CounterMetric(); + + filteredPipelineStats.forEach(pipelineStat -> { + IngestStats.Stats stats = pipelineStat.getStats(); + ingestCount.inc(stats.getIngestCount()); + ingestTimeInMillis.inc(stats.getIngestTimeInMillis()); + ingestCurrent.inc(stats.getIngestCurrent()); + ingestFailedCount.inc(stats.getIngestFailedCount()); + }); + + return new IngestStats( + new IngestStats.Stats(ingestCount.count(), ingestTimeInMillis.count(), ingestCurrent.count(), ingestFailedCount.count()), + filteredPipelineStats, + filteredProcessorStats); + } + + private static IngestStats mergeStats(List ingestStatsList) { + + Map pipelineStatsAcc = new LinkedHashMap<>(ingestStatsList.size()); + Map> processorStatsAcc = new LinkedHashMap<>(ingestStatsList.size()); + IngestStatsAccumulator totalStats = new IngestStatsAccumulator(); + ingestStatsList.forEach(ingestStats -> { + + ingestStats.getPipelineStats() + .forEach(pipelineStat -> + pipelineStatsAcc.computeIfAbsent(pipelineStat.getPipelineId(), + p -> new IngestStatsAccumulator()).inc(pipelineStat.getStats())); + + ingestStats.getProcessorStats() + .forEach((pipelineId, processorStat) -> { + Map processorAcc = processorStatsAcc.computeIfAbsent(pipelineId, + k -> new LinkedHashMap<>()); + processorStat.forEach(p -> + processorAcc.computeIfAbsent(p.getName(), + k -> new 
IngestStatsAccumulator()).inc(p.getStats())); + }); + + totalStats.inc(ingestStats.getTotalStats()); + }); + + List pipelineStatList = new ArrayList<>(pipelineStatsAcc.size()); + pipelineStatsAcc.forEach((pipelineId, accumulator) -> + pipelineStatList.add(new IngestStats.PipelineStat(pipelineId, accumulator.build()))); + + Map> processorStatList = new LinkedHashMap<>(processorStatsAcc.size()); + processorStatsAcc.forEach((pipelineId, accumulatorMap) -> { + List processorStats = new ArrayList<>(accumulatorMap.size()); + accumulatorMap.forEach((processorName, acc) -> processorStats.add(new IngestStats.ProcessorStat(processorName, acc.build()))); + processorStatList.put(pipelineId, processorStats); + }); + + return new IngestStats(totalStats.build(), pipelineStatList, processorStatList); + } + + private static class IngestStatsAccumulator { + CounterMetric ingestCount = new CounterMetric(); + CounterMetric ingestTimeInMillis = new CounterMetric(); + CounterMetric ingestCurrent = new CounterMetric(); + CounterMetric ingestFailedCount = new CounterMetric(); + + IngestStatsAccumulator inc(IngestStats.Stats s) { + ingestCount.inc(s.getIngestCount()); + ingestTimeInMillis.inc(s.getIngestTimeInMillis()); + ingestCurrent.inc(s.getIngestCurrent()); + ingestFailedCount.inc(s.getIngestFailedCount()); + return this; + } + + IngestStats.Stats build() { + return new IngestStats.Stats(ingestCount.count(), ingestTimeInMillis.count(), ingestCurrent.count(), ingestFailedCount.count()); + } + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsStatsAction.java new file mode 100644 index 0000000000000..100c8cfa2f922 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsStatsAction.java @@ -0,0 +1,52 @@ +/* + * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.rest.inference; + +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.cluster.metadata.MetaData; +import org.elasticsearch.common.Strings; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestController; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.action.util.PageParams; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.ml.MachineLearning; + +import java.io.IOException; + +import static org.elasticsearch.rest.RestRequest.Method.GET; +import static org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request.ALLOW_NO_MATCH; + +public class RestGetTrainedModelsStatsAction extends BaseRestHandler { + + public RestGetTrainedModelsStatsAction(RestController controller) { + controller.registerHandler( + GET, MachineLearning.BASE_PATH + "inference/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}/_stats", this); + controller.registerHandler(GET, MachineLearning.BASE_PATH + "inference/_stats", this); + } + + @Override + public String getName() { + return "ml_get_trained_models_stats_action"; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName()); + if (Strings.isNullOrEmpty(modelId)) { + modelId = MetaData.ALL; + } + GetTrainedModelsStatsAction.Request request = new GetTrainedModelsStatsAction.Request(modelId); + if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || 
restRequest.hasParam(PageParams.SIZE.getPreferredName())) { + request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM), + restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE))); + } + request.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(), request.isAllowNoResources())); + return channel -> client.execute(GetTrainedModelsStatsAction.INSTANCE, request, new RestToXContentListener<>(channel)); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java new file mode 100644 index 0000000000000..d00c5a8cab5f3 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java @@ -0,0 +1,282 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.action.admin.cluster.node.stats.NodeStats; +import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse; +import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.MetaData; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.ClusterSettings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.ingest.IngestMetadata; +import org.elasticsearch.ingest.IngestService; +import org.elasticsearch.ingest.IngestStats; +import org.elasticsearch.ingest.PipelineConfiguration; +import org.elasticsearch.ingest.Processor; +import org.elasticsearch.plugins.IngestPlugin; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; +import org.junit.Before; + +import java.io.IOException; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasEntry; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class TransportGetTrainedModelsStatsActionTests extends ESTestCase { + + private static class NotInferenceProcessor implements 
Processor { + + @Override + public IngestDocument execute(IngestDocument ingestDocument) throws Exception { + return ingestDocument; + } + + @Override + public String getType() { + return "not_inference"; + } + + @Override + public String getTag() { + return null; + } + + static class Factory implements Processor.Factory { + + @Override + public Processor create(Map processorFactories, String tag, Map config) { + return new NotInferenceProcessor(); + } + } + } + + private static final IngestPlugin SKINNY_INGEST_PLUGIN = new IngestPlugin() { + @Override + public Map getProcessors(Processor.Parameters parameters) { + Map factoryMap = new HashMap<>(); + factoryMap.put(InferenceProcessor.TYPE, + new InferenceProcessor.Factory(parameters.client, + parameters.ingestService.getClusterService(), + Settings.EMPTY, + parameters.ingestService)); + + factoryMap.put("not_inference", new NotInferenceProcessor.Factory()); + + return factoryMap; + } + }; + + private ClusterService clusterService; + private IngestService ingestService; + private Client client; + + @Before + public void setUpVariables() { + ThreadPool tp = mock(ThreadPool.class); + client = mock(Client.class); + clusterService = mock(ClusterService.class); + ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, + Collections.singleton(InferenceProcessor.MAX_INFERENCE_PROCESSORS)); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + ingestService = new IngestService(clusterService, tp, null, null, + null, Collections.singletonList(SKINNY_INGEST_PLUGIN), client); + } + + + public void testInferenceIngestStatsByPipelineId() throws IOException { + List nodeStatsList = Arrays.asList( + buildNodeStats( + new IngestStats.Stats(2, 2, 3, 4), + Arrays.asList( + new IngestStats.PipelineStat( + "pipeline1", + new IngestStats.Stats(0, 0, 3, 1)), + new IngestStats.PipelineStat( + "pipeline2", + new IngestStats.Stats(1, 1, 0, 1)), + new IngestStats.PipelineStat( + "pipeline3", + new 
IngestStats.Stats(2, 1, 1, 1)) + ), + Arrays.asList( + Arrays.asList( + new IngestStats.ProcessorStat(InferenceProcessor.TYPE, new IngestStats.Stats(10, 1, 0, 0)), + new IngestStats.ProcessorStat("grok", new IngestStats.Stats(10, 1, 0, 0)), + new IngestStats.ProcessorStat(InferenceProcessor.TYPE, new IngestStats.Stats(100, 10, 0, 1)) + ), + Arrays.asList( + new IngestStats.ProcessorStat(InferenceProcessor.TYPE, new IngestStats.Stats(5, 1, 0, 0)), + new IngestStats.ProcessorStat("grok", new IngestStats.Stats(10, 1, 0, 0)) + ), + Arrays.asList( + new IngestStats.ProcessorStat("grok", new IngestStats.Stats(10, 1, 0, 0)) + ) + )), + buildNodeStats( + new IngestStats.Stats(15, 5, 3, 4), + Arrays.asList( + new IngestStats.PipelineStat( + "pipeline1", + new IngestStats.Stats(10, 1, 3, 1)), + new IngestStats.PipelineStat( + "pipeline2", + new IngestStats.Stats(1, 1, 0, 1)), + new IngestStats.PipelineStat( + "pipeline3", + new IngestStats.Stats(2, 1, 1, 1)) + ), + Arrays.asList( + Arrays.asList( + new IngestStats.ProcessorStat(InferenceProcessor.TYPE, new IngestStats.Stats(0, 0, 0, 0)), + new IngestStats.ProcessorStat("grok", new IngestStats.Stats(0, 0, 0, 0)), + new IngestStats.ProcessorStat(InferenceProcessor.TYPE, new IngestStats.Stats(10, 1, 0, 0)) + ), + Arrays.asList( + new IngestStats.ProcessorStat(InferenceProcessor.TYPE, new IngestStats.Stats(5, 1, 0, 0)), + new IngestStats.ProcessorStat("grok", new IngestStats.Stats(10, 1, 0, 0)) + ), + Arrays.asList( + new IngestStats.ProcessorStat("grok", new IngestStats.Stats(10, 1, 0, 0)) + ) + )) + ); + + NodesStatsResponse response = new NodesStatsResponse(new ClusterName("_name"), nodeStatsList, Collections.emptyList()); + + Map> pipelineIdsByModelIds = new HashMap<>(){{ + put("trained_model_1", Collections.singleton("pipeline1")); + put("trained_model_2", new HashSet<>(Arrays.asList("pipeline1", "pipeline2"))); + }}; + Map ingestStatsMap = TransportGetTrainedModelsStatsAction.inferenceIngestStatsByPipelineId(response, + 
pipelineIdsByModelIds); + + assertThat(ingestStatsMap.keySet(), equalTo(new HashSet<>(Arrays.asList("trained_model_1", "trained_model_2")))); + + IngestStats expectedStatsModel1 = new IngestStats( + new IngestStats.Stats(10, 1, 6, 2), + Collections.singletonList(new IngestStats.PipelineStat("pipeline1", new IngestStats.Stats(10, 1, 6, 2))), + Collections.singletonMap("pipeline1", Arrays.asList( + new IngestStats.ProcessorStat("inference", new IngestStats.Stats(120, 12, 0, 1)), + new IngestStats.ProcessorStat("grok", new IngestStats.Stats(10, 1, 0, 0)))) + ); + + IngestStats expectedStatsModel2 = new IngestStats( + new IngestStats.Stats(12, 3, 6, 4), + Arrays.asList( + new IngestStats.PipelineStat("pipeline1", new IngestStats.Stats(10, 1, 6, 2)), + new IngestStats.PipelineStat("pipeline2", new IngestStats.Stats(2, 2, 0, 2))), + new HashMap<>() {{ + put("pipeline2", Arrays.asList( + new IngestStats.ProcessorStat("inference", new IngestStats.Stats(10, 2, 0, 0)), + new IngestStats.ProcessorStat("grok", new IngestStats.Stats(20, 2, 0, 0)))); + put("pipeline1", Arrays.asList( + new IngestStats.ProcessorStat("inference", new IngestStats.Stats(120, 12, 0, 1)), + new IngestStats.ProcessorStat("grok", new IngestStats.Stats(10, 1, 0, 0)))); + }} + ); + + assertThat(ingestStatsMap, hasEntry("trained_model_1", expectedStatsModel1)); + assertThat(ingestStatsMap, hasEntry("trained_model_2", expectedStatsModel2)); + } + + public void testPipelineIdsByModelIds() throws IOException { + String modelId1 = "trained_model_1"; + String modelId2 = "trained_model_2"; + String modelId3 = "trained_model_3"; + Set modelIds = new HashSet<>(Arrays.asList(modelId1, modelId2, modelId3)); + + ClusterState clusterState = buildClusterStateWithModelReferences(modelId1, modelId2, modelId3); + + Map> pipelineIdsByModelIds = + TransportGetTrainedModelsStatsAction.pipelineIdsByModelIds(clusterState, ingestService, modelIds); + + assertThat(pipelineIdsByModelIds.keySet(), equalTo(modelIds)); + 
assertThat(pipelineIdsByModelIds, + hasEntry(modelId1, new HashSet<>(Arrays.asList("pipeline_with_model_" + modelId1 + 0, "pipeline_with_model_" + modelId1 + 1)))); + assertThat(pipelineIdsByModelIds, + hasEntry(modelId2, new HashSet<>(Arrays.asList("pipeline_with_model_" + modelId2 + 0, "pipeline_with_model_" + modelId2 + 1)))); + assertThat(pipelineIdsByModelIds, + hasEntry(modelId3, new HashSet<>(Arrays.asList("pipeline_with_model_" + modelId3 + 0, "pipeline_with_model_" + modelId3 + 1)))); + + } + + private static ClusterState buildClusterStateWithModelReferences(String... modelId) throws IOException { + Map configurations = new HashMap<>(modelId.length); + for (String id : modelId) { + configurations.put("pipeline_with_model_" + id + 0, newConfigurationWithInferenceProcessor(id, 0)); + configurations.put("pipeline_with_model_" + id + 1, newConfigurationWithInferenceProcessor(id, 1)); + } + for (int i = 0; i < 3; i++) { + configurations.put("pipeline_without_model_" + i, newConfigurationWithOutInferenceProcessor(i)); + } + IngestMetadata ingestMetadata = new IngestMetadata(configurations); + + return ClusterState.builder(new ClusterName("_name")) + .metaData(MetaData.builder().putCustom(IngestMetadata.TYPE, ingestMetadata)) + .build(); + } + + private static PipelineConfiguration newConfigurationWithInferenceProcessor(String modelId, int num) throws IOException { + try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(Collections.singletonMap("processors", + Collections.singletonList( + Collections.singletonMap(InferenceProcessor.TYPE, + new HashMap() {{ + put(InferenceProcessor.MODEL_ID, modelId); + put("inference_config", Collections.singletonMap("regression", Collections.emptyMap())); + put("field_mappings", Collections.emptyMap()); + put("target_field", randomAlphaOfLength(10)); + }}))))) { + return new PipelineConfiguration("pipeline_with_model_" + modelId + num, + BytesReference.bytes(xContentBuilder), + XContentType.JSON); + } + } + + 
private static PipelineConfiguration newConfigurationWithOutInferenceProcessor(int i) throws IOException { + try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(Collections.singletonMap("processors", + Collections.singletonList(Collections.singletonMap("not_inference", Collections.emptyMap()))))) { + return new PipelineConfiguration("pipeline_without_model_" + i, BytesReference.bytes(xContentBuilder), XContentType.JSON); + } + } + + private static NodeStats buildNodeStats(IngestStats.Stats overallStats, + List pipelineNames, + List> processorStats) { + List pipelineids = pipelineNames.stream().map(IngestStats.PipelineStat::getPipelineId).collect(Collectors.toList()); + IngestStats ingestStats = new IngestStats( + overallStats, + pipelineNames, + IntStream.range(0, pipelineids.size()).boxed().collect(Collectors.toMap(pipelineids::get, processorStats::get))); + return new NodeStats(mock(DiscoveryNode.class), + Instant.now().toEpochMilli(), null, null, null, null, null, null, null, null, + null, null, null, ingestStats, null); + + } + +} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models_stats.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models_stats.json new file mode 100644 index 0000000000000..703380c708703 --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models_stats.json @@ -0,0 +1,48 @@ +{ + "ml.get_trained_models_stats":{ + "documentation":{ + "url":"TODO" + }, + "stability":"experimental", + "url":{ + "paths":[ + { + "path":"/_ml/inference/{model_id}/_stats", + "methods":[ + "GET" + ], + "parts":{ + "model_id":{ + "type":"string", + "description":"The ID of the trained models stats to fetch" + } + } + }, + { + "path":"/_ml/inference/_stats", + "methods":[ + "GET" + ] + } + ] + }, + "params":{ + "allow_no_match":{ + "type":"boolean", + "required":false, + "description":"Whether to ignore if a wildcard expression matches no trained models. 
(This includes `_all` string or when no trained models have been specified)", + "default":true + }, + "from":{ + "type":"int", + "description":"skips a number of trained models", + "default":0 + }, + "size":{ + "type":"int", + "description":"specifies a max number of trained models to get", + "default":100 + } + } + } +} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml new file mode 100644 index 0000000000000..efc6b784dbeac --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml @@ -0,0 +1,226 @@ +setup: + - skip: + features: headers + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + index: + id: trained_model_config-unused-regression-model1-0 + index: .ml-inference-000001 + body: > + { + "model_id": "unused-regression-model1", + "created_by": "ml_tests", + "version": "8.0.0", + "description": "empty model for tests", + "create_time": 0, + "model_version": 0, + "model_type": "local", + "doc_type": "trained_model_config" + } + + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + index: + id: trained_model_config-unused-regression-model-0 + index: .ml-inference-000001 + body: > + { + "model_id": "unused-regression-model", + "created_by": "ml_tests", + "version": "8.0.0", + "description": "empty model for tests", + "create_time": 0, + "model_version": 0, + "model_type": "local", + "doc_type": "trained_model_config" + } + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. 
the test setup superuser + index: + id: trained_model_config-used-regression-model-0 + index: .ml-inference-000001 + body: > + { + "model_id": "used-regression-model", + "created_by": "ml_tests", + "version": "8.0.0", + "description": "empty model for tests", + "create_time": 0, + "model_version": 0, + "model_type": "local", + "doc_type": "trained_model_config" + } + + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + indices.refresh: {} + + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + ingest.put_pipeline: + id: "regression-model-pipeline" + body: > + { + "processors": [ + { + "inference" : { + "model_id" : "used-regression-model", + "inference_config": {"regression": {}}, + "target_field": "regression_field", + "field_mappings": {} + } + } + ] + } + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. 
the test setup superuser + ingest.put_pipeline: + id: "regression-model-pipeline-1" + body: > + { + "processors": [ + { + "inference" : { + "model_id" : "used-regression-model", + "inference_config": {"regression": {}}, + "target_field": "regression_field", + "field_mappings": {} + } + } + ] + } +--- +"Test get stats given missing trained model": + + - do: + catch: missing + ml.get_trained_models_stats: + model_id: "missing-trained-model" +--- +"Test get stats given expression without matches and allow_no_match is false": + + - do: + catch: missing + ml.get_trained_models_stats: + model_id: "missing-trained-model*" + allow_no_match: false + +--- +"Test get stats given expression without matches and allow_no_match is true": + + - do: + ml.get_trained_models_stats: + model_id: "missing-trained-model*" + allow_no_match: true + - match: { count: 0 } + - match: { trained_model_stats: [] } +--- +"Test get stats given trained models": + + - do: + ml.get_trained_models_stats: + model_id: "unused-regression-model" + + - match: { count: 1 } + + - do: + ml.get_trained_models_stats: + model_id: "_all" + - match: { count: 3 } + - match: { trained_model_stats.0.model_id: unused-regression-model } + - match: { trained_model_stats.0.pipeline_count: 0 } + - is_false: trained_model_stats.0.ingest + - match: { trained_model_stats.1.model_id: unused-regression-model1 } + - match: { trained_model_stats.1.pipeline_count: 0 } + - is_false: trained_model_stats.1.ingest + - match: { trained_model_stats.2.pipeline_count: 2 } + - is_true: trained_model_stats.2.ingest + + - do: + ml.get_trained_models_stats: + model_id: "*" + - match: { count: 3 } + - match: { trained_model_stats.0.model_id: unused-regression-model } + - match: { trained_model_stats.0.pipeline_count: 0 } + - is_false: trained_model_stats.0.ingest + - match: { trained_model_stats.1.model_id: unused-regression-model1 } + - match: { trained_model_stats.1.pipeline_count: 0 } + - is_false: trained_model_stats.1.ingest + - match: { 
trained_model_stats.2.pipeline_count: 2 } + - is_true: trained_model_stats.2.ingest + + - do: + ml.get_trained_models_stats: + model_id: "unused-regression-model*" + - match: { count: 2 } + - match: { trained_model_stats.0.model_id: unused-regression-model } + - match: { trained_model_stats.0.pipeline_count: 0 } + - is_false: trained_model_stats.0.ingest + - match: { trained_model_stats.1.model_id: unused-regression-model1 } + - match: { trained_model_stats.1.pipeline_count: 0 } + - is_false: trained_model_stats.1.ingest + + - do: + ml.get_trained_models_stats: + model_id: "unused-regression-model*" + size: 1 + - match: { count: 2 } + - match: { trained_model_stats.0.model_id: unused-regression-model } + - match: { trained_model_stats.0.pipeline_count: 0 } + - is_false: trained_model_stats.0.ingest + + - do: + ml.get_trained_models_stats: + model_id: "unused-regression-model*" + from: 1 + size: 1 + - match: { count: 2 } + - match: { trained_model_stats.0.model_id: unused-regression-model1 } + - match: { trained_model_stats.0.pipeline_count: 0 } + - is_false: trained_model_stats.0.ingest + + - do: + ml.get_trained_models_stats: + model_id: "used-regression-model" + + - match: { count: 1 } + - match: { trained_model_stats.0.model_id: used-regression-model } + - match: { trained_model_stats.0.pipeline_count: 2 } + - match: + trained_model_stats.0.ingest.total: + count: 0 + time_in_millis: 0 + current: 0 + failed: 0 + + - match: + trained_model_stats.0.ingest.pipelines.regression-model-pipeline: + count: 0 + time_in_millis: 0 + current: 0 + failed: 0 + processors: + - inference: + count: 0 + time_in_millis: 0 + current: 0 + failed: 0 + + - match: + trained_model_stats.0.ingest.pipelines.regression-model-pipeline-1: + count: 0 + time_in_millis: 0 + current: 0 + failed: 0 + processors: + - inference: + count: 0 + time_in_millis: 0 + current: 0 + failed: 0 From eb5e5e8480dd928ddcdbb421ae6285acb6afe6df Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Wed, 30 Oct 2019 
13:32:11 -0400 Subject: [PATCH 13/17] [ML][Inference] add inference processors and trained models to usage (#47869) * [ML][Inference] add inference processors and trained models to usage * renaming usage fields --- .../ml/MachineLearningFeatureSetUsage.java | 13 ++ .../MachineLearningUsageTransportAction.java | 123 +++++++++++++++++- ...chineLearningInfoTransportActionTests.java | 109 ++++++++++++++++ 3 files changed, 239 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsage.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsage.java index 4acfe8f091cb3..d57235f15609f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsage.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsage.java @@ -29,10 +29,12 @@ public class MachineLearningFeatureSetUsage extends XPackFeatureSet.Usage { public static final String CREATED_BY = "created_by"; public static final String NODE_COUNT = "node_count"; public static final String DATA_FRAME_ANALYTICS_JOBS_FIELD = "data_frame_analytics_jobs"; + public static final String INFERENCE_FIELD = "inference"; private final Map jobsUsage; private final Map datafeedsUsage; private final Map analyticsUsage; + private final Map inferenceUsage; private final int nodeCount; public MachineLearningFeatureSetUsage(boolean available, @@ -40,11 +42,13 @@ public MachineLearningFeatureSetUsage(boolean available, Map jobsUsage, Map datafeedsUsage, Map analyticsUsage, + Map inferenceUsage, int nodeCount) { super(XPackField.MACHINE_LEARNING, available, enabled); this.jobsUsage = Objects.requireNonNull(jobsUsage); this.datafeedsUsage = Objects.requireNonNull(datafeedsUsage); this.analyticsUsage = Objects.requireNonNull(analyticsUsage); + this.inferenceUsage = Objects.requireNonNull(inferenceUsage); this.nodeCount = nodeCount; } 
@@ -57,6 +61,11 @@ public MachineLearningFeatureSetUsage(StreamInput in) throws IOException { } else { this.analyticsUsage = Collections.emptyMap(); } + if (in.getVersion().onOrAfter(Version.V_8_0_0)) { + this.inferenceUsage = in.readMap(); + } else { + this.inferenceUsage = Collections.emptyMap(); + } this.nodeCount = in.readInt(); } @@ -68,6 +77,9 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getVersion().onOrAfter(Version.V_7_4_0)) { out.writeMap(analyticsUsage); } + if (out.getVersion().onOrAfter(Version.V_8_0_0)) { + out.writeMap(inferenceUsage); + } out.writeInt(nodeCount); } @@ -77,6 +89,7 @@ protected void innerXContent(XContentBuilder builder, Params params) throws IOEx builder.field(JOBS_FIELD, jobsUsage); builder.field(DATAFEEDS_FIELD, datafeedsUsage); builder.field(DATA_FRAME_ANALYTICS_JOBS_FIELD, analyticsUsage); + builder.field(INFERENCE_FIELD, inferenceUsage); if (nodeCount >= 0) { builder.field(NODE_COUNT, nodeCount); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearningUsageTransportAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearningUsageTransportAction.java index ab815e17fe0c8..4f731d9804c4b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearningUsageTransportAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearningUsageTransportAction.java @@ -7,6 +7,12 @@ import org.apache.lucene.util.Counter; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.admin.cluster.node.stats.NodeStats; +import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsAction; +import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse; +import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequest; +import org.elasticsearch.action.search.SearchRequestBuilder; +import org.elasticsearch.action.search.SearchResponse; import 
org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.ClusterState; @@ -16,11 +22,13 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.env.Environment; +import org.elasticsearch.ingest.IngestStats; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.protocol.xpack.XPackUsageRequest; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.XPackSettings; import org.elasticsearch.xpack.core.action.XPackUsageFeatureAction; import org.elasticsearch.xpack.core.action.XPackUsageFeatureResponse; @@ -32,19 +40,24 @@ import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.job.config.Job; import org.elasticsearch.xpack.core.ml.job.config.JobState; import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSizeStats; import org.elasticsearch.xpack.core.ml.stats.ForecastStats; import org.elasticsearch.xpack.core.ml.stats.StatsAccumulator; +import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.job.JobManagerHolder; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; public class MachineLearningUsageTransportAction extends XPackUsageFeatureTransportAction { @@ -72,7 +85,7 @@ protected void 
masterOperation(Task task, XPackUsageRequest request, ClusterStat ActionListener listener) { if (enabled == false) { MachineLearningFeatureSetUsage usage = new MachineLearningFeatureSetUsage(licenseState.isMachineLearningAllowed(), enabled, - Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap(), 0); + Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap(), 0); listener.onResponse(new XPackUsageFeatureResponse(usage)); return; } @@ -80,20 +93,48 @@ protected void masterOperation(Task task, XPackUsageRequest request, ClusterStat Map jobsUsage = new LinkedHashMap<>(); Map datafeedsUsage = new LinkedHashMap<>(); Map analyticsUsage = new LinkedHashMap<>(); + Map inferenceUsage = new LinkedHashMap<>(); int nodeCount = mlNodeCount(state); - // Step 3. Extract usage from data frame analytics stats and return usage response - ActionListener dataframeAnalyticsListener = ActionListener.wrap( + // Step 5. extract trained model config count and then return results + ActionListener trainedModelConfigCountListener = ActionListener.wrap( response -> { - addDataFrameAnalyticsUsage(response, analyticsUsage); + addTrainedModelStats(response, inferenceUsage); MachineLearningFeatureSetUsage usage = new MachineLearningFeatureSetUsage(licenseState.isMachineLearningAllowed(), - enabled, jobsUsage, datafeedsUsage, analyticsUsage, nodeCount); + enabled, jobsUsage, datafeedsUsage, analyticsUsage, inferenceUsage, nodeCount); listener.onResponse(new XPackUsageFeatureResponse(usage)); }, listener::onFailure ); - // Step 2. Extract usage from datafeeds stats and return usage response + // Step 4. 
Extract usage from ingest statistics and gather trained model config count + ActionListener nodesStatsListener = ActionListener.wrap( + response -> { + addInferenceIngestUsage(response, inferenceUsage); + SearchRequestBuilder requestBuilder = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN) + .setSize(0) + .setTrackTotalHits(true); + ClientHelper.executeAsyncWithOrigin(client.threadPool().getThreadContext(), + ClientHelper.ML_ORIGIN, + requestBuilder.request(), + trainedModelConfigCountListener, + client::search); + }, + listener::onFailure + ); + + // Step 3. Extract usage from data frame analytics stats and then request ingest node stats + ActionListener dataframeAnalyticsListener = ActionListener.wrap( + response -> { + addDataFrameAnalyticsUsage(response, analyticsUsage); + String[] ingestNodes = ingestNodes(state); + NodesStatsRequest nodesStatsRequest = new NodesStatsRequest(ingestNodes).clear().ingest(true); + client.execute(NodesStatsAction.INSTANCE, nodesStatsRequest, nodesStatsListener); + }, + listener::onFailure + ); + + // Step 2. 
Extract usage from datafeeds stats and then request stats for data frame analytics ActionListener datafeedStatsListener = ActionListener.wrap(response -> { addDatafeedsUsage(response, datafeedsUsage); @@ -227,6 +268,66 @@ private void addDataFrameAnalyticsUsage(GetDataFrameAnalyticsStatsAction.Respons } } + private static void initializeStats(Map emptyStatsMap) { + emptyStatsMap.put("sum", 0L); + emptyStatsMap.put("min", 0L); + emptyStatsMap.put("max", 0L); + } + + private static void updateStats(Map statsMap, Long value) { + statsMap.compute("sum", (k, v) -> v + value); + statsMap.compute("min", (k, v) -> Math.min(v, value)); + statsMap.compute("max", (k, v) -> Math.max(v, value)); + } + + //TODO separate out ours and users models possibly regression vs classification + private void addTrainedModelStats(SearchResponse response, Map inferenceUsage) { + inferenceUsage.put("trained_models", + Collections.singletonMap(MachineLearningFeatureSetUsage.ALL, createCountUsageEntry(response.getHits().getTotalHits().value))); + } + + //TODO separate out ours and users models possibly regression vs classification + private void addInferenceIngestUsage(NodesStatsResponse response, Map inferenceUsage) { + Set pipelines = new HashSet<>(); + Map docCountStats = new HashMap<>(3); + Map timeStats = new HashMap<>(3); + Map failureStats = new HashMap<>(3); + initializeStats(docCountStats); + initializeStats(timeStats); + initializeStats(failureStats); + + response.getNodes() + .stream() + .map(NodeStats::getIngestStats) + .map(IngestStats::getProcessorStats) + .forEach(map -> + map.forEach((pipelineId, processors) -> { + boolean containsInference = false; + for(IngestStats.ProcessorStat stats : processors) { + if (stats.getName().equals(InferenceProcessor.TYPE)) { + containsInference = true; + long ingestCount = stats.getStats().getIngestCount(); + long ingestTime = stats.getStats().getIngestTimeInMillis(); + long failureCount = stats.getStats().getIngestFailedCount(); + 
updateStats(docCountStats, ingestCount); + updateStats(timeStats, ingestTime); + updateStats(failureStats, failureCount); + } + } + if (containsInference) { + pipelines.add(pipelineId); + } + }) + ); + + Map ingestUsage = new HashMap<>(6); + ingestUsage.put("pipelines", createCountUsageEntry(pipelines.size())); + ingestUsage.put("num_docs_processed", docCountStats); + ingestUsage.put("time_ms", timeStats); + ingestUsage.put("num_failures", failureStats); + inferenceUsage.put("ingest_processors", Collections.singletonMap(MachineLearningFeatureSetUsage.ALL, ingestUsage)); + } + private static int mlNodeCount(final ClusterState clusterState) { int mlNodeCount = 0; for (DiscoveryNode node : clusterState.getNodes()) { @@ -236,4 +337,14 @@ private static int mlNodeCount(final ClusterState clusterState) { } return mlNodeCount; } + + private static String[] ingestNodes(final ClusterState clusterState) { + String[] ingestNodes = new String[clusterState.nodes().getIngestNodes().size()]; + Iterator nodeIterator = clusterState.nodes().getIngestNodes().keysIt(); + int i = 0; + while(nodeIterator.hasNext()) { + ingestNodes[i++] = nodeIterator.next(); + } + return ingestNodes; + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java index b2f69158aca05..3659778222c87 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java @@ -5,11 +5,19 @@ */ package org.elasticsearch.xpack.ml; +import org.apache.lucene.search.TotalHits; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.admin.cluster.node.stats.NodeStats; +import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsAction; 
+import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse; +import org.elasticsearch.action.search.SearchAction; +import org.elasticsearch.action.search.SearchRequestBuilder; +import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.metadata.MetaData; @@ -20,13 +28,18 @@ import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.env.Environment; import org.elasticsearch.env.TestEnvironment; +import org.elasticsearch.ingest.IngestStats; import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchHits; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.XPackFeatureSet; import org.elasticsearch.xpack.core.XPackField; @@ -41,6 +54,7 @@ import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig; import 
org.elasticsearch.xpack.core.ml.job.config.DataDescription; import org.elasticsearch.xpack.core.ml.job.config.Detector; @@ -50,10 +64,12 @@ import org.elasticsearch.xpack.core.ml.stats.ForecastStats; import org.elasticsearch.xpack.core.ml.stats.ForecastStatsTests; import org.elasticsearch.xpack.core.watcher.support.xcontent.XContentSource; +import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.job.JobManager; import org.elasticsearch.xpack.ml.job.JobManagerHolder; import org.junit.Before; +import java.time.Instant; import java.util.Arrays; import java.util.Collections; import java.util.Date; @@ -62,6 +78,8 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.notNullValue; @@ -99,6 +117,8 @@ public void init() { givenJobs(Collections.emptyList(), Collections.emptyList()); givenDatafeeds(Collections.emptyList()); givenDataFrameAnalytics(Collections.emptyList()); + givenProcessorStats(Collections.emptyList()); + givenTrainedModelConfigCount(0); } private MachineLearningUsageTransportAction newUsageAction(Settings settings) { @@ -169,12 +189,50 @@ public void testUsage() throws Exception { buildDatafeedStats(DatafeedState.STARTED), buildDatafeedStats(DatafeedState.STOPPED) )); + givenDataFrameAnalytics(Arrays.asList( buildDataFrameAnalyticsStats(DataFrameAnalyticsState.STOPPED), buildDataFrameAnalyticsStats(DataFrameAnalyticsState.STOPPED), buildDataFrameAnalyticsStats(DataFrameAnalyticsState.STARTED) )); + givenProcessorStats(Arrays.asList( + buildNodeStats( + Arrays.asList("pipeline1", "pipeline2", "pipeline3"), + Arrays.asList( + Arrays.asList( + new IngestStats.ProcessorStat(InferenceProcessor.TYPE, new IngestStats.Stats(10, 1, 0, 0)), + new IngestStats.ProcessorStat("grok", new IngestStats.Stats(10, 1, 0, 0)), + new 
IngestStats.ProcessorStat(InferenceProcessor.TYPE, new IngestStats.Stats(100, 10, 0, 1)) + ), + Arrays.asList( + new IngestStats.ProcessorStat(InferenceProcessor.TYPE, new IngestStats.Stats(5, 1, 0, 0)), + new IngestStats.ProcessorStat("grok", new IngestStats.Stats(10, 1, 0, 0)) + ), + Arrays.asList( + new IngestStats.ProcessorStat("grok", new IngestStats.Stats(10, 1, 0, 0)) + ) + )), + buildNodeStats( + Arrays.asList("pipeline1", "pipeline2", "pipeline3"), + Arrays.asList( + Arrays.asList( + new IngestStats.ProcessorStat(InferenceProcessor.TYPE, new IngestStats.Stats(0, 0, 0, 0)), + new IngestStats.ProcessorStat("grok", new IngestStats.Stats(0, 0, 0, 0)), + new IngestStats.ProcessorStat(InferenceProcessor.TYPE, new IngestStats.Stats(10, 1, 0, 0)) + ), + Arrays.asList( + new IngestStats.ProcessorStat(InferenceProcessor.TYPE, new IngestStats.Stats(5, 1, 0, 0)), + new IngestStats.ProcessorStat("grok", new IngestStats.Stats(10, 1, 0, 0)) + ), + Arrays.asList( + new IngestStats.ProcessorStat("grok", new IngestStats.Stats(10, 1, 0, 0)) + ) + )) + )); + givenTrainedModelConfigCount(100); + + var usageAction = newUsageAction(settings.build()); PlainActionFuture future = new PlainActionFuture<>(); usageAction.masterOperation(null, null, ClusterState.EMPTY_STATE, future); @@ -251,6 +309,18 @@ public void testUsage() throws Exception { assertThat(source.getValue("jobs.opened.forecasts.total"), equalTo(11)); assertThat(source.getValue("jobs.opened.forecasts.forecasted_jobs"), equalTo(2)); + + assertThat(source.getValue("inference.trained_models._all.count"), equalTo(100)); + assertThat(source.getValue("inference.ingest_processors._all.pipelines.count"), equalTo(2)); + assertThat(source.getValue("inference.ingest_processors._all.num_docs_processed.sum"), equalTo(130)); + assertThat(source.getValue("inference.ingest_processors._all.num_docs_processed.min"), equalTo(0)); + assertThat(source.getValue("inference.ingest_processors._all.num_docs_processed.max"), equalTo(100)); + 
assertThat(source.getValue("inference.ingest_processors._all.time_ms.sum"), equalTo(14)); + assertThat(source.getValue("inference.ingest_processors._all.time_ms.min"), equalTo(0)); + assertThat(source.getValue("inference.ingest_processors._all.time_ms.max"), equalTo(10)); + assertThat(source.getValue("inference.ingest_processors._all.num_failures.sum"), equalTo(1)); + assertThat(source.getValue("inference.ingest_processors._all.num_failures.min"), equalTo(0)); + assertThat(source.getValue("inference.ingest_processors._all.num_failures.max"), equalTo(1)); } } @@ -417,6 +487,34 @@ private void givenDataFrameAnalytics(List stats) { + doAnswer(invocationOnMock -> { + @SuppressWarnings("unchecked") + ActionListener listener = + (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(new NodesStatsResponse(new ClusterName("_name"), stats, Collections.emptyList())); + return Void.TYPE; + }).when(client).execute(same(NodesStatsAction.INSTANCE), any(), any()); + } + + private void givenTrainedModelConfigCount(long count) { + when(client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)) + .thenReturn(new SearchRequestBuilder(client, SearchAction.INSTANCE)); + ThreadPool pool = mock(ThreadPool.class); + when(pool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); + when(client.threadPool()).thenReturn(pool); + doAnswer(invocationOnMock -> { + @SuppressWarnings("unchecked") + ActionListener listener = + (ActionListener) invocationOnMock.getArguments()[1]; + SearchHits searchHits = new SearchHits(new SearchHit[0], new TotalHits(count, TotalHits.Relation.EQUAL_TO), (float)0.0); + SearchResponse searchResponse = mock(SearchResponse.class); + when(searchResponse.getHits()).thenReturn(searchHits); + listener.onResponse(searchResponse); + return Void.TYPE; + }).when(client).search(any(), any()); + } + private static Detector buildMinDetector(String fieldName) { Detector.Builder detectorBuilder = new Detector.Builder(); 
detectorBuilder.setFunction("min"); @@ -463,6 +561,17 @@ private static GetDataFrameAnalyticsStatsAction.Response.Stats buildDataFrameAna return stats; } + private static NodeStats buildNodeStats(List pipelineNames, List> processorStats) { + IngestStats ingestStats = new IngestStats( + new IngestStats.Stats(0,0,0,0), + Collections.emptyList(), + IntStream.range(0, pipelineNames.size()).boxed().collect(Collectors.toMap(pipelineNames::get, processorStats::get))); + return new NodeStats(mock(DiscoveryNode.class), + Instant.now().toEpochMilli(), null, null, null, null, null, null, null, null, + null, null, null, ingestStats, null); + + } + private static ForecastStats buildForecastStats(long numberOfForecasts) { return new ForecastStatsTests().createForecastStats(numberOfForecasts, numberOfForecasts); } From cd79777612f237e743fa7eeea5dbd2bfbf2503ff Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Thu, 31 Oct 2019 18:22:05 -0400 Subject: [PATCH 14/17] [ML][Inference] add new flag for optionally including model definition (#48718) * [ML][Inference] add new flag for optionally including model definition * adjusting after definition and config split * revert unnecessary changes to AbstractTransportGetResourcesAction * fixing TrainedModelDefinitionTests * fixing yaml tests from previous code changes * fixing integration test * making tests an assertBusy for verification --- .../ml/action/GetTrainedModelsAction.java | 71 ++++++-- .../xpack/core/ml/job/messages/Messages.java | 4 + .../action/GetTrainedModelsRequestTests.java | 2 +- .../TrainedModelDefinitionTests.java | 38 ++-- .../xpack/ml/integration/TrainedModelIT.java | 56 +++++- .../TransportGetTrainedModelsAction.java | 102 +++++------ .../TransportGetTrainedModelsStatsAction.java | 106 +---------- .../persistence/TrainedModelProvider.java | 166 +++++++++++++++++- .../inference/RestGetTrainedModelsAction.java | 6 +- .../ModelLoadingServiceTests.java | 8 + .../integration/ModelInferenceActionIT.java | 6 +- 
.../api/ml.get_trained_models.json | 6 + .../test/ml/inference_stats_crud.yml | 20 ++- 13 files changed, 391 insertions(+), 200 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java index 005f0d180cdc1..b86cfced5524f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java @@ -5,17 +5,20 @@ */ package org.elasticsearch.xpack.core.ml.action; -import org.elasticsearch.action.ActionRequestBuilder; import org.elasticsearch.action.ActionType; -import org.elasticsearch.client.ElasticsearchClient; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest; import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse; import org.elasticsearch.xpack.core.action.util.QueryPage; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + public class GetTrainedModelsAction extends ActionType { @@ -28,19 +31,20 @@ private GetTrainedModelsAction() { public static class Request extends AbstractGetResourcesRequest { + public static final ParseField INCLUDE_MODEL_DEFINITION = new ParseField("include_model_definition"); public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); - public Request() { - setAllowNoResources(true); - } + private final boolean includeModelDefinition; - public Request(String id) { + public Request(String id, boolean includeModelDefinition) { setResourceId(id); setAllowNoResources(true); + 
this.includeModelDefinition = includeModelDefinition; } public Request(StreamInput in) throws IOException { super(in); + this.includeModelDefinition = in.readBoolean(); } @Override @@ -48,6 +52,32 @@ public String getResourceIdField() { return TrainedModelConfig.MODEL_ID.getPreferredName(); } + public boolean isIncludeModelDefinition() { + return includeModelDefinition; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeBoolean(includeModelDefinition); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), includeModelDefinition); + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + Request other = (Request) obj; + return super.equals(obj) && this.includeModelDefinition == other.includeModelDefinition; + } } public static class Response extends AbstractGetResourcesResponse { @@ -66,12 +96,33 @@ public Response(QueryPage trainedModels) { protected Reader getReader() { return TrainedModelConfig::new; } - } - public static class RequestBuilder extends ActionRequestBuilder { + public static Builder builder() { + return new Builder(); + } + + public static class Builder { - public RequestBuilder(ElasticsearchClient client) { - super(client, INSTANCE, new Request()); + private long totalCount; + private List configs = Collections.emptyList(); + + private Builder() { + } + + public Builder setTotalCount(long totalCount) { + this.totalCount = totalCount; + return this; + } + + public Builder setModels(List configs) { + this.configs = configs; + return this; + } + + public Response build() { + return new Response(new QueryPage<>(configs, totalCount, RESULTS_FIELD)); + } } } + } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index 61cd542a9fc2d..00fce4e58e826 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -83,9 +83,13 @@ public final class Messages { public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists"; public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]"; public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]"; + public static final String INFERENCE_NOT_FOUND_MULTIPLE = "Could not find trained models {0}"; public static final String INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION = "Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]"; public static final String MODEL_DEFINITION_NOT_FOUND = "Could not find trained model definition [{0}]"; + public static final String INFERENCE_FAILED_TO_DESERIALIZE = "Could not deserialize trained model [{0}]"; + public static final String INFERENCE_TO_MANY_DEFINITIONS_REQUESTED = + "Getting model definition is not supported when getting more than one model"; public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again"; public static final String JOB_AUDIT_CREATED = "Job created"; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java index 0abc0318e215e..85345467df169 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java @@ 
-14,7 +14,7 @@ public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCas @Override protected Request createTestInstance() { - Request request = new Request(randomAlphaOfLength(20)); + Request request = new Request(randomAlphaOfLength(20), randomBoolean()); request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100))); return request; } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java index 20213ba99d62d..5fdadac712d0d 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.inference; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.settings.Settings; @@ -12,6 +13,8 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.XContentParseException; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.search.SearchModule; @@ -22,7 +25,6 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests; -import org.junit.Before; import java.io.IOException; import java.util.ArrayList; @@ 
-33,27 +35,21 @@ import java.util.stream.Stream; import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; public class TrainedModelDefinitionTests extends AbstractSerializingTestCase { - private boolean lenient; - - @Before - public void chooseStrictOrLenient() { - lenient = randomBoolean(); - } - @Override protected TrainedModelDefinition doParseInstance(XContentParser parser) throws IOException { - return TrainedModelDefinition.fromXContent(parser, lenient).build(); + return TrainedModelDefinition.fromXContent(parser, true).build(); } @Override protected boolean supportsUnknownFields() { - return lenient; + return true; } @Override @@ -63,7 +59,7 @@ protected Predicate getRandomFieldsExcludeFilter() { @Override protected ToXContent.Params getToXContentParams() { - return lenient ? ToXContent.EMPTY_PARAMS : new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true")); + return new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true")); } @Override @@ -286,9 +282,27 @@ public void testTreeSchemaDeserialization() throws IOException { assertThat(definition.getTrainedModel().getClass(), equalTo(Tree.class)); } + public void testStrictParser() throws IOException { + TrainedModelDefinition.Builder builder = createRandomBuilder("asdf"); + BytesReference reference = XContentHelper.toXContent(builder.build(), + XContentType.JSON, + new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true")), + false); + + XContentParser parser = XContentHelper.createParser(xContentRegistry(), + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + reference, + XContentType.JSON); + + XContentParseException exception = expectThrows(XContentParseException.class, + () -> TrainedModelDefinition.fromXContent(parser, false)); + + 
assertThat(exception.getMessage(), containsString("[trained_model_definition] unknown field [doc_type]")); + } + @Override protected TrainedModelDefinition createTestInstance() { - return createRandomBuilder(null).build(); + return createRandomBuilder(randomAlphaOfLength(10)).build(); } @Override diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java index 9c059bced93d3..1982cec7eca0c 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java @@ -25,6 +25,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests; @@ -63,6 +64,11 @@ public void testGetTrainedModels() throws IOException { model1.setJsonEntity(buildRegressionModel(modelId)); assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201)); + Request modelDefinition1 = new Request("PUT", + InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinition.docId(modelId)); + modelDefinition1.setJsonEntity(buildRegressionModelDefinition(modelId)); + assertThat(client().performRequest(modelDefinition1).getStatusLine().getStatusCode(), equalTo(201)); + Request model2 = new Request("PUT", InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId2); 
model2.setJsonEntity(buildRegressionModel(modelId2)); @@ -85,8 +91,26 @@ public void testGetTrainedModels() throws IOException { response = EntityUtils.toString(getModel.getEntity()); assertThat(response, containsString("\"model_id\":\"test_regression_model\"")); assertThat(response, containsString("\"model_id\":\"test_regression_model-2\"")); + assertThat(response, not(containsString("\"definition\""))); assertThat(response, containsString("\"count\":2")); + getModel = client().performRequest(new Request("GET", + MachineLearning.BASE_PATH + "inference/test_regression_model?human=true&include_model_definition=true")); + assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); + + response = EntityUtils.toString(getModel.getEntity()); + assertThat(response, containsString("\"model_id\":\"test_regression_model\"")); + assertThat(response, containsString("\"heap_memory_estimation_bytes\"")); + assertThat(response, containsString("\"heap_memory_estimation\"")); + assertThat(response, containsString("\"definition\"")); + assertThat(response, containsString("\"count\":1")); + + ResponseException responseException = expectThrows(ResponseException.class, () -> + client().performRequest(new Request("GET", + MachineLearning.BASE_PATH + "inference/test_regression*?human=true&include_model_definition=true"))); + assertThat(EntityUtils.toString(responseException.getResponse().getEntity()), + containsString(Messages.INFERENCE_TO_MANY_DEFINITIONS_REQUESTED)); + getModel = client().performRequest(new Request("GET", MachineLearning.BASE_PATH + "inference/test_regression_model,test_regression_model-2")); assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); @@ -131,6 +155,11 @@ public void testDeleteTrainedModels() throws IOException { model1.setJsonEntity(buildRegressionModel(modelId)); assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201)); + Request modelDefinition1 = new Request("PUT", + 
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinition.docId(modelId)); + modelDefinition1.setJsonEntity(buildRegressionModelDefinition(modelId)); + assertThat(client().performRequest(modelDefinition1).getStatusLine().getStatusCode(), equalTo(201)); + adminClient().performRequest(new Request("POST", InferenceIndexConstants.LATEST_INDEX_NAME + "/_refresh")); Response delModel = client().performRequest(new Request("DELETE", @@ -141,6 +170,18 @@ public void testDeleteTrainedModels() throws IOException { ResponseException responseException = expectThrows(ResponseException.class, () -> client().performRequest(new Request("DELETE", MachineLearning.BASE_PATH + "inference/" + modelId))); assertThat(responseException.getResponse().getStatusLine().getStatusCode(), equalTo(404)); + + responseException = expectThrows(ResponseException.class, + () -> client().performRequest( + new Request("GET", + InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinition.docId(modelId)))); + assertThat(responseException.getResponse().getStatusLine().getStatusCode(), equalTo(404)); + + responseException = expectThrows(ResponseException.class, + () -> client().performRequest( + new Request("GET", + InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId))); + assertThat(responseException.getResponse().getStatusLine().getStatusCode(), equalTo(404)); } private static String buildRegressionModel(String modelId) throws IOException { @@ -149,9 +190,6 @@ private static String buildRegressionModel(String modelId) throws IOException { .setModelId(modelId) .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3"))) .setCreatedBy("ml_test") - .setDefinition(new TrainedModelDefinition.Builder() - .setPreProcessors(Collections.emptyList()) - .setTrainedModel(LocalModelTests.buildRegression())) .setVersion(Version.CURRENT) .setCreateTime(Instant.now()) .build() @@ -160,6 +198,18 @@ private static String buildRegressionModel(String modelId) throws 
IOException { } } + private static String buildRegressionModelDefinition(String modelId) throws IOException { + try(XContentBuilder builder = XContentFactory.jsonBuilder()) { + new TrainedModelDefinition.Builder() + .setPreProcessors(Collections.emptyList()) + .setTrainedModel(LocalModelTests.buildRegression()) + .setModelId(modelId) + .build() + .toXContent(builder, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"))); + return XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON); + } + } + @After public void clearMlState() throws Exception { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java index ee95ddbd9670d..15629579368f3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java @@ -5,86 +5,72 @@ */ package org.elasticsearch.xpack.ml.action; -import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; -import org.elasticsearch.client.Client; -import org.elasticsearch.common.Nullable; -import org.elasticsearch.common.ParseField; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.inject.Inject; -import org.elasticsearch.common.xcontent.NamedXContentRegistry; -import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.index.query.QueryBuilders; -import org.elasticsearch.search.builder.SearchSourceBuilder; -import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.tasks.Task; import 
org.elasticsearch.transport.TransportService; -import org.elasticsearch.xpack.core.action.AbstractTransportGetResourcesAction; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Response; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; -import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; +import java.util.Collections; +import java.util.Set; -public class TransportGetTrainedModelsAction extends AbstractTransportGetResourcesAction { +public class TransportGetTrainedModelsAction extends HandledTransportAction { + + private final TrainedModelProvider provider; @Inject - public TransportGetTrainedModelsAction(TransportService transportService, ActionFilters actionFilters, Client client, - NamedXContentRegistry xContentRegistry) { - super(GetTrainedModelsAction.NAME, transportService, actionFilters, GetTrainedModelsAction.Request::new, client, - xContentRegistry); + public TransportGetTrainedModelsAction(TransportService transportService, + ActionFilters actionFilters, + TrainedModelProvider trainedModelProvider) { + super(GetTrainedModelsAction.NAME, transportService, actionFilters, GetTrainedModelsAction.Request::new); + this.provider = trainedModelProvider; } @Override - protected ParseField getResultsField() { - return GetTrainedModelsAction.Response.RESULTS_FIELD; - } + protected void doExecute(Task task, Request request, ActionListener listener) { - @Override - protected String[] getIndices() { - return new String[] { InferenceIndexConstants.INDEX_PATTERN }; - } + 
Response.Builder responseBuilder = Response.builder(); - @Override - protected TrainedModelConfig parse(XContentParser parser) { - return TrainedModelConfig.LENIENT_PARSER.apply(parser, null).build(); - } + ActionListener>> idExpansionListener = ActionListener.wrap( + totalAndIds -> { + responseBuilder.setTotalCount(totalAndIds.v1()); - @Override - protected ResourceNotFoundException notFoundException(String resourceId) { - return ExceptionsHelper.missingTrainedModel(resourceId); - } + if (totalAndIds.v2().isEmpty()) { + listener.onResponse(responseBuilder.build()); + return; + } - @Override - protected void doExecute(Task task, GetTrainedModelsAction.Request request, - ActionListener listener) { - searchResources(request, ActionListener.wrap( - queryPage -> listener.onResponse(new GetTrainedModelsAction.Response(queryPage)), - listener::onFailure - )); - } + if (request.isIncludeModelDefinition() && totalAndIds.v2().size() > 1) { + listener.onFailure( + ExceptionsHelper.badRequestException(Messages.INFERENCE_TO_MANY_DEFINITIONS_REQUESTED) + ); + return; + } - @Override - protected String executionOrigin() { - return ML_ORIGIN; - } - - @Override - protected String extractIdFromResource(TrainedModelConfig config) { - return config.getModelId(); - } + if (request.isIncludeModelDefinition()) { + provider.getTrainedModel(totalAndIds.v2().iterator().next(), true, ActionListener.wrap( + config -> listener.onResponse(responseBuilder.setModels(Collections.singletonList(config)).build()), + listener::onFailure + )); + } else { + provider.getTrainedModels(totalAndIds.v2(), request.isAllowNoResources(), ActionListener.wrap( + configs -> listener.onResponse(responseBuilder.setModels(configs).build()), + listener::onFailure + )); + } + }, + listener::onFailure + ); - @Override - protected SearchSourceBuilder customSearchOptions(SearchSourceBuilder searchSourceBuilder) { - return searchSourceBuilder.sort("_index", SortOrder.DESC); + provider.expandIds(request.getResourceId(), 
request.isAllowNoResources(), request.getPageParams(), idExpansionListener); } - @Nullable - protected QueryBuilder additionalQuery() { - return QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelConfig.NAME); - } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java index aabd760be4097..a15579b62de6a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java @@ -11,37 +11,23 @@ import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsAction; import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequest; import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse; -import org.elasticsearch.action.search.SearchRequest; -import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; -import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.common.Strings; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.metrics.CounterMetric; -import org.elasticsearch.common.regex.Regex; -import org.elasticsearch.index.query.BoolQueryBuilder; -import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.ingest.IngestMetadata; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.ingest.IngestStats; import 
org.elasticsearch.ingest.Pipeline; -import org.elasticsearch.search.SearchHit; -import org.elasticsearch.search.builder.SearchSourceBuilder; -import org.elasticsearch.search.sort.SortBuilders; import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; -import org.elasticsearch.xpack.core.action.util.ExpandedIdsMatcher; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import java.util.ArrayList; import java.util.HashMap; @@ -63,17 +49,20 @@ public class TransportGetTrainedModelsStatsAction extends HandledTransportAction private final Client client; private final ClusterService clusterService; private final IngestService ingestService; + private final TrainedModelProvider trainedModelProvider; @Inject public TransportGetTrainedModelsStatsAction(TransportService transportService, ActionFilters actionFilters, ClusterService clusterService, IngestService ingestService, + TrainedModelProvider trainedModelProvider, Client client) { super(GetTrainedModelsStatsAction.NAME, transportService, actionFilters, GetTrainedModelsStatsAction.Request::new); this.client = client; this.clusterService = clusterService; this.ingestService = ingestService; + this.trainedModelProvider = trainedModelProvider; } @Override @@ -105,7 +94,7 @@ protected void doExecute(Task task, listener::onFailure ); - expandIds(request, idsListener); + trainedModelProvider.expandIds(request.getResourceId(), request.isAllowNoResources(), request.getPageParams(), idsListener); } static Map inferenceIngestStatsByPipelineId(NodesStatsResponse response, @@ -124,91 +113,6 @@ static Map 
inferenceIngestStatsByPipelineId(NodesStatsRespo return ingestStatsMap; } - - private void expandIds(GetTrainedModelsStatsAction.Request request, ActionListener>> idsListener) { - String[] tokens = Strings.tokenizeToStringArray(request.getResourceId(), ","); - SearchSourceBuilder sourceBuilder = new SearchSourceBuilder() - .sort(SortBuilders.fieldSort(request.getResourceIdField()) - // If there are no resources, there might be no mapping for the id field. - // This makes sure we don't get an error if that happens. - .unmappedType("long")) - .query(buildQuery(tokens, request.getResourceIdField())); - if (request.getPageParams() != null) { - sourceBuilder.from(request.getPageParams().getFrom()) - .size(request.getPageParams().getSize()); - } - sourceBuilder.trackTotalHits(true) - // we only care about the item id's, there is no need to load large model definitions. - .fetchSource(TrainedModelConfig.MODEL_ID.getPreferredName(), null); - - IndicesOptions indicesOptions = SearchRequest.DEFAULT_INDICES_OPTIONS; - SearchRequest searchRequest = new SearchRequest(InferenceIndexConstants.INDEX_PATTERN) - .indicesOptions(IndicesOptions.fromOptions(true, - indicesOptions.allowNoIndices(), - indicesOptions.expandWildcardsOpen(), - indicesOptions.expandWildcardsClosed(), - indicesOptions)) - .source(sourceBuilder); - - executeAsyncWithOrigin(client.threadPool().getThreadContext(), - ML_ORIGIN, - searchRequest, - ActionListener.wrap( - response -> { - Set foundResourceIds = new LinkedHashSet<>(); - long totalHitCount = response.getHits().getTotalHits().value; - for (SearchHit hit : response.getHits().getHits()) { - Map docSource = hit.getSourceAsMap(); - if (docSource == null) { - continue; - } - Object idValue = docSource.get(TrainedModelConfig.MODEL_ID.getPreferredName()); - if (idValue instanceof String) { - foundResourceIds.add(idValue.toString()); - } - } - ExpandedIdsMatcher requiredMatches = new ExpandedIdsMatcher(tokens, request.isAllowNoResources()); - 
requiredMatches.filterMatchedIds(foundResourceIds); - if (requiredMatches.hasUnmatchedIds()) { - idsListener.onFailure(ExceptionsHelper.missingTrainedModel(requiredMatches.unmatchedIdsString())); - } else { - idsListener.onResponse(Tuple.tuple(totalHitCount, foundResourceIds)); - } - }, - idsListener::onFailure - ), - client::search); - - } - - private QueryBuilder buildQuery(String[] tokens, String resourceIdField) { - BoolQueryBuilder boolQuery = QueryBuilders.boolQuery() - .filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelConfig.NAME)); - - if (Strings.isAllOrWildcard(tokens)) { - return boolQuery; - } - // If the resourceId is not _all or *, we should see if it is a comma delimited string with wild-cards - // e.g. id1,id2*,id3 - BoolQueryBuilder shouldQueries = new BoolQueryBuilder(); - List terms = new ArrayList<>(); - for (String token : tokens) { - if (Regex.isSimpleMatchPattern(token)) { - shouldQueries.should(QueryBuilders.wildcardQuery(resourceIdField, token)); - } else { - terms.add(token); - } - } - if (terms.isEmpty() == false) { - shouldQueries.should(QueryBuilders.termsQuery(resourceIdField, terms)); - } - - if (shouldQueries.should().isEmpty() == false) { - boolQuery.filter(shouldQueries); - } - return boolQuery; - } - static String[] ingestNodes(final ClusterState clusterState) { String[] ingestNodes = new String[clusterState.nodes().getIngestNodes().size()]; Iterator nodeIterator = clusterState.nodes().getIngestNodes().keysIt(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index 63e496060a788..d47a4eb886079 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -20,10 +20,19 @@ import org.elasticsearch.action.search.MultiSearchAction; import org.elasticsearch.action.search.MultiSearchRequestBuilder; import org.elasticsearch.action.search.MultiSearchResponse; +import org.elasticsearch.action.search.SearchAction; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.client.Client; import org.elasticsearch.common.CheckedBiFunction; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.regex.Regex; +import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.ToXContent; @@ -34,12 +43,18 @@ import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.engine.VersionConflictEngineException; +import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.reindex.DeleteByQueryAction; import org.elasticsearch.index.reindex.DeleteByQueryRequest; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.sort.SortBuilders; import org.elasticsearch.search.sort.SortOrder; +import org.elasticsearch.xpack.core.action.util.ExpandedIdsMatcher; +import 
org.elasticsearch.xpack.core.action.util.PageParams; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; @@ -49,10 +64,17 @@ import java.io.IOException; import java.io.InputStream; +import java.util.ArrayList; import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; +import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_FAILED_TO_DESERIALIZE; public class TrainedModelProvider { @@ -191,6 +213,56 @@ public void getTrainedModel(final String modelId, final boolean includeDefinitio multiSearchResponseActionListener); } + /** + * Gets all the provided trained config model objects + * + * NOTE: + * This does no expansion on the ids. + * It assumes that there are fewer than 10k. 
+ */ + public void getTrainedModels(Set modelIds, boolean allowNoResources, final ActionListener> listener) { + QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(modelIds.toArray(new String[0]))); + + SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN) + .addSort(TrainedModelConfig.MODEL_ID.getPreferredName(), SortOrder.ASC) + .addSort("_index", SortOrder.DESC) + .setQuery(queryBuilder) + .request(); + + ActionListener configSearchHandler = ActionListener.wrap( + searchResponse -> { + Set observedIds = new HashSet<>(searchResponse.getHits().getHits().length, 1.0f); + List configs = new ArrayList<>(searchResponse.getHits().getHits().length); + for(SearchHit searchHit : searchResponse.getHits().getHits()) { + try { + if (observedIds.contains(searchHit.getId()) == false) { + configs.add( + parseInferenceDocLenientlyFromSource(searchHit.getSourceRef(), searchHit.getId()).build() + ); + observedIds.add(searchHit.getId()); + } + } catch (IOException ex) { + listener.onFailure( + ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ex, searchHit.getId())); + return; + } + } + // We previously expanded the IDs. + // If the config has gone missing between then and now we should throw if allowNoResources is false + // Otherwise, treat it as if it was never expanded to begin with. 
+ Set missingConfigs = Sets.difference(modelIds, observedIds); + if (missingConfigs.isEmpty() == false && allowNoResources == false) { + listener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs)); + return; + } + listener.onResponse(configs); + }, + listener::onFailure + ); + + executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, configSearchHandler); + } + public void deleteTrainedModel(String modelId, ActionListener listener) { DeleteByQueryRequest request = new DeleteByQueryRequest().setAbortOnVersionConflict(false); @@ -216,6 +288,92 @@ public void deleteTrainedModel(String modelId, ActionListener listener) })); } + public void expandIds(String idExpression, + boolean allowNoResources, + @Nullable PageParams pageParams, + ActionListener>> idsListener) { + String[] tokens = Strings.tokenizeToStringArray(idExpression, ","); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder() + .sort(SortBuilders.fieldSort(TrainedModelConfig.MODEL_ID.getPreferredName()) + // If there are no resources, there might be no mapping for the id field. + // This makes sure we don't get an error if that happens. 
+ .unmappedType("long")) + .query(buildQueryIdExpressionQuery(tokens, TrainedModelConfig.MODEL_ID.getPreferredName())); + if (pageParams != null) { + sourceBuilder.from(pageParams.getFrom()).size(pageParams.getSize()); + } + sourceBuilder.trackTotalHits(true) + // we only care about the item id's + .fetchSource(TrainedModelConfig.MODEL_ID.getPreferredName(), null); + + IndicesOptions indicesOptions = SearchRequest.DEFAULT_INDICES_OPTIONS; + SearchRequest searchRequest = new SearchRequest(InferenceIndexConstants.INDEX_PATTERN) + .indicesOptions(IndicesOptions.fromOptions(true, + indicesOptions.allowNoIndices(), + indicesOptions.expandWildcardsOpen(), + indicesOptions.expandWildcardsClosed(), + indicesOptions)) + .source(sourceBuilder); + + executeAsyncWithOrigin(client.threadPool().getThreadContext(), + ML_ORIGIN, + searchRequest, + ActionListener.wrap( + response -> { + Set foundResourceIds = new LinkedHashSet<>(); + long totalHitCount = response.getHits().getTotalHits().value; + for (SearchHit hit : response.getHits().getHits()) { + Map docSource = hit.getSourceAsMap(); + if (docSource == null) { + continue; + } + Object idValue = docSource.get(TrainedModelConfig.MODEL_ID.getPreferredName()); + if (idValue instanceof String) { + foundResourceIds.add(idValue.toString()); + } + } + ExpandedIdsMatcher requiredMatches = new ExpandedIdsMatcher(tokens, allowNoResources); + requiredMatches.filterMatchedIds(foundResourceIds); + if (requiredMatches.hasUnmatchedIds()) { + idsListener.onFailure(ExceptionsHelper.missingTrainedModel(requiredMatches.unmatchedIdsString())); + } else { + idsListener.onResponse(Tuple.tuple(totalHitCount, foundResourceIds)); + } + }, + idsListener::onFailure + ), + client::search); + + } + + private QueryBuilder buildQueryIdExpressionQuery(String[] tokens, String resourceIdField) { + BoolQueryBuilder boolQuery = QueryBuilders.boolQuery() + .filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), 
TrainedModelConfig.NAME)); + + if (Strings.isAllOrWildcard(tokens)) { + return boolQuery; + } + // If the resourceId is not _all or *, we should see if it is a comma delimited string with wild-cards + // e.g. id1,id2*,id3 + BoolQueryBuilder shouldQueries = new BoolQueryBuilder(); + List terms = new ArrayList<>(); + for (String token : tokens) { + if (Regex.isSimpleMatchPattern(token)) { + shouldQueries.should(QueryBuilders.wildcardQuery(resourceIdField, token)); + } else { + terms.add(token); + } + } + if (terms.isEmpty() == false) { + shouldQueries.should(QueryBuilders.termsQuery(resourceIdField, terms)); + } + + if (shouldQueries.should().isEmpty() == false) { + boolQuery.filter(shouldQueries); + } + return boolQuery; + } + private static T handleSearchItem(MultiSearchResponse.Item item, String resourceId, CheckedBiFunction parseLeniently) throws Exception { @@ -228,23 +386,23 @@ private static T handleSearchItem(MultiSearchResponse.Item item, return parseLeniently.apply(item.getResponse().getHits().getHits()[0].getSourceRef(), resourceId); } - private TrainedModelConfig.Builder parseInferenceDocLenientlyFromSource(BytesReference source, String modelId) throws Exception { + private TrainedModelConfig.Builder parseInferenceDocLenientlyFromSource(BytesReference source, String modelId) throws IOException { try (InputStream stream = source.streamInput(); XContentParser parser = XContentFactory.xContent(XContentType.JSON) .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) { return TrainedModelConfig.fromXContent(parser, true); - } catch (Exception e) { + } catch (IOException e) { logger.error(new ParameterizedMessage("[{}] failed to parse model", modelId), e); throw e; } } - private TrainedModelDefinition parseModelDefinitionDocLenientlyFromSource(BytesReference source, String modelId) throws Exception { + private TrainedModelDefinition parseModelDefinitionDocLenientlyFromSource(BytesReference source, String modelId) throws IOException { try 
(InputStream stream = source.streamInput(); XContentParser parser = XContentFactory.xContent(XContentType.JSON) .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) { return TrainedModelDefinition.fromXContent(parser, true).build(); - } catch (Exception e) { + } catch (IOException e) { logger.error(new ParameterizedMessage("[{}] failed to parse model definition", modelId), e); throw e; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java index 40ddd05827043..578b75fbc0793 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java @@ -41,7 +41,11 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient if (Strings.isNullOrEmpty(modelId)) { modelId = MetaData.ALL; } - GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId); + boolean includeModelDefinition = restRequest.paramAsBoolean( + GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION.getPreferredName(), + false + ); + GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId, includeModelDefinition); if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) { request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM), restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE))); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java 
index 2362635142d0d..272628f4c12f8 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -147,6 +147,14 @@ public void testMaxCachedLimitReached() throws Exception { modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3)); + // Should have been loaded from the cluster change event + // Verify that we have at least loaded all three so that evictions occur in the following loop + assertBusy(() -> { + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(true), any()); + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(true), any()); + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any()); + }); + String[] modelIds = new String[]{model1, model2, model3}; for(int i = 0; i < 10; i++) { // Only reference models 1 and 2, so that cache is only invalidated once for model3 (after initial load) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java index 40c52f55de144..099baf949b684 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -65,7 +65,8 @@ public void testInferModels() throws Exception { .setInput(new TrainedModelInput(Arrays.asList("field1", "field2"))) .setDefinition(new TrainedModelDefinition.Builder() .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding))) - .setTrainedModel(buildClassification(true))) + .setTrainedModel(buildClassification(true)) + .setModelId(modelId1)) .setVersion(Version.CURRENT) .setCreateTime(Instant.now()) .build(); 
@@ -73,7 +74,8 @@ public void testInferModels() throws Exception { .setInput(new TrainedModelInput(Arrays.asList("field1", "field2"))) .setDefinition(new TrainedModelDefinition.Builder() .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding))) - .setTrainedModel(buildRegression())) + .setTrainedModel(buildRegression()) + .setModelId(modelId2)) .setVersion(Version.CURRENT) .setCreateTime(Instant.now()) .build(); diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json index 481f8b25975bb..22d16a6c36941 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json @@ -33,6 +33,12 @@ "description":"Whether to ignore if a wildcard expression matches no trained models. (This includes `_all` string or when no trained models have been specified)", "default":true }, + "include_model_definition":{ + "type":"boolean", + "required":false, + "description":"Should the full model definition be included in the results. 
These definitions can be large", + "default":false + }, "from":{ "type":"int", "description":"skips a number of trained models", diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml index efc6b784dbeac..6062f6519067f 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml @@ -207,10 +207,12 @@ setup: failed: 0 processors: - inference: - count: 0 - time_in_millis: 0 - current: 0 - failed: 0 + type: inference + stats: + count: 0 + time_in_millis: 0 + current: 0 + failed: 0 - match: trained_model_stats.0.ingest.pipelines.regression-model-pipeline-1: @@ -220,7 +222,9 @@ setup: failed: 0 processors: - inference: - count: 0 - time_in_millis: 0 - current: 0 - failed: 0 + type: inference + stats: + count: 0 + time_in_millis: 0 + current: 0 + failed: 0 From cfd5c641e30f940d5f5312f90e21a82811308e4a Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Thu, 14 Nov 2019 09:13:26 -0500 Subject: [PATCH 15/17] [ML][Inference] adding license checks (#49056) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [ML][Inference] adding license checks * Apply suggestions from code review Co-Authored-By: Przemysław Witek --- .../xpack/ml/MachineLearning.java | 8 +- .../ml/action/TransportInferModelAction.java | 13 +- .../inference/ingest/InferenceProcessor.java | 25 +- .../MachineLearningLicensingTests.java | 223 ++++++++++++++++++ ...sportGetTrainedModelsStatsActionTests.java | 6 +- .../InferenceProcessorFactoryTests.java | 27 ++- 6 files changed, 291 insertions(+), 11 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 2338136f51c3d..bd89b82074168 100644 --- 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -342,10 +342,16 @@ protected Setting roleSetting() { @Override public Map getProcessors(Processor.Parameters parameters) { + if (this.enabled == false) { + return Collections.emptyMap(); + } + InferenceProcessor.Factory inferenceFactory = new InferenceProcessor.Factory(parameters.client, parameters.ingestService.getClusterService(), this.settings, - parameters.ingestService); + parameters.ingestService, + getLicenseState()); + getLicenseState().addListener(inferenceFactory); parameters.ingestService.addIngestClusterStateListener(inferenceFactory); return Collections.singletonMap(InferenceProcessor.TYPE, inferenceFactory); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java index b01063cac48dd..4edd214094f2c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java @@ -10,9 +10,12 @@ import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.client.Client; import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.XPackField; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.ml.inference.loadingservice.Model; @@ -24,20 +27,28 @@ public class TransportInferModelAction extends 
HandledTransportAction listener) { + if (licenseState.isMachineLearningAllowed() == false) { + listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING)); + return; + } + ActionListener getModelListener = ActionListener.wrap( model -> { TypedChainTaskExecutor typedChainTaskExecutor = diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java index 5a0e50cd84bae..fe9f942dff431 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -26,7 +26,11 @@ import org.elasticsearch.ingest.Pipeline; import org.elasticsearch.ingest.PipelineConfiguration; import org.elasticsearch.ingest.Processor; +import org.elasticsearch.license.LicenseStateListener; +import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.XPackField; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; @@ -144,20 +148,28 @@ public String getType() { return TYPE; } - public static final class Factory implements Processor.Factory, Consumer { + public static final class Factory implements Processor.Factory, Consumer, LicenseStateListener { private static final Logger logger = LogManager.getLogger(Factory.class); private final Client client; private final IngestService ingestService; + private final XPackLicenseState licenseState; private volatile int currentInferenceProcessors; private volatile int maxIngestProcessors; private volatile Version minNodeVersion = Version.CURRENT; + 
private volatile boolean inferenceAllowed; - public Factory(Client client, ClusterService clusterService, Settings settings, IngestService ingestService) { + public Factory(Client client, + ClusterService clusterService, + Settings settings, + IngestService ingestService, + XPackLicenseState licenseState) { this.client = client; this.maxIngestProcessors = MAX_INFERENCE_PROCESSORS.get(settings); this.ingestService = ingestService; + this.licenseState = licenseState; + this.inferenceAllowed = licenseState.isMachineLearningAllowed(); clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_INFERENCE_PROCESSORS, this::setMaxIngestProcessors); } @@ -199,6 +211,10 @@ int numInferenceProcessors() { public InferenceProcessor create(Map processorFactories, String tag, Map config) throws Exception { + if (inferenceAllowed == false) { + throw LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING); + } + if (this.maxIngestProcessors <= currentInferenceProcessors) { throw new ElasticsearchStatusException("Max number of inference processors reached, total inference processors [{}]. 
" + "Adjust the setting [{}]: [{}] if a greater number is desired.", @@ -272,5 +288,10 @@ void checkSupportedVersion(InferenceConfig config) { minNodeVersion)); } } + + @Override + public void licenseStateChanged() { + this.inferenceAllowed = licenseState.isMachineLearningAllowed(); + } } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java index 19da0d1a682e5..7f01ac0aa2670 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java @@ -6,9 +6,16 @@ package org.elasticsearch.license; import org.elasticsearch.ElasticsearchSecurityException; +import org.elasticsearch.action.ingest.PutPipelineAction; +import org.elasticsearch.action.ingest.PutPipelineRequest; +import org.elasticsearch.action.ingest.SimulatePipelineAction; +import org.elasticsearch.action.ingest.SimulatePipelineRequest; +import org.elasticsearch.action.ingest.SimulatePipelineResponse; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.license.License.OperationMode; @@ -20,22 +27,30 @@ import org.elasticsearch.xpack.core.ml.action.DeleteJobAction; import org.elasticsearch.xpack.core.ml.action.GetDatafeedsStatsAction; import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.action.OpenJobAction; import 
org.elasticsearch.xpack.core.ml.action.PutDatafeedAction; import org.elasticsearch.xpack.core.ml.action.PutJobAction; import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction; import org.elasticsearch.xpack.core.ml.action.StopDatafeedAction; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.job.config.JobState; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase; import org.junit.Before; +import java.nio.charset.StandardCharsets; import java.util.Collections; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; public class MachineLearningLicensingTests extends BaseMlIntegTestCase { @@ -453,6 +468,214 @@ public void testMachineLearningDeleteDatafeedActionNotRestricted() throws Except listener.actionGet(); } + public void testMachineLearningCreateInferenceProcessorRestricted() { + String modelId = "modelprocessorlicensetest"; + assertMLAllowed(true); + putInferenceModel(modelId); + + String pipeline = "{" + + " \"processors\": [\n" + + " {\n" + + " \"inference\": {\n" + + " \"target_field\": \"regression_value\",\n" + + " \"model_id\": \"modelprocessorlicensetest\",\n" + + " \"inference_config\": {\"regression\": {}},\n" + + " \"field_mappings\": {\n" + + " \"col1\": \"col1\",\n" + + " \"col2\": \"col2\",\n" + + " \"col3\": \"col3\",\n" + + " \"col4\": \"col4\"\n" + + " }\n" + + " }\n" + + " }]}\n"; + // test that license restricted apis do now work + PlainActionFuture 
putPipelineListener = PlainActionFuture.newFuture(); + client().execute(PutPipelineAction.INSTANCE, + new PutPipelineRequest("test_infer_license_pipeline", + new BytesArray(pipeline.getBytes(StandardCharsets.UTF_8)), + XContentType.JSON), + putPipelineListener); + AcknowledgedResponse putPipelineResponse = putPipelineListener.actionGet(); + assertTrue(putPipelineResponse.isAcknowledged()); + + String simulateSource = "{\n" + + " \"pipeline\": \n" + + pipeline + + " ,\n" + + " \"docs\": [\n" + + " {\"_source\": {\n" + + " \"col1\": \"female\",\n" + + " \"col2\": \"M\",\n" + + " \"col3\": \"none\",\n" + + " \"col4\": 10\n" + + " }}]\n" + + "}"; + PlainActionFuture simulatePipelineListener = PlainActionFuture.newFuture(); + client().execute(SimulatePipelineAction.INSTANCE, + new SimulatePipelineRequest(new BytesArray(simulateSource.getBytes(StandardCharsets.UTF_8)), XContentType.JSON), + simulatePipelineListener); + + assertThat(simulatePipelineListener.actionGet().getResults(), is(not(empty()))); + + + // Pick a license that does not allow machine learning + License.OperationMode mode = randomInvalidLicenseType(); + enableLicensing(mode); + assertMLAllowed(false); + + // creating a new pipeline should fail + ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> { + PlainActionFuture listener = PlainActionFuture.newFuture(); + client().execute(PutPipelineAction.INSTANCE, + new PutPipelineRequest("test_infer_license_pipeline_failure", + new BytesArray(pipeline.getBytes(StandardCharsets.UTF_8)), + XContentType.JSON), + listener); + listener.actionGet(); + }); + assertThat(e.status(), is(RestStatus.FORBIDDEN)); + assertThat(e.getMessage(), containsString("non-compliant")); + assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackField.MACHINE_LEARNING)); + + // Simulating the pipeline should fail + e = expectThrows(ElasticsearchSecurityException.class, () -> { + PlainActionFuture listener = 
PlainActionFuture.newFuture(); + client().execute(SimulatePipelineAction.INSTANCE, + new SimulatePipelineRequest(new BytesArray(simulateSource.getBytes(StandardCharsets.UTF_8)), XContentType.JSON), + listener); + listener.actionGet(); + }); + assertThat(e.status(), is(RestStatus.FORBIDDEN)); + assertThat(e.getMessage(), containsString("non-compliant")); + assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackField.MACHINE_LEARNING)); + + // Pick a license that does allow machine learning + mode = randomValidLicenseType(); + enableLicensing(mode); + assertMLAllowed(true); + // test that license restricted apis do now work + PlainActionFuture putPipelineListenerNewLicense = PlainActionFuture.newFuture(); + client().execute(PutPipelineAction.INSTANCE, + new PutPipelineRequest("test_infer_license_pipeline", + new BytesArray(pipeline.getBytes(StandardCharsets.UTF_8)), + XContentType.JSON), + putPipelineListenerNewLicense); + AcknowledgedResponse putPipelineResponseNewLicense = putPipelineListenerNewLicense.actionGet(); + assertTrue(putPipelineResponseNewLicense.isAcknowledged()); + + PlainActionFuture simulatePipelineListenerNewLicense = PlainActionFuture.newFuture(); + client().execute(SimulatePipelineAction.INSTANCE, + new SimulatePipelineRequest(new BytesArray(simulateSource.getBytes(StandardCharsets.UTF_8)), XContentType.JSON), + simulatePipelineListenerNewLicense); + + assertThat(simulatePipelineListenerNewLicense.actionGet().getResults(), is(not(empty()))); + } + + public void testMachineLearningInferModelRestricted() throws Exception { + String modelId = "modelinfermodellicensetest"; + assertMLAllowed(true); + putInferenceModel(modelId); + + + PlainActionFuture inferModelSuccess = PlainActionFuture.newFuture(); + client().execute(InferModelAction.INSTANCE, new InferModelAction.Request( + modelId, + Collections.singletonList(Collections.emptyMap()), + new RegressionConfig() + ), inferModelSuccess); + 
assertThat(inferModelSuccess.actionGet().getInferenceResults(), is(not(empty()))); + + // Pick a license that does not allow machine learning + License.OperationMode mode = randomInvalidLicenseType(); + enableLicensing(mode); + assertMLAllowed(false); + + // inferring against a model should now fail + ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> { + PlainActionFuture listener = PlainActionFuture.newFuture(); + client().execute(InferModelAction.INSTANCE, new InferModelAction.Request( + modelId, + Collections.singletonList(Collections.emptyMap()), + new RegressionConfig() + ), listener); + listener.actionGet(); + }); + assertThat(e.status(), is(RestStatus.FORBIDDEN)); + assertThat(e.getMessage(), containsString("non-compliant")); + assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackField.MACHINE_LEARNING)); + + // Pick a license that does allow machine learning + mode = randomValidLicenseType(); + enableLicensing(mode); + assertMLAllowed(true); + + PlainActionFuture listener = PlainActionFuture.newFuture(); + client().execute(InferModelAction.INSTANCE, new InferModelAction.Request( + modelId, + Collections.singletonList(Collections.emptyMap()), + new RegressionConfig() + ), listener); + assertThat(listener.actionGet().getInferenceResults(), is(not(empty()))); + } + + private void putInferenceModel(String modelId) { + String config = "" + + "{\n" + + " \"model_id\": \"" + modelId + "\",\n" + + " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," + + " \"description\": \"test model for classification\",\n" + + " \"version\": \"8.0.0\",\n" + + " \"created_by\": \"benwtrent\",\n" + + " \"created_time\": 0\n" + + "}"; + String definition = "" + + "{" + + " \"trained_model\": {\n" + + " \"tree\": {\n" + + " \"feature_names\": [\n" + + " \"col1_male\",\n" + + " \"col1_female\",\n" + + " \"col2_encoded\",\n" + + " \"col3_encoded\",\n" + + " \"col4\"\n" + + " ],\n" + + " 
\"tree_structure\": [\n" + + " {\n" + + " \"node_index\": 0,\n" + + " \"split_feature\": 0,\n" + + " \"split_gain\": 12.0,\n" + + " \"threshold\": 10.0,\n" + + " \"decision_type\": \"lte\",\n" + + " \"default_left\": true,\n" + + " \"left_child\": 1,\n" + + " \"right_child\": 2\n" + + " },\n" + + " {\n" + + " \"node_index\": 1,\n" + + " \"leaf_value\": 1\n" + + " },\n" + + " {\n" + + " \"node_index\": 2,\n" + + " \"leaf_value\": 2\n" + + " }\n" + + " ],\n" + + " \"target_type\": \"regression\"\n" + + " }\n" + + " }," + + " \"model_id\": \"" + modelId + "\"\n" + + "}"; + assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME) + .setId(modelId) + .setSource(config, XContentType.JSON) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .get().status(), equalTo(RestStatus.CREATED)); + assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME) + .setId(TrainedModelDefinition.docId(modelId)) + .setSource(definition, XContentType.JSON) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .get().status(), equalTo(RestStatus.CREATED)); + } + private static OperationMode randomInvalidLicenseType() { return randomFrom(License.OperationMode.GOLD, License.OperationMode.STANDARD, License.OperationMode.BASIC); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java index 5a148f0676ef9..483906f3c5ecf 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java @@ -25,6 +25,7 @@ import org.elasticsearch.ingest.IngestStats; import org.elasticsearch.ingest.PipelineConfiguration; import org.elasticsearch.ingest.Processor; +import org.elasticsearch.license.XPackLicenseState; 
import org.elasticsearch.plugins.IngestPlugin; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; @@ -80,11 +81,14 @@ public Processor create(Map processorFactories, Strin @Override public Map getProcessors(Processor.Parameters parameters) { Map factoryMap = new HashMap<>(); + XPackLicenseState licenseState = mock(XPackLicenseState.class); + when(licenseState.isMachineLearningAllowed()).thenReturn(true); factoryMap.put(InferenceProcessor.TYPE, new InferenceProcessor.Factory(parameters.client, parameters.ingestService.getClusterService(), Settings.EMPTY, - parameters.ingestService)); + parameters.ingestService, + licenseState)); factoryMap.put("not_inference", new NotInferenceProcessor.Factory()); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java index 322b5cfb4ec2c..3337ad7be53a1 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java @@ -26,6 +26,7 @@ import org.elasticsearch.ingest.IngestService; import org.elasticsearch.ingest.PipelineConfiguration; import org.elasticsearch.ingest.Processor; +import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.plugins.IngestPlugin; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; @@ -48,14 +49,18 @@ public class InferenceProcessorFactoryTests extends ESTestCase { private static final IngestPlugin SKINNY_PLUGIN = new IngestPlugin() { @Override public Map getProcessors(Processor.Parameters parameters) { + XPackLicenseState licenseState = mock(XPackLicenseState.class); + when(licenseState.isMachineLearningAllowed()).thenReturn(true); return 
Collections.singletonMap(InferenceProcessor.TYPE, new InferenceProcessor.Factory(parameters.client, parameters.ingestService.getClusterService(), Settings.EMPTY, - parameters.ingestService)); + parameters.ingestService, + licenseState)); } }; private Client client; + private XPackLicenseState licenseState; private ClusterService clusterService; private IngestService ingestService; @@ -69,12 +74,18 @@ public void setUpVariables() { when(clusterService.getClusterSettings()).thenReturn(clusterSettings); ingestService = new IngestService(clusterService, tp, null, null, null, Collections.singletonList(SKINNY_PLUGIN), client); + licenseState = mock(XPackLicenseState.class); + when(licenseState.isMachineLearningAllowed()).thenReturn(true); } public void testNumInferenceProcessors() throws Exception { MetaData metaData = null; - InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, clusterService, Settings.EMPTY, ingestService); + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, + clusterService, + Settings.EMPTY, + ingestService, + licenseState); processorFactory.accept(buildClusterState(metaData)); assertThat(processorFactory.numInferenceProcessors(), equalTo(0)); @@ -91,7 +102,8 @@ public void testCreateProcessorWithTooManyExisting() throws Exception { InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, clusterService, Settings.builder().put(InferenceProcessor.MAX_INFERENCE_PROCESSORS.getKey(), 1).build(), - ingestService); + ingestService, + licenseState); processorFactory.accept(buildClusterStateWithModelReferences("model1")); @@ -106,7 +118,8 @@ public void testCreateProcessorWithInvalidInferenceConfig() { InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, clusterService, Settings.EMPTY, - ingestService); + ingestService, + licenseState); Map config = new HashMap<>() {{ put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); @@ 
-147,7 +160,8 @@ public void testCreateProcessorWithTooOldMinNodeVersion() throws IOException { InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, clusterService, Settings.EMPTY, - ingestService); + ingestService, + licenseState); processorFactory.accept(builderClusterStateWithModelReferences(Version.V_7_5_0, "model1")); Map regression = new HashMap<>() {{ @@ -190,7 +204,8 @@ public void testCreateProcessor() { InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, clusterService, Settings.EMPTY, - ingestService); + ingestService, + licenseState); Map regression = new HashMap<>() {{ put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap()); From be53e316d6afa80f940e3080bb6710c9b35f10fa Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Thu, 14 Nov 2019 09:13:48 -0500 Subject: [PATCH 16/17] [ML][Inference] Adding memory and compute estimates to inference (#48955) * [ML][Inference] Adding memory and compute estimates to inference * Make nodes non-empty * fixing tests --- .../ml/inference/TrainedModelConfig.java | 51 +++++++++++++- .../ml/inference/TrainedModelConfigTests.java | 5 +- .../core/ml/inference/TrainedModelConfig.java | 69 ++++++++++++++++++- .../inference/trainedmodel/TrainedModel.java | 5 ++ .../trainedmodel/ensemble/Ensemble.java | 9 +++ .../ensemble/LogisticRegression.java | 3 +- .../trainedmodel/ensemble/WeightedMode.java | 3 +- .../trainedmodel/ensemble/WeightedSum.java | 3 +- .../ml/inference/trainedmodel/tree/Tree.java | 18 ++--- .../ml/inference/TrainedModelConfigTests.java | 8 ++- .../TrainedModelDefinitionTests.java | 3 +- .../trainedmodel/ensemble/EnsembleTests.java | 12 ++++ .../trainedmodel/tree/TreeTests.java | 5 ++ .../ml/integration/InferenceIngestIT.java | 8 +-- .../xpack/ml/integration/TrainedModelIT.java | 2 + .../process/AnalyticsResultProcessor.java | 2 + .../persistence/InferenceInternalIndex.java | 7 ++ .../AnalyticsResultProcessorTests.java | 2 + 
.../integration/ModelInferenceActionIT.java | 4 ++ .../integration/TrainedModelProviderIT.java | 2 + 20 files changed, 198 insertions(+), 23 deletions(-) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java index 273aa6b021325..384bfe53e4bb5 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java @@ -22,6 +22,7 @@ import org.elasticsearch.client.common.TimeUtil; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; @@ -47,6 +48,8 @@ public class TrainedModelConfig implements ToXContentObject { public static final ParseField TAGS = new ParseField("tags"); public static final ParseField METADATA = new ParseField("metadata"); public static final ParseField INPUT = new ParseField("input"); + public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes"); + public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations"); public static final ObjectParser PARSER = new ObjectParser<>(NAME, true, @@ -66,6 +69,8 @@ public class TrainedModelConfig implements ToXContentObject { PARSER.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS); PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA); PARSER.declareObject(TrainedModelConfig.Builder::setInput, (p, c) -> TrainedModelInput.fromXContent(p), INPUT); + PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedHeapMemory, 
ESTIMATED_HEAP_MEMORY_USAGE_BYTES); + PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS); } public static TrainedModelConfig.Builder fromXContent(XContentParser parser) throws IOException { @@ -81,6 +86,8 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr private final List tags; private final Map metadata; private final TrainedModelInput input; + private final Long estimatedHeapMemory; + private final Long estimatedOperations; TrainedModelConfig(String modelId, String createdBy, @@ -90,7 +97,9 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr TrainedModelDefinition definition, List tags, Map metadata, - TrainedModelInput input) { + TrainedModelInput input, + Long estimatedHeapMemory, + Long estimatedOperations) { this.modelId = modelId; this.createdBy = createdBy; this.version = version; @@ -100,6 +109,8 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr this.tags = tags == null ? null : Collections.unmodifiableList(tags); this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata); this.input = input; + this.estimatedHeapMemory = estimatedHeapMemory; + this.estimatedOperations = estimatedOperations; } public String getModelId() { @@ -138,6 +149,18 @@ public TrainedModelInput getInput() { return input; } + public ByteSizeValue getEstimatedHeapMemory() { + return estimatedHeapMemory == null ? 
null : new ByteSizeValue(estimatedHeapMemory); + } + + public Long getEstimatedHeapMemoryBytes() { + return estimatedHeapMemory; + } + + public Long getEstimatedOperations() { + return estimatedOperations; + } + public static Builder builder() { return new Builder(); } @@ -172,6 +195,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (input != null) { builder.field(INPUT.getPreferredName(), input); } + if (estimatedHeapMemory != null) { + builder.field(ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(), estimatedHeapMemory); + } + if (estimatedOperations != null) { + builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations); + } builder.endObject(); return builder; } @@ -194,6 +223,8 @@ public boolean equals(Object o) { Objects.equals(definition, that.definition) && Objects.equals(tags, that.tags) && Objects.equals(input, that.input) && + Objects.equals(estimatedHeapMemory, that.estimatedHeapMemory) && + Objects.equals(estimatedOperations, that.estimatedOperations) && Objects.equals(metadata, that.metadata); } @@ -206,6 +237,8 @@ public int hashCode() { definition, description, tags, + estimatedHeapMemory, + estimatedOperations, metadata, input); } @@ -222,6 +255,8 @@ public static class Builder { private List tags; private TrainedModelDefinition definition; private TrainedModelInput input; + private Long estimatedHeapMemory; + private Long estimatedOperations; public Builder setModelId(String modelId) { this.modelId = modelId; @@ -277,6 +312,16 @@ public Builder setInput(TrainedModelInput input) { return this; } + public Builder setEstimatedHeapMemory(Long estimatedHeapMemory) { + this.estimatedHeapMemory = estimatedHeapMemory; + return this; + } + + public Builder setEstimatedOperations(Long estimatedOperations) { + this.estimatedOperations = estimatedOperations; + return this; + } + public TrainedModelConfig build() { return new TrainedModelConfig( modelId, @@ -287,7 +332,9 @@ public TrainedModelConfig 
build() { definition, tags, metadata, - input); + input, + estimatedHeapMemory, + estimatedOperations); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java index 6d1e04f066cb7..7afba62861362 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java @@ -64,7 +64,10 @@ protected TrainedModelConfig createTestInstance() { randomBoolean() ? null : Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()), randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)), - randomBoolean() ? null : TrainedModelInputTests.createRandomInput()); + randomBoolean() ? null : TrainedModelInputTests.createRandomInput(), + randomBoolean() ? null : randomNonNegativeLong(), + randomBoolean() ? 
null : randomNonNegativeLong()); + } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index 7078322b5d331..5361760e5ca26 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; @@ -34,6 +35,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { public static final String NAME = "trained_model_config"; + private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage"; + public static final ParseField MODEL_ID = new ParseField("model_id"); public static final ParseField CREATED_BY = new ParseField("created_by"); public static final ParseField VERSION = new ParseField("version"); @@ -43,6 +46,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { public static final ParseField TAGS = new ParseField("tags"); public static final ParseField METADATA = new ParseField("metadata"); public static final ParseField INPUT = new ParseField("input"); + public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes"); + public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations"); // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is 
parsed strictly public static final ObjectParser LENIENT_PARSER = createParser(true); @@ -66,6 +71,8 @@ private static ObjectParser createParser(boole parser.declareObject(TrainedModelConfig.Builder::setInput, (p, c) -> TrainedModelInput.fromXContent(p, ignoreUnknownFields), INPUT); + parser.declareLong(TrainedModelConfig.Builder::setEstimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES); + parser.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS); return parser; } @@ -81,6 +88,8 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo private final List tags; private final Map metadata; private final TrainedModelInput input; + private final long estimatedHeapMemory; + private final long estimatedOperations; private final TrainedModelDefinition definition; @@ -92,7 +101,9 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo TrainedModelDefinition definition, List tags, Map metadata, - TrainedModelInput input) { + TrainedModelInput input, + Long estimatedHeapMemory, + Long estimatedOperations) { this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID); this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY); this.version = ExceptionsHelper.requireNonNull(version, VERSION); @@ -102,6 +113,15 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo this.tags = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(tags, TAGS)); this.metadata = metadata == null ? 
null : Collections.unmodifiableMap(metadata); this.input = ExceptionsHelper.requireNonNull(input, INPUT); + if (ExceptionsHelper.requireNonNull(estimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES) < 0) { + throw new IllegalArgumentException( + "[" + ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName() + "] must be greater than or equal to 0"); + } + this.estimatedHeapMemory = estimatedHeapMemory; + if (ExceptionsHelper.requireNonNull(estimatedOperations, ESTIMATED_OPERATIONS) < 0) { + throw new IllegalArgumentException("[" + ESTIMATED_OPERATIONS.getPreferredName() + "] must be greater than or equal to 0"); + } + this.estimatedOperations = estimatedOperations; } public TrainedModelConfig(StreamInput in) throws IOException { @@ -114,6 +134,8 @@ public TrainedModelConfig(StreamInput in) throws IOException { tags = Collections.unmodifiableList(in.readList(StreamInput::readString)); metadata = in.readMap(); input = new TrainedModelInput(in); + estimatedHeapMemory = in.readVLong(); + estimatedOperations = in.readVLong(); } public String getModelId() { @@ -157,6 +179,14 @@ public static Builder builder() { return new Builder(); } + public long getEstimatedHeapMemory() { + return estimatedHeapMemory; + } + + public long getEstimatedOperations() { + return estimatedOperations; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(modelId); @@ -168,6 +198,8 @@ public void writeTo(StreamOutput out) throws IOException { out.writeCollection(tags, StreamOutput::writeString); out.writeMap(metadata); input.writeTo(out); + out.writeVLong(estimatedHeapMemory); + out.writeVLong(estimatedOperations); } @Override @@ -192,6 +224,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME); } builder.field(INPUT.getPreferredName(), input); + builder.humanReadableField( + ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(), + 
ESTIMATED_HEAP_MEMORY_USAGE_HUMAN, + new ByteSizeValue(estimatedHeapMemory)); + builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations); builder.endObject(); return builder; } @@ -214,6 +251,8 @@ public boolean equals(Object o) { Objects.equals(definition, that.definition) && Objects.equals(tags, that.tags) && Objects.equals(input, that.input) && + Objects.equals(estimatedHeapMemory, that.estimatedHeapMemory) && + Objects.equals(estimatedOperations, that.estimatedOperations) && Objects.equals(metadata, that.metadata); } @@ -227,6 +266,8 @@ public int hashCode() { description, tags, metadata, + estimatedHeapMemory, + estimatedOperations, input); } @@ -241,6 +282,8 @@ public static class Builder { private Map metadata; private TrainedModelInput input; private TrainedModelDefinition definition; + private Long estimatedHeapMemory; + private Long estimatedOperations; public Builder setModelId(String modelId) { this.modelId = modelId; @@ -296,6 +339,16 @@ public Builder setInput(TrainedModelInput input) { return this; } + public Builder setEstimatedHeapMemory(long estimatedHeapMemory) { + this.estimatedHeapMemory = estimatedHeapMemory; + return this; + } + + public Builder setEstimatedOperations(long estimatedOperations) { + this.estimatedOperations = estimatedOperations; + return this; + } + // TODO move to REST level instead of here in the builder public void validate() { // We require a definition to be available here even though it will be stored in a different doc @@ -326,6 +379,16 @@ public void validate() { throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation", CREATE_TIME.getPreferredName()); } + + if (estimatedHeapMemory != null) { + throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation", + ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName()); + } + + if (estimatedOperations != null) { + throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference 
model creation", + ESTIMATED_OPERATIONS.getPreferredName()); + } } public TrainedModelConfig build() { @@ -338,7 +401,9 @@ public TrainedModelConfig build() { definition, tags, metadata, - input); + input, + estimatedHeapMemory, + estimatedOperations); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java index a9028efdffa94..e206a70918096 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java @@ -50,4 +50,9 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou * @throws org.elasticsearch.ElasticsearchException if validations fail */ void validate(); + + /** + * @return The estimated number of operations required at inference time + */ + long estimatedNumOperations(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java index 79883f4db4b4e..a59f1a1c245d9 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java @@ -38,6 +38,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.OptionalDouble; import java.util.stream.Collectors; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel; @@ -251,6 +252,14 @@ public void validate() { this.models.forEach(TrainedModel::validate); } + @Override + public long estimatedNumOperations() { + 
OptionalDouble avg = models.stream().mapToLong(TrainedModel::estimatedNumOperations).average(); + assert avg.isPresent() : "unexpected null when calculating number of operations"; + // Average operations for each model and the operations required for processing and aggregating with the outputAggregator + return (long)Math.ceil(avg.getAsDouble()) + 2 * (models.size() - 1); + } + public static Builder builder() { return new Builder(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java index 14f2b1b64b523..2dba96916390c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java @@ -157,6 +157,7 @@ public int hashCode() { @Override public long ramBytesUsed() { - return SHALLOW_SIZE + RamUsageEstimator.sizeOf(weights); + long weightSize = weights == null ? 0L : RamUsageEstimator.sizeOf(weights); + return SHALLOW_SIZE + weightSize; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java index 29b311794be06..73689d16b1cf8 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java @@ -174,6 +174,7 @@ public int hashCode() { @Override public long ramBytesUsed() { - return SHALLOW_SIZE + RamUsageEstimator.sizeOf(weights); + long weightSize = weights == null ? 
0L : RamUsageEstimator.sizeOf(weights); + return SHALLOW_SIZE + weightSize; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java index b9b34508b88ba..ed1c13cf10203 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java @@ -150,6 +150,7 @@ public boolean compatibleWith(TargetType targetType) { @Override public long ramBytesUsed() { - return SHALLOW_SIZE + RamUsageEstimator.sizeOf(weights); + long weightSize = weights == null ? 0L : RamUsageEstimator.sizeOf(weights); + return SHALLOW_SIZE + weightSize; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java index 60192f46234b9..1408b17a0691a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java @@ -87,7 +87,10 @@ public static Tree fromXContentLenient(XContentParser parser) { Tree(List featureNames, List nodes, TargetType targetType, List classificationLabels) { this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES)); - this.nodes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE)); + if(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE).size() == 0) { + throw new IllegalArgumentException("[tree_structure] must not be empty"); + } + this.nodes = Collections.unmodifiableList(nodes); this.targetType = 
ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE); this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels); this.highestOrderCategory = new CachedSupplier<>(() -> this.maxLeafValue()); @@ -257,6 +260,12 @@ public void validate() { detectCycle(); } + @Override + public long estimatedNumOperations() { + // Grabbing the features from the doc + the depth of the tree + return (long)Math.ceil(Math.log(nodes.size())) + featureNames.size(); + } + private void checkTargetType() { if ((this.targetType == TargetType.CLASSIFICATION) != (this.classificationLabels != null)) { throw ExceptionsHelper.badRequestException( @@ -265,9 +274,6 @@ private void checkTargetType() { } private void detectCycle() { - if (nodes.isEmpty()) { - return; - } Set visited = new HashSet<>(nodes.size()); Queue toVisit = new ArrayDeque<>(nodes.size()); toVisit.add(0); @@ -288,10 +294,6 @@ private void detectCycle() { } private void detectMissingNodes() { - if (nodes.isEmpty()) { - return; - } - List missingNodes = new ArrayList<>(); for (int i = 0; i < nodes.size(); i++) { TreeNode currentNode = nodes.get(i); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java index 5583081df2732..3f3f3cb9a3ad8 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java @@ -74,7 +74,9 @@ protected TrainedModelConfig createTestInstance() { null, // is not parsed so should not be provided tags, randomBoolean() ? 
null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)), - TrainedModelInputTests.createRandomInput()); + TrainedModelInputTests.createRandomInput(), + randomNonNegativeLong(), + randomNonNegativeLong()); } @Override @@ -117,7 +119,9 @@ public void testToXContentWithParams() throws IOException { TrainedModelDefinitionTests.createRandomBuilder(randomAlphaOfLength(10)).build(), Collections.emptyList(), randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)), - TrainedModelInputTests.createRandomInput()); + TrainedModelInputTests.createRandomInput(), + randomNonNegativeLong(), + randomNonNegativeLong()); BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false); assertThat(reference.utf8ToString(), containsString("definition")); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java index 5fdadac712d0d..69ff018db9b44 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java @@ -23,6 +23,7 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests; import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.EnsembleTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests; @@ -78,7 +79,7 @@ public static TrainedModelDefinition.Builder createRandomBuilder(String modelId) 
TargetMeanEncodingTests.createRandom())) .limit(numberOfProcessors) .collect(Collectors.toList())) - .setTrainedModel(randomFrom(TreeTests.createRandom())); + .setTrainedModel(randomFrom(TreeTests.createRandom(), EnsembleTests.createRandom())); } private static final String ENSEMBLE_MODEL = "" + diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java index 753a9d3dd3cad..c38591ab6cfc5 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -445,6 +445,18 @@ public void testRegressionInference() { closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new RegressionConfig())).value(), 0.00001)); } + public void testOperationsEstimations() { + Tree tree1 = TreeTests.buildRandomTree(Arrays.asList("foo", "bar"), 2); + Tree tree2 = TreeTests.buildRandomTree(Arrays.asList("foo", "bar", "baz"), 5); + Tree tree3 = TreeTests.buildRandomTree(Arrays.asList("foo", "baz"), 3); + Ensemble ensemble = Ensemble.builder().setTrainedModels(Arrays.asList(tree1, tree2, tree3)) + .setTargetType(TargetType.CLASSIFICATION) + .setFeatureNames(Arrays.asList("foo", "bar", "baz")) + .setOutputAggregator(new LogisticRegression(new double[]{0.1, 0.4, 1.0})) + .build(); + assertThat(ensemble.estimatedNumOperations(), equalTo(9L)); + } + private static Map zipObjMap(List keys, List values) { return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get)); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java index 11bf44fd165e4..7f5158706941f 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java @@ -300,6 +300,11 @@ public void testTreeWithTargetTypeAndLabelsMismatch() { assertThat(ex.getMessage(), equalTo(msg)); } + public void testOperationsEstimations() { + Tree tree = buildRandomTree(Arrays.asList("foo", "bar", "baz"), 5); + assertThat(tree.estimatedNumOperations(), equalTo(7L)); + } + private static Map zipObjMap(List keys, List values) { return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get)); } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java index 790d623fbec9c..17b9ed512c82b 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java @@ -366,12 +366,12 @@ private Map generateSourceDoc() { private static final String REGRESSION_CONFIG = "{" + " \"model_id\": \"test_regression\",\n" + - " \"model_version\": 0,\n" + " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," + " \"description\": \"test model for regression\",\n" + " \"version\": \"8.0.0\",\n" + " \"created_by\": \"ml_test\",\n" + - " \"model_type\": \"local\",\n" + + " \"estimated_heap_memory_usage_bytes\": 0," + + " \"estimated_operations\": 0," + " \"created_time\": 0" + "}"; @@ -499,12 +499,12 @@ private Map generateSourceDoc() { private static final String 
CLASSIFICATION_CONFIG = "" + "{\n" + " \"model_id\": \"test_classification\",\n" + - " \"model_version\": 0,\n" + " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," + " \"description\": \"test model for classification\",\n" + " \"version\": \"8.0.0\",\n" + " \"created_by\": \"benwtrent\",\n" + - " \"model_type\": \"local\",\n" + + " \"estimated_heap_memory_usage_bytes\": 0," + + " \"estimated_operations\": 0," + " \"created_time\": 0\n" + "}"; diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java index 1982cec7eca0c..153b169ea8f32 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java @@ -192,6 +192,8 @@ private static String buildRegressionModel(String modelId) throws IOException { .setCreatedBy("ml_test") .setVersion(Version.CURRENT) .setCreateTime(Instant.now()) + .setEstimatedOperations(0) + .setEstimatedHeapMemory(0) .build() .toXContent(builder, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"))); return XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index eb47c17137a3b..3abc3b5e43cce 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -150,6 +150,8 @@ private 
TrainedModelConfig createTrainedModelConfig(TrainedModelDefinition.Build .setMetadata(Collections.singletonMap("analytics_config", XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true))) .setDefinition(definition) + .setEstimatedHeapMemory(definition.ramBytesUsed()) + .setEstimatedOperations(definition.getTrainedModel().estimatedNumOperations()) .setInput(new TrainedModelInput(fieldNames)) .build(); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/InferenceInternalIndex.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/InferenceInternalIndex.java index 19d5d33abe4cd..aa80807aae85e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/InferenceInternalIndex.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/InferenceInternalIndex.java @@ -24,6 +24,7 @@ import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.DYNAMIC; import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.ENABLED; import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.KEYWORD; +import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.LONG; import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.PROPERTIES; import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.TEXT; import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.TYPE; @@ -103,6 +104,12 @@ private static void addInferenceDocFields(XContentBuilder builder) throws IOExce .endObject() .startObject(TrainedModelConfig.METADATA.getPreferredName()) .field(ENABLED, false) + .endObject() + .startObject(TrainedModelConfig.ESTIMATED_OPERATIONS.getPreferredName()) + .field(TYPE, LONG) + .endObject() + 
.startObject(TrainedModelConfig.ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName()) + .field(TYPE, LONG) .endObject(); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index cb90b39772a0c..bdccdf8c6722f 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -145,6 +145,8 @@ public void testProcess_GivenInferenceModelIsStoredSuccessfully() { assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION)); assertThat(storedModel.getDefinition(), equalTo(inferenceModel.build())); assertThat(storedModel.getInput().getFieldNames(), equalTo(expectedFieldNames)); + assertThat(storedModel.getEstimatedHeapMemory(), equalTo(inferenceModel.build().ramBytesUsed())); + assertThat(storedModel.getEstimatedOperations(), equalTo(inferenceModel.build().getTrainedModel().estimatedNumOperations())); Map metadata = storedModel.getMetadata(); assertThat(metadata.size(), equalTo(1)); assertThat(metadata, hasKey("analytics_config")); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java index 099baf949b684..0b0f7514caf2e 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -69,6 +69,8 @@ public void testInferModels() throws Exception { .setModelId(modelId1)) .setVersion(Version.CURRENT) .setCreateTime(Instant.now()) + .setEstimatedOperations(0) + .setEstimatedHeapMemory(0) .build(); 
TrainedModelConfig config2 = buildTrainedModelConfigBuilder(modelId1) .setInput(new TrainedModelInput(Arrays.asList("field1", "field2"))) @@ -77,6 +79,8 @@ public void testInferModels() throws Exception { .setTrainedModel(buildRegression()) .setModelId(modelId2)) .setVersion(Version.CURRENT) + .setEstimatedOperations(0) + .setEstimatedHeapMemory(0) .setCreateTime(Instant.now()) .build(); AtomicReference putConfigHolder = new AtomicReference<>(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java index 75e83acb0b22e..10644cb6da547 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java @@ -144,6 +144,8 @@ private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String .setDescription("trained model config for test") .setModelId(modelId) .setVersion(Version.CURRENT) + .setEstimatedHeapMemory(0) + .setEstimatedOperations(0) .setInput(TrainedModelInputTests.createRandomInput()); } From 551bdd0c8bb788b10d33f934d0ffba62b55d0a3b Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Thu, 14 Nov 2019 09:46:14 -0500 Subject: [PATCH 17/17] fixing licensing tests --- .../elasticsearch/license/MachineLearningLicensingTests.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java index 7f01ac0aa2670..42edb301372bd 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java @@ -626,6 
+626,8 @@ private void putInferenceModel(String modelId) { " \"description\": \"test model for classification\",\n" + " \"version\": \"8.0.0\",\n" + " \"created_by\": \"benwtrent\",\n" + + " \"estimated_heap_memory_usage_bytes\": 0,\n" + + " \"estimated_operations\": 0,\n" + " \"created_time\": 0\n" + "}"; String definition = "" +