Skip to content

Commit 18ec001

Browse files
committed
[ML][Inference] add tags url param to GET (elastic#51330)
Adds a new URL parameter, `tags` to the GET _ml/inference/<model_id> endpoint. This parameter allows the list of models to be further reduced to those who contain all the provided tags.
1 parent ded7407 commit 18ec001

File tree

16 files changed

+177
-12
lines changed

16 files changed

+177
-12
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,6 +755,9 @@ static Request getTrainedModels(GetTrainedModelsRequest getTrainedModelsRequest)
755755
params.putParam(GetTrainedModelsRequest.INCLUDE_MODEL_DEFINITION,
756756
Boolean.toString(getTrainedModelsRequest.getIncludeDefinition()));
757757
}
758+
if (getTrainedModelsRequest.getTags() != null) {
759+
params.putParam(GetTrainedModelsRequest.TAGS, Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getTags()));
760+
}
758761
Request request = new Request(HttpGet.METHOD_NAME, endpoint);
759762
request.addParameters(params.asMap());
760763
return request;

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.elasticsearch.client.Validatable;
2323
import org.elasticsearch.client.ValidationException;
2424
import org.elasticsearch.client.core.PageParams;
25+
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
2526
import org.elasticsearch.common.Nullable;
2627

2728
import java.util.Arrays;
@@ -34,12 +35,14 @@ public class GetTrainedModelsRequest implements Validatable {
3435
public static final String ALLOW_NO_MATCH = "allow_no_match";
3536
public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
3637
public static final String DECOMPRESS_DEFINITION = "decompress_definition";
38+
public static final String TAGS = "tags";
3739

3840
private final List<String> ids;
3941
private Boolean allowNoMatch;
4042
private Boolean includeDefinition;
4143
private Boolean decompressDefinition;
4244
private PageParams pageParams;
45+
private List<String> tags;
4346

4447
/**
4548
* Helper method to create a request that will get ALL TrainedModelConfigs
@@ -111,6 +114,29 @@ public GetTrainedModelsRequest setDecompressDefinition(Boolean decompressDefinit
111114
return this;
112115
}
113116

117+
public List<String> getTags() {
118+
return tags;
119+
}
120+
121+
/**
122+
* The tags that the trained model must match. These correspond to {@link TrainedModelConfig#getTags()}.
123+
*
124+
* The models returned will match ALL tags supplied.
125+
* If none are provided, only the provided ids are used to find models
126+
* @param tags The tags to match when finding models
127+
*/
128+
public GetTrainedModelsRequest setTags(List<String> tags) {
129+
this.tags = tags;
130+
return this;
131+
}
132+
133+
/**
134+
* See {@link GetTrainedModelsRequest#setTags(List)}
135+
*/
136+
public GetTrainedModelsRequest setTags(String... tags) {
137+
return setTags(Arrays.asList(tags));
138+
}
139+
114140
@Override
115141
public Optional<ValidationException> validate() {
116142
if (ids == null || ids.isEmpty()) {

client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,7 @@ public void testGetTrainedModels() {
834834
.setAllowNoMatch(false)
835835
.setDecompressDefinition(true)
836836
.setIncludeDefinition(false)
837+
.setTags("tag1", "tag2")
837838
.setPageParams(new PageParams(100, 300));
838839

839840
Request request = MLRequestConverters.getTrainedModels(getRequest);
@@ -845,6 +846,7 @@ public void testGetTrainedModels() {
845846
hasEntry("size", "300"),
846847
hasEntry("allow_no_match", "false"),
847848
hasEntry("decompress_definition", "true"),
849+
hasEntry("tags", "tag1,tag2"),
848850
hasEntry("include_model_definition", "false")
849851
));
850852
assertNull(request.getEntity());

client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3587,8 +3587,10 @@ public void testGetTrainedModels() throws Exception {
35873587
.setPageParams(new PageParams(0, 1)) // <2>
35883588
.setIncludeDefinition(false) // <3>
35893589
.setDecompressDefinition(false) // <4>
3590-
.setAllowNoMatch(true); // <5>
3590+
.setAllowNoMatch(true) // <5>
3591+
.setTags("regression"); // <6>
35913592
// end::get-trained-models-request
3593+
request.setTags((List<String>)null);
35923594

35933595
// tag::get-trained-models-execute
35943596
GetTrainedModelsResponse response = client.machineLearning().getTrainedModels(request, RequestOptions.DEFAULT);

docs/java-rest/high-level/ml/get-trained-models.asciidoc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ include-tagged::{doc-tests-file}[{api}-request]
2929
<5> Allow empty response if no Trained Models match the provided ID patterns.
3030
If false, an error will be thrown if no Trained Models match the
3131
ID patterns.
32+
<6> An optional list of tags used to narrow the model search. A Trained Model
33+
can have many tags or none. The trained models in the response will
34+
contain all the provided tags.
3235

3336
include::../execution.asciidoc[]
3437

docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=include-model-definition]
7474
(Optional, integer)
7575
include::{docdir}/ml/ml-shared.asciidoc[tag=size]
7676

77+
`tags`::
78+
(Optional, string)
79+
include::{docdir}/ml/ml-shared.asciidoc[tag=tags]
7780

7881
[[ml-get-inference-response-codes]]
7982
==== {api-response-codes-title}
@@ -96,4 +99,4 @@ The following example gets configuration information for all the trained models:
9699
--------------------------------------------------
97100
GET _ml/inference/
98101
--------------------------------------------------
99-
// TEST[skip:TBD]
102+
// TEST[skip:TBD]

docs/reference/ml/ml-shared.asciidoc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,12 @@ to `false`. When `true`, only a single model must match the ID patterns
639639
provided, otherwise a bad request is returned.
640640
end::include-model-definition[]
641641

642+
tag::tags[]
643+
A comma delimited string of tags. A {infer} model can have many tags, or none.
644+
When supplied, only {infer} models that contain all the supplied tags are
645+
returned.
646+
end::tags[]
647+
642648
tag::indices[]
643649
An array of index names. Wildcards are supported. For example:
644650
`["it_ops_metrics", "server*"]`.

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
*/
66
package org.elasticsearch.xpack.core.ml.action;
77

8+
import org.elasticsearch.Version;
89
import org.elasticsearch.action.ActionType;
910
import org.elasticsearch.common.ParseField;
1011
import org.elasticsearch.common.io.stream.StreamInput;
@@ -33,18 +34,26 @@ public static class Request extends AbstractGetResourcesRequest {
3334

3435
public static final ParseField INCLUDE_MODEL_DEFINITION = new ParseField("include_model_definition");
3536
public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");
37+
public static final ParseField TAGS = new ParseField("tags");
3638

3739
private final boolean includeModelDefinition;
40+
private final List<String> tags;
3841

39-
public Request(String id, boolean includeModelDefinition) {
42+
public Request(String id, boolean includeModelDefinition, List<String> tags) {
4043
setResourceId(id);
4144
setAllowNoResources(true);
4245
this.includeModelDefinition = includeModelDefinition;
46+
this.tags = tags == null ? Collections.emptyList() : tags;
4347
}
4448

4549
public Request(StreamInput in) throws IOException {
4650
super(in);
4751
this.includeModelDefinition = in.readBoolean();
52+
if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
53+
this.tags = in.readStringList();
54+
} else {
55+
this.tags = Collections.emptyList();
56+
}
4857
}
4958

5059
@Override
@@ -56,15 +65,22 @@ public boolean isIncludeModelDefinition() {
5665
return includeModelDefinition;
5766
}
5867

68+
public List<String> getTags() {
69+
return tags;
70+
}
71+
5972
@Override
6073
public void writeTo(StreamOutput out) throws IOException {
6174
super.writeTo(out);
6275
out.writeBoolean(includeModelDefinition);
76+
if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
77+
out.writeStringCollection(tags);
78+
}
6379
}
6480

6581
@Override
6682
public int hashCode() {
67-
return Objects.hash(super.hashCode(), includeModelDefinition);
83+
return Objects.hash(super.hashCode(), includeModelDefinition, tags);
6884
}
6985

7086
@Override
@@ -76,7 +92,7 @@ public boolean equals(Object obj) {
7692
return false;
7793
}
7894
Request other = (Request) obj;
79-
return super.equals(obj) && this.includeModelDefinition == other.includeModelDefinition;
95+
return super.equals(obj) && this.includeModelDefinition == other.includeModelDefinition && Objects.equals(tags, other.tags);
8096
}
8197
}
8298

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@ public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCas
1414

1515
@Override
1616
protected Request createTestInstance() {
17-
Request request = new Request(randomAlphaOfLength(20), randomBoolean());
17+
Request request = new Request(randomAlphaOfLength(20),
18+
randomBoolean(),
19+
randomBoolean() ? null :
20+
randomList(10, () -> randomAlphaOfLength(10)));
1821
request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100)));
1922
return request;
2023
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
2121

