diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index b74bf82466eac..2c8553d68fe98 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -98,6 +98,7 @@ import org.elasticsearch.xpack.core.ml.action.GetModelSnapshotsAction; import org.elasticsearch.xpack.core.ml.action.GetOverallBucketsAction; import org.elasticsearch.xpack.core.ml.action.GetRecordsAction; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction; import org.elasticsearch.xpack.core.ml.action.KillProcessAction; import org.elasticsearch.xpack.core.ml.action.MlInfoAction; @@ -139,7 +140,14 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Recall; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ScoreByThresholdResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding; import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; @@ -323,6 +331,7 @@ public List> getClientActions() { StartDataFrameAnalyticsAction.INSTANCE, EvaluateDataFrameAction.INSTANCE, EstimateMemoryUsageAction.INSTANCE, + InferModelAction.INSTANCE, // security ClearRealmCacheAction.INSTANCE, ClearRolesCacheAction.INSTANCE, @@ -451,6 +460,17 @@ public List getNamedWriteables() { new NamedWriteableRegistry.Entry(PreProcessor.class, TargetMeanEncoding.NAME.getPreferredName(), TargetMeanEncoding::new), // ML - Inference models new NamedWriteableRegistry.Entry(TrainedModel.class, Tree.NAME.getPreferredName(), Tree::new), + new NamedWriteableRegistry.Entry(TrainedModel.class, Ensemble.NAME.getPreferredName(), Ensemble::new), + // ML - Inference aggregators + new NamedWriteableRegistry.Entry(OutputAggregator.class, WeightedSum.NAME.getPreferredName(), WeightedSum::new), + new NamedWriteableRegistry.Entry(OutputAggregator.class, WeightedMode.NAME.getPreferredName(), WeightedMode::new), + // ML - Inference Results + new NamedWriteableRegistry.Entry(InferenceResults.class, + ClassificationInferenceResults.NAME, + ClassificationInferenceResults::new), + new NamedWriteableRegistry.Entry(InferenceResults.class, + RegressionInferenceResults.NAME, + RegressionInferenceResults::new), // monitoring new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.MONITORING, MonitoringFeatureSetUsage::new), diff --git 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java
new file mode 100644
index 0000000000000..248d6180d3256
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java
@@ -0,0 +1,153 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.action.ActionRequest;
+import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.action.ActionResponse;
+import org.elasticsearch.action.ActionType;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+public class InferModelAction extends ActionType<InferModelAction.Response> {
+
+    public static final InferModelAction INSTANCE = new InferModelAction();
+    public static final String NAME = "cluster:admin/xpack/ml/infer";
+
+    private InferModelAction() {
+        super(NAME, Response::new);
+    }
+
+    public static class Request extends ActionRequest {
+
+        private final String modelId;
+        private final long modelVersion;
+        private final List<Map<String, Object>> objectsToInfer;
+        private final InferenceParams params;
+
+        public Request(String modelId, long modelVersion) {
+            this(modelId, modelVersion, Collections.emptyList(), InferenceParams.EMPTY_PARAMS);
+        }
+
+        public Request(String modelId, long modelVersion, List<Map<String, Object>> objectsToInfer, InferenceParams inferenceParams) {
+            this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID);
+            this.modelVersion = modelVersion;
+            this.objectsToInfer = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(objectsToInfer, "objects_to_infer"));
+            this.params = inferenceParams == null ? InferenceParams.EMPTY_PARAMS : inferenceParams;
+        }
+
+        public Request(String modelId, long modelVersion, Map<String, Object> objectToInfer, InferenceParams params) {
+            this(modelId,
+                modelVersion,
+                Arrays.asList(ExceptionsHelper.requireNonNull(objectToInfer, "objects_to_infer")),
+                params);
+        }
+
+        public Request(StreamInput in) throws IOException {
+            super(in);
+            this.modelId = in.readString();
+            this.modelVersion = in.readVLong();
+            this.objectsToInfer = Collections.unmodifiableList(in.readList(StreamInput::readMap));
+            this.params = new InferenceParams(in);
+        }
+
+        public String getModelId() {
+            return modelId;
+        }
+
+        public long getModelVersion() {
+            return modelVersion;
+        }
+
+        public List<Map<String, Object>> getObjectsToInfer() {
+            return objectsToInfer;
+        }
+
+        public InferenceParams getParams() {
+            return params;
+        }
+
+        @Override
+        public ActionRequestValidationException validate() {
+            return null;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            super.writeTo(out);
+            out.writeString(modelId);
+            out.writeVLong(modelVersion);
+            out.writeCollection(objectsToInfer, StreamOutput::writeMap);
+            params.writeTo(out);
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            InferModelAction.Request that = (InferModelAction.Request) o;
+            return Objects.equals(modelId, that.modelId)
+                && Objects.equals(modelVersion, that.modelVersion)
+                && Objects.equals(params, that.params)
+                && Objects.equals(objectsToInfer, that.objectsToInfer);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(modelId, modelVersion, objectsToInfer, params);
+        }
+
+    }
+
+    public static class Response extends ActionResponse {
+
+        private final List<InferenceResults> inferenceResults;
+
+        public Response(List<InferenceResults> inferenceResults) {
+            super();
+            this.inferenceResults = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(inferenceResults, "inferenceResults"));
+        }
+
+        public Response(StreamInput in) throws IOException {
+            super(in);
+            this.inferenceResults = Collections.unmodifiableList(in.readNamedWriteableList(InferenceResults.class));
+        }
+
+        public List<InferenceResults> getInferenceResults() {
+            return inferenceResults;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeNamedWriteableList(inferenceResults);
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            InferModelAction.Response that = (InferModelAction.Response) o;
+            return Objects.equals(inferenceResults, that.inferenceResults);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(inferenceResults);
+        }
+
+    }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java
index 7fff4d6abbd3b..7b56a4c3b4da3 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java
@@ -8,6 +8,9 @@
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.plugins.spi.NamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
+import 
org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; @@ -100,6 +103,14 @@ public List getNamedWriteables() { WeightedMode.NAME.getPreferredName(), WeightedMode::new)); + // Inference Results + namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, + ClassificationInferenceResults.NAME, + ClassificationInferenceResults::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, + RegressionInferenceResults.NAME, + RegressionInferenceResults::new)); + return namedWriteables; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java index f85c184646e1f..e936d60bf87b7 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java @@ -18,6 +18,8 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; @@ -27,6 +29,7 @@ import java.io.IOException; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Objects; public class TrainedModelDefinition implements ToXContentObject, Writeable { @@ -118,6 +121,15 @@ public Input getInput() { return input; } + private void preProcess(Map fields) { + preProcessors.forEach(preProcessor -> preProcessor.process(fields)); + } + + public InferenceResults infer(Map fields, InferenceParams params) { + preProcess(fields); + return trainedModel.infer(fields, params); + } + @Override public String toString() { return Strings.toString(this); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java new file mode 100644 index 0000000000000..662585bedf51d --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java @@ -0,0 +1,175 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +public class ClassificationInferenceResults extends SingleValueInferenceResults { + + public static final String NAME = "classification"; + public static final ParseField CLASSIFICATION_LABEL = new ParseField("classification_label"); + public static final ParseField TOP_CLASSES = new ParseField("top_classes"); + + private final String classificationLabel; + private final List topClasses; + + public ClassificationInferenceResults(double value, String classificationLabel, List topClasses) { + super(value); + this.classificationLabel = classificationLabel; + this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses); + } + + public ClassificationInferenceResults(StreamInput in) throws IOException { + super(in); + this.classificationLabel = in.readOptionalString(); + this.topClasses = Collections.unmodifiableList(in.readList(TopClassEntry::new)); + } + + public String getClassificationLabel() { + return classificationLabel; + } + + public List getTopClasses() { + return topClasses; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeOptionalString(classificationLabel); + out.writeCollection(topClasses); + } + + @Override + XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException { + if (classificationLabel != null) { + builder.field(CLASSIFICATION_LABEL.getPreferredName(), classificationLabel); + } + if (topClasses.isEmpty() == false) { + builder.field(TOP_CLASSES.getPreferredName(), topClasses); + } + return builder; + } + + @Override + public boolean equals(Object object) { + if (object == this) { return true; } + if (object == null || getClass() != object.getClass()) { return false; } + ClassificationInferenceResults that = (ClassificationInferenceResults) object; + return Objects.equals(value(), that.value()) && + Objects.equals(classificationLabel, that.classificationLabel) && + Objects.equals(topClasses, that.topClasses); + } + + @Override + public int hashCode() { + return Objects.hash(value(), classificationLabel, topClasses); + } + + @Override + public String valueAsString() { + return classificationLabel == null ? 
super.valueAsString() : classificationLabel; + } + + @Override + public void writeResult(IngestDocument document, String resultField) { + ExceptionsHelper.requireNonNull(document, "document"); + ExceptionsHelper.requireNonNull(resultField, "resultField"); + if (topClasses.isEmpty()) { + document.setFieldValue(resultField, valueAsString()); + } else { + document.setFieldValue(resultField, topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList())); + } + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public String getName() { + return NAME; + } + + public static class TopClassEntry implements ToXContentObject, Writeable { + + public final ParseField CLASSIFICATION = new ParseField("classification"); + public final ParseField PROBABILITY = new ParseField("probability"); + + private final String classification; + private final double probability; + + public TopClassEntry(String classification, Double probability) { + this.classification = ExceptionsHelper.requireNonNull(classification, CLASSIFICATION); + this.probability = ExceptionsHelper.requireNonNull(probability, PROBABILITY); + } + + public TopClassEntry(StreamInput in) throws IOException { + this.classification = in.readString(); + this.probability = in.readDouble(); + } + + public String getClassification() { + return classification; + } + + public double getProbability() { + return probability; + } + + public Map asValueMap() { + Map map = new HashMap<>(2); + map.put(CLASSIFICATION.getPreferredName(), classification); + map.put(PROBABILITY.getPreferredName(), probability); + return map; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(classification); + out.writeDouble(probability); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CLASSIFICATION.getPreferredName(), classification); + builder.field(PROBABILITY.getPreferredName(), probability); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object object) { + if (object == this) { return true; } + if (object == null || getClass() != object.getClass()) { return false; } + TopClassEntry that = (TopClassEntry) object; + return Objects.equals(classification, that.classification) && + Objects.equals(probability, that.probability); + } + + @Override + public int hashCode() { + return Objects.hash(classification, probability); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java new file mode 100644 index 0000000000000..00744f6982f46 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; + +public interface InferenceResults extends NamedXContentObject, NamedWriteable { + + void writeResult(IngestDocument document, String resultField); + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java new file mode 100644 index 0000000000000..e186489b91dab --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Objects; + +public class RegressionInferenceResults extends SingleValueInferenceResults { + + public static final String NAME = "regression"; + + public RegressionInferenceResults(double value) { + super(value); + } + + public RegressionInferenceResults(StreamInput in) throws IOException { + super(in.readDouble()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + } + + @Override + XContentBuilder innerToXContent(XContentBuilder builder, Params params) { + return builder; + } + + @Override + public boolean equals(Object object) { + if (object == this) { return true; } + if (object == null || getClass() != object.getClass()) { return false; } + RegressionInferenceResults that = (RegressionInferenceResults) object; + return Objects.equals(value(), that.value()); + } + + @Override + public int hashCode() { + return Objects.hash(value()); + } + + @Override + public void writeResult(IngestDocument document, String resultField) { + ExceptionsHelper.requireNonNull(document, "document"); + ExceptionsHelper.requireNonNull(resultField, "resultField"); + document.setFieldValue(resultField, value()); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public String getName() { + return NAME; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java new file mode 100644 index 0000000000000..2905a6679584c --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java @@ -0,0 +1,51 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; + +public abstract class SingleValueInferenceResults implements InferenceResults { + + public final ParseField VALUE = new ParseField("value"); + + private final double value; + + SingleValueInferenceResults(StreamInput in) throws IOException { + value = in.readDouble(); + } + + SingleValueInferenceResults(double value) { + this.value = value; + } + + public Double value() { + return value; + } + + public String valueAsString() { + return String.valueOf(value); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(value); + } + + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(VALUE.getPreferredName(), value); + innerToXContent(builder, params); + builder.endObject(); + return builder; + } + + abstract XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException; +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java new file mode 100644 index 0000000000000..5e37b237e9f79 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java @@ -0,0 +1,75 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public final class InferenceHelpers { + + private InferenceHelpers() { } + + public static List topClasses(List probabilities, + List classificationLabels, + int numToInclude) { + if (numToInclude == 0) { + return Collections.emptyList(); + } + int[] sortedIndices = IntStream.range(0, probabilities.size()) + .boxed() + .sorted(Comparator.comparing(probabilities::get).reversed()) + .mapToInt(i -> i) + .toArray(); + + if (classificationLabels != null && probabilities.size() != classificationLabels.size()) { + throw ExceptionsHelper + .serverError( + "model returned classification probabilities of size [{}] which is not equal to classification labels size [{}]", + null, + probabilities.size(), + classificationLabels); + } + + List labels = classificationLabels == null ? + // If we don't have the labels we should return the top classification values anyways, they will just be numeric + IntStream.range(0, probabilities.size()).boxed().map(String::valueOf).collect(Collectors.toList()) : + classificationLabels; + + int count = numToInclude < 0 ? 
probabilities.size() : Math.min(numToInclude, probabilities.size()); + List topClassEntries = new ArrayList<>(count); + for(int i = 0; i < count; i++) { + int idx = sortedIndices[i]; + topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities.get(idx))); + } + + return topClassEntries; + } + + + public static String classificationLabel(double inferenceValue, @Nullable List classificationLabels) { + assert inferenceValue == Math.rint(inferenceValue); + if (classificationLabels == null) { + return String.valueOf(inferenceValue); + } + int label = Double.valueOf(inferenceValue).intValue(); + if (label < 0 || label >= classificationLabels.size()) { + throw ExceptionsHelper.serverError( + "model returned classification value of [{}] which is not a valid index in classification labels [{}]", + null, + label, + classificationLabels); + } + return classificationLabels.get(label); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParams.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParams.java new file mode 100644 index 0000000000000..150bf3d483f26 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParams.java @@ -0,0 +1,65 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +public class InferenceParams implements ToXContentObject, Writeable { + + public static ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); + + public static InferenceParams EMPTY_PARAMS = new InferenceParams(0); + + private final int numTopClasses; + + public InferenceParams(Integer numTopClasses) { + this.numTopClasses = numTopClasses == null ? 
0 : numTopClasses; + } + + public InferenceParams(StreamInput in) throws IOException { + this.numTopClasses = in.readInt(); + } + + public int getNumTopClasses() { + return numTopClasses; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeInt(numTopClasses); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InferenceParams that = (InferenceParams) o; + return Objects.equals(numTopClasses, that.numTopClasses); + } + + @Override + public int hashCode() { + return Objects.hash(numTopClasses); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (numTopClasses != 0) { + builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses); + } + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java index cad5a6c0a8c74..a6c6f1eff011d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java @@ -7,6 +7,7 @@ import org.elasticsearch.common.Nullable; import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; import java.util.List; @@ -26,13 +27,7 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable { * @return The predicted value. For classification this will be discrete values (e.g. 0.0, or 1.0). * For regression this is continuous. */ - double infer(Map fields); - - /** - * @param fields similar to {@link TrainedModel#infer(Map)}, but fields are already in order and doubles - * @return The predicted value. - */ - double infer(List fields); + InferenceResults infer(Map fields, InferenceParams params); /** * @return {@link TargetType} for the model. @@ -40,26 +35,7 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable { TargetType targetType(); /** - * This gathers the probabilities for each potential classification value. - * - * The probabilities are indexed by classification ordinal label encoding. - * The length of this list is equal to the number of classification labels. - * - * This only should return if the implementation model is inferring classification values and not regression - * @param fields The fields and their values to infer against - * @return The probabilities of each classification value - */ - List classificationProbability(Map fields); - - /** - * @param fields similar to {@link TrainedModel#classificationProbability(Map)} but the fields are already in order and doubles - * @return The probabilities of each classification value - */ - List classificationProbability(List fields); - - /** - * The ordinal encoded list of the classification labels. - * @return Oridinal encoded list of classification labels. + * @return Ordinal encoded list of classification labels. 
*/ @Nullable List classificationLabels(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java index 5e5199c24053f..09e418cec916d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java @@ -12,6 +12,12 @@ import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; @@ -26,6 +32,8 @@ import java.util.Objects; import java.util.stream.Collectors; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel; + public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel { // TODO should we have regression/classification sub-classes that accept the builder? 
@@ -106,14 +114,21 @@ public List getFeatureNames() { } @Override - public double infer(Map fields) { - List processedInferences = inferAndProcess(fields); - return outputAggregator.aggregate(processedInferences); - } - - @Override - public double infer(List fields) { - throw new UnsupportedOperationException("Ensemble requires map containing field names and values"); + public InferenceResults infer(Map fields, InferenceParams params) { + if (params.getNumTopClasses() != 0 && + (targetType != TargetType.CLASSIFICATION || outputAggregator.providesProbabilities() == false)) { + throw ExceptionsHelper.badRequestException( + "Cannot return top classes for target_type [{}] and aggregate_output [{}]", + targetType, + outputAggregator.getName()); + } + List inferenceResults = this.models.stream().map(model -> { + InferenceResults results = model.infer(fields, InferenceParams.EMPTY_PARAMS); + assert results instanceof SingleValueInferenceResults; + return ((SingleValueInferenceResults)results).value(); + }).collect(Collectors.toList()); + List processed = outputAggregator.processValues(inferenceResults); + return buildResults(processed, params); } @Override @@ -121,18 +136,20 @@ public TargetType targetType() { return targetType; } - @Override - public List classificationProbability(Map fields) { - if ((targetType == TargetType.CLASSIFICATION) == false) { - throw new UnsupportedOperationException( - "Cannot determine classification probability with target_type [" + targetType.toString() + "]"); + private InferenceResults buildResults(List processedInferences, InferenceParams params) { + switch(targetType) { + case REGRESSION: + return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences)); + case CLASSIFICATION: + List topClasses = + InferenceHelpers.topClasses(processedInferences, classificationLabels, params.getNumTopClasses()); + double value = outputAggregator.aggregate(processedInferences); + return new ClassificationInferenceResults(outputAggregator.aggregate(processedInferences), + classificationLabel(value, classificationLabels), + topClasses); + default: + throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on ensemble model"); } - return inferAndProcess(fields); - } - - @Override - public List classificationProbability(List fields) { - throw new UnsupportedOperationException("Ensemble requires map containing field names and values"); } @Override @@ -140,11 +157,6 @@ public List classificationLabels() { return classificationLabels; } - private List inferAndProcess(Map fields) { - List modelInferences = models.stream().map(m -> m.infer(fields)).collect(Collectors.toList()); - return outputAggregator.processValues(modelInferences); - } - @Override public String getWriteableName() { return NAME.getPreferredName(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java index 1f882b724ee94..012f474ab0618 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java @@ -44,4 +44,6 @@ public interface OutputAggregator extends NamedXContentObject, NamedWriteable { * @return The name of the output aggregator */ String getName(); + + boolean 
providesProbabilities(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java index 739a4e13d8659..0689d748b0ccb 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java @@ -158,4 +158,9 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(weights); } + + @Override + public boolean providesProbabilities() { + return true; + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java index f5812dabf88f2..9c5c2bf582e54 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java @@ -135,4 +135,9 @@ public int hashCode() { public Integer expectedValueSize() { return weights == null ? null : this.weights.size(); } + + @Override + public boolean providesProbabilities() { + return false; + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java index 3a91ec0cd86c1..bce6b08b6ed4b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java @@ -13,6 +13,11 @@ import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; @@ -31,6 +36,8 @@ import java.util.Set; import java.util.stream.Collectors; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel; + public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel { // TODO should we have regression/classification sub-classes that accept the builder? 
@@ -105,20 +112,36 @@ public List getNodes() { } @Override - public double infer(Map fields) { + public InferenceResults infer(Map fields, InferenceParams params) { + if (targetType != TargetType.CLASSIFICATION && params.getNumTopClasses() != 0) { + throw ExceptionsHelper.badRequestException( + "Cannot return top classes for target_type [{}]", targetType.toString()); + } List features = featureNames.stream().map(f -> - fields.get(f) instanceof Number ? ((Number)fields.get(f)).doubleValue() : null + fields.get(f) instanceof Number ? ((Number) fields.get(f)).doubleValue() : null ).collect(Collectors.toList()); - return infer(features); + return infer(features, params); } - @Override - public double infer(List features) { + private InferenceResults infer(List features, InferenceParams params) { TreeNode node = nodes.get(0); while(node.isLeaf() == false) { node = nodes.get(node.compare(features)); } - return node.getLeafValue(); + return buildResult(node.getLeafValue(), params); + } + + private InferenceResults buildResult(Double value, InferenceParams params) { + switch (targetType) { + case CLASSIFICATION: + List topClasses = + InferenceHelpers.topClasses(classificationProbability(value), classificationLabels, params.getNumTopClasses()); + return new ClassificationInferenceResults(value, classificationLabel(value, classificationLabels), topClasses); + case REGRESSION: + return new RegressionInferenceResults(value); + default: + throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model"); + } } /** @@ -142,34 +165,15 @@ public TargetType targetType() { return targetType; } - @Override - public List classificationProbability(Map fields) { - if ((targetType == TargetType.CLASSIFICATION) == false) { - throw new UnsupportedOperationException( - "Cannot determine classification probability with target_type [" + targetType.toString() + "]"); - } - List features = featureNames.stream().map(f -> - fields.get(f) instanceof Number ? ((Number)fields.get(f)).doubleValue() : null) - .collect(Collectors.toList()); - - return classificationProbability(features); - } - - @Override - public List classificationProbability(List fields) { - if ((targetType == TargetType.CLASSIFICATION) == false) { - throw new UnsupportedOperationException( - "Cannot determine classification probability with target_type [" + targetType.toString() + "]"); - } - double label = infer(fields); + private List classificationProbability(double inferenceValue) { // If we are classification, we should assume that the inference return value is whole. - assert label == Math.rint(label); + assert inferenceValue == Math.rint(inferenceValue); double maxCategory = this.highestOrderCategory.get(); // If we are classification, we should assume that the largest leaf value is whole. 
assert maxCategory == Math.rint(maxCategory); List list = new ArrayList<>(Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0)); // TODO, eventually have TreeNodes contain confidence levels - list.set(Double.valueOf(label).intValue(), 1.0); + list.set(Double.valueOf(inferenceValue).intValue(), 1.0); return list; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java index cb44d03e22bb2..44cca308ea794 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java @@ -5,6 +5,8 @@ */ package org.elasticsearch.xpack.core.ml.inference.utils; +import org.elasticsearch.common.Numbers; + import java.util.List; import java.util.stream.Collectors; @@ -22,31 +24,31 @@ private Statistics(){} */ public static List softMax(List values) { Double expSum = 0.0; - Double max = values.stream().filter(v -> isInvalid(v) == false).max(Double::compareTo).orElse(null); + Double max = values.stream().filter(Statistics::isValid).max(Double::compareTo).orElse(null); if (max == null) { throw new IllegalArgumentException("no valid values present"); } - List exps = values.stream().map(v -> isInvalid(v) ? Double.NEGATIVE_INFINITY : v - max) + List exps = values.stream().map(v -> isValid(v) ? v - max : Double.NEGATIVE_INFINITY) .collect(Collectors.toList()); for (int i = 0; i < exps.size(); i++) { - if (isInvalid(exps.get(i)) == false) { + if (isValid(exps.get(i))) { Double exp = Math.exp(exps.get(i)); expSum += exp; exps.set(i, exp); } } for (int i = 0; i < exps.size(); i++) { - if (isInvalid(exps.get(i))) { - exps.set(i, 0.0); - } else { + if (isValid(exps.get(i))) { exps.set(i, exps.get(i)/expSum); + } else { + exps.set(i, 0.0); } } return exps; } - public static boolean isInvalid(Double v) { - return v == null || Double.isInfinite(v) || Double.isNaN(v); + private static boolean isValid(Double v) { + return v != null && Numbers.isValidDouble(v); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java index 320eace983590..8dfcc5fc59977 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java @@ -51,6 +51,10 @@ public static ElasticsearchException serverError(String msg, Throwable cause) { return new ElasticsearchException(msg, cause); } + public static ElasticsearchException serverError(String msg, Throwable cause, Object... args) { + return new ElasticsearchException(msg, cause, args); + } + public static ElasticsearchStatusException conflictStatusException(String msg, Throwable cause, Object... 
args) { return new ElasticsearchStatusException(msg, RestStatus.CONFLICT, cause, args); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java new file mode 100644 index 0000000000000..a49643d081957 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.action.InferModelAction.Request; + +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParamsTests.randomInferenceParams; + +public class InferModelActionRequestTests extends AbstractWireSerializingTestCase { + + @Override + protected Request createTestInstance() { + return randomBoolean() ? + new Request( + randomAlphaOfLength(10), + randomLongBetween(1, 100), + Stream.generate(InferModelActionRequestTests::randomMap).limit(randomInt(10)).collect(Collectors.toList()), + randomBoolean() ? null : randomInferenceParams()) : + new Request( + randomAlphaOfLength(10), + randomLongBetween(1, 100), + randomMap(), + randomBoolean() ? null : randomInferenceParams()); + } + + private static Map randomMap() { + return Stream.generate(()-> randomAlphaOfLength(10)) + .limit(randomInt(10)) + .collect(Collectors.toMap(Function.identity(), (v) -> randomAlphaOfLength(10))); + } + + @Override + protected Writeable.Reader instanceReader() { + return Request::new; + } + +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java new file mode 100644 index 0000000000000..9e72d1c4e682a --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.action.InferModelAction.Response; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResultsTests; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResultsTests; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class InferModelActionResponseTests extends AbstractWireSerializingTestCase { + + @Override + protected Response createTestInstance() { + String resultType = randomFrom(ClassificationInferenceResults.NAME, RegressionInferenceResults.NAME); + return new Response( + Stream.generate(() -> randomInferenceResult(resultType)) + .limit(randomIntBetween(0, 10)) + .collect(Collectors.toList())); + } + + private static InferenceResults randomInferenceResult(String resultType) { + if (resultType.equals(ClassificationInferenceResults.NAME)) { + return ClassificationInferenceResultsTests.createRandomResults(); + } else if (resultType.equals(RegressionInferenceResults.NAME)) { + return RegressionInferenceResultsTests.createRandomResults(); + } else { + fail("unexpected result type [" + resultType + "]"); + return null; + } + } + + @Override + protected Writeable.Reader instanceReader() { + return Response::new; + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List entries = new ArrayList<>(); + entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); + return new NamedWriteableRegistry(entries); + } + +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java new file mode 100644 index 0000000000000..ba90fece02f2a --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java @@ -0,0 +1,83 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.equalTo; + +public class ClassificationInferenceResultsTests extends AbstractWireSerializingTestCase { + + public static ClassificationInferenceResults createRandomResults() { + return new ClassificationInferenceResults(randomDouble(), + randomBoolean() ? null : randomAlphaOfLength(10), + randomBoolean() ? null : + Stream.generate(ClassificationInferenceResultsTests::createRandomClassEntry) + .limit(randomIntBetween(0, 10)) + .collect(Collectors.toList())); + } + + private static ClassificationInferenceResults.TopClassEntry createRandomClassEntry() { + return new ClassificationInferenceResults.TopClassEntry(randomAlphaOfLength(10), randomDouble()); + } + + public void testWriteResultsWithClassificationLabel() { + ClassificationInferenceResults result = new ClassificationInferenceResults(1.0, "foo", Collections.emptyList()); + IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); + result.writeResult(document, "result_field"); + + assertThat(document.getFieldValue("result_field", String.class), equalTo("foo")); + } + + public void testWriteResultsWithoutClassificationLabel() { + ClassificationInferenceResults result = new ClassificationInferenceResults(1.0, null, Collections.emptyList()); + IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); + result.writeResult(document, "result_field"); + + assertThat(document.getFieldValue("result_field", String.class), equalTo("1.0")); + } + + @SuppressWarnings("unchecked") + public void testWriteResultsWithTopClasses() { + List entries = Arrays.asList( + new ClassificationInferenceResults.TopClassEntry("foo", 0.7), + new ClassificationInferenceResults.TopClassEntry("bar", 0.2), + new ClassificationInferenceResults.TopClassEntry("baz", 0.1)); + ClassificationInferenceResults result = new ClassificationInferenceResults(1.0, + "foo", + entries); + IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); + result.writeResult(document, "result_field"); + + List list = document.getFieldValue("result_field", List.class); + assertThat(list.size(), equalTo(3)); + + for(int i = 0; i < 3; i++) { + Map map = (Map)list.get(i); + assertThat(map, equalTo(entries.get(i).asValueMap())); + } + } + + @Override + protected ClassificationInferenceResults createTestInstance() { + return createRandomResults(); + } + + @Override + protected Writeable.Reader instanceReader() { + return ClassificationInferenceResults::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java new file mode 100644 index 0000000000000..4f2d5926c84dc --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; + +import java.util.HashMap; + +import static org.hamcrest.Matchers.equalTo; + + +public class RegressionInferenceResultsTests extends AbstractWireSerializingTestCase { + + public static RegressionInferenceResults createRandomResults() { + return new RegressionInferenceResults(randomDouble()); + } + + public void testWriteResults() { + RegressionInferenceResults result = new RegressionInferenceResults(0.3); + IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); + result.writeResult(document, "result_field"); + + assertThat(document.getFieldValue("result_field", Double.class), equalTo(0.3)); + } + + @Override + protected RegressionInferenceResults createTestInstance() { + return createRandomResults(); + } + + @Override + protected Writeable.Reader instanceReader() { + return RegressionInferenceResults::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParamsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParamsTests.java new file mode 100644 index 0000000000000..2586cdd75de47 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParamsTests.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +public class InferenceParamsTests extends AbstractWireSerializingTestCase { + + public static InferenceParams randomInferenceParams() { + return randomBoolean() ? 
InferenceParams.EMPTY_PARAMS : new InferenceParams(randomIntBetween(-1, 100)); + } + + @Override + protected InferenceParams createTestInstance() { + return randomInferenceParams(); + } + + @Override + protected Writeable.Reader instanceReader() { + return InferenceParams::new; + } + +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java index eb537e247e994..8da0c15718f24 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -15,6 +15,9 @@ import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; @@ -239,27 +242,30 @@ public void testClassificationProbability() { List featureVector = Arrays.asList(0.4, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); - List expected = Arrays.asList(0.231475216, 0.768524783); + List expected = Arrays.asList(0.768524783, 0.231475216); double eps = 0.000001; - List probabilities = ensemble.classificationProbability(featureMap); + List probabilities = + ((ClassificationInferenceResults)ensemble.infer(featureMap, new InferenceParams(2))).getTopClasses(); for(int i = 0; i < expected.size(); i++) { - assertThat(probabilities.get(i), closeTo(expected.get(i), eps)); + assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); } featureVector = Arrays.asList(2.0, 0.7); featureMap = zipObjMap(featureNames, featureVector); - expected = Arrays.asList(0.3100255188, 0.689974481); - probabilities = ensemble.classificationProbability(featureMap); + expected = Arrays.asList(0.689974481, 0.3100255188); + probabilities = + ((ClassificationInferenceResults)ensemble.infer(featureMap, new InferenceParams(2))).getTopClasses(); for(int i = 0; i < expected.size(); i++) { - assertThat(probabilities.get(i), closeTo(expected.get(i), eps)); + assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); } featureVector = Arrays.asList(0.0, 1.0); featureMap = zipObjMap(featureNames, featureVector); - expected = Arrays.asList(0.231475216, 0.768524783); - probabilities = ensemble.classificationProbability(featureMap); + expected = Arrays.asList(0.768524783, 0.231475216); + probabilities = + ((ClassificationInferenceResults)ensemble.infer(featureMap, new InferenceParams(2))).getTopClasses(); for(int i = 0; i < expected.size(); i++) { - assertThat(probabilities.get(i), closeTo(expected.get(i), eps)); + assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); } // This should handle missing values and take the default_left path @@ -268,9 +274,10 @@ public void testClassificationProbability() { put("bar", null); }}; expected = 
Arrays.asList(0.6899744811, 0.3100255188); - probabilities = ensemble.classificationProbability(featureMap); + probabilities = + ((ClassificationInferenceResults)ensemble.infer(featureMap, new InferenceParams(2))).getTopClasses(); for(int i = 0; i < expected.size(); i++) { - assertThat(probabilities.get(i), closeTo(expected.get(i), eps)); + assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); } } @@ -320,21 +327,25 @@ public void testClassificationInference() { List featureVector = Arrays.asList(0.4, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); - assertEquals(1.0, ensemble.infer(featureMap), 0.00001); + assertThat(1.0, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); featureVector = Arrays.asList(2.0, 0.7); featureMap = zipObjMap(featureNames, featureVector); - assertEquals(1.0, ensemble.infer(featureMap), 0.00001); + assertThat(1.0, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); featureVector = Arrays.asList(0.0, 1.0); featureMap = zipObjMap(featureNames, featureVector); - assertEquals(1.0, ensemble.infer(featureMap), 0.00001); + assertThat(1.0, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); featureMap = new HashMap<>(2) {{ put("foo", 0.3); put("bar", null); }}; - assertEquals(0.0, ensemble.infer(featureMap), 0.00001); + assertThat(0.0, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); } public void testRegressionInference() { @@ -373,11 +384,13 @@ public void testRegressionInference() { List featureVector = Arrays.asList(0.4, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); - assertEquals(0.9, ensemble.infer(featureMap), 0.00001); + assertThat(0.9, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); featureVector = Arrays.asList(2.0, 0.7); featureMap = zipObjMap(featureNames, featureVector); - assertEquals(0.5, ensemble.infer(featureMap), 0.00001); + assertThat(0.5, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); // Test with NO aggregator supplied, verifies default behavior of non-weighted sum ensemble = Ensemble.builder() @@ -388,17 +401,20 @@ public void testRegressionInference() { featureVector = Arrays.asList(0.4, 0.0); featureMap = zipObjMap(featureNames, featureVector); - assertEquals(1.8, ensemble.infer(featureMap), 0.00001); + assertThat(1.8, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); featureVector = Arrays.asList(2.0, 0.7); featureMap = zipObjMap(featureNames, featureVector); - assertEquals(1.0, ensemble.infer(featureMap), 0.00001); + assertThat(1.0, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); featureMap = new HashMap<>(2) {{ put("foo", 0.3); put("bar", null); }}; - assertEquals(1.8, ensemble.infer(featureMap), 0.00001); + assertThat(1.8, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); } private static Map zipObjMap(List keys, List values) { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java index 81030585f1889..c362c17fd579d 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java @@ -10,6 +10,9 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.junit.Before; @@ -120,26 +123,30 @@ public void testInfer() { // This feature vector should hit the right child of the root node List featureVector = Arrays.asList(0.6, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); - assertThat(0.3, closeTo(tree.infer(featureMap), 0.00001)); + assertThat(0.3, + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); // This should hit the left child of the left child of the root node // i.e. it takes the path left, left featureVector = Arrays.asList(0.3, 0.7); featureMap = zipObjMap(featureNames, featureVector); - assertThat(0.1, closeTo(tree.infer(featureMap), 0.00001)); + assertThat(0.1, + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); // This should hit the right child of the left child of the root node // i.e. 
it takes the path left, right featureVector = Arrays.asList(0.3, 0.9); featureMap = zipObjMap(featureNames, featureVector); - assertThat(0.2, closeTo(tree.infer(featureMap), 0.00001)); + assertThat(0.2, + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); // This should handle missing values and take the default_left path featureMap = new HashMap<>(2) {{ put("foo", 0.3); put("bar", null); }}; - assertThat(0.1, closeTo(tree.infer(featureMap), 0.00001)); + assertThat(0.1, + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); } public void testTreeClassificationProbability() { @@ -153,31 +160,43 @@ public void testTreeClassificationProbability() { builder.addLeaf(leftChildNode.getRightChild(), 0.0); List featureNames = Arrays.asList("foo", "bar"); - Tree tree = builder.setFeatureNames(featureNames).build(); + Tree tree = builder.setFeatureNames(featureNames).setClassificationLabels(Arrays.asList("cat", "dog")).build(); + double eps = 0.000001; // This feature vector should hit the right child of the root node List featureVector = Arrays.asList(0.6, 0.0); + List expectedProbs = Arrays.asList(1.0, 0.0); + List expectedFields = Arrays.asList("dog", "cat"); Map featureMap = zipObjMap(featureNames, featureVector); - assertEquals(Arrays.asList(0.0, 1.0), tree.classificationProbability(featureMap)); + List probabilities = + ((ClassificationInferenceResults)tree.infer(featureMap, new InferenceParams(2))).getTopClasses(); + for(int i = 0; i < expectedProbs.size(); i++) { + assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps)); + assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i))); + } // This should hit the left child of the left child of the root node // i.e. it takes the path left, left featureVector = Arrays.asList(0.3, 0.7); featureMap = zipObjMap(featureNames, featureVector); - assertEquals(Arrays.asList(0.0, 1.0), tree.classificationProbability(featureMap)); - - // This should hit the right child of the left child of the root node - // i.e. 
it takes the path left, right - featureVector = Arrays.asList(0.3, 0.9); - featureMap = zipObjMap(featureNames, featureVector); - assertEquals(Arrays.asList(1.0, 0.0), tree.classificationProbability(featureMap)); + probabilities = + ((ClassificationInferenceResults)tree.infer(featureMap, new InferenceParams(2))).getTopClasses(); + for(int i = 0; i < expectedProbs.size(); i++) { + assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps)); + assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i))); + } // This should handle missing values and take the default_left path featureMap = new HashMap<>(2) {{ put("foo", 0.3); put("bar", null); }}; - assertEquals(1.0, tree.infer(featureMap), 0.00001); + probabilities = + ((ClassificationInferenceResults)tree.infer(featureMap, new InferenceParams(2))).getTopClasses(); + for(int i = 0; i < expectedProbs.size(); i++) { + assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps)); + assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i))); + } } public void testTreeWithNullRoot() { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 277d7018e8d4e..0b09b0736bfbb 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -95,6 +95,7 @@ import org.elasticsearch.xpack.core.ml.action.GetOverallBucketsAction; import org.elasticsearch.xpack.core.ml.action.GetRecordsAction; import org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.action.KillProcessAction; import org.elasticsearch.xpack.core.ml.action.MlInfoAction; import org.elasticsearch.xpack.core.ml.action.OpenJobAction; @@ -161,6 +162,7 @@ import org.elasticsearch.xpack.ml.action.TransportGetModelSnapshotsAction; import org.elasticsearch.xpack.ml.action.TransportGetOverallBucketsAction; import org.elasticsearch.xpack.ml.action.TransportGetRecordsAction; +import org.elasticsearch.xpack.ml.action.TransportInferModelAction; import org.elasticsearch.xpack.ml.action.TransportIsolateDatafeedAction; import org.elasticsearch.xpack.ml.action.TransportKillProcessAction; import org.elasticsearch.xpack.ml.action.TransportMlInfoAction; @@ -200,7 +202,9 @@ import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult; +import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.inference.persistence.InferenceInternalIndex; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.job.JobManager; import org.elasticsearch.xpack.ml.job.JobManagerHolder; import org.elasticsearch.xpack.ml.job.UpdateJobProcessNotifier; @@ -501,6 +505,8 @@ public Collection createComponents(Client client, ClusterService cluster notifier, xContentRegistry); + final TrainedModelProvider trainedModelProvider = new TrainedModelProvider(client, xContentRegistry); + final ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, threadPool, 
clusterService); // special holder for @link(MachineLearningFeatureSetUsage) which needs access to job manager if ML is enabled JobManagerHolder jobManagerHolder = new JobManagerHolder(jobManager); @@ -613,7 +619,9 @@ public Collection createComponents(Client client, ClusterService cluster analyticsProcessManager, memoryEstimationProcessManager, dataFrameAnalyticsConfigProvider, - nativeStorageProvider + nativeStorageProvider, + modelLoadingService, + trainedModelProvider ); } @@ -768,6 +776,7 @@ public List getRestHandlers(Settings settings, RestController restC new ActionHandler<>(StopDataFrameAnalyticsAction.INSTANCE, TransportStopDataFrameAnalyticsAction.class), new ActionHandler<>(EvaluateDataFrameAction.INSTANCE, TransportEvaluateDataFrameAction.class), new ActionHandler<>(EstimateMemoryUsageAction.INSTANCE, TransportEstimateMemoryUsageAction.class), + new ActionHandler<>(InferModelAction.INSTANCE, TransportInferModelAction.class), usageAction, infoAction); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java new file mode 100644 index 0000000000000..a2f79fc9d0437 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.ml.inference.loadingservice.Model; +import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; +import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor; + + +public class TransportInferModelAction extends HandledTransportAction { + + private final ModelLoadingService modelLoadingService; + private final Client client; + + @Inject + public TransportInferModelAction(TransportService transportService, + ActionFilters actionFilters, + ModelLoadingService modelLoadingService, + Client client) { + super(InferModelAction.NAME, transportService, actionFilters, InferModelAction.Request::new); + this.modelLoadingService = modelLoadingService; + this.client = client; + } + + @Override + protected void doExecute(Task task, InferModelAction.Request request, ActionListener listener) { + + ActionListener getModelListener = ActionListener.wrap( + model -> { + TypedChainTaskExecutor typedChainTaskExecutor = + new TypedChainTaskExecutor<>(client.threadPool().executor(ThreadPool.Names.SAME), + // run through all tasks + r -> true, + // Always fail immediately and return an error + ex -> true); + request.getObjectsToInfer().forEach(stringObjectMap -> + typedChainTaskExecutor.add(chainedTask -> + model.infer(stringObjectMap, request.getParams(), chainedTask))); + + 
typedChainTaskExecutor.execute(ActionListener.wrap( + inferenceResultsInterfaces -> + listener.onResponse(new InferModelAction.Response(inferenceResultsInterfaces)), + listener::onFailure + )); + }, + listener::onFailure + ); + + this.modelLoadingService.getModel(request.getModelId(), request.getModelVersion(), getModelListener); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java new file mode 100644 index 0000000000000..59f6c62a7f55e --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.inference.ingest; + +import org.elasticsearch.client.Client; +import org.elasticsearch.ingest.AbstractProcessor; +import org.elasticsearch.ingest.IngestDocument; + +import java.util.function.BiConsumer; + +public class InferenceProcessor extends AbstractProcessor { + + public static final String TYPE = "inference"; + public static final String MODEL_ID = "model_id"; + + private final Client client; + + public InferenceProcessor(Client client, String tag) { + super(tag); + this.client = client; + } + + @Override + public void execute(IngestDocument ingestDocument, BiConsumer handler) { + //TODO actually work + handler.accept(ingestDocument, null); + } + + @Override + public IngestDocument execute(IngestDocument ingestDocument) { + throw new UnsupportedOperationException("should never be called"); + } + + @Override + public String getType() { + return TYPE; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java new file mode 100644 index 0000000000000..e5253b3d5b173 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java @@ -0,0 +1,51 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.inference.loadingservice; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; + +import java.util.Map; + +public class LocalModel implements Model { + + private final TrainedModelDefinition trainedModelDefinition; + private final String modelId; + + public LocalModel(String modelId, TrainedModelDefinition trainedModelDefinition) { + this.trainedModelDefinition = trainedModelDefinition; + this.modelId = modelId; + } + + @Override + public String getResultsType() { + switch (trainedModelDefinition.getTrainedModel().targetType()) { + case CLASSIFICATION: + return ClassificationInferenceResults.NAME; + case REGRESSION: + return RegressionInferenceResults.NAME; + default: + throw ExceptionsHelper.badRequestException("Model [{}] has unsupported target type [{}]", + modelId, + trainedModelDefinition.getTrainedModel().targetType()); + } + } + + @Override + public void infer(Map fields, InferenceParams params, ActionListener listener) { + try { + listener.onResponse(trainedModelDefinition.infer(fields, params)); + } catch (Exception e) { + listener.onFailure(e); + } + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java new file mode 100644 index 0000000000000..27924a47aa153 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.inference.loadingservice; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; + +import java.util.Map; + +public interface Model { + + String getResultsType(); + + void infer(Map fields, InferenceParams inferenceParams, ActionListener listener); + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java new file mode 100644 index 0000000000000..b4fc552ba5f93 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -0,0 +1,287 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.inference.loadingservice; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.ClusterChangedEvent; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.ClusterStateListener; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.ingest.IngestMetadata; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; + +public class ModelLoadingService implements ClusterStateListener { + + private static final Logger logger = LogManager.getLogger(ModelLoadingService.class); + private final Map loadedModels = new HashMap<>(); + private final Map>> loadingListeners = new HashMap<>(); + private final TrainedModelProvider provider; + private final ThreadPool threadPool; + + public ModelLoadingService(TrainedModelProvider trainedModelProvider, + ThreadPool threadPool, + ClusterService clusterService) { + this.provider = trainedModelProvider; + this.threadPool = threadPool; + clusterService.addListener(this); + } + + public void getModel(String modelId, long modelVersion, ActionListener modelActionListener) { + String key = modelKey(modelId, modelVersion); + MaybeModel cachedModel = loadedModels.get(key); + if (cachedModel != null) { + if (cachedModel.isSuccess()) { + modelActionListener.onResponse(cachedModel.getModel()); + return; + } + } + if (loadModelIfNecessary(key, modelId, modelVersion, modelActionListener) == false) { + // If the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called + // by a simulated pipeline + logger.debug("[{}] version [{}] not actively loading, eager loading without cache", modelId, modelVersion); + provider.getTrainedModel(modelId, modelVersion, ActionListener.wrap( + trainedModelConfig -> + modelActionListener.onResponse(new LocalModel(trainedModelConfig.getModelId(), trainedModelConfig.getDefinition())), + modelActionListener::onFailure + )); + } else { + logger.debug("[{}] version [{}] is currently loading, added new listener to queue", modelId, modelVersion); + } + } + + /** + * Returns true if the model is loaded and the listener has been given the cached model + * Returns true if the model is CURRENTLY being loaded and the listener was added to be notified when it is loaded + * Returns false if the model is neither loaded nor actively being loaded + */ + private boolean loadModelIfNecessary(String key, String modelId, long modelVersion, ActionListener modelActionListener) { + synchronized (loadingListeners) { + MaybeModel cachedModel = loadedModels.get(key); + if (cachedModel != null) { + if (cachedModel.isSuccess()) { + modelActionListener.onResponse(cachedModel.getModel()); + return true; + } + // If the loaded model entry is there but the model is not present, that means the previous
load attempt ran into an issue + // Attempt to load and cache the model if necessary + if (loadingListeners.computeIfPresent( + key, + (storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)) == null) { + logger.debug("[{}] version [{}] attempting to load and cache", modelId, modelVersion); + loadingListeners.put(key, addFluently(new ArrayDeque<>(), modelActionListener)); + loadModel(key, modelId, modelVersion); + } + return true; + } + // if the cachedModel entry is null, but there are listeners present, that means it is being loaded + return loadingListeners.computeIfPresent(key, + (storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)) != null; + } + } + + private void loadModel(String modelKey, String modelId, long modelVersion) { + provider.getTrainedModel(modelId, modelVersion, ActionListener.wrap( + trainedModelConfig -> { + logger.debug("[{}] successfully loaded model", modelKey); + handleLoadSuccess(modelKey, trainedModelConfig); + }, + failure -> { + logger.warn(new ParameterizedMessage("[{}] failed to load model", modelKey), failure); + handleLoadFailure(modelKey, failure); + } + )); + } + + private void handleLoadSuccess(String modelKey, TrainedModelConfig trainedModelConfig) { + Queue> listeners; + Model loadedModel = new LocalModel(trainedModelConfig.getModelId(), trainedModelConfig.getDefinition()); + synchronized (loadingListeners) { + listeners = loadingListeners.remove(modelKey); + // If there is no loadingListener that means the loading was canceled and the listener was already notified as such + // Consequently, we should not store the retrieved model + if (listeners != null) { + loadedModels.put(modelKey, MaybeModel.of(loadedModel)); + } + } + if (listeners != null) { + for (ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { + listener.onResponse(loadedModel); + } + } + } + + private void handleLoadFailure(String modelKey, Exception failure) { + Queue> listeners; + synchronized (loadingListeners) { + listeners = loadingListeners.remove(modelKey); + if (listeners != null) { + // If we failed to load and there were listeners present, that means that this model is referenced by a processor + // Add an empty entry here so that we can attempt to load and cache the model again when it is accessed again. + loadedModels.computeIfAbsent(modelKey, (key) -> MaybeModel.of(failure)); + } + } + if (listeners != null) { + for (ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { + listener.onFailure(failure); + } + } + } + + @Override + public void clusterChanged(ClusterChangedEvent event) { + if (event.changedCustomMetaDataSet().contains(IngestMetadata.TYPE)) { + ClusterState state = event.state(); + IngestMetadata currentIngestMetadata = state.metaData().custom(IngestMetadata.TYPE); + Set allReferencedModelKeys = getReferencedModelKeys(currentIngestMetadata); + // The listeners still waiting for a model and we are canceling the load? 
+ List>>> drainWithFailure = new ArrayList<>(); + synchronized (loadingListeners) { + // If we had models still loading here but are no longer referenced + // we should remove them from loadingListeners and alert the listeners + for (String modelKey : loadingListeners.keySet()) { + if (allReferencedModelKeys.contains(modelKey) == false) { + drainWithFailure.add(Tuple.tuple(splitModelKey(modelKey).v1(), new ArrayList<>(loadingListeners.remove(modelKey)))); + } + } + + // Remove all cached models that are not referenced by any processors + loadedModels.keySet().retainAll(allReferencedModelKeys); + + // Remove all that are currently being loaded + allReferencedModelKeys.removeAll(loadingListeners.keySet()); + + // Remove all that are fully loaded, will attempt empty model loading again + loadedModels.forEach((id, optionalModel) -> { + if (optionalModel.isSuccess()) { + allReferencedModelKeys.remove(id); + } + }); + // Populate loadingListeners key so we know that we are currently loading the model + for (String modelId : allReferencedModelKeys) { + loadingListeners.put(modelId, new ArrayDeque<>()); + } + } + for (Tuple>> modelAndListeners : drainWithFailure) { + final String msg = new ParameterizedMessage( + "Cancelling load of model [{}] as it is no longer referenced by a pipeline", + modelAndListeners.v1()).getFormat(); + for (ActionListener listener : modelAndListeners.v2()) { + listener.onFailure(new ElasticsearchException(msg)); + } + } + loadModels(allReferencedModelKeys); + } + } + + private void loadModels(Set modelKeys) { + if (modelKeys.isEmpty()) { + return; + } + // Execute this on a utility thread as when the callbacks occur we don't want them tying up the cluster listener thread pool + threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> { + for (String modelKey : modelKeys) { + Tuple modelIdAndVersion = splitModelKey(modelKey); + this.loadModel(modelKey, modelIdAndVersion.v1(), modelIdAndVersion.v2()); + } + }); + } + + private static Queue addFluently(Queue queue, T object) { + queue.add(object); + return queue; + } + + private static String modelKey(String modelId, long modelVersion) { + return modelId + "_" + modelVersion; + } + + private static Tuple splitModelKey(String modelKey) { + int delim = modelKey.lastIndexOf('_'); + String modelId = modelKey.substring(0, delim); + Long version = Long.valueOf(modelKey.substring(delim + 1)); + return Tuple.tuple(modelId, version); + } + + private static Set getReferencedModelKeys(IngestMetadata ingestMetadata) { + Set allReferencedModelKeys = new HashSet<>(); + if (ingestMetadata != null) { + ingestMetadata.getPipelines().forEach((pipelineId, pipelineConfiguration) -> { + Object processors = pipelineConfiguration.getConfigAsMap().get("processors"); + if (processors instanceof List) { + for(Object processor : (List)processors) { + if (processor instanceof Map) { + Object processorConfig = ((Map)processor).get(InferenceProcessor.TYPE); + if (processorConfig instanceof Map) { + Object modelId = ((Map)processorConfig).get(InferenceProcessor.MODEL_ID); + if (modelId != null) { + assert modelId instanceof String; + // TODO also read model version + allReferencedModelKeys.add(modelKey(modelId.toString(), 0)); + } + } + } + } + } + }); + } + return allReferencedModelKeys; + } + + private static class MaybeModel { + + private final Model model; + private final Exception exception; + + static MaybeModel of(Model model) { + return new MaybeModel(model, null); + } + + static MaybeModel of(Exception exception) { + return 
new MaybeModel(null, exception); + } + + private MaybeModel(Model model, Exception exception) { + this.model = model; + this.exception = exception; + } + + Model getModel() { + return model; + } + + Exception getException() { + return exception; + } + + boolean isSuccess() { + return this.model != null; + } + + boolean isFailure() { + return this.exception != null; + } + + } + +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java new file mode 100644 index 0000000000000..e66a5790d85e5 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java @@ -0,0 +1,213 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.inference.loadingservice; + +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; +import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +public class LocalModelTests extends ESTestCase { + + public void testClassificationInfer() throws Exception { + String modelId = "classification_model"; + TrainedModelDefinition definition = new TrainedModelDefinition.Builder() + .setInput(new TrainedModelDefinition.Input(Arrays.asList("field1", "field2"))) + .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) + .setTrainedModel(buildClassification(false)) + .build(); + + Model model = new LocalModel(modelId, definition); + Map fields = new HashMap<>() {{ + put("foo", 1.0); + put("bar", 0.5); + put("categorical", "dog"); + }}; + + SingleValueInferenceResults result = getSingleValue(model, fields, InferenceParams.EMPTY_PARAMS); + assertThat(result.value(), equalTo(0.0)); + assertThat(result.valueAsString(), is("0.0")); + + ClassificationInferenceResults classificationResult = + (ClassificationInferenceResults)getSingleValue(model, fields, new 
InferenceParams(1)); + assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); + assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("0")); + + // Test with labels + definition = new TrainedModelDefinition.Builder() + .setInput(new TrainedModelDefinition.Input(Arrays.asList("field1", "field2"))) + .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) + .setTrainedModel(buildClassification(true)) + .build(); + model = new LocalModel(modelId, definition); + result = getSingleValue(model, fields, InferenceParams.EMPTY_PARAMS); + assertThat(result.value(), equalTo(0.0)); + assertThat(result.valueAsString(), equalTo("not_to_be")); + + classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new InferenceParams(1)); + assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); + assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("not_to_be")); + + classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new InferenceParams(2)); + assertThat(classificationResult.getTopClasses(), hasSize(2)); + + classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new InferenceParams(-1)); + assertThat(classificationResult.getTopClasses(), hasSize(2)); + } + + public void testRegression() throws Exception { + TrainedModelDefinition trainedModelDefinition = new TrainedModelDefinition.Builder() + .setInput(new TrainedModelDefinition.Input(Arrays.asList("field1", "field2"))) + .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) + .setTrainedModel(buildRegression()) + .build(); + Model model = new LocalModel("regression_model", trainedModelDefinition); + + Map fields = new HashMap<>() {{ + put("foo", 1.0); + put("bar", 0.5); + put("categorical", "dog"); + }}; + + SingleValueInferenceResults results = getSingleValue(model, fields, InferenceParams.EMPTY_PARAMS); + assertThat(results.value(), equalTo(1.3)); + + PlainActionFuture failedFuture = new PlainActionFuture<>(); + model.infer(fields, new InferenceParams(2), failedFuture); + ExecutionException ex = expectThrows(ExecutionException.class, failedFuture::get); + assertThat(ex.getCause().getMessage(), + equalTo("Cannot return top classes for target_type [regression] and aggregate_output [weighted_sum]")); + } + + private static SingleValueInferenceResults getSingleValue(Model model, + Map fields, + InferenceParams params) throws Exception { + PlainActionFuture future = new PlainActionFuture<>(); + model.infer(fields, params, future); + return (SingleValueInferenceResults)future.get(); + } + + private static Map oneHotMap() { + Map oneHotEncoding = new HashMap<>(); + oneHotEncoding.put("cat", "animal_cat"); + oneHotEncoding.put("dog", "animal_dog"); + return oneHotEncoding; + } + + public static TrainedModel buildClassification(boolean includeLabels) { + List featureNames = Arrays.asList("foo", "bar", "animal_cat", "animal_dog"); + Tree tree1 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(0.5)) + .addNode(TreeNode.builder(1).setLeafValue(1.0)) + .addNode(TreeNode.builder(2) + .setThreshold(0.8) + .setSplitFeature(1) + .setLeftChild(3) + .setRightChild(4)) + .addNode(TreeNode.builder(3).setLeafValue(0.0)) + 
.addNode(TreeNode.builder(4).setLeafValue(1.0)).build(); + Tree tree2 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(3) + .setThreshold(1.0)) + .addNode(TreeNode.builder(1).setLeafValue(0.0)) + .addNode(TreeNode.builder(2).setLeafValue(1.0)) + .build(); + Tree tree3 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(1.0)) + .addNode(TreeNode.builder(1).setLeafValue(1.0)) + .addNode(TreeNode.builder(2).setLeafValue(0.0)) + .build(); + return Ensemble.builder() + .setClassificationLabels(includeLabels ? Arrays.asList("not_to_be", "to_be") : null) + .setTargetType(TargetType.CLASSIFICATION) + .setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList(tree1, tree2, tree3)) + .setOutputAggregator(new WeightedMode(Arrays.asList(0.7, 0.5, 1.0))) + .build(); + } + + public static TrainedModel buildRegression() { + List featureNames = Arrays.asList("foo", "bar", "animal_cat", "animal_dog"); + Tree tree1 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(0.5)) + .addNode(TreeNode.builder(1).setLeafValue(0.3)) + .addNode(TreeNode.builder(2) + .setThreshold(0.0) + .setSplitFeature(3) + .setLeftChild(3) + .setRightChild(4)) + .addNode(TreeNode.builder(3).setLeafValue(0.1)) + .addNode(TreeNode.builder(4).setLeafValue(0.2)).build(); + Tree tree2 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(2) + .setThreshold(1.0)) + .addNode(TreeNode.builder(1).setLeafValue(1.5)) + .addNode(TreeNode.builder(2).setLeafValue(0.9)) + .build(); + Tree tree3 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(1) + .setThreshold(0.2)) + .addNode(TreeNode.builder(1).setLeafValue(1.5)) + .addNode(TreeNode.builder(2).setLeafValue(0.9)) + .build(); + return Ensemble.builder() + .setTargetType(TargetType.REGRESSION) + .setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList(tree1, tree2, tree3)) + .setOutputAggregator(new WeightedSum(Arrays.asList(0.5, 0.5, 0.5))) + .build(); + } + +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java new file mode 100644 index 0000000000000..4fbf973aae7f0 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -0,0 +1,212 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.inference.loadingservice; + +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.ClusterChangedEvent; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.ClusterStateListener; +import org.elasticsearch.cluster.metadata.MetaData; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.ingest.IngestMetadata; +import org.elasticsearch.ingest.PipelineConfiguration; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ScalingExecutorBuilder; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.atMost; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class ModelLoadingServiceTests extends ESTestCase { + + private TrainedModelProvider trainedModelProvider; + private ThreadPool threadPool; + private ClusterService clusterService; + + @Before + public void setUpComponents() { + threadPool = new TestThreadPool("ModelLoadingServiceTests", new ScalingExecutorBuilder(UTILITY_THREAD_POOL_NAME, + 1, 4, TimeValue.timeValueMinutes(10), "xpack.ml.utility_thread_pool")); + trainedModelProvider = mock(TrainedModelProvider.class); + clusterService = mock(ClusterService.class); + doAnswer((invocationOnMock) -> null).when(clusterService).addListener(any(ClusterStateListener.class)); + when(clusterService.state()).thenReturn(ClusterState.builder(new ClusterName("_name")).build()); + } + + @After + public void terminateThreadPool() { + terminate(threadPool); + } + + public void testGetCachedModels() throws Exception { + String model1 = "test-load-model-1"; + String model2 = "test-load-model-2"; + String model3 = "test-load-model-3"; + withTrainedModel(model1, 0); + withTrainedModel(model2, 0); + withTrainedModel(model3, 0); + + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, threadPool, clusterService); + + 
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3)); + + String[] modelIds = new String[]{model1, model2, model3}; + for(int i = 0; i < 10; i++) { + String model = modelIds[i%3]; + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModel(model, 0, future); + assertThat(future.get(), is(not(nullValue()))); + } + + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(0L), any()); + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(0L), any()); + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(0L), any()); + } + + public void testGetCachedMissingModel() throws Exception { + String model = "test-load-cached-missing-model"; + withMissingModel(model, 0); + + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, threadPool, clusterService); + modelLoadingService.clusterChanged(ingestChangedEvent(model)); + + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModel(model, 0, future); + + try { + future.get(); + fail("Should not have succeeded in loaded model"); + } catch (Exception ex) { + assertThat(ex.getCause().getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model, 0))); + } + + verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model), eq(0L), any()); + } + + public void testGetMissingModel() { + String model = "test-load-missing-model"; + withMissingModel(model, 0); + + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, threadPool, clusterService); + + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModel(model, 0, future); + try { + future.get(); + fail("Should not have succeeded"); + } catch (Exception ex) { + assertThat(ex.getCause().getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model, 0))); + } + } + + public void testGetModelEagerly() throws Exception { + String model = "test-get-model-eagerly"; + withTrainedModel(model, 0); + + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, threadPool, clusterService); + + for(int i = 0; i < 3; i++) { + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModel(model, 0, future); + assertThat(future.get(), is(not(nullValue()))); + } + + verify(trainedModelProvider, times(3)).getTrainedModel(eq(model), eq(0L), any()); + } + + @SuppressWarnings("unchecked") + private void withTrainedModel(String modelId, long modelVersion) { + TrainedModelConfig trainedModelConfig = buildTrainedModelConfigBuilder(modelId, modelVersion).build(Version.CURRENT); + doAnswer(invocationOnMock -> { + @SuppressWarnings("rawtypes") + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(trainedModelConfig); + return null; + }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(modelVersion), any()); + } + + private void withMissingModel(String modelId, long modelVersion) { + doAnswer(invocationOnMock -> { + @SuppressWarnings("rawtypes") + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onFailure(new ResourceNotFoundException( + Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId, modelVersion))); + return null; + }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(modelVersion), any()); + } + + private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId, long modelVersion) { + 
return TrainedModelConfig.builder() + .setCreatedBy("ml_test") + .setDefinition(TrainedModelDefinitionTests.createRandomBuilder()) + .setDescription("trained model config for test") + .setModelId(modelId) + .setModelType("binary_decision_tree") + .setModelVersion(modelVersion); + } + + private static ClusterChangedEvent ingestChangedEvent(String... modelId) throws IOException { + ClusterChangedEvent event = mock(ClusterChangedEvent.class); + when(event.changedCustomMetaDataSet()).thenReturn(Collections.singleton(IngestMetadata.TYPE)); + when(event.state()).thenReturn(buildClusterStateWithModelReferences(modelId)); + return event; + } + + private static ClusterState buildClusterStateWithModelReferences(String... modelId) throws IOException { + Map configurations = new HashMap<>(modelId.length); + for (String id : modelId) { + configurations.put("pipeline_with_model_" + id, newConfigurationWithInferenceProcessor(id)); + } + IngestMetadata ingestMetadata = new IngestMetadata(configurations); + + return ClusterState.builder(new ClusterName("_name")) + .metaData(MetaData.builder().putCustom(IngestMetadata.TYPE, ingestMetadata)) + .build(); + } + + private static PipelineConfiguration newConfigurationWithInferenceProcessor(String modelId) throws IOException { + try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(Collections.singletonMap("processors", + Collections.singletonList( + Collections.singletonMap(InferenceProcessor.TYPE, + Collections.singletonMap(InferenceProcessor.MODEL_ID, + modelId)))))) { + return new PipelineConfiguration("pipeline_with_model_" + modelId, BytesReference.bytes(xContentBuilder), XContentType.JSON); + } + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java new file mode 100644 index 0000000000000..032f84d16b52d --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -0,0 +1,185 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.integration; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.Version; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; +import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; + +import static org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests.buildClassification; +import static org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests.buildRegression; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.nullValue; + +public class ModelInferenceActionIT extends MlSingleNodeTestCase { + + private TrainedModelProvider trainedModelProvider; + + @Before + public void createComponents() throws Exception { + trainedModelProvider = new TrainedModelProvider(client(), xContentRegistry()); + waitForMlTemplates(); + } + + public void testInferModels() throws Exception { + String modelId1 = "test-load-models-regression"; + String modelId2 = "test-load-models-classification"; + Map oneHotEncoding = new HashMap<>(); + oneHotEncoding.put("cat", "animal_cat"); + oneHotEncoding.put("dog", "animal_dog"); + TrainedModelConfig config1 = buildTrainedModelConfigBuilder(modelId2, 0) + .setDefinition(new TrainedModelDefinition.Builder() + .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding))) + .setInput(new TrainedModelDefinition.Input(Arrays.asList("field1", "field2"))) + .setTrainedModel(buildClassification(true))) + .build(Version.CURRENT); + TrainedModelConfig config2 = buildTrainedModelConfigBuilder(modelId1, 0) + .setDefinition(new TrainedModelDefinition.Builder() + .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding))) + .setInput(new TrainedModelDefinition.Input(Arrays.asList("field1", "field2"))) + .setTrainedModel(buildRegression())) + .build(Version.CURRENT); + AtomicReference putConfigHolder = new AtomicReference<>(); + AtomicReference exceptionHolder = new AtomicReference<>(); + + blockingCall(listener -> trainedModelProvider.storeTrainedModel(config1, listener), putConfigHolder, exceptionHolder); + 
+        assertThat(putConfigHolder.get(), is(true));
+        assertThat(exceptionHolder.get(), is(nullValue()));
+        blockingCall(listener -> trainedModelProvider.storeTrainedModel(config2, listener), putConfigHolder, exceptionHolder);
+        assertThat(putConfigHolder.get(), is(true));
+        assertThat(exceptionHolder.get(), is(nullValue()));
+
+
+        List<Map<String, Object>> toInfer = new ArrayList<>();
+        toInfer.add(new HashMap<>() {{
+            put("foo", 1.0);
+            put("bar", 0.5);
+            put("categorical", "dog");
+        }});
+        toInfer.add(new HashMap<>() {{
+            put("foo", 0.9);
+            put("bar", 1.5);
+            put("categorical", "cat");
+        }});
+
+        List<Map<String, Object>> toInfer2 = new ArrayList<>();
+        toInfer2.add(new HashMap<>() {{
+            put("foo", 0.0);
+            put("bar", 0.01);
+            put("categorical", "dog");
+        }});
+        toInfer2.add(new HashMap<>() {{
+            put("foo", 1.0);
+            put("bar", 0.0);
+            put("categorical", "cat");
+        }});
+
+        // Test regression
+        InferModelAction.Request request = new InferModelAction.Request(modelId1, 0, toInfer, null);
+        InferModelAction.Response response = client().execute(InferModelAction.INSTANCE, request).actionGet();
+        assertThat(response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults)i).value()).collect(Collectors.toList()),
+            contains(1.3, 1.25));
+
+        request = new InferModelAction.Request(modelId1, 0, toInfer2, null);
+        response = client().execute(InferModelAction.INSTANCE, request).actionGet();
+        assertThat(response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults)i).value()).collect(Collectors.toList()),
+            contains(1.65, 1.55));
+
+
+        // Test classification
+        request = new InferModelAction.Request(modelId2, 0, toInfer, null);
+        response = client().execute(InferModelAction.INSTANCE, request).actionGet();
+        assertThat(response.getInferenceResults()
+                .stream()
+                .map(i -> ((SingleValueInferenceResults)i).valueAsString())
+                .collect(Collectors.toList()),
+            contains("not_to_be", "to_be"));
+
+        // Get top classes
+        request = new InferModelAction.Request(modelId2, 0, toInfer, new InferenceParams(2));
+        response = client().execute(InferModelAction.INSTANCE, request).actionGet();
+
+        ClassificationInferenceResults classificationInferenceResults =
+            (ClassificationInferenceResults)response.getInferenceResults().get(0);
+
+        assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("not_to_be"));
+        assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("to_be"));
+        assertThat(classificationInferenceResults.getTopClasses().get(0).getProbability(),
+            greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability()));
+
+        classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(1);
+        assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("to_be"));
+        assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("not_to_be"));
+        // they should always be in order of most probable to least probable
+        assertThat(classificationInferenceResults.getTopClasses().get(0).getProbability(),
+            greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability()));
+
+        // Test that top classes restrict the number returned
+        request = new InferModelAction.Request(modelId2, 0, toInfer2, new InferenceParams(1));
+        response = client().execute(InferModelAction.INSTANCE, request).actionGet();
+
+        classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(0);
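+        // with a single top class requested, only the most probable class should come back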
+        assertThat(classificationInferenceResults.getTopClasses(), hasSize(1));
+        assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("to_be"));
+    }
+
+    public void testInferMissingModel() {
+        String model = "test-infer-missing-model";
+        InferModelAction.Request request = new InferModelAction.Request(model, 0, Collections.emptyList(), null);
+        try {
+            client().execute(InferModelAction.INSTANCE, request).actionGet();
+        } catch (ElasticsearchException ex) {
+            assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model, 0)));
+        }
+    }
+
+    private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId, long modelVersion) {
+        return TrainedModelConfig.builder()
+            .setCreatedBy("ml_test")
+            .setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
+            .setDescription("trained model config for test")
+            .setModelId(modelId)
+            .setModelType("binary_decision_tree")
+            .setModelVersion(modelVersion);
+    }
+
+    @Override
+    public NamedXContentRegistry xContentRegistry() {
+        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
+        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+        namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents());
+        return new NamedXContentRegistry(namedXContent);
+    }
+
+}