diff --git a/docs/reference/ml/df-analytics/apis/get-trained-models-stats.asciidoc b/docs/reference/ml/df-analytics/apis/get-trained-models-stats.asciidoc index f49db649631b2..2c6b15777a125 100644 --- a/docs/reference/ml/df-analytics/apis/get-trained-models-stats.asciidoc +++ b/docs/reference/ml/df-analytics/apis/get-trained-models-stats.asciidoc @@ -48,7 +48,7 @@ request by using a comma-separated list of model IDs or a wildcard expression. `<model_id>`:: (Optional, string) -include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id] +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id-or-alias] [[ml-get-trained-models-stats-query-params]]
diff --git a/docs/reference/ml/df-analytics/apis/get-trained-models.asciidoc b/docs/reference/ml/df-analytics/apis/get-trained-models.asciidoc index 909e2d74d2a48..2b25c12c3e8c2 100644 --- a/docs/reference/ml/df-analytics/apis/get-trained-models.asciidoc +++ b/docs/reference/ml/df-analytics/apis/get-trained-models.asciidoc @@ -50,7 +50,7 @@ using a comma-separated list of model IDs or a wildcard expression. `<model_id>`:: (Optional, string) -include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id] +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id-or-alias] [[ml-get-trained-models-query-params]]
diff --git a/docs/reference/ml/df-analytics/apis/index.asciidoc b/docs/reference/ml/df-analytics/apis/index.asciidoc index 63a46480ce757..958298f027874 100644 --- a/docs/reference/ml/df-analytics/apis/index.asciidoc +++ b/docs/reference/ml/df-analytics/apis/index.asciidoc @@ -2,6 +2,7 @@ include::ml-df-analytics-apis.asciidoc[leveloffset=+1] //CREATE include::put-dfanalytics.asciidoc[leveloffset=+2] include::put-trained-models.asciidoc[leveloffset=+2] +include::put-trained-models-aliases.asciidoc[leveloffset=+2] //UPDATE include::update-dfanalytics.asciidoc[leveloffset=+2] //DELETE
diff --git a/docs/reference/ml/df-analytics/apis/ml-df-analytics-apis.asciidoc b/docs/reference/ml/df-analytics/apis/ml-df-analytics-apis.asciidoc index dae8757275ee4..7c485f6c35f48 100644 --- a/docs/reference/ml/df-analytics/apis/ml-df-analytics-apis.asciidoc +++ b/docs/reference/ml/df-analytics/apis/ml-df-analytics-apis.asciidoc @@ -22,8 +22,9 @@ You can use the following APIs to perform {infer} operations. * <> * <> * <> +* <> -You can deploy a trained model to make predictions in an ingest pipeline or in +You can deploy a trained model to make predictions in an ingest pipeline or in an aggregation. Refer to the following documentation to learn more. * <>
diff --git a/docs/reference/ml/df-analytics/apis/put-trained-models-aliases.asciidoc b/docs/reference/ml/df-analytics/apis/put-trained-models-aliases.asciidoc new file mode 100644 index 0000000000000..6869af8fbc427 --- /dev/null +++ b/docs/reference/ml/df-analytics/apis/put-trained-models-aliases.asciidoc @@ -0,0 +1,89 @@ +[role="xpack"] +[testenv="platinum"] +[[put-trained-models-aliases]] += Put Trained Models Aliases API +[subs="attributes"] +++++ +Put Trained Models Aliases +++++ + +Creates a trained model alias. These model aliases can be used instead of the trained model ID +when referencing the model in the stack. Model aliases must be unique, and a trained model can have +more than one model alias referring to it, but a model alias can refer to only a single trained model.
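
For instance, once the alias exists it can stand in for the concrete model ID anywhere the model is referenced, such as in an {infer} ingest processor. The following sketch is purely illustrative and is not part of this change; the pipeline name, alias, and target field are assumptions.

[source,console]
--------------------------------------------------
PUT _ingest/pipeline/flight-delay-pipeline
{
  "processors": [
    {
      "inference": {
        "model_id": "flight_delay_model",
        "inference_config": { "regression": {} },
        "target_field": "ml.inference.flight_delay",
        "field_map": {}
      }
    }
  ]
}
--------------------------------------------------
// TEST[skip:illustrative sketch only]

Because the processor references the alias rather than a specific model ID, the alias can later be reassigned to a newer model without editing the pipeline.
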
+ +beta::[] + +[[ml-put-trained-models-aliases-request]] +== {api-request-title} + +`PUT _ml/trained_models/<model_id>/model_aliases/<model_alias>` + + +[[ml-put-trained-models-aliases-prereq]] +== {api-prereq-title} + +If the {es} {security-features} are enabled, you must have the following +built-in roles and privileges: + +* `machine_learning_admin` + +For more information, see <>, <>, and +{ml-docs-setup-privileges}. + +[[ml-put-trained-models-aliases-desc]] +== {api-description-title} + +This API creates a new model alias that refers to a trained model, or updates an existing +model alias to refer to a different trained model. + +When updating an existing model alias to a new model ID, this API returns an error if the models +are of different inference types. For example, attempting to reassign the model alias +`flights-delay-prediction` from a regression model to a classification model results in an error. + +The API returns a warning if the old and new models for the model alias have very few input fields +in common. + +[[ml-put-trained-models-aliases-path-params]] +== {api-path-parms-title} + +`model_id`:: +(Required, string) +The trained model ID to which the model alias refers. + +`model_alias`:: +(Required, string) +The model alias to create or update. The `model_alias` cannot end in a number. + +[[ml-put-trained-models-aliases-query-params]] +== {api-query-parms-title} + +`reassign`:: +(Optional, boolean) +Specifies whether the `model_alias` should be reassigned to the provided `model_id` if it is already +assigned to another model. Defaults to `false`. The API returns an error if the `model_alias` +is already assigned to a model and this parameter is `false`. + +[[ml-put-trained-models-aliases-example]] +== {api-examples-title} + +[[ml-put-trained-models-aliases-example-new-alias]] +=== Creating a new model alias + +The following example shows how to create a new model alias for a trained model ID. + +[source,console] +-------------------------------------------------- +PUT _ml/trained_models/flight-delay-prediction-1574775339910/model_aliases/flight_delay_model +-------------------------------------------------- +// TEST[skip:setup kibana sample data] + +[[ml-put-trained-models-aliases-example-put-alias]] +=== Updating an existing model alias + +The following example shows how to reassign the model alias `flight_delay_model` to a different trained model ID. + +[source,console] +-------------------------------------------------- +PUT _ml/trained_models/flight-delay-prediction-1580004349800/model_aliases/flight_delay_model?reassign=true +-------------------------------------------------- +// TEST[skip:setup kibana sample data]
diff --git a/docs/reference/ml/ml-shared.asciidoc b/docs/reference/ml/ml-shared.asciidoc index 0dd46136c2ade..82e0cfe6e381e 100644 --- a/docs/reference/ml/ml-shared.asciidoc +++ b/docs/reference/ml/ml-shared.asciidoc @@ -1149,6 +1149,10 @@ tag::model-id[] The unique identifier of the trained model. end::model-id[] +tag::model-id-or-alias[] +The unique identifier of the trained model or a model alias. +end::model-id-or-alias[] + tag::model-memory-limit[] The approximate maximum amount of memory resources that are required for analytical processing.
Once this limit is approached, data pruning becomes diff --git a/server/src/main/java/org/elasticsearch/common/util/set/Sets.java b/server/src/main/java/org/elasticsearch/common/util/set/Sets.java index 252f94af83b7a..763e56cf6265a 100644 --- a/server/src/main/java/org/elasticsearch/common/util/set/Sets.java +++ b/server/src/main/java/org/elasticsearch/common/util/set/Sets.java @@ -62,6 +62,12 @@ public static boolean haveEmptyIntersection(Set left, Set right) { return left.stream().noneMatch(right::contains); } + public static boolean haveNonEmptyIntersection(Set left, Set right) { + Objects.requireNonNull(left); + Objects.requireNonNull(right); + return left.stream().anyMatch(right::contains); + } + /** * The relative complement, or difference, of the specified left and right set. Namely, the resulting set contains all the elements that * are in the left set but not in the right set. Neither input is mutated by this operation, an entirely new set is returned. diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java index dad03eed463b5..88daf2dc1e0ec 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java @@ -171,7 +171,7 @@ protected Reader getReader() { public static class Builder { private long totalModelCount; - private Set expandedIds; + private Map> expandedIdsWithAliases; private Map ingestStatsMap; private Map inferenceStatsMap; @@ -180,13 +180,13 @@ public Builder setTotalModelCount(long totalModelCount) { return this; } - public Builder setExpandedIds(Set expandedIds) { - this.expandedIds = expandedIds; + public Builder setExpandedIdsWithAliases(Map> expandedIdsWithAliases) { + this.expandedIdsWithAliases = expandedIdsWithAliases; return this; } - public Set getExpandedIds() { - return this.expandedIds; + public Map> getExpandedIdsWithAliases() { + return this.expandedIdsWithAliases; } public Builder setIngestStatsByModelId(Map ingestStatsByModelId) { @@ -200,8 +200,8 @@ public Builder setInferenceStatsByModelId(Map infereceSt } public Response build() { - List trainedModelStats = new ArrayList<>(expandedIds.size()); - expandedIds.forEach(id -> { + List trainedModelStats = new ArrayList<>(expandedIdsWithAliases.size()); + expandedIdsWithAliases.keySet().forEach(id -> { IngestStats ingestStats = ingestStatsMap.get(id); InferenceStats inferenceStats = inferenceStatsMap.get(id); trainedModelStats.add(new TrainedModelStats( diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelAction.java index 723aef2f25c38..21546bf22d6ae 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelAction.java @@ -143,18 +143,25 @@ public int hashCode() { public static class Response extends ActionResponse { private final List inferenceResults; + private final String modelId; private final boolean isLicensed; - public Response(List inferenceResults, boolean isLicensed) { + public Response(List inferenceResults, String modelId, boolean isLicensed) { super(); 
this.inferenceResults = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(inferenceResults, "inferenceResults")); this.isLicensed = isLicensed; + this.modelId = modelId; } public Response(StreamInput in) throws IOException { super(in); this.inferenceResults = Collections.unmodifiableList(in.readNamedWriteableList(InferenceResults.class)); this.isLicensed = in.readBoolean(); + if (in.getVersion().onOrAfter(Version.V_8_0_0)) { + this.modelId = in.readOptionalString(); + } else { + this.modelId = null; + } } public List getInferenceResults() { @@ -165,10 +172,17 @@ public boolean isLicensed() { return isLicensed; } + public String getModelId() { + return modelId; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeNamedWriteableList(inferenceResults); out.writeBoolean(isLicensed); + if (out.getVersion().onOrAfter(Version.V_8_0_0)) { + out.writeOptionalString(modelId); + } } @Override @@ -176,12 +190,14 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; InternalInferModelAction.Response that = (InternalInferModelAction.Response) o; - return isLicensed == that.isLicensed && Objects.equals(inferenceResults, that.inferenceResults); + return isLicensed == that.isLicensed + && Objects.equals(inferenceResults, that.inferenceResults) + && Objects.equals(modelId, that.modelId); } @Override public int hashCode() { - return Objects.hash(inferenceResults, isLicensed); + return Objects.hash(inferenceResults, isLicensed, modelId); } public static Builder builder() { @@ -190,6 +206,7 @@ public static Builder builder() { public static class Builder { private List inferenceResults; + private String modelId; private boolean isLicensed; public Builder setInferenceResults(List inferenceResults) { @@ -202,8 +219,13 @@ public Builder setLicensed(boolean licensed) { return this; } + public Builder setModelId(String modelId) { + this.modelId = modelId; + return this; + } + public Response build() { - return new Response(inferenceResults, isLicensed); + return new Response(inferenceResults, modelId, isLicensed); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAliasAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAliasAction.java new file mode 100644 index 0000000000000..d5078c1dadd21 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAliasAction.java @@ -0,0 +1,119 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.master.AcknowledgedRequest; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Locale; +import java.util.Objects; +import java.util.regex.Pattern; + +import static org.elasticsearch.action.ValidateActions.addValidationError; +import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INVALID_MODEL_ALIAS; + +public class PutTrainedModelAliasAction extends ActionType { + + // NOTE this is similar to our valid ID check. The difference here is that model_aliases cannot end in numbers + // This is to protect our automatic model naming conventions from hitting weird model_alias conflicts + private static final Pattern VALID_MODEL_ALIAS_CHAR_PATTERN = Pattern.compile("[a-z0-9](?:[a-z0-9_\\-\\.]*[a-z])?"); + + public static final PutTrainedModelAliasAction INSTANCE = new PutTrainedModelAliasAction(); + public static final String NAME = "cluster:admin/xpack/ml/inference/model_aliases/put"; + + private PutTrainedModelAliasAction() { + super(NAME, AcknowledgedResponse::readFrom); + } + + public static class Request extends AcknowledgedRequest { + + public static final String MODEL_ALIAS = "model_alias"; + public static final String REASSIGN = "reassign"; + + private final String modelAlias; + private final String modelId; + private final boolean reassign; + + public Request(String modelAlias, String modelId, boolean reassign) { + this.modelAlias = ExceptionsHelper.requireNonNull(modelAlias, MODEL_ALIAS); + this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID); + this.reassign = reassign; + } + + public Request(StreamInput in) throws IOException { + super(in); + this.modelAlias = in.readString(); + this.modelId = in.readString(); + this.reassign = in.readBoolean(); + } + + public String getModelAlias() { + return modelAlias; + } + + public String getModelId() { + return modelId; + } + + public boolean isReassign() { + return reassign; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(modelAlias); + out.writeString(modelId); + out.writeBoolean(reassign); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (modelAlias.equals(modelId)) { + validationException = addValidationError( + String.format( + Locale.ROOT, + "model_alias [%s] cannot equal model_id [%s]", + modelAlias, + modelId + ), + validationException + ); + } + if (VALID_MODEL_ALIAS_CHAR_PATTERN.matcher(modelAlias).matches() == false) { + validationException = addValidationError(Messages.getMessage(INVALID_MODEL_ALIAS, modelAlias), validationException); + } + return validationException; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Request request = (Request) o; + return Objects.equals(modelAlias, request.modelAlias) + && Objects.equals(modelId, request.modelId) + && 
Objects.equals(reassign, request.reassign); + } + + @Override + public int hashCode() { + return Objects.hash(modelAlias, modelId, reassign); + } + + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index 1ba35f12f707d..3248e3f8bc2ad 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -43,6 +43,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.stream.Collectors; import static org.elasticsearch.action.ValidateActions.addValidationError; @@ -58,6 +59,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { public static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance"; public static final String FEATURE_IMPORTANCE_BASELINE = "feature_importance_baseline"; public static final String HYPERPARAMETERS = "hyperparameters"; + public static final String MODEL_ALIASES = "model_aliases"; private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage"; @@ -471,34 +473,41 @@ public Builder setFeatureImportance(List totalFeatureImp if (totalFeatureImportance == null) { return this; } - if (this.metadata == null) { - this.metadata = new HashMap<>(); - } - this.metadata.put(TOTAL_FEATURE_IMPORTANCE, - totalFeatureImportance.stream().map(TotalFeatureImportance::asMap).collect(Collectors.toList())); - return this; + return addToMetadata( + TOTAL_FEATURE_IMPORTANCE, + totalFeatureImportance.stream().map(TotalFeatureImportance::asMap).collect(Collectors.toList()) + ); } public Builder setBaselineFeatureImportance(FeatureImportanceBaseline featureImportanceBaseline) { if (featureImportanceBaseline == null) { return this; } - if (this.metadata == null) { - this.metadata = new HashMap<>(); - } - this.metadata.put(FEATURE_IMPORTANCE_BASELINE, featureImportanceBaseline.asMap()); - return this; + return addToMetadata(FEATURE_IMPORTANCE_BASELINE, featureImportanceBaseline.asMap()); } public Builder setHyperparameters(List hyperparameters) { if (hyperparameters == null) { return this; } + return addToMetadata( + HYPERPARAMETERS, + hyperparameters.stream().map(Hyperparameters::asMap).collect(Collectors.toList()) + ); + } + + public Builder setModelAliases(Set modelAliases) { + if (modelAliases == null || modelAliases.isEmpty()) { + return this; + } + return addToMetadata(MODEL_ALIASES, modelAliases.stream().sorted().collect(Collectors.toList())); + } + + private Builder addToMetadata(String fieldName, Object value) { if (this.metadata == null) { this.metadata = new HashMap<>(); } - this.metadata.put(HYPERPARAMETERS, - hyperparameters.stream().map(Hyperparameters::asMap).collect(Collectors.toList())); + this.metadata.put(fieldName, value); return this; } @@ -663,6 +672,10 @@ public Builder validate(boolean forCreation) { metadata.get(TOTAL_FEATURE_IMPORTANCE), METADATA.getPreferredName() + "." + TOTAL_FEATURE_IMPORTANCE, validationException); + validationException = checkIllegalSetting( + metadata.get(MODEL_ALIASES), + METADATA.getPreferredName() + "." 
+ MODEL_ALIASES, + validationException); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index 9b556052b9a25..401e6660f8917 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -120,6 +120,11 @@ public final class Messages { public static final String INFERENCE_TAGS_AND_MODEL_IDS_UNIQUE = "The provided tags {0} must not match existing model_ids."; public static final String INFERENCE_MODEL_ID_AND_TAGS_UNIQUE = "The provided model_id {0} must not match existing tags."; + public static final String INVALID_MODEL_ALIAS = "Invalid model_alias; ''{0}'' can contain lowercase alphanumeric (a-z and 0-9), " + + "hyphens or underscores; must start with alphanumeric and cannot end with numbers"; + public static final String TRAINED_MODEL_INPUTS_DIFFER_SIGNIFICANTLY = + "The input fields for new model [{0}] and for old model [{1}] differ significantly, model results may change drastically."; + public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again"; public static final String JOB_AUDIT_CREATED = "Job created"; public static final String JOB_AUDIT_UPDATED = "Job updated: {0}"; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionResponseTests.java index 9d752e6b6799e..9732ff14963b0 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionResponseTests.java @@ -29,6 +29,7 @@ protected Response createTestInstance() { Stream.generate(() -> randomInferenceResult(resultType)) .limit(randomIntBetween(0, 10)) .collect(Collectors.toList()), + randomAlphaOfLength(10), randomBoolean()); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAliasActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAliasActionRequestTests.java new file mode 100644 index 0000000000000..8f77b116d7e96 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAliasActionRequestTests.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAliasAction.Request; +import org.junit.Before; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; + +public class PutTrainedModelAliasActionRequestTests extends AbstractWireSerializingTestCase { + + private String modelAlias; + + @Before + public void setupModelAlias() { + modelAlias = randomAlphaOfLength(10); + } + + @Override + protected Request createTestInstance() { + return new Request( + modelAlias, + randomAlphaOfLength(10), + randomBoolean() + ); + } + + @Override + protected Writeable.Reader instanceReader() { + return Request::new; + } + + public void testCtor() { + expectThrows(Exception.class, () -> new Request(null, randomAlphaOfLength(10), randomBoolean())); + expectThrows(Exception.class, () -> new Request(randomAlphaOfLength(10), null, randomBoolean())); + } + + public void testValidate() { + + { // model_alias equal to model Id + ActionRequestValidationException ex = new Request("foo", "foo", randomBoolean()).validate(); + assertThat(ex, not(nullValue())); + assertThat(ex.getMessage(), containsString("model_alias [foo] cannot equal model_id [foo]")); + } + { // model_alias cannot end in numbers + String modelAlias = randomAlphaOfLength(10) + randomIntBetween(0, Integer.MAX_VALUE); + ActionRequestValidationException ex = new Request(modelAlias, "foo", randomBoolean()).validate(); + assertThat(ex, not(nullValue())); + assertThat( + ex.getMessage(), + containsString( + "can contain lowercase alphanumeric (a-z and 0-9), hyphens or underscores; " + + "must start with alphanumeric and cannot end with numbers" + ) + ); + } + } + +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/integration/MlRestTestStateCleaner.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/integration/MlRestTestStateCleaner.java index bbfca86e27eaf..6b6d76188cef3 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/integration/MlRestTestStateCleaner.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/integration/MlRestTestStateCleaner.java @@ -14,11 +14,14 @@ import org.elasticsearch.test.rest.ESRestTestCase; import java.io.IOException; +import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Set; public class MlRestTestStateCleaner { + private static final Set NOT_DELETED_TRAINED_MODELS = Collections.singleton("lang_ident_model_1"); private final Logger logger; private final RestClient adminClient; @@ -28,12 +31,34 @@ public MlRestTestStateCleaner(Logger logger, RestClient adminClient) { } public void clearMlMetadata() throws IOException { + deleteAllTrainedModels(); deleteAllDatafeeds(); deleteAllJobs(); deleteAllDataFrameAnalytics(); // indices will be deleted by the ESRestTestCase class } + @SuppressWarnings("unchecked") + private void deleteAllTrainedModels() throws IOException { + final Request getTrainedModels = new Request("GET", "/_ml/trained_models"); + getTrainedModels.addParameter("size", "10000"); + final Response trainedModelsResponse = adminClient.performRequest(getTrainedModels); + final List> models = (List>) XContentMapValues.extractValue( + 
"trained_model_configs", + ESRestTestCase.entityAsMap(trainedModelsResponse) + ); + if (models == null || models.isEmpty()) { + return; + } + for (Map model : models) { + String modelId = (String) model.get("model_id"); + if (NOT_DELETED_TRAINED_MODELS.contains(modelId)) { + continue; + } + adminClient.performRequest(new Request("DELETE", "/_ml/trained_models/" + modelId)); + } + } + @SuppressWarnings("unchecked") private void deleteAllDatafeeds() throws IOException { final Request datafeedsRequest = new Request("GET", "/_ml/datafeeds"); diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index afd4453625995..e26ab4bbd5792 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -145,6 +145,10 @@ tasks.named("yamlRestTest").configure { 'ml/inference_crud/Test put ensemble with tree where tree has out of bounds feature_names index', 'ml/inference_crud/Test put model with empty input.field_names', 'ml/inference_crud/Test PUT model where target type and inference config mismatch', + 'ml/inference_crud/Test update model alias with model id referring to missing model', + 'ml/inference_crud/Test update model alias with bad alias', + 'ml/inference_crud/Test update model alias where alias exists but old model id is different inference type', + 'ml/inference_crud/Test update model alias where alias exists but reassign is false', 'ml/inference_processor/Test create processor with missing mandatory fields', 'ml/inference_stats_crud/Test get stats given missing trained model', 'ml/inference_stats_crud/Test get stats given expression without matches and allow_no_match is false', diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java index 7607514d74275..e60b8217798eb 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java @@ -48,6 +48,7 @@ import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; /** * This is a {@link ESRestTestCase} because the cleanup code in {@link ExternalTestCluster#ensureEstimatedStats()} causes problems @@ -201,6 +202,84 @@ public void testPipelineIngest() throws Exception { }, 30, TimeUnit.SECONDS); } + public void testPipelineIngestWithModelAliases() throws Exception { + String regressionModelId = "test_regression_1"; + putModel(regressionModelId, REGRESSION_CONFIG); + String regressionModelId2 = "test_regression_2"; + putModel(regressionModelId2, REGRESSION_CONFIG); + String modelAlias = "test_regression"; + putModelAlias(modelAlias, regressionModelId); + + client().performRequest(putPipeline("simple_regression_pipeline", pipelineDefinition(modelAlias, "regression"))); + + for (int i = 0; i < 10; i++) { + client().performRequest(indexRequest("index_for_inference_test", "simple_regression_pipeline", generateSourceDoc())); + } + putModelAlias(modelAlias, regressionModelId2); + // Need to assert busy as loading the model and then switching the model alias can take time + assertBusy(() -> { + String source = 
"{\n" + + " \"docs\": [\n" + + " {\"_source\": {\n" + + " \"col1\": \"female\",\n" + + " \"col2\": \"M\",\n" + + " \"col3\": \"none\",\n" + + " \"col4\": 10\n" + + " }}]\n" + + "}"; + Request request = new Request("POST", "_ingest/pipeline/simple_regression_pipeline/_simulate"); + request.setJsonEntity(source); + Response response = client().performRequest(request); + String responseString = EntityUtils.toString(response.getEntity()); + assertThat(responseString, containsString("\"model_id\":\"test_regression_2\"")); + }, 30, TimeUnit.SECONDS); + + for (int i = 0; i < 10; i++) { + client().performRequest(indexRequest("index_for_inference_test", "simple_regression_pipeline", generateSourceDoc())); + } + + client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_regression_pipeline")); + + client().performRequest(new Request("POST", "index_for_inference_test/_refresh")); + + Response searchResponse = client().performRequest(searchRequest("index_for_inference_test", + QueryBuilders.boolQuery() + .filter( + QueryBuilders.existsQuery("ml.inference.regression.predicted_value")))); + // Verify we have 20 documents that contain a predicted value for regression + assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":20")); + + + // Since this is a multi-node cluster, the model could be loaded and cached on one ingest node but not the other + // Consequently, we should only verify that some of the documents refer to the first regression model + // and some refer to the second. + searchResponse = client().performRequest(searchRequest("index_for_inference_test", + QueryBuilders.boolQuery() + .filter( + QueryBuilders.termQuery("ml.inference.regression.model_id.keyword", regressionModelId)))); + assertThat(EntityUtils.toString(searchResponse.getEntity()), not(containsString("\"value\":0"))); + + searchResponse = client().performRequest(searchRequest("index_for_inference_test", + QueryBuilders.boolQuery() + .filter( + QueryBuilders.termQuery("ml.inference.regression.model_id.keyword", regressionModelId2)))); + assertThat(EntityUtils.toString(searchResponse.getEntity()), not(containsString("\"value\":0"))); + + assertBusy(() -> { + try (XContentParser parser = createParser(JsonXContent.jsonXContent, client().performRequest(new Request("GET", + "_ml/trained_models/" + modelAlias + "/_stats")).getEntity().getContent())) { + GetTrainedModelsStatsResponse response = GetTrainedModelsStatsResponse.fromXContent(parser); + assertThat(response.toString(), response.getTrainedModelStats(), hasSize(1)); + TrainedModelStats trainedModelStats = response.getTrainedModelStats().get(0); + assertThat(trainedModelStats.getModelId(), equalTo(regressionModelId2)); + assertThat(trainedModelStats.getInferenceStats(), is(notNullValue())); + } catch (ResponseException ex) { + //this could just mean shard failures. 
+ fail(ex.getMessage()); + } + }); + } + public void assertStatsWithCacheMisses(String modelId, long inferenceCount) throws IOException { Response statsResponse = client().performRequest(new Request("GET", "_ml/trained_models/" + modelId + "/_stats")); @@ -629,4 +708,9 @@ private void putModel(String modelId, String modelConfiguration) throws IOExcept client().performRequest(request); } + private void putModelAlias(String modelAlias, String newModel) throws IOException { + Request request = new Request("PUT", "_ml/trained_models/" + newModel + "/model_aliases/" + modelAlias + "?reassign=true"); + client().performRequest(request); + } + } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java index b4fb0d27de99c..78798dedcc83c 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java @@ -79,6 +79,7 @@ import org.elasticsearch.xpack.ilm.IndexLifecycle; import org.elasticsearch.xpack.ml.LocalStateMachineLearning; import org.elasticsearch.xpack.ml.autoscaling.MlScalingReason; +import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata; import java.io.IOException; import java.io.UncheckedIOException; @@ -257,6 +258,8 @@ protected void ensureClusterStateConsistency() throws IOException { if (cluster() != null && cluster().size() > 0) { List entries = new ArrayList<>(ClusterModule.getNamedWriteables()); entries.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedWriteables()); + entries.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, ModelAliasMetadata.NAME, ModelAliasMetadata::new)); + entries.add(new NamedWriteableRegistry.Entry(NamedDiff.class, ModelAliasMetadata.NAME, ModelAliasMetadata::readDiffFrom)); entries.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, "ml", MlMetadata::new)); entries.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, IndexLifecycleMetadata.TYPE, IndexLifecycleMetadata::new)); entries.add(new NamedWriteableRegistry.Entry(LifecycleType.class, TimeseriesLifecycleType.TYPE, diff --git a/x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceProcessorIT.java b/x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceProcessorIT.java index a13e6c38ad846..960514b6e1cc0 100644 --- a/x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceProcessorIT.java +++ b/x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceProcessorIT.java @@ -13,19 +13,26 @@ import org.elasticsearch.client.ResponseException; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.test.rest.ESRestTestCase; +import org.junit.After; import org.junit.Before; import java.io.IOException; +import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; public class 
InferenceProcessorIT extends ESRestTestCase { + private static final Set NOT_DELETED_TRAINED_MODELS = Collections.singleton("lang_ident_model_1"); private static final String MODEL_ID = "a-perfect-regression-model"; + private final Set createdPipelines = new HashSet<>(); @Before public void enableLogging() throws IOException { @@ -36,8 +43,39 @@ public void enableLogging() throws IOException { assertThat(client().performRequest(setTrace).getStatusLine().getStatusCode(), equalTo(200)); } - private void putRegressionModel() throws IOException { + @SuppressWarnings("unchecked") + @After + public void cleanup() throws Exception { + for (String createdPipeline : createdPipelines) { + deletePipeline(createdPipeline); + } + createdPipelines.clear(); + waitForStats(); + final Request getTrainedModels = new Request("GET", "/_ml/trained_models"); + getTrainedModels.addParameter("size", "10000"); + final Response trainedModelsResponse = adminClient().performRequest(getTrainedModels); + final List> models = (List>) XContentMapValues.extractValue( + "trained_model_configs", + ESRestTestCase.entityAsMap(trainedModelsResponse) + ); + if (models == null || models.isEmpty()) { + return; + } + for (Map model : models) { + String modelId = (String) model.get("model_id"); + if (NOT_DELETED_TRAINED_MODELS.contains(modelId)) { + continue; + } + adminClient().performRequest(new Request("DELETE", "/_ml/trained_models/" + modelId)); + } + } + + private void putModelAlias(String modelAlias, String newModel) throws IOException { + Request request = new Request("PUT", "_ml/trained_models/" + newModel + "/model_aliases/" + modelAlias + "?reassign=true"); + client().performRequest(request); + } + private void putRegressionModel() throws IOException { Request model = new Request("PUT", "_ml/trained_models/" + MODEL_ID); model.setJsonEntity( " {\n" + @@ -66,24 +104,9 @@ private void putRegressionModel() throws IOException { @SuppressWarnings("unchecked") public void testCreateAndDeletePipelineWithInferenceProcessor() throws Exception { putRegressionModel(); - - Request putPipeline = new Request("PUT", "_ingest/pipeline/regression-model-pipeline"); - putPipeline.setJsonEntity( - " {\n" + - " \"processors\": [\n" + - " {\n" + - " \"inference\" : {\n" + - " \"model_id\" : \"" + MODEL_ID + "\",\n" + - " \"inference_config\": {\"regression\": {}},\n" + - " \"target_field\": \"regression_field\",\n" + - " \"field_map\": {}\n" + - " }\n" + - " }\n" + - " ]\n" + - " }" - ); - - assertThat(client().performRequest(putPipeline).getStatusLine().getStatusCode(), equalTo(200)); + String pipelineId = "regression-model-pipeline"; + createdPipelines.add(pipelineId); + putPipeline(MODEL_ID, pipelineId); Map statsAsMap = getStats(); List pipelineCount = @@ -100,8 +123,8 @@ public void testCreateAndDeletePipelineWithInferenceProcessor() throws Exception // using the model will ensure it is loaded and stats will be written before it is deleted infer("regression-model-pipeline"); - Request deletePipeline = new Request("DELETE", "_ingest/pipeline/regression-model-pipeline"); - assertThat(client().performRequest(deletePipeline).getStatusLine().getStatusCode(), equalTo(200)); + deletePipeline(pipelineId); + createdPipelines.remove(pipelineId); // check stats are updated assertBusy(() -> { @@ -129,9 +152,100 @@ public void testCreateAndDeletePipelineWithInferenceProcessor() throws Exception }); } + @SuppressWarnings("unchecked") + public void testCreateAndDeletePipelineWithInferenceProcessorByName() throws Exception { + putRegressionModel(); + + 
putModelAlias("regression_first", MODEL_ID); + putModelAlias("regression_second", MODEL_ID); + createdPipelines.add("first_pipeline"); + putPipeline("regression_first", "first_pipeline"); + createdPipelines.add("second_pipeline"); + putPipeline("regression_second", "second_pipeline"); + + Map statsAsMap = getStats(); + List pipelineCount = + (List)XContentMapValues.extractValue("trained_model_stats.pipeline_count", statsAsMap); + assertThat(pipelineCount.get(0), equalTo(2)); + + List> counts = + (List>)XContentMapValues.extractValue("trained_model_stats.ingest.total", statsAsMap); + assertThat(counts.get(0).get("count"), equalTo(0)); + assertThat(counts.get(0).get("time_in_millis"), equalTo(0)); + assertThat(counts.get(0).get("current"), equalTo(0)); + assertThat(counts.get(0).get("failed"), equalTo(0)); + + // using the model will ensure it is loaded and stats will be written before it is deleted + infer("first_pipeline"); + deletePipeline("first_pipeline"); + createdPipelines.remove("first_pipeline"); + + infer("second_pipeline"); + deletePipeline("second_pipeline"); + createdPipelines.remove("second_pipeline"); + + // check stats are updated + assertBusy(() -> { + Map updatedStatsMap = null; + try { + updatedStatsMap = getStats(); + } catch (ResponseException e) { + // the search may fail because the index is not ready yet in which case retry + if (e.getMessage().contains("search_phase_execution_exception")) { + fail("search failed- retry"); + } else { + throw e; + } + } + + List updatedPipelineCount = + (List) XContentMapValues.extractValue("trained_model_stats.pipeline_count", updatedStatsMap); + assertThat(updatedPipelineCount.get(0), equalTo(0)); + + List> inferenceStats = + (List>) XContentMapValues.extractValue("trained_model_stats.inference_stats", updatedStatsMap); + assertNotNull(inferenceStats); + assertThat(inferenceStats, hasSize(1)); + assertThat(inferenceStats.toString(), inferenceStats.get(0).get("inference_count"), equalTo(2)); + }); + } + + public void testDeleteModelWhileAliasReferencedByPipeline() throws Exception { + putRegressionModel(); + putModelAlias("regression_first", MODEL_ID); + createdPipelines.add("first_pipeline"); + putPipeline("regression_first", "first_pipeline"); + Exception ex = expectThrows(Exception.class, + () -> client().performRequest(new Request("DELETE", "_ml/trained_models/" + MODEL_ID))); + assertThat(ex.getMessage(), + containsString("Cannot delete model [" + + MODEL_ID + + "] as it has a model_alias [regression_first] that is still referenced by ingest processors")); + infer("first_pipeline"); + deletePipeline("first_pipeline"); + waitForStats(); + } + + public void testDeleteModelWhileReferencedByPipeline() throws Exception { + putRegressionModel(); + createdPipelines.add("first_pipeline"); + putPipeline(MODEL_ID, "first_pipeline"); + Exception ex = expectThrows(Exception.class, + () -> client().performRequest(new Request("DELETE", "_ml/trained_models/" + MODEL_ID))); + assertThat(ex.getMessage(), + containsString("Cannot delete model [" + + MODEL_ID + + "] as it is still referenced by ingest processors")); + infer("first_pipeline"); + deletePipeline("first_pipeline"); + waitForStats(); + } + + @SuppressWarnings("unchecked") public void testCreateProcessorWithDeprecatedFields() throws Exception { putRegressionModel(); + createdPipelines.add("regression-model-deprecated-pipeline"); Request putPipeline = new Request("PUT", "_ingest/pipeline/regression-model-deprecated-pipeline"); putPipeline.setJsonEntity( "{\n" + @@ -155,14 +269,35 @@ 
public void testCreateProcessorWithDeprecatedFields() throws Exception { // using the model will ensure it is loaded and stats will be written before it is deleted infer("regression-model-deprecated-pipeline"); - Request deletePipeline = new Request("DELETE", "_ingest/pipeline/regression-model-deprecated-pipeline"); - Response deleteResponse = client().performRequest(deletePipeline); - assertThat(deleteResponse.getStatusLine().getStatusCode(), equalTo(200)); + deletePipeline("regression-model-deprecated-pipeline"); + createdPipelines.remove("regression-model-deprecated-pipeline"); + waitForStats(); + assertBusy(() -> { + Map updatedStatsMap = null; + try { + updatedStatsMap = getStats(); + } catch (ResponseException e) { + // the search may fail because the index is not ready yet in which case retry + if (e.getMessage().contains("search_phase_execution_exception")) { + fail("search failed- retry"); + } else { + throw e; + } + } + + List updatedPipelineCount = + (List) XContentMapValues.extractValue("trained_model_stats.pipeline_count", updatedStatsMap); + assertThat(updatedPipelineCount.get(0), equalTo(0)); - waitForStatsDoc(); + List> inferenceStats = + (List>) XContentMapValues.extractValue("trained_model_stats.inference_stats", updatedStatsMap); + assertNotNull(inferenceStats); + assertThat(inferenceStats, hasSize(1)); + assertThat(inferenceStats.get(0).get("inference_count"), equalTo(1)); + }); } - public void infer(String pipelineId) throws IOException { + private void infer(String pipelineId) throws IOException { Request putDoc = new Request("POST", "any_index/_doc?pipeline=" + pipelineId); putDoc.setJsonEntity("{\"field1\": 1, \"field2\": 2}"); @@ -170,43 +305,56 @@ public void infer(String pipelineId) throws IOException { assertThat(response.getStatusLine().getStatusCode(), equalTo(201)); } - @SuppressWarnings("unchecked") - public void waitForStatsDoc() throws Exception { - assertBusy( () -> { - Request searchForStats = new Request("GET", ".ml-stats-*/_search?rest_total_hits_as_int"); - searchForStats.setJsonEntity( - "{\n" + - " \"query\": {\n" + - " \"bool\": {\n" + - " \"filter\": [\n" + - " {\n" + - " \"term\": {\n" + - " \"type\": \"inference_stats\"\n" + - " }\n" + - " },\n" + - " {\n" + - " \"term\": {\n" + - " \"model_id\": \"" + MODEL_ID + "\"\n" + - " }\n" + - " }\n" + - " ]\n" + - " }\n" + - " }\n" + - "}" - ); + private void putPipeline(String modelId, String pipelineName) throws IOException { + Request putPipeline = new Request("PUT", "_ingest/pipeline/" + pipelineName); + putPipeline.setJsonEntity( + " {\n" + + " \"processors\": [\n" + + " {\n" + + " \"inference\" : {\n" + + " \"model_id\" : \"" + modelId + "\",\n" + + " \"inference_config\": {\"regression\": {}},\n" + + " \"target_field\": \"regression_field\",\n" + + " \"field_map\": {}\n" + + " }\n" + + " }\n" + + " ]\n" + + " }" + ); - try { - Response searchResponse = client().performRequest(searchForStats); + assertThat(client().performRequest(putPipeline).getStatusLine().getStatusCode(), equalTo(200)); + } - Map responseAsMap = entityAsMap(searchResponse); - Map hits = (Map)responseAsMap.get("hits"); - assertThat(responseAsMap.toString(), hits.get("total"), equalTo(1)); + private void deletePipeline(String pipelineId) throws IOException { + try { + Request deletePipeline = new Request("DELETE", "_ingest/pipeline/" + pipelineId); + assertThat(client().performRequest(deletePipeline).getStatusLine().getStatusCode(), equalTo(200)); + } catch (ResponseException ex) { + if 
(ex.getResponse().getStatusLine().getStatusCode() != 404) { + throw ex; + } + } + } + + @SuppressWarnings("unchecked") + private void waitForStats() throws Exception { + assertBusy(() -> { + Map updatedStatsMap = null; + try { + ensureGreen(".ml-stats-*"); + updatedStatsMap = getStats(); } catch (ResponseException e) { // the search may fail because the index is not ready yet in which case retry - if (e.getMessage().contains("search_phase_execution_exception") == false) { + if (e.getMessage().contains("search_phase_execution_exception")) { + fail("search failed- retry"); + } else { throw e; } } + + List> inferenceStats = + (List>) XContentMapValues.extractValue("trained_model_stats.inference_stats", updatedStatsMap); + assertNotNull(inferenceStats); }); } diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java index 8dd6413eea435..cb68f15e750bc 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java @@ -36,6 +36,7 @@ import org.elasticsearch.xpack.ml.extractor.DocValueField; import org.elasticsearch.xpack.ml.extractor.ExtractedField; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; +import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata; import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider; import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo; import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests; @@ -102,11 +103,18 @@ public void testStoreModelViaChunkedPersister() throws IOException { .collect(Collectors.toList())); persister.createAndIndexInferenceModelMetadata(modelMetadata); - PlainActionFuture>> getIdsFuture = new PlainActionFuture<>(); - trainedModelProvider.expandIds(modelId + "*", false, PageParams.defaultParams(), Collections.emptySet(), getIdsFuture); - Tuple> ids = getIdsFuture.actionGet(); + PlainActionFuture>>> getIdsFuture = new PlainActionFuture<>(); + trainedModelProvider.expandIds( + modelId + "*", + false, + PageParams.defaultParams(), + Collections.emptySet(), + ModelAliasMetadata.EMPTY, + getIdsFuture + ); + Tuple>> ids = getIdsFuture.actionGet(); assertThat(ids.v1(), equalTo(1L)); - String inferenceModelId = ids.v2().iterator().next(); + String inferenceModelId = ids.v2().keySet().iterator().next(); PlainActionFuture getTrainedModelFuture = new PlainActionFuture<>(); trainedModelProvider.getTrainedModel(inferenceModelId, GetTrainedModelsAction.Includes.all(), getTrainedModelFuture); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 74a2fd36cce40..ed7292e29ad96 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -25,6 +25,7 @@ import org.elasticsearch.cluster.node.DiscoveryNodeRole; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.ParseField; import org.elasticsearch.common.breaker.CircuitBreaker; import 
org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.ClusterSettings; @@ -141,6 +142,7 @@ import org.elasticsearch.xpack.core.ml.action.UpdateJobAction; import org.elasticsearch.xpack.core.ml.action.UpdateModelSnapshotAction; import org.elasticsearch.xpack.core.ml.action.UpdateProcessAction; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAliasAction; import org.elasticsearch.xpack.core.ml.action.UpgradeJobModelSnapshotAction; import org.elasticsearch.xpack.core.ml.action.ValidateDetectorAction; import org.elasticsearch.xpack.core.ml.action.ValidateJobConfigAction; @@ -219,6 +221,7 @@ import org.elasticsearch.xpack.ml.action.TransportUpdateJobAction; import org.elasticsearch.xpack.ml.action.TransportUpdateModelSnapshotAction; import org.elasticsearch.xpack.ml.action.TransportUpdateProcessAction; +import org.elasticsearch.xpack.ml.action.TransportPutTrainedModelAliasAction; import org.elasticsearch.xpack.ml.action.TransportUpgradeJobModelSnapshotAction; import org.elasticsearch.xpack.ml.action.TransportValidateDetectorAction; import org.elasticsearch.xpack.ml.action.TransportValidateJobConfigAction; @@ -239,6 +242,7 @@ import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult; +import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata; import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService; import org.elasticsearch.xpack.ml.inference.aggs.InferencePipelineAggregationBuilder; import org.elasticsearch.xpack.ml.inference.aggs.InternalInferenceAggregation; @@ -316,6 +320,7 @@ import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction; import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsStatsAction; import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelAction; +import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelAliasAction; import org.elasticsearch.xpack.ml.rest.job.RestCloseJobAction; import org.elasticsearch.xpack.ml.rest.job.RestDeleteForecastAction; import org.elasticsearch.xpack.ml.rest.job.RestDeleteJobAction; @@ -933,6 +938,7 @@ public List getRestHandlers(Settings settings, RestController restC new RestGetTrainedModelsStatsAction(), new RestPutTrainedModelAction(), new RestUpgradeJobModelSnapshotAction(), + new RestPutTrainedModelAliasAction(), // CAT Handlers new RestCatJobsAction(), new RestCatTrainedModelsAction(), @@ -1016,7 +1022,8 @@ public List getRestHandlers(Settings settings, RestController restC new ActionHandler<>(GetTrainedModelsStatsAction.INSTANCE, TransportGetTrainedModelsStatsAction.class), new ActionHandler<>(PutTrainedModelAction.INSTANCE, TransportPutTrainedModelAction.class), new ActionHandler<>(UpgradeJobModelSnapshotAction.INSTANCE, TransportUpgradeJobModelSnapshotAction.class), - usageAction, + new ActionHandler<>(PutTrainedModelAliasAction.INSTANCE, TransportPutTrainedModelAliasAction.class), + usageAction, infoAction); } @@ -1121,6 +1128,13 @@ public List getNamedXContent() { namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers()); + namedXContent.add( + new NamedXContentRegistry.Entry( + Metadata.Custom.class, 
+ new ParseField(ModelAliasMetadata.NAME), + ModelAliasMetadata::fromXContent + ) + ); return namedXContent; } @@ -1131,6 +1145,8 @@ public List getNamedWriteables() { // Custom metadata namedWriteables.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, "ml", MlMetadata::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(NamedDiff.class, "ml", MlMetadata.MlMetadataDiff::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, ModelAliasMetadata.NAME, ModelAliasMetadata::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(NamedDiff.class, ModelAliasMetadata.NAME, ModelAliasMetadata::readDiffFrom)); // Persistent tasks params namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskParams.class, MlTasks.DATAFEED_TASK_NAME, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java index 4820912b99bc8..9f51f41f83b57 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java @@ -14,10 +14,12 @@ import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.action.support.master.AcknowledgedTransportMasterNodeAction; +import org.elasticsearch.cluster.AckedClusterStateUpdateTask; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.block.ClusterBlockException; import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.ingest.IngestMetadata; @@ -29,11 +31,15 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction; +import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; +import java.util.ArrayList; +import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; @@ -80,13 +86,60 @@ protected void masterOperation(Task task, return; } - trainedModelProvider.deleteTrainedModel(request.getId(), ActionListener.wrap( - r -> { - auditor.info(request.getId(), "trained model deleted"); - listener.onResponse(AcknowledgedResponse.TRUE); - }, + final ModelAliasMetadata currentMetadata = ModelAliasMetadata.fromState(state); + final List modelAliases = new ArrayList<>(); + for (Map.Entry modelAliasEntry : currentMetadata.modelAliases().entrySet()) { + if (modelAliasEntry.getValue().getModelId().equals(id)) { + modelAliases.add(modelAliasEntry.getKey()); + } + } + for (String modelAlias : modelAliases) { + if (referencedModels.contains(modelAlias)) { + listener.onFailure(new ElasticsearchStatusException( + "Cannot delete model [{}] as it has a model_alias [{}] that is still referenced by ingest processors", + RestStatus.CONFLICT, + id, + 
modelAlias)); + return; + } + } + + ActionListener nameDeletionListener = ActionListener.wrap( + ack -> trainedModelProvider.deleteTrainedModel(request.getId(), ActionListener.wrap( + r -> { + auditor.info(request.getId(), "trained model deleted"); + listener.onResponse(AcknowledgedResponse.TRUE); + }, + listener::onFailure + )), + listener::onFailure - )); + ); + + // No reason to update cluster state, simply delete the model + if (modelAliases.isEmpty()) { + nameDeletionListener.onResponse(AcknowledgedResponse.of(true)); + return; + } + + clusterService.submitStateUpdateTask("delete-trained-model-alias", new AckedClusterStateUpdateTask(request, nameDeletionListener) { + @Override + public ClusterState execute(final ClusterState currentState) { + final ClusterState.Builder builder = ClusterState.builder(currentState); + final ModelAliasMetadata currentMetadata = ModelAliasMetadata.fromState(currentState); + if (currentMetadata.modelAliases().isEmpty()) { + return currentState; + } + final Map newMetadata = new HashMap<>(currentMetadata.modelAliases()); + logger.info("[{}] delete model model_aliases {}", request.getId(), modelAliases); + modelAliases.forEach(newMetadata::remove); + final ModelAliasMetadata modelAliasMetadata = new ModelAliasMetadata(newMetadata); + builder.metadata(Metadata.builder(currentState.getMetadata()) + .putCustom(ModelAliasMetadata.NAME, modelAliasMetadata) + .build()); + return builder.build(); + } + }); } private Set getReferencedModelKeys(IngestMetadata ingestMetadata) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java index f8d78db353634..201fa4a353519 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java @@ -9,6 +9,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.tasks.Task; @@ -18,22 +19,27 @@ import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Response; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import java.util.Collections; import java.util.HashSet; +import java.util.Map; import java.util.Set; public class TransportGetTrainedModelsAction extends HandledTransportAction { private final TrainedModelProvider provider; + private final ClusterService clusterService; @Inject public TransportGetTrainedModelsAction(TransportService transportService, ActionFilters actionFilters, + ClusterService clusterService, TrainedModelProvider trainedModelProvider) { super(GetTrainedModelsAction.NAME, transportService, actionFilters, GetTrainedModelsAction.Request::new); this.provider = trainedModelProvider; + this.clusterService = clusterService; } @Override @@ -41,7 +47,7 @@ protected void doExecute(Task task, Request request, ActionListener li Response.Builder responseBuilder = Response.builder(); - 
ActionListener>> idExpansionListener = ActionListener.wrap( + ActionListener>>> idExpansionListener = ActionListener.wrap( totalAndIds -> { responseBuilder.setTotalCount(totalAndIds.v1()); @@ -58,8 +64,10 @@ protected void doExecute(Task task, Request request, ActionListener li } if (request.getIncludes().isIncludeModelDefinition()) { + Map.Entry> modelIdAndAliases = totalAndIds.v2().entrySet().iterator().next(); provider.getTrainedModel( - totalAndIds.v2().iterator().next(), + modelIdAndAliases.getKey(), + modelIdAndAliases.getValue(), request.getIncludes(), ActionListener.wrap( config -> listener.onResponse(responseBuilder.setModels(Collections.singletonList(config)).build()), @@ -80,11 +88,11 @@ protected void doExecute(Task task, Request request, ActionListener li }, listener::onFailure ); - provider.expandIds(request.getResourceId(), request.isAllowNoResources(), request.getPageParams(), new HashSet<>(request.getTags()), + ModelAliasMetadata.fromState(clusterService.state()), idExpansionListener); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java index 945e621daa601..4e8f0a8783a0d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java @@ -20,6 +20,7 @@ import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.metrics.CounterMetric; +import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.ingest.IngestMetadata; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.ingest.IngestStats; @@ -28,6 +29,7 @@ import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats; +import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; @@ -42,6 +44,7 @@ import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; @@ -73,6 +76,7 @@ protected void doExecute(Task task, GetTrainedModelsStatsAction.Request request, ActionListener listener) { + final ModelAliasMetadata currentMetadata = ModelAliasMetadata.fromState(clusterService.state()); GetTrainedModelsStatsAction.Response.Builder responseBuilder = new GetTrainedModelsStatsAction.Response.Builder(); ActionListener> inferenceStatsListener = ActionListener.wrap( @@ -84,20 +88,30 @@ protected void doExecute(Task task, ActionListener nodesStatsListener = ActionListener.wrap( nodesStatsResponse -> { + Set allPossiblePipelineReferences = responseBuilder.getExpandedIdsWithAliases() + .entrySet() + .stream() + .flatMap(entry -> Stream.concat(entry.getValue().stream(), Stream.of(entry.getKey()))) + .collect(Collectors.toSet()); + Map> pipelineIdsByModelIdsOrAliases = pipelineIdsByModelIdsOrAliases(clusterService.state(), + ingestService, + allPossiblePipelineReferences); Map modelIdIngestStats = 
inferenceIngestStatsByModelId(nodesStatsResponse, - pipelineIdsByModelIds(clusterService.state(), - ingestService, - responseBuilder.getExpandedIds())); + currentMetadata, + pipelineIdsByModelIdsOrAliases + ); responseBuilder.setIngestStatsByModelId(modelIdIngestStats); - trainedModelProvider.getInferenceStats(responseBuilder.getExpandedIds().toArray(new String[0]), inferenceStatsListener); + trainedModelProvider.getInferenceStats( + responseBuilder.getExpandedIdsWithAliases().keySet().toArray(new String[0]), + inferenceStatsListener + ); }, listener::onFailure ); - ActionListener>> idsListener = ActionListener.wrap( + ActionListener>>> idsListener = ActionListener.wrap( tuple -> { - responseBuilder.setExpandedIds(tuple.v2()) - .setTotalModelCount(tuple.v1()); + responseBuilder.setExpandedIdsWithAliases(tuple.v2()).setTotalModelCount(tuple.v1()); String[] ingestNodes = ingestNodes(clusterService.state()); NodesStatsRequest nodesStatsRequest = new NodesStatsRequest(ingestNodes).clear() .addMetric(NodesStatsRequest.Metric.INGEST.metricName()); @@ -105,27 +119,36 @@ protected void doExecute(Task task, }, listener::onFailure ); - trainedModelProvider.expandIds(request.getResourceId(), request.isAllowNoResources(), request.getPageParams(), Collections.emptySet(), + currentMetadata, idsListener); } static Map inferenceIngestStatsByModelId(NodesStatsResponse response, + ModelAliasMetadata currentMetadata, Map> modelIdToPipelineId) { Map ingestStatsMap = new HashMap<>(); - - modelIdToPipelineId.forEach((modelId, pipelineIds) -> { + Map> trueModelIdToPipelines = modelIdToPipelineId.entrySet() + .stream() + .collect(Collectors.toMap( + entry -> { + String maybeModelId = currentMetadata.getModelId(entry.getKey()); + return maybeModelId == null ? entry.getKey() : maybeModelId; + }, + Map.Entry::getValue, + Sets::union + )); + trueModelIdToPipelines.forEach((modelId, pipelineIds) -> { List collectedStats = response.getNodes() .stream() .map(nodeStats -> ingestStatsForPipelineIds(nodeStats, pipelineIds)) .collect(Collectors.toList()); ingestStatsMap.put(modelId, mergeStats(collectedStats)); }); - return ingestStatsMap; } @@ -139,7 +162,7 @@ static String[] ingestNodes(final ClusterState clusterState) { return ingestNodes; } - static Map> pipelineIdsByModelIds(ClusterState state, IngestService ingestService, Set modelIds) { + static Map> pipelineIdsByModelIdsOrAliases(ClusterState state, IngestService ingestService, Set modelIds) { IngestMetadata ingestMetadata = state.metadata().custom(IngestMetadata.TYPE); Map> pipelineIdsByModelIds = new HashMap<>(); if (ingestMetadata == null) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java index d50ba9fd6addc..8ba42634b68f7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java @@ -69,7 +69,9 @@ protected void doExecute(Task task, Request request, ActionListener li typedChainTaskExecutor.execute(ActionListener.wrap( inferenceResultsInterfaces -> { model.release(); - listener.onResponse(responseBuilder.setInferenceResults(inferenceResultsInterfaces).build()); + listener.onResponse(responseBuilder.setInferenceResults(inferenceResultsInterfaces) + .setModelId(model.getModelId()) + .build()); }, e -> { model.release(); diff 
--git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java index 7423083851956..fed61310f1c35 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java @@ -38,6 +38,7 @@ import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import java.io.IOException; @@ -123,6 +124,13 @@ protected void masterOperation(Task task, .setEstimatedHeapMemory(request.getTrainedModelConfig().getModelDefinition().ramBytesUsed()) .setEstimatedOperations(request.getTrainedModelConfig().getModelDefinition().getTrainedModel().estimatedNumOperations()) .build(); + if (ModelAliasMetadata.fromState(state).getModelId(trainedModelConfig.getModelId()) != null) { + listener.onFailure(ExceptionsHelper.badRequestException( + "requested model_id [{}] is the same as an existing model_alias. Model model_aliases and ids must be unique", + request.getTrainedModelConfig().getModelId() + )); + return; + } ActionListener tagsModelIdCheckListener = ActionListener.wrap( r -> trainedModelProvider.storeTrainedModel(trainedModelConfig, ActionListener.wrap( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAliasAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAliasAction.java new file mode 100644 index 0000000000000..7e4ea2660fc3f --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAliasAction.java @@ -0,0 +1,209 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.action; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.action.support.master.AcknowledgedTransportMasterNodeAction; +import org.elasticsearch.cluster.AckedClusterStateUpdateTask; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.block.ClusterBlockException; +import org.elasticsearch.cluster.block.ClusterBlockLevel; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.logging.HeaderWarning; +import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.XPackField; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAliasAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.function.Predicate; + +import static org.elasticsearch.xpack.core.ml.job.messages.Messages.TRAINED_MODEL_INPUTS_DIFFER_SIGNIFICANTLY; + +public class TransportPutTrainedModelAliasAction extends AcknowledgedTransportMasterNodeAction { + + private static final Logger logger = LogManager.getLogger(TransportPutTrainedModelAliasAction.class); + + private final XPackLicenseState licenseState; + private final TrainedModelProvider trainedModelProvider; + private final InferenceAuditor auditor; + + @Inject + public TransportPutTrainedModelAliasAction( + TransportService transportService, + TrainedModelProvider trainedModelProvider, + ClusterService clusterService, + ThreadPool threadPool, + XPackLicenseState licenseState, + ActionFilters actionFilters, + InferenceAuditor auditor, + IndexNameExpressionResolver indexNameExpressionResolver) { + super( + PutTrainedModelAliasAction.NAME, + transportService, + clusterService, + threadPool, + actionFilters, + PutTrainedModelAliasAction.Request::new, + indexNameExpressionResolver, + ThreadPool.Names.SAME + ); + this.licenseState = licenseState; + this.trainedModelProvider = trainedModelProvider; + this.auditor = auditor; + } + + @Override + protected void masterOperation( + Task task, + PutTrainedModelAliasAction.Request request, + ClusterState state, + ActionListener listener + ) throws Exception { + final boolean mlSupported = licenseState.checkFeature(XPackLicenseState.Feature.MACHINE_LEARNING); + final Predicate isLicensed = (model) -> mlSupported || licenseState.isAllowedByLicense(model.getLicenseLevel()); + final String oldModelId = 
ModelAliasMetadata.fromState(state).getModelId(request.getModelAlias()); + + if (oldModelId != null && (request.isReassign() == false)) { + listener.onFailure(ExceptionsHelper.badRequestException( + "cannot assign model_alias [{}] to model_id [{}] as model_alias already refers to [{}]. " + + + "Set parameter [reassign] to [true] if model_alias should be reassigned.", + request.getModelAlias(), + request.getModelId(), + oldModelId)); + return; + } + Set modelIds = new HashSet<>(); + modelIds.add(request.getModelAlias()); + modelIds.add(request.getModelId()); + if (oldModelId != null) { + modelIds.add(oldModelId); + } + trainedModelProvider.getTrainedModels(modelIds, GetTrainedModelsAction.Includes.empty(), true, ActionListener.wrap( + models -> { + TrainedModelConfig newModel = null; + TrainedModelConfig oldModel = null; + for (TrainedModelConfig config : models) { + if (config.getModelId().equals(request.getModelId())) { + newModel = config; + } + if (config.getModelId().equals(oldModelId)) { + oldModel = config; + } + if (config.getModelId().equals(request.getModelAlias())) { + listener.onFailure( + ExceptionsHelper.badRequestException("model_alias cannot be the same as an existing trained model_id") + ); + return; + } + } + if (newModel == null) { + listener.onFailure( + ExceptionsHelper.missingTrainedModel(request.getModelId()) + ); + return; + } + if (isLicensed.test(newModel) == false) { + listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING)); + return; + } + // if old model is null, none of these validations matter + // we should still allow reassignment even if the old model was some how deleted and the alias still refers to it + if (oldModel != null) { + // validate inference configs are the same type. Moving an alias from regression -> classification seems dangerous + if (newModel.getInferenceConfig() != null && oldModel.getInferenceConfig() != null) { + if (newModel.getInferenceConfig().getName().equals(oldModel.getInferenceConfig().getName()) == false) { + listener.onFailure( + ExceptionsHelper.badRequestException( + "cannot reassign model_alias [{}] to model [{}] " + + "with inference config type [{}] from model [{}] with type [{}]", + request.getModelAlias(), + newModel.getModelId(), + newModel.getInferenceConfig().getName(), + oldModel.getModelId(), + oldModel.getInferenceConfig().getName() + ) + ); + return; + } + } + + Set oldInputFields = new HashSet<>(oldModel.getInput().getFieldNames()); + Set newInputFields = new HashSet<>(newModel.getInput().getFieldNames()); + // TODO should we fail in this case??? 
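+                    // Worked illustration of the heuristic below (field names are hypothetical): with
+                    // oldInputFields = [price, day_of_week, carrier, distance] and newInputFields = [carrier, distance, weather],
+                    // difference(old, new) = [price, day_of_week] has size 2, which is not greater than oldInputFields.size() / 2 = 2,
+                    // and the intersection [carrier, distance] has size 2, which is not less than 2, so no warning is issued;
+                    // dropping one more shared input field would trigger the TRAINED_MODEL_INPUTS_DIFFER_SIGNIFICANTLY warning.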
+ if (Sets.difference(oldInputFields, newInputFields).size() > (oldInputFields.size() / 2) + || Sets.intersection(newInputFields, oldInputFields).size() < (oldInputFields.size() / 2)) { + String warning = Messages.getMessage( + TRAINED_MODEL_INPUTS_DIFFER_SIGNIFICANTLY, + request.getModelId(), + oldModelId); + auditor.warning(oldModelId, warning); + logger.warn("[{}] {}", oldModelId, warning); + HeaderWarning.addWarning(warning); + } + } + clusterService.submitStateUpdateTask("update-model-alias", new AckedClusterStateUpdateTask(request, listener) { + @Override + public ClusterState execute(final ClusterState currentState) { + return updateModelAlias(currentState, request); + } + }); + + }, + listener::onFailure + )); + } + + static ClusterState updateModelAlias(final ClusterState currentState, final PutTrainedModelAliasAction.Request request) { + final ClusterState.Builder builder = ClusterState.builder(currentState); + final ModelAliasMetadata currentMetadata = ModelAliasMetadata.fromState(currentState); + String currentModelId = currentMetadata.getModelId(request.getModelAlias()); + final Map newMetadata = new HashMap<>(currentMetadata.modelAliases()); + if (currentModelId == null) { + logger.info("creating new model_alias [{}] for model [{}]", request.getModelAlias(), request.getModelId()); + } else { + logger.info( + "updating model_alias [{}] to refer to model [{}] from model [{}]", + request.getModelAlias(), + request.getModelId(), + currentModelId + ); + } + newMetadata.put(request.getModelAlias(), new ModelAliasMetadata.ModelAliasEntry(request.getModelId())); + final ModelAliasMetadata modelAliasMetadata = new ModelAliasMetadata(newMetadata); + builder.metadata(Metadata.builder(currentState.getMetadata()).putCustom(ModelAliasMetadata.NAME, modelAliasMetadata).build()); + return builder.build(); + } + + @Override + protected ClusterBlockException checkBlock(PutTrainedModelAliasAction.Request request, ClusterState state) { + return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ModelAliasMetadata.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ModelAliasMetadata.java new file mode 100644 index 0000000000000..ecf28a78ec307 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ModelAliasMetadata.java @@ -0,0 +1,222 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.inference; + +import org.elasticsearch.Version; +import org.elasticsearch.cluster.AbstractDiffable; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.Diff; +import org.elasticsearch.cluster.DiffableUtils; +import org.elasticsearch.cluster.NamedDiff; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Collections; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +/** + * Custom {@link Metadata} implementation for storing a map of model aliases that point to model IDs + */ +public class ModelAliasMetadata implements Metadata.Custom { + + public static final String NAME = "trained_model_alias"; + + public static final ModelAliasMetadata EMPTY = new ModelAliasMetadata(new HashMap<>()); + + public static ModelAliasMetadata fromState(ClusterState cs) { + ModelAliasMetadata modelAliasMetadata = cs.metadata().custom(NAME); + return modelAliasMetadata == null ? EMPTY : modelAliasMetadata; + } + + public static NamedDiff readDiffFrom(StreamInput in) throws IOException { + return new ModelAliasMetadataDiff(in); + } + + private static final ParseField MODEL_ALIASES = new ParseField("model_aliases"); + private static final ParseField MODEL_ID = new ParseField("model_id"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME, + // to protect BWC serialization + true, + args -> new ModelAliasMetadata((Map)args[0]) + ); + + static { + PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> { + Map modelAliases = new HashMap<>(); + while (p.nextToken() != XContentParser.Token.END_OBJECT) { + String modelAlias = p.currentName(); + modelAliases.put(modelAlias, ModelAliasEntry.fromXContent(p)); + } + return modelAliases; + }, MODEL_ALIASES); + } + + public static ModelAliasMetadata fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final Map modelAliases; + + public ModelAliasMetadata(Map modelAliases) { + this.modelAliases = Collections.unmodifiableMap(modelAliases); + } + + public ModelAliasMetadata(StreamInput in) throws IOException { + this.modelAliases = Collections.unmodifiableMap(in.readMap(StreamInput::readString, ModelAliasEntry::new)); + } + + public Map modelAliases() { + return modelAliases; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(MODEL_ALIASES.getPreferredName()); + for (Map.Entry modelAliasEntry : modelAliases.entrySet()) { + builder.field(modelAliasEntry.getKey(), modelAliasEntry.getValue()); + } + builder.endObject(); + return builder; + } + + @Override + public Diff diff(Metadata.Custom previousState) { + return new ModelAliasMetadataDiff((ModelAliasMetadata) previousState, this); + } + + @Override + public EnumSet context() { + return Metadata.ALL_CONTEXTS; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public Version getMinimalSupportedVersion() 
{ + // TODO change after backport + return Version.V_8_0_0; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(this.modelAliases, StreamOutput::writeString, (stream, val) -> val.writeTo(stream)); + } + + public String getModelId(String modelAlias) { + ModelAliasEntry entry = this.modelAliases.get(modelAlias); + if (entry == null) { + return null; + } + return entry.modelId; + } + + static class ModelAliasMetadataDiff implements NamedDiff { + + final Diff> modelAliasesDiff; + + ModelAliasMetadataDiff(ModelAliasMetadata before, ModelAliasMetadata after) { + this.modelAliasesDiff = DiffableUtils.diff(before.modelAliases, after.modelAliases, DiffableUtils.getStringKeySerializer()); + } + + ModelAliasMetadataDiff(StreamInput in) throws IOException { + this.modelAliasesDiff = DiffableUtils.readJdkMapDiff(in, DiffableUtils.getStringKeySerializer(), + ModelAliasEntry::new, ModelAliasEntry::readDiffFrom); + } + + @Override + public Metadata.Custom apply(Metadata.Custom part) { + return new ModelAliasMetadata(modelAliasesDiff.apply(((ModelAliasMetadata) part).modelAliases)); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + modelAliasesDiff.writeTo(out); + } + } + + public static class ModelAliasEntry extends AbstractDiffable implements ToXContentObject { + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "model_alias_metadata_alias_entry", + // to protect BWC serialization + true, + args -> new ModelAliasEntry((String)args[0]) + ); + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), MODEL_ID); + } + + private static Diff readDiffFrom(StreamInput in) throws IOException { + return readDiffFrom(ModelAliasEntry::new, in); + } + + private static ModelAliasEntry fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final String modelId; + + public ModelAliasEntry(String modelId) { + this.modelId = modelId; + } + + ModelAliasEntry(StreamInput in) throws IOException { + this.modelId = in.readString(); + } + + public String getModelId() { + return modelId; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(MODEL_ID.getPreferredName(), modelId); + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ModelAliasEntry modelAliasEntry = (ModelAliasEntry) o; + return Objects.equals(modelId, modelAliasEntry.modelId); + } + + @Override + public int hashCode() { + return Objects.hash(modelId); + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java index c1e6b69517a5c..761154fa54341 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -157,7 +157,12 @@ void mutateDocument(InternalInferModelAction.Response response, IngestDocument i throw new ElasticsearchStatusException("Unexpected empty inference response", 
RestStatus.INTERNAL_SERVER_ERROR); } assert response.getInferenceResults().size() == 1; - InferenceResults.writeResult(response.getInferenceResults().get(0), ingestDocument, targetField, modelId); + InferenceResults.writeResult( + response.getInferenceResults().get(0), + ingestDocument, + targetField, + response.getModelId() != null ? response.getModelId() : modelId + ); } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index f9a922b286a79..01caec5d0e166 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.breaker.CircuitBreakingException; import org.elasticsearch.common.cache.Cache; import org.elasticsearch.common.cache.CacheBuilder; +import org.elasticsearch.common.cache.CacheLoader; import org.elasticsearch.common.cache.RemovalNotification; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; @@ -37,12 +38,14 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata; import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; import java.util.ArrayDeque; +import java.util.Collections; import java.util.EnumSet; import java.util.HashMap; import java.util.HashSet; @@ -50,6 +53,7 @@ import java.util.Map; import java.util.Queue; import java.util.Set; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; /** @@ -108,11 +112,14 @@ private ModelAndConsumer(LocalModel model, Consumer consumer) { } } - private static final Logger logger = LogManager.getLogger(ModelLoadingService.class); private final TrainedModelStatsService modelStatsService; private final Cache localModelCache; + // Referenced models can be model aliases or IDs private final Set referencedModels = new HashSet<>(); + private final Map modelAliasToId = new HashMap<>(); + private final Map> modelIdToModelAliases = new HashMap<>(); + private final Map> modelIdToUpdatedModelAliases = new HashMap<>(); private final Map>> loadingListeners = new HashMap<>(); private final TrainedModelProvider provider; private final Set shouldNotAudit; @@ -148,8 +155,13 @@ public ModelLoadingService(TrainedModelProvider trainedModelProvider, this.trainedModelCircuitBreaker = ExceptionsHelper.requireNonNull(trainedModelCircuitBreaker, "trainedModelCircuitBreaker"); } + // for testing + String getModelId(String modelIdOrAlias) { + return modelAliasToId.getOrDefault(modelIdOrAlias, modelIdOrAlias); + } + boolean isModelCached(String modelId) { - return localModelCache.get(modelId) != null; + return localModelCache.get(modelAliasToId.getOrDefault(modelId, modelId)) != null; } /** @@ -195,11 +207,12 @@ public void getModelForSearch(String modelId, ActionListener modelAc * The main difference being that models for search are 
always cached whereas pipeline models * are only cached if they are referenced by an ingest pipeline * - * @param modelId the model to get + * @param modelIdOrAlias the model id or model alias to get * @param consumer which feature is requesting the model * @param modelActionListener the listener to alert when the model has been retrieved. */ - private void getModel(String modelId, Consumer consumer, ActionListener modelActionListener) { + private void getModel(String modelIdOrAlias, Consumer consumer, ActionListener modelActionListener) { + final String modelId = modelAliasToId.getOrDefault(modelIdOrAlias, modelIdOrAlias); ModelAndConsumer cachedModel = localModelCache.get(modelId); if (cachedModel != null) { cachedModel.consumers.add(consumer); @@ -210,12 +223,16 @@ private void getModel(String modelId, Consumer consumer, ActionListener new ParameterizedMessage("[{}] loaded from cache", modelId)); + logger.trace(() -> new ParameterizedMessage("[{}] (model_alias [{}]) loaded from cache", modelId, modelIdOrAlias)); return; } - if (loadModelIfNecessary(modelId, consumer, modelActionListener)) { - logger.trace(() -> new ParameterizedMessage("[{}] is loading or loaded, added new listener to queue", modelId)); + if (loadModelIfNecessary(modelIdOrAlias, consumer, modelActionListener)) { + logger.trace(() -> new ParameterizedMessage( + "[{}] (model_alias [{}]) is loading or loaded, added new listener to queue", + modelId, + modelIdOrAlias + )); } } @@ -224,14 +241,15 @@ private void getModel(String modelId, Consumer consumer, ActionListener modelActionListener) { + private boolean loadModelIfNecessary(String modelIdOrAlias, Consumer consumer, ActionListener modelActionListener) { synchronized (loadingListeners) { + final String modelId = modelAliasToId.getOrDefault(modelIdOrAlias, modelIdOrAlias); ModelAndConsumer cachedModel = localModelCache.get(modelId); if (cachedModel != null) { cachedModel.consumers.add(consumer); @@ -257,13 +275,21 @@ private boolean loadModelIfNecessary(String modelId, Consumer consumer, ActionLi if (Consumer.PIPELINE == consumer && referencedModels.contains(modelId) == false) { // The model is requested by a pipeline but not referenced by any ingest pipelines. 
// This means it is a simulate call and the model should not be cached + logger.trace(() -> new ParameterizedMessage( + "[{}] (model_alias [{}]) not actively loading, eager loading without cache", + modelId, + modelIdOrAlias + )); loadWithoutCaching(modelId, modelActionListener); } else { - logger.trace(() -> new ParameterizedMessage("[{}] attempting to load and cache", modelId)); + logger.trace(() -> new ParameterizedMessage( + "[{}] (model_alias [{}]) attempting to load and cache", + modelId, + modelIdOrAlias + )); loadingListeners.put(modelId, addFluently(new ArrayDeque<>(), modelActionListener)); loadModel(modelId, consumer); } - return false; } // synchronized (loadingListeners) } @@ -304,7 +330,6 @@ private void loadModel(String modelId, Consumer consumer) { private void loadWithoutCaching(String modelId, ActionListener modelActionListener) { // If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called // by a simulated pipeline - logger.trace(() -> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId)); provider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), ActionListener.wrap( trainedModelConfig -> { // Verify we can pull the model into memory without causing OOM @@ -377,34 +402,41 @@ private void handleLoadSuccess(String modelId, trainedModelConfig.getLicenseLevel(), modelStatsService, trainedModelCircuitBreaker); - boolean modelAcquired = false; + final ModelAndConsumerLoader modelAndConsumerLoader = new ModelAndConsumerLoader(new ModelAndConsumer(loadedModel, consumer)); synchronized (loadingListeners) { - listeners = loadingListeners.remove(modelId); - // if there are no listeners, simply release and leave - if (listeners == null) { - loadedModel.release(); - return; - } - + populateNewModelAlias(modelId); // If the model is referenced, that means it is currently in a pipeline somewhere // Also, if the consume is a search consumer, we should always cache it - if (referencedModels.contains(modelId) || consumer.equals(Consumer.SEARCH)) { - // temporarily increase the reference count before adding to - // the cache in case the model is evicted before the listeners - // are called in which case acquire() would throw. - loadedModel.acquire(); - localModelCache.put(modelId, new ModelAndConsumer(loadedModel, consumer)); + if (referencedModels.contains(modelId) + || Sets.haveNonEmptyIntersection(modelIdToModelAliases.getOrDefault(modelId, new HashSet<>()), referencedModels) + || consumer.equals(Consumer.SEARCH)) { + try { + // The local model may already be in cache. If it is, we don't bother adding it to cache. + // If it isn't, we flip an `isLoaded` flag, and increment the model counter to make sure if it is evicted + // between now and when the listeners access it, the circuit breaker reflects actual usage. 
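+                    // computeIfAbsent only invokes ModelAndConsumerLoader.load() on a cache miss; load() acquires the
+                    // model and sets the loader's `loaded` flag, so the release() calls that follow are guarded by
+                    // modelAndConsumerLoader.isLoaded() and only undo an acquire that actually happened, whether or not
+                    // any listeners are waiting for the model.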
+ localModelCache.computeIfAbsent(modelId, modelAndConsumerLoader); + } catch (ExecutionException ee) { + logger.warn(() -> new ParameterizedMessage("[{}] threw when attempting add to cache", modelId), ee); + } shouldNotAudit.remove(modelId); - modelAcquired = true; + } + listeners = loadingListeners.remove(modelId); + // if there are no listeners, we should just exit + if (listeners == null) { + // If we newly added it into cache, release the model so that the circuit breaker can still accurately keep track + // of memory + if(modelAndConsumerLoader.isLoaded()) { + loadedModel.release(); + } + return; } } // synchronized (loadingListeners) for (ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { loadedModel.acquire(); listener.onResponse(loadedModel); } - // account for the acquire in the synchronized block above - // We cannot simply utilize the same conditionals as `referencedModels` could have changed once we exited the synchronized block - if (modelAcquired) { + // account for the acquire in the synchronized block above if the model was loaded into the cache + if (modelAndConsumerLoader.isLoaded()) { loadedModel.release(); } } @@ -413,6 +445,7 @@ private void handleLoadFailure(String modelId, Exception failure) { Queue> listeners; synchronized (loadingListeners) { listeners = loadingListeners.remove(modelId); + populateNewModelAlias(modelId); if (listeners == null) { return; } @@ -424,6 +457,20 @@ private void handleLoadFailure(String modelId, Exception failure) { } } + private void populateNewModelAlias(String modelId) { + Set newModelAliases = modelIdToUpdatedModelAliases.remove(modelId); + if (newModelAliases != null && newModelAliases.isEmpty() == false) { + logger.trace(() -> new ParameterizedMessage( + "[{}] model is now loaded, setting new model_aliases {}", + modelId, + newModelAliases + )); + for (String modelAlias: newModelAliases) { + modelAliasToId.put(modelAlias, modelId); + } + } + } + private void cacheEvictionListener(RemovalNotification notification) { try { if (notification.getRemovalReason() == RemovalNotification.RemovalReason.EVICTED) { @@ -438,12 +485,15 @@ private void cacheEvictionListener(RemovalNotification INFERENCE_MODEL_CACHE_TTL.getKey()); auditIfNecessary(notification.getKey(), msg); } - - logger.trace(() -> new ParameterizedMessage("Persisting stats for evicted model [{}]", - notification.getValue().model.getModelId())); + String modelId = modelAliasToId.getOrDefault(notification.getKey(), notification.getKey()); + logger.trace(() -> new ParameterizedMessage( + "Persisting stats for evicted model [{}] (model_aliases {})", + modelId, + modelIdToModelAliases.getOrDefault(modelId, new HashSet<>()) + )); // If the model is no longer referenced, flush the stats to persist as soon as possible - notification.getValue().model.persistStats(referencedModels.contains(notification.getKey()) == false); + notification.getValue().model.persistStats(referencedModels.contains(modelId) == false); } finally { notification.getValue().model.release(); } @@ -451,46 +501,112 @@ private void cacheEvictionListener(RemovalNotification @Override public void clusterChanged(ClusterChangedEvent event) { - // If ingest data has not changed or if the current node is not an ingest node, don't bother caching models - if (event.changedCustomMetadataSet().contains(IngestMetadata.TYPE) == false || - event.state().nodes().getLocalNode().isIngestNode() == false) { + final boolean prefetchModels = event.state().nodes().getLocalNode().isIngestNode(); + 
// If we are not prefetching models and there were no model alias changes, don't bother handling the changes + if ((prefetchModels == false) + && (event.changedCustomMetadataSet().contains(IngestMetadata.TYPE) == false) + && (event.changedCustomMetadataSet().contains(ModelAliasMetadata.NAME) == false)) { return; } ClusterState state = event.state(); IngestMetadata currentIngestMetadata = state.metadata().custom(IngestMetadata.TYPE); - Set allReferencedModelKeys = getReferencedModelKeys(currentIngestMetadata); - if (allReferencedModelKeys.equals(referencedModels)) { - return; - } - Set referencedModelsBeforeClusterState = null; + Set allReferencedModelKeys = event.changedCustomMetadataSet().contains(IngestMetadata.TYPE) ? + getReferencedModelKeys(currentIngestMetadata) : + new HashSet<>(referencedModels); + Set referencedModelsBeforeClusterState; Set loadingModelBeforeClusterState = null; - Set removedModels = null; + Set removedModels; + Map> addedModelViaAliases = new HashMap<>(); + Map> oldIdToAliases; synchronized (loadingListeners) { + oldIdToAliases = new HashMap<>(modelIdToModelAliases); + Map changedAliases = gatherLazyChangedAliasesAndUpdateModelAliases( + event, + prefetchModels, + allReferencedModelKeys + ); + + // if we are not prefetching, exit now. + if (prefetchModels == false) { + return; + } + referencedModelsBeforeClusterState = new HashSet<>(referencedModels); if (logger.isTraceEnabled()) { loadingModelBeforeClusterState = new HashSet<>(loadingListeners.keySet()); } removedModels = Sets.difference(referencedModelsBeforeClusterState, allReferencedModelKeys); - // Remove all cached models that are not referenced by any processors - // and are not used in search - removedModels.forEach(modelId -> { - ModelAndConsumer modelAndConsumer = localModelCache.get(modelId); - if (modelAndConsumer != null && modelAndConsumer.consumers.contains(Consumer.SEARCH) == false) { - localModelCache.invalidate(modelId); - } - }); // Remove the models that are no longer referenced referencedModels.removeAll(removedModels); shouldNotAudit.removeAll(removedModels); + // Remove all cached models that are not referenced by any processors + // and are not used in search + for (String modelAliasOrId : removedModels) { + String modelId = changedAliases.getOrDefault(modelAliasOrId, modelAliasToId.getOrDefault(modelAliasOrId, modelAliasOrId)); + // If the "old" model_alias is referenced, we don't want to invalidate. This way the model that now has the model_alias + // can be loaded in first + boolean oldModelAliasesNotReferenced = Sets.haveEmptyIntersection(referencedModels, + oldIdToAliases.getOrDefault(modelId, Collections.emptySet())); + // If the model itself is referenced, we shouldn't evict. + boolean modelIsNotReferenced = referencedModels.contains(modelId) == false; + // If a model_alias change causes it to NOW be referenced, we shouldn't attempt to evict it + boolean newModelAliasesNotReferenced = Sets.haveEmptyIntersection(referencedModels, + modelIdToModelAliases.getOrDefault(modelId, Collections.emptySet())); + if (oldModelAliasesNotReferenced && newModelAliasesNotReferenced && modelIsNotReferenced) { + ModelAndConsumer modelAndConsumer = localModelCache.get(modelId); + if (modelAndConsumer != null && modelAndConsumer.consumers.contains(Consumer.SEARCH) == false) { + logger.trace("[{} ({})] invalidated from cache", modelId, modelAliasOrId); + localModelCache.invalidate(modelId); + } + } + } // Remove all that are still referenced, i.e. 
the intersection of allReferencedModelKeys and referencedModels allReferencedModelKeys.removeAll(referencedModels); + for (String newlyReferencedModel : allReferencedModelKeys) { + // check if the model_alias has changed in this round + String modelId = changedAliases.getOrDefault( + newlyReferencedModel, + // If the model_alias hasn't changed, get the model id IF it is a model_alias, otherwise we assume it is an id + modelAliasToId.getOrDefault( + newlyReferencedModel, + newlyReferencedModel + ) + ); + // Verify that it isn't an old model id but just a new model_alias + if (referencedModels.contains(modelId) == false) { + addedModelViaAliases.computeIfAbsent(modelId, k -> new HashSet<>()).add(newlyReferencedModel); + } + } + // For any previously referenced model, the model_alias COULD have changed, so it is actually a NEWLY referenced model + for (Map.Entry modelAliasAndId : changedAliases.entrySet()) { + String modelAlias = modelAliasAndId.getKey(); + String modelId = modelAliasAndId.getValue(); + if (referencedModels.contains(modelAlias)) { + // we need to load the underlying model since its model_alias is referenced + addedModelViaAliases.computeIfAbsent(modelId, k -> new HashSet<>()).add(modelAlias); + // If we are in cache, keep the old translation for now, it will be updated later + String oldModelId = modelAliasToId.get(modelAlias); + if (oldModelId != null && localModelCache.get(oldModelId) != null) { + modelIdToUpdatedModelAliases.computeIfAbsent(modelId, k -> new HashSet<>()).add(modelAlias); + } else { + // If we are not cached, might as well add the translation right away as new callers will have to load + // from disk anyways. + modelAliasToId.put(modelAlias, modelId); + } + } else { + // Add model_alias and id here, since the model_alias wasn't previously referenced, + // no reason to wait on updating the model_alias -> model_id mapping + modelAliasToId.put(modelAlias, modelId); + } + } + // Gather ALL currently referenced model ids referencedModels.addAll(allReferencedModelKeys); // Populate loadingListeners key so we know that we are currently loading the model - for (String modelId : allReferencedModelKeys) { + for (String modelId : addedModelViaAliases.keySet()) { loadingListeners.computeIfAbsent(modelId, (s) -> new ArrayDeque<>()); } } // synchronized (loadingListeners) @@ -503,9 +619,51 @@ public void clusterChanged(ClusterChangedEvent event) { logger.trace("cluster state event changed referenced models: before {} after {}", referencedModelsBeforeClusterState, referencedModels); } + if (oldIdToAliases.equals(modelIdToModelAliases) == false) { + logger.trace("model id to alias mappings changed. before {} after {}. 
Model alias to IDs {}", + oldIdToAliases, + modelIdToModelAliases, + modelAliasToId); + } + if (addedModelViaAliases.isEmpty() == false) { + logger.trace("adding new models via model_aliases and ids: {}", addedModelViaAliases); + } + if (modelIdToUpdatedModelAliases.isEmpty() == false) { + logger.trace("delayed model aliases to update {}", modelIdToModelAliases); + } } removedModels.forEach(this::auditUnreferencedModel); - loadModelsForPipeline(allReferencedModelKeys); + loadModelsForPipeline(addedModelViaAliases.keySet()); + } + + private Map gatherLazyChangedAliasesAndUpdateModelAliases(ClusterChangedEvent event, + boolean prefetchModels, + Set allReferencedModelKeys) { + Map changedAliases = new HashMap<>(); + if (event.changedCustomMetadataSet().contains(ModelAliasMetadata.NAME)) { + final Map modelAliasesToIds = new HashMap<>( + ModelAliasMetadata.fromState(event.state()).modelAliases() + ); + modelIdToModelAliases.clear(); + for (Map.Entry aliasToId : modelAliasesToIds.entrySet()) { + modelIdToModelAliases.computeIfAbsent(aliasToId.getValue().getModelId(), k -> new HashSet<>()).add(aliasToId.getKey()); + java.lang.String modelId = modelAliasToId.get(aliasToId.getKey()); + if (modelId != null + && modelId.equals(aliasToId.getValue().getModelId()) == false) { + if (prefetchModels && allReferencedModelKeys.contains(aliasToId.getKey())) { + changedAliases.put(aliasToId.getKey(), aliasToId.getValue().getModelId()); + } else { + modelAliasToId.put(aliasToId.getKey(), aliasToId.getValue().getModelId()); + } + } + if (modelId == null) { + modelAliasToId.put(aliasToId.getKey(), aliasToId.getValue().getModelId()); + } + } + Set removedAliases = Sets.difference(modelAliasToId.keySet(), modelAliasesToIds.keySet()); + modelAliasToId.keySet().removeAll(removedAliases); + } + return changedAliases; } private void auditIfNecessary(String modelId, MessageSupplier msg) { @@ -600,4 +758,25 @@ void addModelLoadedListener(String modelId, ActionListener modelLoad }); } } + + private static class ModelAndConsumerLoader implements CacheLoader { + + private boolean loaded; + private final ModelAndConsumer modelAndConsumer; + + ModelAndConsumerLoader(ModelAndConsumer modelAndConsumer) { + this.modelAndConsumer = modelAndConsumer; + } + + boolean isLoaded() { + return loaded; + } + + @Override + public ModelAndConsumer load(String key) throws Exception { + loaded = true; + modelAndConsumer.model.acquire(); + return modelAndConsumer; + } + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index c5e054277cf59..37afd5d193289 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -80,6 +80,7 @@ import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; +import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata; import java.io.IOException; import java.io.InputStream; @@ -97,6 +98,7 @@ import java.util.Objects; import java.util.Set; import java.util.TreeSet; +import java.util.function.Function; import java.util.stream.Collectors; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; @@ -442,6 +444,13 @@ public void 
getTrainedModelForInference(final String modelId, final ActionListen public void getTrainedModel(final String modelId, final GetTrainedModelsAction.Includes includes, final ActionListener finalListener) { + getTrainedModel(modelId, Collections.emptySet(), includes, finalListener); + } + + public void getTrainedModel(final String modelId, + final Set modelAliases, + final GetTrainedModelsAction.Includes includes, + final ActionListener finalListener) { if (MODELS_STORED_AS_RESOURCE.contains(modelId)) { try { @@ -455,6 +464,7 @@ public void getTrainedModel(final String modelId, ActionListener getTrainedModelListener = ActionListener.wrap( modelBuilder -> { + modelBuilder.setModelAliases(modelAliases); if ((includes.isIncludeFeatureImportanceBaseline() || includes.isIncludeTotalFeatureImportance() || includes.isIncludeHyperparameters()) == false) { finalListener.onResponse(modelBuilder.build()); @@ -570,6 +580,18 @@ public void getTrainedModel(final String modelId, multiSearchResponseActionListener); } + public void getTrainedModels(Set modelIds, + GetTrainedModelsAction.Includes includes, + boolean allowNoResources, + final ActionListener> finalListener) { + getTrainedModels( + modelIds.stream().collect(Collectors.toMap(Function.identity(), _k -> Collections.emptySet())), + includes, + allowNoResources, + finalListener + ); + } + /** * Gets all the provided trained config model objects * @@ -577,11 +599,15 @@ public void getTrainedModel(final String modelId, * This does no expansion on the ids. * It assumes that there are fewer than 10k. */ - public void getTrainedModels(Set modelIds, + public void getTrainedModels(Map> modelIds, GetTrainedModelsAction.Includes includes, boolean allowNoResources, final ActionListener> finalListener) { - QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(modelIds.toArray(new String[0]))); + QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery( + QueryBuilders + .idsQuery() + .addIds(modelIds.keySet().toArray(new String[0])) + ); SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN) .addSort(TrainedModelConfig.MODEL_ID.getPreferredName(), SortOrder.ASC) @@ -590,8 +616,8 @@ public void getTrainedModels(Set modelIds, .setSize(modelIds.size()) .request(); List configs = new ArrayList<>(modelIds.size()); - Set modelsInIndex = Sets.difference(modelIds, MODELS_STORED_AS_RESOURCE); - Set modelsAsResource = Sets.intersection(MODELS_STORED_AS_RESOURCE, modelIds); + Set modelsInIndex = Sets.difference(modelIds.keySet(), MODELS_STORED_AS_RESOURCE); + Set modelsAsResource = Sets.intersection(MODELS_STORED_AS_RESOURCE, modelIds.keySet()); for(String modelId : modelsAsResource) { try { configs.add(loadModelFromResource(modelId, true)); @@ -613,12 +639,12 @@ public void getTrainedModels(Set modelIds, if ((includes.isIncludeFeatureImportanceBaseline() || includes.isIncludeTotalFeatureImportance() || includes.isIncludeHyperparameters()) == false) { finalListener.onResponse(modelBuilders.stream() - .map(TrainedModelConfig.Builder::build) + .map(b -> b.setModelAliases(modelIds.get(b.getModelId())).build()) .sorted(Comparator.comparing(TrainedModelConfig::getModelId)) .collect(Collectors.toList())); return; } - this.getTrainedModelMetadata(modelIds, ActionListener.wrap( + this.getTrainedModelMetadata(modelIds.keySet(), ActionListener.wrap( metadata -> finalListener.onResponse(modelBuilders.stream() .map(builder -> { @@ -633,9 +659,8 @@ public void getTrainedModels(Set modelIds, if 
(includes.isIncludeHyperparameters()) { builder.setHyperparameters(modelMetadata.getHyperparameters()); } - } - return builder.build(); + return builder.setModelAliases(modelIds.get(builder.getModelId())).build(); }) .sorted(Comparator.comparing(TrainedModelConfig::getModelId)) .collect(Collectors.toList())), @@ -679,7 +704,7 @@ public void getTrainedModels(Set modelIds, // We previously expanded the IDs. // If the config has gone missing between then and now we should throw if allowNoResources is false // Otherwise, treat it as if it was never expanded to begin with. - Set missingConfigs = Sets.difference(modelIds, observedIds); + Set missingConfigs = Sets.difference(modelIds.keySet(), observedIds); if (missingConfigs.isEmpty() == false && allowNoResources == false) { getTrainedModelListener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs)); return; @@ -729,8 +754,23 @@ public void expandIds(String idExpression, boolean allowNoResources, PageParams pageParams, Set tags, - ActionListener>> idsListener) { + ModelAliasMetadata modelAliasMetadata, + ActionListener>>> idsListener) { String[] tokens = Strings.tokenizeToStringArray(idExpression, ","); + Set expandedIdsFromAliases = new HashSet<>(); + if (Strings.isAllOrWildcard(tokens) == false) { + for (String token : tokens) { + if (Regex.isSimpleMatchPattern(token)) { + for (String modelAlias : modelAliasMetadata.modelAliases().keySet()) { + if (Regex.simpleMatch(token, modelAlias)) { + expandedIdsFromAliases.add(modelAliasMetadata.getModelId(modelAlias)); + } + } + } else if (modelAliasMetadata.getModelId(token) != null) { + expandedIdsFromAliases.add(modelAliasMetadata.getModelId(token)); + } + } + } Set matchedResourceIds = matchedResourceIds(tokens); Set foundResourceIds; if (tags.isEmpty()) { @@ -744,12 +784,17 @@ public void expandIds(String idExpression, } } } + expandedIdsFromAliases.addAll(Arrays.asList(tokens)); + + // We need to include the translated model alias, and ANY tokens that were not translated + String[] tokensForQuery = expandedIdsFromAliases.toArray(new String[0]); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder() .sort(SortBuilders.fieldSort(TrainedModelConfig.MODEL_ID.getPreferredName()) // If there are no resources, there might be no mapping for the id field. // This makes sure we don't get an error if that happens. .unmappedType("long")) - .query(buildExpandIdsQuery(tokens, tags)) + .query(buildExpandIdsQuery(tokensForQuery, tags)) // We "buffer" the from and size to take into account models stored as resources. // This is so we handle the edge cases when the model that is stored as a resource is at the start/end of // a page. @@ -785,9 +830,28 @@ public void expandIds(String idExpression, foundFromDocs.add(idValue.toString()); } } - Set allFoundIds = collectIds(pageParams, foundResourceIds, foundFromDocs); + Map> allFoundIds = collectIds(pageParams, foundResourceIds, foundFromDocs) + .stream() + .collect(Collectors.toMap(Function.identity(), k -> new HashSet<>())); + + // We technically have matched on model tokens and any reversed referenced aliases + // We may end up with "over matching" on the aliases (matching on an alias that was not provided) + // But the expanded ID matcher does not care. 
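+            // Hypothetical walk-through (alias and id invented for illustration): an expression "prod-*" matches the
+            // alias "prod-classifier", whose ModelAliasMetadata entry points at "model-abc"; "model-abc" was folded
+            // into tokensForQuery above, so it comes back as a found id, and the loop below adds "prod-classifier"
+            // both to that id's alias set and to matchedTokens, letting the ExpandedIdsMatcher treat the original
+            // wildcard expression as satisfied.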
+ Set matchedTokens = new HashSet<>(allFoundIds.keySet()); + + // We should gather ALL model aliases referenced by the given model IDs + // This way the callers have access to them + modelAliasMetadata.modelAliases().forEach((alias, modelIdEntry) -> { + final String modelId = modelIdEntry.getModelId(); + if (allFoundIds.containsKey(modelId)) { + allFoundIds.get(modelId).add(alias); + matchedTokens.add(alias); + } + }); + + // Reverse lookup to see what model aliases were matched by their found trained model IDs ExpandedIdsMatcher requiredMatches = new ExpandedIdsMatcher(tokens, allowNoResources); - requiredMatches.filterMatchedIds(allFoundIds); + requiredMatches.filterMatchedIds(matchedTokens); if (requiredMatches.hasUnmatchedIds()) { idsListener.onFailure(ExceptionsHelper.missingTrainedModel(requiredMatches.unmatchedIdsString())); } else { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelAliasAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelAliasAction.java new file mode 100644 index 0000000000000..7c440f2cc77a5 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelAliasAction.java @@ -0,0 +1,57 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +package org.elasticsearch.xpack.ml.rest.inference; + +import static java.util.Collections.singletonList; +import static org.elasticsearch.rest.RestRequest.Method.PUT; + +import java.io.IOException; +import java.util.List; + +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAliasAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.ml.MachineLearning; + +public class RestPutTrainedModelAliasAction extends BaseRestHandler { + + @Override + public List routes() { + return singletonList( + new Route( + PUT, + MachineLearning.BASE_PATH + + "trained_models/{" + + TrainedModelConfig.MODEL_ID.getPreferredName() + + "}/model_aliases/{" + + PutTrainedModelAliasAction.Request.MODEL_ALIAS + + "}" + + ) + ); + } + + @Override + public String getName() { + return "ml_put_trained_model_alias_action"; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + String modelAlias = restRequest.param(PutTrainedModelAliasAction.Request.MODEL_ALIAS); + String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName()); + boolean reassign = restRequest.paramAsBoolean(PutTrainedModelAliasAction.Request.REASSIGN, false); + return channel -> client.execute( + PutTrainedModelAliasAction.INSTANCE, + new PutTrainedModelAliasAction.Request(modelAlias, modelId, reassign), + new RestToXContentListener<>(channel) + ); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java index 69210a22b5554..2aa52e5af0e46 100644 --- 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java @@ -35,6 +35,7 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.junit.Before; @@ -129,7 +130,6 @@ public void setUpVariables() { null, Collections.singletonList(SKINNY_INGEST_PLUGIN), client); } - public void testInferenceIngestStatsByModelId() { List nodeStatsList = Arrays.asList( buildNodeStats( @@ -198,6 +198,7 @@ public void testInferenceIngestStatsByModelId() { put("trained_model_2", new HashSet<>(Arrays.asList("pipeline1", "pipeline2"))); }}; Map ingestStatsMap = TransportGetTrainedModelsStatsAction.inferenceIngestStatsByModelId(response, + ModelAliasMetadata.EMPTY, pipelineIdsByModelIds); assertThat(ingestStatsMap.keySet(), equalTo(new HashSet<>(Arrays.asList("trained_model_1", "trained_model_2")))); @@ -238,7 +239,7 @@ public void testPipelineIdsByModelIds() throws IOException { ClusterState clusterState = buildClusterStateWithModelReferences(modelId1, modelId2, modelId3); Map> pipelineIdsByModelIds = - TransportGetTrainedModelsStatsAction.pipelineIdsByModelIds(clusterState, ingestService, modelIds); + TransportGetTrainedModelsStatsAction.pipelineIdsByModelIdsOrAliases(clusterState, ingestService, modelIds); assertThat(pipelineIdsByModelIds.keySet(), equalTo(modelIds)); assertThat(pipelineIdsByModelIds, diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java index 731a876bea99b..b5c7b274072de 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java @@ -74,6 +74,7 @@ public void testMutateDocumentWithClassification() { ClassificationConfig.EMPTY_PARAMS, 1.0, 1.0)), + null, true); inferenceProcessor.mutateDocument(response, document); @@ -110,6 +111,7 @@ public void testMutateDocumentClassificationTopNClasses() { classificationConfig, 0.6, 0.6)), + null, true); inferenceProcessor.mutateDocument(response, document); @@ -152,6 +154,7 @@ public void testMutateDocumentClassificationFeatureInfluence() { classificationConfig, 0.6, 0.6)), + null, true); inferenceProcessor.mutateDocument(response, document); @@ -193,6 +196,7 @@ public void testMutateDocumentClassificationTopNClassesWithSpecificField() { classificationConfig, 0.6, 0.6)), + null, true); inferenceProcessor.mutateDocument(response, document); @@ -218,14 +222,16 @@ public void testMutateDocumentRegression() { IngestDocument document = new IngestDocument(source, ingestMetadata); InternalInferModelAction.Response response = new InternalInferModelAction.Response( - Collections.singletonList(new RegressionInferenceResults(0.7, regressionConfig)), true); + Collections.singletonList(new RegressionInferenceResults(0.7, regressionConfig)), + null, + true); inferenceProcessor.mutateDocument(response, document); assertThat(document.getFieldValue("ml.my_processor.foo", Double.class), equalTo(0.7)); 
assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("regression_model")); } - public void testMutateDocumentRegressionWithTopFetures() { + public void testMutateDocumentRegressionWithTopFeatures() { RegressionConfig regressionConfig = new RegressionConfig("foo", 2); RegressionConfigUpdate regressionConfigUpdate = new RegressionConfigUpdate("foo", 2); InferenceProcessor inferenceProcessor = new InferenceProcessor(client, @@ -245,7 +251,9 @@ public void testMutateDocumentRegressionWithTopFetures() { featureInfluence.add(new RegressionFeatureImportance("feature_2", -42.0)); InternalInferModelAction.Response response = new InternalInferModelAction.Response( - Collections.singletonList(new RegressionInferenceResults(0.7, regressionConfig, featureInfluence)), true); + Collections.singletonList(new RegressionInferenceResults(0.7, regressionConfig, featureInfluence)), + null, + true); inferenceProcessor.mutateDocument(response, document); assertThat(document.getFieldValue("ml.my_processor.foo", Double.class), equalTo(0.7)); @@ -383,7 +391,9 @@ public void testHandleResponseLicenseChanged() { assertThat(inferenceProcessor.buildRequest(document).isPreviouslyLicensed(), is(false)); InternalInferModelAction.Response response = new InternalInferModelAction.Response( - Collections.singletonList(new RegressionInferenceResults(0.7, RegressionConfig.EMPTY_PARAMS)), true); + Collections.singletonList(new RegressionInferenceResults(0.7, RegressionConfig.EMPTY_PARAMS)), + null, + true); inferenceProcessor.handleResponse(response, document, (doc, ex) -> { assertThat(doc, is(not(nullValue()))); assertThat(ex, is(nullValue())); @@ -392,7 +402,9 @@ public void testHandleResponseLicenseChanged() { assertThat(inferenceProcessor.buildRequest(document).isPreviouslyLicensed(), is(true)); response = new InternalInferModelAction.Response( - Collections.singletonList(new RegressionInferenceResults(0.7, RegressionConfig.EMPTY_PARAMS)), false); + Collections.singletonList(new RegressionInferenceResults(0.7, RegressionConfig.EMPTY_PARAMS)), + null, + false); inferenceProcessor.handleResponse(response, document, (doc, ex) -> { assertThat(doc, is(not(nullValue()))); @@ -424,11 +436,37 @@ public void testMutateDocumentWithWarningResult() { IngestDocument document = new IngestDocument(source, ingestMetadata); InternalInferModelAction.Response response = new InternalInferModelAction.Response( - Collections.singletonList(new WarningInferenceResults("something broke")), true); + Collections.singletonList(new WarningInferenceResults("something broke")), null, true); inferenceProcessor.mutateDocument(response, document); assertThat(document.hasField(targetField), is(false)); assertThat(document.hasField("ml.warning"), is(true)); assertThat(document.hasField("ml.my_processor"), is(false)); } + + public void testMutateDocumentWithModelIdResult() { + String modelAlias = "special_model"; + String modelId = "regression-123"; + InferenceProcessor inferenceProcessor = new InferenceProcessor(client, + auditor, + "my_processor", + null, + "ml.my_processor", + modelAlias, + new RegressionConfigUpdate("foo", null), + Collections.emptyMap()); + + Map source = new HashMap<>(); + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + InternalInferModelAction.Response response = new InternalInferModelAction.Response( + Collections.singletonList(new RegressionInferenceResults(0.7, new RegressionConfig("foo"))), + modelId, + true); + 
inferenceProcessor.mutateDocument(response, document); + + assertThat(document.getFieldValue("ml.my_processor.foo", Double.class), equalTo(0.7)); + assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo(modelId)); + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java index d9bf1f84edd96..5207616023f8f 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.breaker.CircuitBreakingException; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.unit.ByteSizeValue; @@ -44,6 +45,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata; import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; @@ -59,10 +61,12 @@ import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME; import static org.hamcrest.Matchers.equalTo; @@ -282,7 +286,6 @@ public boolean matches(final Object o) { verify(trainedModelProvider, times(5)).getTrainedModelForInference(eq(model3), any()); } - public void testWhenCacheEnabledButNotIngestNode() throws Exception { String model1 = "test-uncached-not-ingest-model-1"; withTrainedModel(model1, 1L); @@ -538,6 +541,101 @@ public void testReferenceCounting_ModelIsNotCached() throws ExecutionException, assertEquals(1, model.getReferenceCount()); } + public void testGetCachedModelViaModelAliases() throws Exception { + String model1 = "test-load-model-1"; + String model2 = "test-load-model-2"; + withTrainedModel(model1, 1L); + withTrainedModel(model2, 1L); + + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, + auditor, + threadPool, + clusterService, + trainedModelStatsService, + Settings.EMPTY, + "test-node", + circuitBreaker); + + modelLoadingService.clusterChanged(aliasChangeEvent( + true, + new String[]{"loaded_model"}, + true, + Arrays.asList(Tuple.tuple(model1, "loaded_model")) + )); + + String[] modelIds = new String[]{model1, "loaded_model"}; + for(int i = 0; i < 10; i++) { + String model = modelIds[i%2]; + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModelForPipeline(model, future); + assertThat(future.get(), is(not(nullValue()))); + } + + verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model1), any()); + + 
assertTrue(modelLoadingService.isModelCached(model1)); + assertTrue(modelLoadingService.isModelCached("loaded_model")); + + // alias change only + modelLoadingService.clusterChanged(aliasChangeEvent( + true, + new String[]{"loaded_model"}, + false, + Arrays.asList(Tuple.tuple(model2, "loaded_model")) + )); + + modelIds = new String[]{model2, "loaded_model"}; + for(int i = 0; i < 10; i++) { + String model = modelIds[i%2]; + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModelForPipeline(model, future); + assertThat(future.get(), is(not(nullValue()))); + } + + verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model2), any()); + assertTrue(modelLoadingService.isModelCached(model2)); + assertTrue(modelLoadingService.isModelCached("loaded_model")); + } + + public void testAliasesGetUpdatedEvenWhenNotIngestNode() throws IOException { + String model1 = "test-load-model-1"; + withTrainedModel(model1, 1L); + String model2 = "test-load-model-2"; + withTrainedModel(model2, 1L); + + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, + auditor, + threadPool, + clusterService, + trainedModelStatsService, + Settings.EMPTY, + "test-node", + circuitBreaker); + + modelLoadingService.clusterChanged(aliasChangeEvent( + false, + new String[0], + false, + Arrays.asList(Tuple.tuple(model1, "loaded_model")) + )); + + assertThat(modelLoadingService.getModelId("loaded_model"), equalTo(model1)); + + modelLoadingService.clusterChanged(aliasChangeEvent( + false, + new String[0], + false, + Arrays.asList( + Tuple.tuple(model1, "loaded_model_again"), + Tuple.tuple(model1, "loaded_model_foo"), + Tuple.tuple(model2, "loaded_model") + ) + )); + assertThat(modelLoadingService.getModelId("loaded_model"), equalTo(model2)); + assertThat(modelLoadingService.getModelId("loaded_model_foo"), equalTo(model1)); + assertThat(modelLoadingService.getModelId("loaded_model_again"), equalTo(model1)); + } + @SuppressWarnings("unchecked") private void withTrainedModel(String modelId, long size) { InferenceDefinition definition = mock(InferenceDefinition.class); @@ -601,6 +699,21 @@ private static ClusterChangedEvent ingestChangedEvent(String... modelId) throws return ingestChangedEvent(true, modelId); } + private static ClusterChangedEvent aliasChangeEvent(boolean isIngestNode, + String[] modelId, + boolean ingestToo, + List> modelIdAndAliases) throws IOException { + ClusterChangedEvent event = mock(ClusterChangedEvent.class); + Set set = new HashSet<>(); + set.add(ModelAliasMetadata.NAME); + if (ingestToo) { + set.add(IngestMetadata.TYPE); + } + when(event.changedCustomMetadataSet()).thenReturn(set); + when(event.state()).thenReturn(withModelReferencesAndAliasChange(isIngestNode, modelId, modelIdAndAliases)); + return event; + } + private static ClusterChangedEvent ingestChangedEvent(boolean isIngestNode, String... modelId) throws IOException { ClusterChangedEvent event = mock(ClusterChangedEvent.class); when(event.changedCustomMetadataSet()).thenReturn(Collections.singleton(IngestMetadata.TYPE)); @@ -609,14 +722,17 @@ private static ClusterChangedEvent ingestChangedEvent(boolean isIngestNode, Stri } private static ClusterState buildClusterStateWithModelReferences(boolean isIngestNode, String... 
modelId) throws IOException { - Map configurations = new HashMap<>(modelId.length); - for (String id : modelId) { - configurations.put("pipeline_with_model_" + id, newConfigurationWithInferenceProcessor(id)); - } - IngestMetadata ingestMetadata = new IngestMetadata(configurations); + return builder(isIngestNode).metadata(addIngest(Metadata.builder(), modelId)).build(); + } + + private static ClusterState withModelReferencesAndAliasChange(boolean isIngestNode, + String[] modelId, + List> modelIdAndAliases) throws IOException { + return builder(isIngestNode).metadata(addAliases(addIngest(Metadata.builder(), modelId), modelIdAndAliases)).build(); + } + private static ClusterState.Builder builder(boolean isIngestNode) { return ClusterState.builder(new ClusterName("_name")) - .metadata(Metadata.builder().putCustom(IngestMetadata.TYPE, ingestMetadata)) .nodes(DiscoveryNodes.builder().add( new DiscoveryNode("node_name", "node_id", @@ -625,8 +741,23 @@ private static ClusterState buildClusterStateWithModelReferences(boolean isInges isIngestNode ? Collections.singleton(DiscoveryNodeRole.INGEST_ROLE) : Collections.emptySet(), Version.CURRENT)) .localNodeId("node_id") - .build()) - .build(); + .build() + ); + } + + private static Metadata.Builder addIngest(Metadata.Builder builder, String... modelId) throws IOException { + Map configurations = new HashMap<>(modelId.length); + for (String id : modelId) { + configurations.put("pipeline_with_model_" + id, newConfigurationWithInferenceProcessor(id)); + } + IngestMetadata ingestMetadata = new IngestMetadata(configurations); + return builder.putCustom(IngestMetadata.TYPE, ingestMetadata); + } + + private static Metadata.Builder addAliases(Metadata.Builder builder, List> modelIdAndAliases) { + ModelAliasMetadata modelAliasMetadata = new ModelAliasMetadata(modelIdAndAliases.stream() + .collect(Collectors.toMap(Tuple::v2, t -> new ModelAliasMetadata.ModelAliasEntry(t.v1())))); + return builder.putCustom(ModelAliasMetadata.NAME, modelAliasMetadata); } private static PipelineConfiguration newConfigurationWithInferenceProcessor(String modelId) throws IOException { diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index ec6daa1b44b77..7a99f18b7a402 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -135,6 +135,7 @@ public class Constants { "cluster:admin/xpack/ml/filters/update", "cluster:admin/xpack/ml/inference/delete", "cluster:admin/xpack/ml/inference/put", + "cluster:admin/xpack/ml/inference/model_aliases/put", "cluster:admin/xpack/ml/job/close", "cluster:admin/xpack/ml/job/data/post", "cluster:admin/xpack/ml/job/delete", diff --git a/x-pack/plugin/src/test/java/org/elasticsearch/xpack/test/rest/AbstractXPackRestTest.java b/x-pack/plugin/src/test/java/org/elasticsearch/xpack/test/rest/AbstractXPackRestTest.java index bd3fd47586ea1..9cd81093d9e0d 100644 --- a/x-pack/plugin/src/test/java/org/elasticsearch/xpack/test/rest/AbstractXPackRestTest.java +++ b/x-pack/plugin/src/test/java/org/elasticsearch/xpack/test/rest/AbstractXPackRestTest.java @@ -18,12 +18,12 @@ import 
org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.plugins.MetadataUpgrader; import org.elasticsearch.test.SecuritySettingsSourceField; -import org.elasticsearch.test.rest.ESRestTestCase; import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate; import org.elasticsearch.test.rest.yaml.ClientYamlTestResponse; import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase; import org.elasticsearch.xpack.core.ml.MlConfigIndex; import org.elasticsearch.xpack.core.ml.MlMetaIndex; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndexFields; diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_trained_model_alias.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_trained_model_alias.json new file mode 100644 index 0000000000000..c07e96397fe29 --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_trained_model_alias.json @@ -0,0 +1,40 @@ +{ + "ml.put_trained_model_alias":{ + "documentation":{ + "url":"https://www.elastic.co/guide/en/elasticsearch/reference/current/put-trained-models-aliases.html", + "description":"Creates a new model alias (or reassigns an existing one) to refer to the trained model" + }, + "stability":"beta", + "visibility":"public", + "headers":{ + "accept": [ "application/json"], + "content_type": ["application/json"] + }, + "url":{ + "paths":[ + { + "path":"/_ml/trained_models/{model_id}/model_aliases/{model_alias}", + "methods":[ + "PUT" + ], + "parts":{ + "model_alias":{ + "type":"string", + "description":"The trained model alias to update" + }, + "model_id": { + "type": "string", + "description": "The trained model where the model alias should be assigned" + } + } + } + ] + }, + "params":{ + "reassign":{ + "type":"boolean", + "description":"If the model_alias already exists and points to a separate model_id, this parameter must be true. Defaults to false." + } + } + } +} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/get_trained_model_stats.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/get_trained_model_stats.yml index 4acb694835f68..93f23e1f0f405 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/get_trained_model_stats.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/get_trained_model_stats.yml @@ -79,6 +79,12 @@ setup: } } } + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. 
the test setup superuser + ml.put_trained_model_alias: + model_alias: "my-regression" + model_id: "a-unused-regression-model1" --- "Test get stats given missing trained model": @@ -175,3 +181,20 @@ setup: - match: { count: 1 } - match: { trained_model_stats.0.model_id: another-regression-model } - match: { trained_model_stats.0.pipeline_count: 0 } + + +# test with model alias + - do: + ml.get_trained_models_stats: + model_id: "my-regression" + + - match: { count: 1 } + - match: { trained_model_stats.0.model_id: a-unused-regression-model1 } + + - do: + ml.get_trained_models_stats: + model_id: "my-regression,another-regression-model" + + - match: { count: 2 } + - match: { trained_model_stats.0.model_id: a-unused-regression-model1 } + - match: { trained_model_stats.1.model_id: another-regression-model } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml index de7ce694f9857..0994bdf33319c 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml @@ -561,30 +561,6 @@ setup: ml.delete_trained_model: model_id: "missing-trained-model" --- -"Test delete given used trained model": - - do: - ingest.put_pipeline: - id: "regression-model-pipeline" - body: > - { - "processors": [ - { - "inference" : { - "model_id" : "a-regression-model-0", - "inference_config": {"regression": {}}, - "target_field": "regression_field", - "field_map": {} - } - } - ] - } - - match: { acknowledged: true } - - - do: - catch: conflict - ml.delete_trained_model: - model_id: "a-regression-model-0" ---- "Test get pre-packaged trained models": - do: ml.get_trained_models: @@ -851,3 +827,89 @@ setup: model_id: "a-regression-model-1" include_model_definition: true decompress_definition: false +--- +"Test put trained model aliases": + + - do: + ml.put_trained_model_alias: + model_alias: "regression-model" + model_id: "a-regression-model-1" + - do: + ml.get_trained_models: + model_id: "regression-model,a-classification-model" + + - match: { count: 2 } + - length: { trained_model_configs: 2 } + - match: { trained_model_configs.0.model_id: "a-classification-model" } + - match: { trained_model_configs.1.model_id: "a-regression-model-1" } + + - do: + ml.put_trained_model_alias: + model_alias: "regression-model" + model_id: "a-regression-model-0" + reassign: true + - do: + ml.get_trained_models: + model_id: "regression-model,a-classification-model" + + - match: { count: 2 } + - length: { trained_model_configs: 2 } + - match: { trained_model_configs.0.model_id: "a-classification-model" } + - match: { trained_model_configs.1.model_id: "a-regression-model-0" } + + - do: + ml.put_trained_model_alias: + model_alias: "regression-model-again" + model_id: "a-regression-model-0" + - do: + ml.get_trained_models: + model_id: "a-regression-model-*" + size: 1 + + - length: { trained_model_configs: 1 } + - match: { trained_model_configs.0.model_id: "a-regression-model-0" } + - match: { trained_model_configs.0.metadata.model_aliases.0: "regression-model" } + - match: { trained_model_configs.0.metadata.model_aliases.1: "regression-model-again" } +--- +"Test update model alias with model id referring to missing model": + - do: + catch: missing + ml.put_trained_model_alias: + model_alias: "regression-model" + model_id: "missing-model" +--- +"Test update model alias with bad alias": + - do: + catch: /must start with alphanumeric and cannot 
end with numbers/ + ml.put_trained_model_alias: + model_alias: "regression-model-123123" + model_id: "regression-model-123123" + - do: + catch: bad_request + ml.put_trained_model_alias: + model_alias: "z-classification-model" + model_id: "z-classification-model" +--- +"Test update model alias where alias exists but old model id is different inference type": + - do: + ml.put_trained_model_alias: + model_alias: "regression-model" + model_id: "a-regression-model-0" + - do: + catch: bad_request + ml.put_trained_model_alias: + model_alias: "regression-model" + model_id: "a-classification-model" + reassign: true +--- +"Test update model alias where alias exists but reassign is false": + - do: + ml.put_trained_model_alias: + model_alias: "regression-model" + model_id: "a-regression-model-0" + - do: + catch: bad_request + ml.put_trained_model_alias: + model_alias: "regression-model" + model_id: "a-regression-model-1" + reassign: false