Commit 26eef89

[ML] adds new trained model alias API to simplify trained model updates and deployments (#68922)
A `model_alias` lets a trained model be referred to by a user-defined moniker. This makes many API calls simpler and easier to read, and it simplifies deployment and upgrade procedures for trained models. Previously, if an ingest pipeline referenced a model ID directly and a new, better-performing model became available, the pipeline itself had to be updated. If the model was used in numerous pipelines, ALL of those pipelines had to be updated. When an ingest pipeline references a `model_alias` instead, only that `model_alias` needs to be updated; the underlying referenced model then changes in place for all ingest pipelines automatically. An additional benefit is that the referenced model is not changed until the new model is fully loaded into cache, so throughput is not hampered by changing models.
1 parent d24f5cb commit 26eef89

File tree

40 files changed

+1962
-219
lines changed

docs/reference/ml/df-analytics/apis/get-trained-models-stats.asciidoc

+1-1
@@ -48,7 +48,7 @@ request by using a comma-separated list of model IDs or a wildcard expression.

 `<model_id>`::
 (Optional, string)
-include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id-or-alias]


 [[ml-get-trained-models-stats-query-params]]

docs/reference/ml/df-analytics/apis/get-trained-models.asciidoc

+1-1
@@ -50,7 +50,7 @@ using a comma-separated list of model IDs or a wildcard expression.

 `<model_id>`::
 (Optional, string)
-include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id-or-alias]


 [[ml-get-trained-models-query-params]]

docs/reference/ml/df-analytics/apis/index.asciidoc

+1
@@ -2,6 +2,7 @@ include::ml-df-analytics-apis.asciidoc[leveloffset=+1]
 //CREATE
 include::put-dfanalytics.asciidoc[leveloffset=+2]
 include::put-trained-models.asciidoc[leveloffset=+2]
+include::put-trained-models-aliases.asciidoc[leveloffset=+2]
 //UPDATE
 include::update-dfanalytics.asciidoc[leveloffset=+2]
 //DELETE

docs/reference/ml/df-analytics/apis/ml-df-analytics-apis.asciidoc

+2-1
@@ -22,8 +22,9 @@ You can use the following APIs to perform {infer} operations.
 * <<get-trained-models>>
 * <<get-trained-models-stats>>
 * <<delete-trained-models>>
+* <<put-trained-models-aliases>>

-You can deploy a trained model to make predictions in an ingest pipeline or in
+You can deploy a trained model to make predictions in an ingest pipeline or in
 an aggregation. Refer to the following documentation to learn more.

 * <<inference-processor,{infer-cap} processor>>
@@ -0,0 +1,89 @@
+[role="xpack"]
+[testenv="platinum"]
+[[put-trained-models-aliases]]
+= Put Trained Models Aliases API
+[subs="attributes"]
+++++
+<titleabbrev>Put Trained Models Aliases</titleabbrev>
+++++
+
+Creates a trained models alias. These model aliases can be used instead of the trained model ID
+when referencing the model in the stack. Model aliases must be unique, and a trained model can have
+more than one model alias referring to it. But a model alias can only refer to a single trained model.
+
+beta::[]
+
+[[ml-put-trained-models-aliases-request]]
+== {api-request-title}
+
+`PUT _ml/trained_models/<model_id>/model_aliases/<model_alias>`
+
+
+[[ml-put-trained-models-aliases-prereq]]
+== {api-prereq-title}
+
+If the {es} {security-features} are enabled, you must have the following
+built-in roles and privileges:
+
+* `machine_learning_admin`
+
+For more information, see <<built-in-roles>>, <<security-privileges>>, and
+{ml-docs-setup-privileges}.
+
+[[ml-put-trained-models-aliases-desc]]
+== {api-description-title}
+
+This API creates a new model alias to refer to trained models, or updates an existing
+trained model's alias.
+
+When updating an existing model alias to a new model ID, this API will return a error if the models
+are of different inference types. Example, if attempting to put the model alias
+`flights-delay-prediction` from a regression model to a classification model, the API will error.
+
+The API will return a warning if there are very few input fields in common between the old
+and new models for the model alias.
+
+[[ml-put-trained-models-aliases-path-params]]
+== {api-path-parms-title}
+
+`model_id`::
+(Required, string)
+The trained model ID to which the model alias should refer.
+
+`model_alias`::
+(Required, string)
+The model alias to create or update. The model_alias cannot end in numbers.
+
+[[ml-put-trained-models-aliases-query-params]]
+== {api-query-parms-title}
+
+`reassign`::
+(Optional, boolean)
+Should the `model_alias` get reassigned to the provided `model_id` if it is already
+assigned to a model. Defaults to false. The API will return an error if the `model_alias`
+is already assigned to a model but this parameter is `false`.
+
+[[ml-put-trained-models-aliases-example]]
+== {api-examples-title}
+
+[[ml-put-trained-models-aliases-example-new-alias]]
+=== Creating a new model alias
+
+The following example shows how to create a new model alias for a trained model ID.
+
+[source,console]
+--------------------------------------------------
+PUT _ml/trained_models/flight-delay-prediction-1574775339910/model_aliases/flight_delay_model
+--------------------------------------------------
+// TEST[skip:setup kibana sample data]
+
+[[ml-put-trained-models-aliases-example-put-alias]]
+=== Updating an existing model alias
+
+The following example shows how to reassign an existing model alias for a trained model ID.
+
+[source,console]
+--------------------------------------------------
+PUT _ml/trained_models/flight-delay-prediction-1580004349800/model_aliases/flight_delay_model?reassign=true
+--------------------------------------------------
+// TEST[skip:setup kibana sample data]
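
The `reassign` semantics documented in this file can be sketched with a plain in-memory map. This is a minimal illustration, not the Elasticsearch implementation: the class `ModelAliasRegistry` and its methods are hypothetical, and the sketch follows the documentation text literally (any existing assignment is an error unless `reassign` is true). It leaves out the inference-type check and the input-field warning.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical in-memory sketch of the documented alias semantics; not Elasticsearch code.
public class ModelAliasRegistry {

    // One alias refers to exactly one model; several aliases may share a model.
    private final Map<String, String> aliasToModelId = new HashMap<>();

    /** Points modelAlias at modelId. An already assigned alias is only moved when reassign is true. */
    public void putAlias(String modelId, String modelAlias, boolean reassign) {
        String current = aliasToModelId.get(modelAlias);
        if (current != null && !reassign) {
            throw new IllegalStateException(
                "model_alias [" + modelAlias + "] is already assigned to [" + current + "]; set reassign=true to move it");
        }
        aliasToModelId.put(modelAlias, modelId);
    }

    /** Resolves an alias to its model ID; anything that is not a known alias is treated as a model ID. */
    public String resolve(String modelIdOrAlias) {
        return aliasToModelId.getOrDefault(modelIdOrAlias, modelIdOrAlias);
    }

    public static void main(String[] args) {
        ModelAliasRegistry registry = new ModelAliasRegistry();
        registry.putAlias("flight-delay-prediction-1574775339910", "flight_delay_model", false);
        // A pipeline that stores only the alias picks up the new model after this call.
        registry.putAlias("flight-delay-prediction-1580004349800", "flight_delay_model", true);
        System.out.println(registry.resolve("flight_delay_model")); // flight-delay-prediction-1580004349800
    }
}
```

Callers that hold only the alias never need to change when the alias is moved, which is the upgrade path the commit message describes.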

docs/reference/ml/ml-shared.asciidoc

+4
@@ -1149,6 +1149,10 @@ tag::model-id[]
 The unique identifier of the trained model.
 end::model-id[]

+tag::model-id-or-alias[]
+The unique identifier of the trained model or a model alias.
+end::model-id-or-alias[]
+
 tag::model-memory-limit[]
 The approximate maximum amount of memory resources that are required for
 analytical processing. Once this limit is approached, data pruning becomes

server/src/main/java/org/elasticsearch/common/util/set/Sets.java

+6
@@ -62,6 +62,12 @@ public static <T> boolean haveEmptyIntersection(Set<T> left, Set<T> right) {
         return left.stream().noneMatch(right::contains);
     }

+    public static <T> boolean haveNonEmptyIntersection(Set<T> left, Set<T> right) {
+        Objects.requireNonNull(left);
+        Objects.requireNonNull(right);
+        return left.stream().anyMatch(right::contains);
+    }
+
     /**
      * The relative complement, or difference, of the specified left and right set. Namely, the resulting set contains all the elements that
      * are in the left set but not in the right set. Neither input is mutated by this operation, an entirely new set is returned.
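
A small standalone sketch of how the new helper behaves next to the existing one. The class `SetIntersectionDemo` and the field names are made up for illustration, and the null checks in `haveEmptyIntersection` are an assumption, since the hunk only shows that method's last line.

```java
import java.util.Objects;
import java.util.Set;

// Standalone copy of the two helpers for illustration; the real ones live in Sets.java.
public class SetIntersectionDemo {

    public static <T> boolean haveEmptyIntersection(Set<T> left, Set<T> right) {
        Objects.requireNonNull(left);   // assumed; not visible in the hunk
        Objects.requireNonNull(right);  // assumed; not visible in the hunk
        return left.stream().noneMatch(right::contains);
    }

    public static <T> boolean haveNonEmptyIntersection(Set<T> left, Set<T> right) {
        Objects.requireNonNull(left);
        Objects.requireNonNull(right);
        // Stops at the first shared element instead of building the intersection.
        return left.stream().anyMatch(right::contains);
    }

    public static void main(String[] args) {
        Set<String> oldFields = Set.of("distance", "origin", "carrier"); // hypothetical field names
        Set<String> newFields = Set.of("carrier", "weather");
        System.out.println(haveNonEmptyIntersection(oldFields, newFields)); // true
        System.out.println(haveNonEmptyIntersection(oldFields, Set.of()));  // false
    }
}
```

For non-null inputs the two helpers are exact negations of each other, so the new method mainly saves callers from writing `== false` around the old one.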

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java

+7-7
@@ -171,7 +171,7 @@ protected Reader<Response.TrainedModelStats> getReader() {
         public static class Builder {

             private long totalModelCount;
-            private Set<String> expandedIds;
+            private Map<String, Set<String>> expandedIdsWithAliases;
             private Map<String, IngestStats> ingestStatsMap;
             private Map<String, InferenceStats> inferenceStatsMap;

@@ -180,13 +180,13 @@ public Builder setTotalModelCount(long totalModelCount) {
                 return this;
             }

-            public Builder setExpandedIds(Set<String> expandedIds) {
-                this.expandedIds = expandedIds;
+            public Builder setExpandedIdsWithAliases(Map<String, Set<String>> expandedIdsWithAliases) {
+                this.expandedIdsWithAliases = expandedIdsWithAliases;
                 return this;
             }

-            public Set<String> getExpandedIds() {
-                return this.expandedIds;
+            public Map<String, Set<String>> getExpandedIdsWithAliases() {
+                return this.expandedIdsWithAliases;
             }

             public Builder setIngestStatsByModelId(Map<String, IngestStats> ingestStatsByModelId) {
@@ -200,8 +200,8 @@ public Builder setInferenceStatsByModelId(Map<String, InferenceStats> infereceSt
             }

             public Response build() {
-                List<TrainedModelStats> trainedModelStats = new ArrayList<>(expandedIds.size());
-                expandedIds.forEach(id -> {
+                List<TrainedModelStats> trainedModelStats = new ArrayList<>(expandedIdsWithAliases.size());
+                expandedIdsWithAliases.keySet().forEach(id -> {
                     IngestStats ingestStats = ingestStatsMap.get(id);
                     InferenceStats inferenceStats = inferenceStatsMap.get(id);
                     trainedModelStats.add(new TrainedModelStats(

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelAction.java

+26-4
@@ -143,18 +143,25 @@ public int hashCode() {
     public static class Response extends ActionResponse {

         private final List<InferenceResults> inferenceResults;
+        private final String modelId;
         private final boolean isLicensed;

-        public Response(List<InferenceResults> inferenceResults, boolean isLicensed) {
+        public Response(List<InferenceResults> inferenceResults, String modelId, boolean isLicensed) {
             super();
             this.inferenceResults = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(inferenceResults, "inferenceResults"));
             this.isLicensed = isLicensed;
+            this.modelId = modelId;
         }

         public Response(StreamInput in) throws IOException {
             super(in);
             this.inferenceResults = Collections.unmodifiableList(in.readNamedWriteableList(InferenceResults.class));
             this.isLicensed = in.readBoolean();
+            if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
+                this.modelId = in.readOptionalString();
+            } else {
+                this.modelId = null;
+            }
         }

         public List<InferenceResults> getInferenceResults() {
@@ -165,23 +172,32 @@ public boolean isLicensed() {
             return isLicensed;
         }

+        public String getModelId() {
+            return modelId;
+        }
+
         @Override
         public void writeTo(StreamOutput out) throws IOException {
             out.writeNamedWriteableList(inferenceResults);
             out.writeBoolean(isLicensed);
+            if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
+                out.writeOptionalString(modelId);
+            }
         }

         @Override
         public boolean equals(Object o) {
             if (this == o) return true;
             if (o == null || getClass() != o.getClass()) return false;
             InternalInferModelAction.Response that = (InternalInferModelAction.Response) o;
-            return isLicensed == that.isLicensed && Objects.equals(inferenceResults, that.inferenceResults);
+            return isLicensed == that.isLicensed
+                && Objects.equals(inferenceResults, that.inferenceResults)
+                && Objects.equals(modelId, that.modelId);
         }

         @Override
         public int hashCode() {
-            return Objects.hash(inferenceResults, isLicensed);
+            return Objects.hash(inferenceResults, isLicensed, modelId);
         }

         public static Builder builder() {
@@ -190,6 +206,7 @@ public static Builder builder() {

         public static class Builder {
             private List<InferenceResults> inferenceResults;
+            private String modelId;
             private boolean isLicensed;

             public Builder setInferenceResults(List<InferenceResults> inferenceResults) {
@@ -202,8 +219,13 @@ public Builder setLicensed(boolean licensed) {
                 return this;
             }

+            public Builder setModelId(String modelId) {
+                this.modelId = modelId;
+                return this;
+            }
+
             public Response build() {
-                return new Response(inferenceResults, modelId, isLicensed);
+                return new Response(inferenceResults, modelId, isLicensed);
             }
         }
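
The version gate around `modelId` in `Response` follows a common wire-compatibility pattern: a field added later is only written to, and read from, peers that know about it. The sketch below imitates that pattern with `java.io` streams rather than Elasticsearch's `StreamInput`/`StreamOutput`. The class name, the integer version, and the constant `MODEL_ID_WIRE_VERSION` are all hypothetical, and the presence-flag encoding of the optional string is an assumption made for this illustration.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

// Hypothetical stand-in for a version-gated optional field; not Elasticsearch code.
public class VersionGatedFieldDemo {

    // Hypothetical first wire version that understands the modelId field.
    public static final int MODEL_ID_WIRE_VERSION = 8;

    public static byte[] write(boolean isLicensed, String modelId, int peerVersion) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.writeBoolean(isLicensed);
            if (peerVersion >= MODEL_ID_WIRE_VERSION) {
                // Optional string: presence flag first, then the value if there is one.
                out.writeBoolean(modelId != null);
                if (modelId != null) {
                    out.writeUTF(modelId);
                }
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    public static String readModelId(byte[] payload, int peerVersion) {
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(payload))) {
            in.readBoolean(); // isLicensed, always present
            if (peerVersion >= MODEL_ID_WIRE_VERSION) {
                return in.readBoolean() ? in.readUTF() : null;
            }
            return null; // older peers never send the field
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(readModelId(write(true, "model-a", 8), 8)); // model-a
        System.out.println(readModelId(write(true, "model-a", 7), 7)); // null
        System.out.println(write(true, "model-a", 7).length);          // 1
    }
}
```

Reader and writer must apply the same version check, otherwise the reader would consume bytes that belong to the next field. This is why the hunk adds the gate to both the `StreamInput` constructor and `writeTo`.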

@@ -0,0 +1,119 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.action.ActionType;
+import org.elasticsearch.action.support.master.AcknowledgedRequest;
+import org.elasticsearch.action.support.master.AcknowledgedResponse;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.job.messages.Messages;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.Locale;
+import java.util.Objects;
+import java.util.regex.Pattern;
+
+import static org.elasticsearch.action.ValidateActions.addValidationError;
+import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INVALID_MODEL_ALIAS;
+
+public class PutTrainedModelAliasAction extends ActionType<AcknowledgedResponse> {
+
+    // NOTE this is similar to our valid ID check. The difference here is that model_aliases cannot end in numbers
+    // This is to protect our automatic model naming conventions from hitting weird model_alias conflicts
+    private static final Pattern VALID_MODEL_ALIAS_CHAR_PATTERN = Pattern.compile("[a-z0-9](?:[a-z0-9_\\-\\.]*[a-z])?");
+
+    public static final PutTrainedModelAliasAction INSTANCE = new PutTrainedModelAliasAction();
+    public static final String NAME = "cluster:admin/xpack/ml/inference/model_aliases/put";
+
+    private PutTrainedModelAliasAction() {
+        super(NAME, AcknowledgedResponse::readFrom);
+    }
+
+    public static class Request extends AcknowledgedRequest<Request> {
+
+        public static final String MODEL_ALIAS = "model_alias";
+        public static final String REASSIGN = "reassign";
+
+        private final String modelAlias;
+        private final String modelId;
+        private final boolean reassign;
+
+        public Request(String modelAlias, String modelId, boolean reassign) {
+            this.modelAlias = ExceptionsHelper.requireNonNull(modelAlias, MODEL_ALIAS);
+            this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID);
+            this.reassign = reassign;
+        }
+
+        public Request(StreamInput in) throws IOException {
+            super(in);
+            this.modelAlias = in.readString();
+            this.modelId = in.readString();
+            this.reassign = in.readBoolean();
+        }
+
+        public String getModelAlias() {
+            return modelAlias;
+        }
+
+        public String getModelId() {
+            return modelId;
+        }
+
+        public boolean isReassign() {
+            return reassign;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            super.writeTo(out);
+            out.writeString(modelAlias);
+            out.writeString(modelId);
+            out.writeBoolean(reassign);
+        }
+
+        @Override
+        public ActionRequestValidationException validate() {
+            ActionRequestValidationException validationException = null;
+            if (modelAlias.equals(modelId)) {
+                validationException = addValidationError(
+                    String.format(
+                        Locale.ROOT,
+                        "model_alias [%s] cannot equal model_id [%s]",
+                        modelAlias,
+                        modelId
+                    ),
+                    validationException
+                );
+            }
+            if (VALID_MODEL_ALIAS_CHAR_PATTERN.matcher(modelAlias).matches() == false) {
+                validationException = addValidationError(Messages.getMessage(INVALID_MODEL_ALIAS, modelAlias), validationException);
+            }
+            return validationException;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Request request = (Request) o;
+            return Objects.equals(modelAlias, request.modelAlias)
+                && Objects.equals(modelId, request.modelId)
+                && Objects.equals(reassign, request.reassign);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(modelAlias, modelId, reassign);
+        }
+
+    }
+}
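
To see what `VALID_MODEL_ALIAS_CHAR_PATTERN` accepts in practice, the following standalone sketch applies the same regular expression to a few candidate aliases. The class `ModelAliasPatternDemo` and the method `isValidAlias` are hypothetical names for this illustration; only the pattern string is taken from the file above.

```java
import java.util.regex.Pattern;

// Illustration of the alias pattern from PutTrainedModelAliasAction; the class itself is hypothetical.
public class ModelAliasPatternDemo {

    // Same expression as VALID_MODEL_ALIAS_CHAR_PATTERN in the diff.
    private static final Pattern VALID_MODEL_ALIAS = Pattern.compile("[a-z0-9](?:[a-z0-9_\\-\\.]*[a-z])?");

    public static boolean isValidAlias(String alias) {
        // matches() requires the whole string to fit the pattern, as in validate().
        return VALID_MODEL_ALIAS.matcher(alias).matches();
    }

    public static void main(String[] args) {
        String[] candidates = {
            "flight_delay_model",                    // true: starts with a letter, ends with a letter
            "flight-delay-prediction-1574775339910", // false: ends in digits, like generated model IDs
            "Flight_Delay",                          // false: uppercase is not allowed
            "_model",                                // false: must start with a letter or digit
            "7"                                      // true: the trailing group is optional
        };
        for (String candidate : candidates) {
            System.out.println(candidate + " -> " + isValidAlias(candidate));
        }
    }
}
```

One consequence of the pattern as written is that a single-character alias made of one digit still passes, because the group that enforces a trailing letter is optional. Any alias of two or more characters has to end in a lowercase letter.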
