Commit 3250dd7

[7.x] [ML] adds new trained model alias API to simplify trained model updates and deployments (#68922) (#69208)
* [ML] adds new trained model alias API to simplify trained model updates and deployments (#68922)

A `model_alias` allows trained models to be referred to by a user-defined moniker. This not only improves the readability and simplicity of numerous API calls, but also allows for simpler deployment and upgrade procedures for trained models.

Previously, if you referenced a model ID directly within an ingest pipeline and a new model performed better than the referenced one, you had to update the pipeline itself. If the model was used in numerous pipelines, ALL of those pipelines had to be updated.

When you use a `model_alias` in an ingest pipeline, only that `model_alias` needs to be updated. The underlying referenced model then changes in place for all ingest pipelines automatically.

An additional benefit is that the referenced model is not changed until the new model is fully loaded into cache, so throughput is not hampered by changing models.
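The indirection described in the commit message can be pictured with a small lookup table. The following is a minimal sketch only: `ModelAliasSketch`, `put`, and `resolve` are hypothetical names invented for illustration and are not taken from the Elasticsearch code base, and the sketch ignores the cache-loading behaviour mentioned above.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical illustration of alias indirection; not the Elasticsearch implementation.
final class ModelAliasSketch {
    // alias -> model ID; each alias points at exactly one model.
    private final Map<String, String> aliasToModelId = new ConcurrentHashMap<>();

    // Create the alias, or repoint it at a different model.
    void put(String alias, String modelId) {
        aliasToModelId.put(alias, modelId);
    }

    // Resolve a reference that may be either an alias or a concrete model ID.
    String resolve(String modelIdOrAlias) {
        return aliasToModelId.getOrDefault(modelIdOrAlias, modelIdOrAlias);
    }

    public static void main(String[] args) {
        ModelAliasSketch aliases = new ModelAliasSketch();
        aliases.put("flight_delay_model", "flight-delay-prediction-1574775339910");

        // Two pipelines reference the alias rather than the model ID.
        String[] pipelineRefs = {"flight_delay_model", "flight_delay_model"};

        // Repointing the alias once changes what every pipeline resolves to.
        aliases.put("flight_delay_model", "flight-delay-prediction-1580004349800");
        for (String ref : pipelineRefs) {
            System.out.println(aliases.resolve(ref));
        }
    }
}
```

Because pipelines hold only the alias, the upgrade is a single write to the table instead of one edit per pipeline.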
1 parent ee19703 commit 3250dd7

41 files changed

+1957
-219
lines changed

docs/reference/ml/df-analytics/apis/get-trained-models-stats.asciidoc

+1-1
@@ -47,7 +47,7 @@ request by using a comma-separated list of model IDs or a wildcard expression.
4747

4848
`<model_id>`::
4949
(Optional, string)
50-
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
50+
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id-or-alias]
5151

5252

5353
[[ml-get-trained-models-stats-query-params]]

docs/reference/ml/df-analytics/apis/get-trained-models.asciidoc

+1-1
@@ -50,7 +50,7 @@ using a comma-separated list of model IDs or a wildcard expression.
5050

5151
`<model_id>`::
5252
(Optional, string)
53-
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
53+
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id-or-alias]
5454

5555

5656
[[ml-get-trained-models-query-params]]

docs/reference/ml/df-analytics/apis/index.asciidoc

+1
@@ -2,6 +2,7 @@ include::ml-df-analytics-apis.asciidoc[leveloffset=+1]
22
//CREATE
33
include::put-dfanalytics.asciidoc[leveloffset=+2]
44
include::put-trained-models.asciidoc[leveloffset=+2]
5+
include::put-trained-models-aliases.asciidoc[leveloffset=+2]
56
//UPDATE
67
include::update-dfanalytics.asciidoc[leveloffset=+2]
78
//DELETE

docs/reference/ml/df-analytics/apis/ml-df-analytics-apis.asciidoc

+2-1
@@ -22,8 +22,9 @@ You can use the following APIs to perform {infer} operations.
2222
* <<get-trained-models>>
2323
* <<get-trained-models-stats>>
2424
* <<delete-trained-models>>
25+
* <<put-trained-models-aliases>>
2526

26-
You can deploy a trained model to make predictions in an ingest pipeline or in
27+
You can deploy a trained model to make predictions in an ingest pipeline or in
2728
an aggregation. Refer to the following documentation to learn more.
2829

2930
* <<inference-processor,{infer-cap} processor>>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
[role="xpack"]
2+
[testenv="platinum"]
3+
[[put-trained-models-aliases]]
4+
= Put Trained Models Aliases API
5+
[subs="attributes"]
6+
++++
7+
<titleabbrev>Put Trained Models Aliases</titleabbrev>
8+
++++
9+
10+
Creates a trained model alias. A model alias can be used instead of the trained model ID
11+
when referencing the model in the stack. Model aliases must be unique, and a trained model can have
12+
more than one model alias referring to it, but a model alias can refer to only a single trained model.
13+
14+
beta::[]
15+
16+
[[ml-put-trained-models-aliases-request]]
17+
== {api-request-title}
18+
19+
`PUT _ml/trained_models/<model_id>/model_aliases/<model_alias>`
20+
21+
22+
[[ml-put-trained-models-aliases-prereq]]
23+
== {api-prereq-title}
24+
25+
If the {es} {security-features} are enabled, you must have the following
26+
built-in roles and privileges:
27+
28+
* `machine_learning_admin`
29+
30+
For more information, see <<built-in-roles>>, <<security-privileges>>, and
31+
{ml-docs-setup-privileges}.
32+
33+
[[ml-put-trained-models-aliases-desc]]
34+
== {api-description-title}
35+
36+
This API creates a new model alias to refer to trained models, or updates an existing
37+
trained model's alias.
38+
39+
When updating an existing model alias to a new model ID, this API will return an error if the models
40+
are of different inference types. For example, if attempting to reassign the model alias
41+
`flights-delay-prediction` from a regression model to a classification model, the API returns an error.
42+
43+
The API will return a warning if there are very few input fields in common between the old
44+
and new models for the model alias.
45+
46+
[[ml-put-trained-models-aliases-path-params]]
47+
== {api-path-parms-title}
48+
49+
`model_id`::
50+
(Required, string)
51+
The trained model ID to which the model alias should refer.
52+
53+
`model_alias`::
54+
(Required, string)
55+
The model alias to create or update. The `model_alias` cannot end in numbers.
56+
57+
[[ml-put-trained-models-aliases-query-params]]
58+
== {api-query-parms-title}
59+
60+
`reassign`::
61+
(Optional, boolean)
62+
Specifies whether the `model_alias` is reassigned to the provided `model_id` if it is already
63+
assigned to a model. Defaults to `false`. The API returns an error if the `model_alias`
64+
is already assigned to a model but this parameter is `false`.
65+
66+
[[ml-put-trained-models-aliases-example]]
67+
== {api-examples-title}
68+
69+
[[ml-put-trained-models-aliases-example-new-alias]]
70+
=== Creating a new model alias
71+
72+
The following example shows how to create a new model alias for a trained model ID.
73+
74+
[source,console]
75+
--------------------------------------------------
76+
PUT _ml/trained_models/flight-delay-prediction-1574775339910/model_aliases/flight_delay_model
77+
--------------------------------------------------
78+
// TEST[skip:setup kibana sample data]
79+
80+
[[ml-put-trained-models-aliases-example-put-alias]]
81+
=== Updating an existing model alias
82+
83+
The following example shows how to reassign an existing model alias to a different trained model ID.
84+
85+
[source,console]
86+
--------------------------------------------------
87+
PUT _ml/trained_models/flight-delay-prediction-1580004349800/model_aliases/flight_delay_model?reassign=true
88+
--------------------------------------------------
89+
// TEST[skip:setup kibana sample data]
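
The rules stated in the new API page (an alias cannot end in numbers, reassignment requires `reassign=true`, the old and new models must share an inference type, and a warning is returned when input fields barely overlap) can be sketched as one method. This is a hedged illustration, not the transport action from this commit: `PutAliasSketch`, `Model`, and `putAlias` are hypothetical names, and the warning here fires only when no input field is shared, which is an assumption because the page says only "very few".

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

// Hypothetical sketch of the documented put-alias rules; not Elasticsearch code.
final class PutAliasSketch {
    enum InferenceType { REGRESSION, CLASSIFICATION }

    static final class Model {
        final String id;
        final InferenceType type;
        final Set<String> inputFields;

        Model(String id, InferenceType type, Set<String> inputFields) {
            this.id = id;
            this.type = type;
            this.inputFields = inputFields;
        }
    }

    private final Map<String, Model> aliasToModel = new HashMap<>();

    // Returns an optional warning; throws on the documented error cases.
    Optional<String> putAlias(String alias, Model newModel, boolean reassign) {
        if (!alias.isEmpty() && Character.isDigit(alias.charAt(alias.length() - 1))) {
            throw new IllegalArgumentException("model_alias cannot end in numbers");
        }
        Model old = aliasToModel.get(alias);
        Optional<String> warning = Optional.empty();
        if (old != null) {
            if (!reassign) {
                throw new IllegalStateException("alias already assigned; pass reassign=true");
            }
            if (old.type != newModel.type) {
                throw new IllegalArgumentException("old and new models have different inference types");
            }
            // Assumption: warn only when the two models share no input field at all.
            boolean shared = old.inputFields.stream().anyMatch(newModel.inputFields::contains);
            if (!shared) {
                warning = Optional.of("old and new models have no input fields in common");
            }
        }
        aliasToModel.put(alias, newModel);
        return warning;
    }

    public static void main(String[] args) {
        PutAliasSketch api = new PutAliasSketch();
        Model v1 = new Model("model-a", InferenceType.REGRESSION, Set.of("distance", "carrier"));
        Model v2 = new Model("model-b", InferenceType.REGRESSION, Set.of("distance", "weather"));
        api.putAlias("flight_delay_model", v1, false);
        System.out.println(api.putAlias("flight_delay_model", v2, true).isPresent());
    }
}
```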

docs/reference/ml/ml-shared.asciidoc

+4
@@ -1149,6 +1149,10 @@ tag::model-id[]
11491149
The unique identifier of the trained model.
11501150
end::model-id[]
11511151

1152+
tag::model-id-or-alias[]
1153+
The unique identifier of the trained model or a model alias.
1154+
end::model-id-or-alias[]
1155+
11521156
tag::model-memory-limit[]
11531157
The approximate maximum amount of memory resources that are required for
11541158
analytical processing. Once this limit is approached, data pruning becomes

server/src/main/java/org/elasticsearch/common/util/set/Sets.java

+6
@@ -60,6 +60,12 @@ public static <T> boolean haveEmptyIntersection(Set<T> left, Set<T> right) {
6060
return left.stream().noneMatch(right::contains);
6161
}
6262

63+
public static <T> boolean haveNonEmptyIntersection(Set<T> left, Set<T> right) {
64+
Objects.requireNonNull(left);
65+
Objects.requireNonNull(right);
66+
return left.stream().anyMatch(right::contains);
67+
}
68+
6369
/**
6470
* The relative complement, or difference, of the specified left and right set. Namely, the resulting set contains all the elements that
6571
* are in the left set but not in the right set. Neither input is mutated by this operation, an entirely new set is returned.
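
As a quick illustration of the helper added in this hunk, the sketch below copies the method body into a standalone class so it can run without the Elasticsearch `Sets` class. `SetsSketch` and the field names in `main` are hypothetical and chosen only for the example.

```java
import java.util.Objects;
import java.util.Set;

// Standalone copy of the added helper so the example runs on its own.
final class SetsSketch {
    static <T> boolean haveNonEmptyIntersection(Set<T> left, Set<T> right) {
        Objects.requireNonNull(left);
        Objects.requireNonNull(right);
        // True as soon as one element of left is also contained in right.
        return left.stream().anyMatch(right::contains);
    }

    public static void main(String[] args) {
        Set<String> oldFields = Set.of("distance", "carrier");
        Set<String> newFields = Set.of("carrier", "weather");
        Set<String> unrelated = Set.of("price");

        System.out.println(haveNonEmptyIntersection(oldFields, newFields)); // true
        System.out.println(haveNonEmptyIntersection(oldFields, unrelated)); // false
        System.out.println(haveNonEmptyIntersection(Set.of(), newFields));  // false
    }
}
```

The method short-circuits on the first shared element, so it does not build the intersection set.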

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java

+8
@@ -108,6 +108,7 @@
108108
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
109109
import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction;
110110
import org.elasticsearch.xpack.core.ml.action.FinalizeJobExecutionAction;
111+
import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata;
111112
import org.elasticsearch.xpack.core.rollup.action.RollupIndexerAction;
112113
import org.elasticsearch.xpack.core.ml.action.FlushJobAction;
113114
import org.elasticsearch.xpack.core.ml.action.ForecastJobAction;
@@ -534,6 +535,8 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
534535
// logstash
535536
new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.LOGSTASH, LogstashFeatureSetUsage::new),
536537
// ML - Custom metadata
538+
new NamedWriteableRegistry.Entry(Metadata.Custom.class, ModelAliasMetadata.NAME, ModelAliasMetadata::new),
539+
new NamedWriteableRegistry.Entry(NamedDiff.class, ModelAliasMetadata.NAME, ModelAliasMetadata::readDiffFrom),
537540
new NamedWriteableRegistry.Entry(Metadata.Custom.class, "ml", MlMetadata::new),
538541
new NamedWriteableRegistry.Entry(NamedDiff.class, "ml", MlMetadata.MlMetadataDiff::new),
539542
// ML - Persistent action requests
@@ -712,6 +715,11 @@ public List<NamedXContentRegistry.Entry> getNamedXContent() {
712715
// ML - Custom metadata
713716
new NamedXContentRegistry.Entry(Metadata.Custom.class, new ParseField("ml"),
714717
parser -> MlMetadata.LENIENT_PARSER.parse(parser, null).build()),
718+
new NamedXContentRegistry.Entry(
719+
Metadata.Custom.class,
720+
new ParseField(ModelAliasMetadata.NAME),
721+
ModelAliasMetadata::fromXContent
722+
),
715723
// ML - Persistent action requests
716724
new NamedXContentRegistry.Entry(PersistentTaskParams.class, new ParseField(MlTasks.DATAFEED_TASK_NAME),
717725
StartDatafeedAction.DatafeedParams::fromXContent),

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java

+7-7
@@ -191,7 +191,7 @@ protected Reader<Response.TrainedModelStats> getReader() {
191191
public static class Builder {
192192

193193
private long totalModelCount;
194-
private Set<String> expandedIds;
194+
private Map<String, Set<String>> expandedIdsWithAliases;
195195
private Map<String, IngestStats> ingestStatsMap;
196196
private Map<String, InferenceStats> inferenceStatsMap;
197197

@@ -200,13 +200,13 @@ public Builder setTotalModelCount(long totalModelCount) {
200200
return this;
201201
}
202202

203-
public Builder setExpandedIds(Set<String> expandedIds) {
204-
this.expandedIds = expandedIds;
203+
public Builder setExpandedIdsWithAliases(Map<String, Set<String>> expandedIdsWithAliases) {
204+
this.expandedIdsWithAliases = expandedIdsWithAliases;
205205
return this;
206206
}
207207

208-
public Set<String> getExpandedIds() {
209-
return this.expandedIds;
208+
public Map<String, Set<String>> getExpandedIdsWithAliases() {
209+
return this.expandedIdsWithAliases;
210210
}
211211

212212
public Builder setIngestStatsByModelId(Map<String, IngestStats> ingestStatsByModelId) {
@@ -220,8 +220,8 @@ public Builder setInferenceStatsByModelId(Map<String, InferenceStats> infereceSt
220220
}
221221

222222
public Response build() {
223-
List<TrainedModelStats> trainedModelStats = new ArrayList<>(expandedIds.size());
224-
expandedIds.forEach(id -> {
223+
List<TrainedModelStats> trainedModelStats = new ArrayList<>(expandedIdsWithAliases.size());
224+
expandedIdsWithAliases.keySet().forEach(id -> {
225225
IngestStats ingestStats = ingestStatsMap.get(id);
226226
InferenceStats inferenceStats = inferenceStatsMap.get(id);
227227
trainedModelStats.add(new TrainedModelStats(

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelAction.java

+26-4
@@ -160,18 +160,25 @@ public String toString() {
160160
public static class Response extends ActionResponse {
161161

162162
private final List<InferenceResults> inferenceResults;
163+
private final String modelId;
163164
private final boolean isLicensed;
164165

165-
public Response(List<InferenceResults> inferenceResults, boolean isLicensed) {
166+
public Response(List<InferenceResults> inferenceResults, String modelId, boolean isLicensed) {
166167
super();
167168
this.inferenceResults = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(inferenceResults, "inferenceResults"));
168169
this.isLicensed = isLicensed;
170+
this.modelId = modelId;
169171
}
170172

171173
public Response(StreamInput in) throws IOException {
172174
super(in);
173175
this.inferenceResults = Collections.unmodifiableList(in.readNamedWriteableList(InferenceResults.class));
174176
this.isLicensed = in.readBoolean();
177+
if (in.getVersion().onOrAfter(Version.V_7_13_0)) {
178+
this.modelId = in.readOptionalString();
179+
} else {
180+
this.modelId = null;
181+
}
175182
}
176183

177184
public List<InferenceResults> getInferenceResults() {
@@ -182,23 +189,32 @@ public boolean isLicensed() {
182189
return isLicensed;
183190
}
184191

192+
public String getModelId() {
193+
return modelId;
194+
}
195+
185196
@Override
186197
public void writeTo(StreamOutput out) throws IOException {
187198
out.writeNamedWriteableList(inferenceResults);
188199
out.writeBoolean(isLicensed);
200+
if (out.getVersion().onOrAfter(Version.V_7_13_0)) {
201+
out.writeOptionalString(modelId);
202+
}
189203
}
190204

191205
@Override
192206
public boolean equals(Object o) {
193207
if (this == o) return true;
194208
if (o == null || getClass() != o.getClass()) return false;
195209
InternalInferModelAction.Response that = (InternalInferModelAction.Response) o;
196-
return isLicensed == that.isLicensed && Objects.equals(inferenceResults, that.inferenceResults);
210+
return isLicensed == that.isLicensed
211+
&& Objects.equals(inferenceResults, that.inferenceResults)
212+
&& Objects.equals(modelId, that.modelId);
197213
}
198214

199215
@Override
200216
public int hashCode() {
201-
return Objects.hash(inferenceResults, isLicensed);
217+
return Objects.hash(inferenceResults, isLicensed, modelId);
202218
}
203219

204220
public static Builder builder() {
@@ -207,6 +223,7 @@ public static Builder builder() {
207223

208224
public static class Builder {
209225
private List<InferenceResults> inferenceResults;
226+
private String modelId;
210227
private boolean isLicensed;
211228

212229
public Builder setInferenceResults(List<InferenceResults> inferenceResults) {
@@ -219,8 +236,13 @@ public Builder setLicensed(boolean licensed) {
219236
return this;
220237
}
221238

239+
public Builder setModelId(String modelId) {
240+
this.modelId = modelId;
241+
return this;
242+
}
243+
222244
public Response build() {
223-
return new Response(inferenceResults, isLicensed);
245+
return new Response(inferenceResults, modelId, isLicensed);
224246
}
225247
}
226248