2222
import java.util.Collections;
23+
import java.util.HashSet;
2324
import java.util.Set;
2425

2526

@@ -70,7 +71,11 @@ protected void doExecute(Task task, Request request, ActionListener<Response> li
7071
listener::onFailure
7172
);
7273

73-
provider.expandIds(request.getResourceId(), request.isAllowNoResources(), request.getPageParams(), idExpansionListener);
74+
provider.expandIds(request.getResourceId(),
75+
request.isAllowNoResources(),
76+
request.getPageParams(),
77+
new HashSet<>(request.getTags()),
78+
idExpansionListener);
7479
}
7580

7681
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
3131

3232
import java.util.ArrayList;
33+
import java.util.Collections;
3334
import java.util.HashMap;
3435
import java.util.Iterator;
3536
import java.util.LinkedHashMap;
@@ -94,7 +95,11 @@ protected void doExecute(Task task,
9495
listener::onFailure
9596
);
9697

97-
trainedModelProvider.expandIds(request.getResourceId(), request.isAllowNoResources(), request.getPageParams(), idsListener);
98+
trainedModelProvider.expandIds(request.getResourceId(),
99+
request.isAllowNoResources(),
100+
request.getPageParams(),
101+
Collections.emptySet(),
102+
idsListener);
98103
}
99104

100105
static Map<String, IngestStats> inferenceIngestStatsByPipelineId(NodesStatsResponse response,

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
import java.io.InputStream;
7171
import java.net.URL;
7272
import java.util.ArrayList;
73+
import java.util.Collection;
7374
import java.util.Collections;
7475
import java.util.Comparator;
7576
import java.util.HashSet;
@@ -382,14 +383,15 @@ public void deleteTrainedModel(String modelId, ActionListener<Boolean> listener)
382383
public void expandIds(String idExpression,
383384
boolean allowNoResources,
384385
@Nullable PageParams pageParams,
386+
Set<String> tags,
385387
ActionListener<Tuple<Long, Set<String>>> idsListener) {
386388
String[] tokens = Strings.tokenizeToStringArray(idExpression, ",");
387389
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder()
388390
.sort(SortBuilders.fieldSort(TrainedModelConfig.MODEL_ID.getPreferredName())
389391
// If there are no resources, there might be no mapping for the id field.
390392
// This makes sure we don't get an error if that happens.
391393
.unmappedType("long"))
392-
.query(buildQueryIdExpressionQuery(tokens, TrainedModelConfig.MODEL_ID.getPreferredName()));
394+
.query(buildExpandIdsQuery(tokens, tags));
393395
if (pageParams != null) {
394396
sourceBuilder.from(pageParams.getFrom()).size(pageParams.getSize());
395397
}
@@ -405,13 +407,23 @@ public void expandIds(String idExpression,
405407
indicesOptions.expandWildcardsClosed(),
406408
indicesOptions))
407409
.source(sourceBuilder);
410+
Set<String> foundResourceIds = new LinkedHashSet<>();
411+
if (tags.isEmpty()) {
412+
foundResourceIds.addAll(matchedResourceIds(tokens));
413+
} else {
414+
for(String resourceId : matchedResourceIds(tokens)) {
415+
// Does the model as a resource have all the tags?
416+
if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) {
417+
foundResourceIds.add(resourceId);
418+
}
419+
}
420+
}
408421

409422
executeAsyncWithOrigin(client.threadPool().getThreadContext(),
410423
ML_ORIGIN,
411424
searchRequest,
412425
ActionListener.<SearchResponse>wrap(
413426
response -> {
414-
Set<String> foundResourceIds = new LinkedHashSet<>(matchedResourceIds(tokens));
415427
long totalHitCount = response.getHits().getTotalHits().value + foundResourceIds.size();
416428
for (SearchHit hit : response.getHits().getHits()) {
417429
Map<String, Object> docSource = hit.getSourceAsMap();
@@ -434,7 +446,15 @@ public void expandIds(String idExpression,
434446
idsListener::onFailure
435447
),
436448
client::search);
449+
}
437450

451+
static QueryBuilder buildExpandIdsQuery(String[] tokens, Collection<String> tags) {
452+
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery()
453+
.filter(buildQueryIdExpressionQuery(tokens, TrainedModelConfig.MODEL_ID.getPreferredName()));
454+
for(String tag : tags) {
455+
boolQueryBuilder.filter(QueryBuilders.termQuery(TrainedModelConfig.TAGS.getPreferredName(), tag));
456+
}
457+
return QueryBuilders.constantScoreQuery(boolQueryBuilder);
438458
}
439459

440460
TrainedModelConfig loadModelFromResource(String modelId, boolean nullOutDefinition) {
@@ -468,7 +488,7 @@ TrainedModelConfig loadModelFromResource(String modelId, boolean nullOutDefiniti
468488
}
469489
}
470490

471-
private QueryBuilder buildQueryIdExpressionQuery(String[] tokens, String resourceIdField) {
491+
private static QueryBuilder buildQueryIdExpressionQuery(String[] tokens, String resourceIdField) {
472492
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
473493
.filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelConfig.NAME));
474494

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
import org.elasticsearch.xpack.ml.MachineLearning;
1919

2020
import java.io.IOException;
21+
import java.util.Arrays;
2122
import java.util.Collections;
23+
import java.util.List;
2224
import java.util.Set;
2325

2426
import static org.elasticsearch.rest.RestRequest.Method.GET;
@@ -47,7 +49,8 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
4749
GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION.getPreferredName(),
4850
false
4951
);
50-
GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId, includeModelDefinition);
52+
List<String> tags = Arrays.asList(restRequest.paramAsStringArray(TrainedModelConfig.TAGS.getPreferredName(), Strings.EMPTY_ARRAY));
53+
GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId, includeModelDefinition, tags);
5154
if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) {
5255
request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM),
5356
restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE)));

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,24 @@
99
import org.elasticsearch.action.support.PlainActionFuture;
1010
import org.elasticsearch.client.Client;
1111
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
12+
import org.elasticsearch.index.query.BoolQueryBuilder;
13+
import org.elasticsearch.index.query.ConstantScoreQueryBuilder;
14+
import org.elasticsearch.index.query.QueryBuilder;
15+
import org.elasticsearch.index.query.TermQueryBuilder;
1216
import org.elasticsearch.test.ESTestCase;
1317
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
1418
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
1519
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
1620
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
1721

22+
import java.util.Arrays;
23+
1824
import static org.hamcrest.Matchers.equalTo;
25+
import static org.hamcrest.Matchers.instanceOf;
1926
import static org.hamcrest.Matchers.is;
2027
import static org.hamcrest.Matchers.not;
2128
import static org.hamcrest.Matchers.nullValue;
29+
import static org.hamcrest.Matchers.oneOf;
2230
import static org.mockito.Mockito.mock;
2331

2432
public class TrainedModelProviderTests extends ESTestCase {
@@ -60,6 +68,24 @@ public void testGetModelThatExistsAsResource() throws Exception {
6068
}
6169
}
6270

71+
public void testExpandIdsQuery() {
72+
QueryBuilder queryBuilder = TrainedModelProvider.buildExpandIdsQuery(new String[]{"model*", "trained_mode"},
73+
Arrays.asList("tag1", "tag2"));
74+
assertThat(queryBuilder, is(instanceOf(ConstantScoreQueryBuilder.class)));
75+
76+
QueryBuilder innerQuery = ((ConstantScoreQueryBuilder)queryBuilder).innerQuery();
77+
assertThat(innerQuery, is(instanceOf(BoolQueryBuilder.class)));
78+
79+
((BoolQueryBuilder)innerQuery).filter().forEach(qb -> {
80+
if (qb instanceof TermQueryBuilder) {
81+
assertThat(((TermQueryBuilder)qb).fieldName(), equalTo(TrainedModelConfig.TAGS.getPreferredName()));
82+
assertThat(((TermQueryBuilder)qb).value(), is(oneOf("tag1", "tag2")));
83+
return;
84+
}
85+
assertThat(qb, is(instanceOf(BoolQueryBuilder.class)));
86+
});
87+
}
88+
6389
public void testGetModelThatExistsAsResourceButIsMissing() {
6490
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
6591
ElasticsearchException ex = expectThrows(ElasticsearchException.class,

0 commit comments

Comments
 (0)