Skip to content

Commit e163559

Browse files
authored
[7.x] [ML] Add new include flag to GET inference/<model_id> API for model training metadata (#61922) (#62620)
* [ML] Add new include flag to GET inference/<model_id> API for model training metadata (#61922) Adds new flag include to the get trained models API The flag initially has two valid values: definition, total_feature_importance. Consequently, the old include_model_definition flag is now deprecated. When total_feature_importance is included, the total_feature_importance field is included in the model metadata object. Including definition is the same as previously setting include_model_definition=true. * fixing test * Update x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java
1 parent e1a4a30 commit e163559

File tree

28 files changed

+833
-162
lines changed

28 files changed

+833
-162
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -779,9 +779,9 @@ static Request getTrainedModels(GetTrainedModelsRequest getTrainedModelsRequest)
779779
params.putParam(GetTrainedModelsRequest.DECOMPRESS_DEFINITION,
780780
Boolean.toString(getTrainedModelsRequest.getDecompressDefinition()));
781781
}
782-
if (getTrainedModelsRequest.getIncludeDefinition() != null) {
783-
params.putParam(GetTrainedModelsRequest.INCLUDE_MODEL_DEFINITION,
784-
Boolean.toString(getTrainedModelsRequest.getIncludeDefinition()));
782+
if (getTrainedModelsRequest.getIncludes().isEmpty() == false) {
783+
params.putParam(GetTrainedModelsRequest.INCLUDE,
784+
Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getIncludes()));
785785
}
786786
if (getTrainedModelsRequest.getTags() != null) {
787787
params.putParam(GetTrainedModelsRequest.TAGS, Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getTags()));

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java

+26-8
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,26 @@
2626
import org.elasticsearch.common.Nullable;
2727

2828
import java.util.Arrays;
29+
import java.util.Collections;
30+
import java.util.HashSet;
2931
import java.util.List;
3032
import java.util.Objects;
3133
import java.util.Optional;
34+
import java.util.Set;
3235

3336
public class GetTrainedModelsRequest implements Validatable {
3437

38+
private static final String DEFINITION = "definition";
39+
private static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance";
3540
public static final String ALLOW_NO_MATCH = "allow_no_match";
36-
public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
3741
public static final String FOR_EXPORT = "for_export";
3842
public static final String DECOMPRESS_DEFINITION = "decompress_definition";
3943
public static final String TAGS = "tags";
44+
public static final String INCLUDE = "include";
4045

4146
private final List<String> ids;
4247
private Boolean allowNoMatch;
43-
private Boolean includeDefinition;
48+
private Set<String> includes = new HashSet<>();
4449
private Boolean decompressDefinition;
4550
private Boolean forExport;
4651
private PageParams pageParams;
@@ -86,19 +91,32 @@ public GetTrainedModelsRequest setPageParams(@Nullable PageParams pageParams) {
8691
return this;
8792
}
8893

89-
public Boolean getIncludeDefinition() {
90-
return includeDefinition;
94+
public Set<String> getIncludes() {
95+
return Collections.unmodifiableSet(includes);
96+
}
97+
98+
public GetTrainedModelsRequest includeDefinition() {
99+
this.includes.add(DEFINITION);
100+
return this;
101+
}
102+
103+
public GetTrainedModelsRequest includeTotalFeatureImportance() {
104+
this.includes.add(TOTAL_FEATURE_IMPORTANCE);
105+
return this;
91106
}
92107

93108
/**
94109
* Whether to include the full model definition.
95110
*
96111
* The full model definition can be very large.
97-
*
112+
* @deprecated Use {@link GetTrainedModelsRequest#includeDefinition()}
98113
* @param includeDefinition If {@code true}, the definition is included.
99114
*/
115+
@Deprecated
100116
public GetTrainedModelsRequest setIncludeDefinition(Boolean includeDefinition) {
101-
this.includeDefinition = includeDefinition;
117+
if (includeDefinition != null && includeDefinition) {
118+
return this.includeDefinition();
119+
}
102120
return this;
103121
}
104122

@@ -173,13 +191,13 @@ public boolean equals(Object o) {
173191
return Objects.equals(ids, other.ids)
174192
&& Objects.equals(allowNoMatch, other.allowNoMatch)
175193
&& Objects.equals(decompressDefinition, other.decompressDefinition)
176-
&& Objects.equals(includeDefinition, other.includeDefinition)
194+
&& Objects.equals(includes, other.includes)
177195
&& Objects.equals(forExport, other.forExport)
178196
&& Objects.equals(pageParams, other.pageParams);
179197
}
180198

181199
@Override
182200
public int hashCode() {
183-
return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includeDefinition, forExport);
201+
return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includes, forExport);
184202
}
185203
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.elasticsearch.client.ml.inference.trainedmodel.metadata;
21+
22+
import org.elasticsearch.common.Nullable;
23+
import org.elasticsearch.common.ParseField;
24+
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
25+
import org.elasticsearch.common.xcontent.ObjectParser;
26+
import org.elasticsearch.common.xcontent.ToXContentObject;
27+
import org.elasticsearch.common.xcontent.XContentBuilder;
28+
import org.elasticsearch.common.xcontent.XContentParseException;
29+
import org.elasticsearch.common.xcontent.XContentParser;
30+
31+
import java.io.IOException;
32+
import java.util.Collections;
33+
import java.util.List;
34+
import java.util.Objects;
35+
36+
public class TotalFeatureImportance implements ToXContentObject {
37+
38+
private static final String NAME = "total_feature_importance";
39+
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
40+
public static final ParseField IMPORTANCE = new ParseField("importance");
41+
public static final ParseField CLASSES = new ParseField("classes");
42+
public static final ParseField MEAN_MAGNITUDE = new ParseField("mean_magnitude");
43+
public static final ParseField MIN = new ParseField("min");
44+
public static final ParseField MAX = new ParseField("max");
45+
46+
@SuppressWarnings("unchecked")
47+
public static final ConstructingObjectParser<TotalFeatureImportance, Void> PARSER = new ConstructingObjectParser<>(NAME,
48+
true,
49+
a -> new TotalFeatureImportance((String)a[0], (Importance)a[1], (List<ClassImportance>)a[2]));
50+
51+
static {
52+
PARSER.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
53+
PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), Importance.PARSER, IMPORTANCE);
54+
PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), ClassImportance.PARSER, CLASSES);
55+
}
56+
57+
public static TotalFeatureImportance fromXContent(XContentParser parser) {
58+
return PARSER.apply(parser, null);
59+
}
60+
61+
public final String featureName;
62+
public final Importance importance;
63+
public final List<ClassImportance> classImportances;
64+
65+
TotalFeatureImportance(String featureName, @Nullable Importance importance, @Nullable List<ClassImportance> classImportances) {
66+
this.featureName = featureName;
67+
this.importance = importance;
68+
this.classImportances = classImportances == null ? Collections.emptyList() : classImportances;
69+
}
70+
71+
@Override
72+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
73+
builder.startObject();
74+
builder.field(FEATURE_NAME.getPreferredName(), featureName);
75+
if (importance != null) {
76+
builder.field(IMPORTANCE.getPreferredName(), importance);
77+
}
78+
if (classImportances.isEmpty() == false) {
79+
builder.field(CLASSES.getPreferredName(), classImportances);
80+
}
81+
builder.endObject();
82+
return builder;
83+
}
84+
85+
@Override
86+
public boolean equals(Object o) {
87+
if (this == o) return true;
88+
if (o == null || getClass() != o.getClass()) return false;
89+
TotalFeatureImportance that = (TotalFeatureImportance) o;
90+
return Objects.equals(that.importance, importance)
91+
&& Objects.equals(featureName, that.featureName)
92+
&& Objects.equals(classImportances, that.classImportances);
93+
}
94+
95+
@Override
96+
public int hashCode() {
97+
return Objects.hash(featureName, importance, classImportances);
98+
}
99+
100+
public static class Importance implements ToXContentObject {
101+
private static final String NAME = "importance";
102+
103+
public static final ConstructingObjectParser<Importance, Void> PARSER = new ConstructingObjectParser<>(NAME,
104+
true,
105+
a -> new Importance((double)a[0], (double)a[1], (double)a[2]));
106+
107+
static {
108+
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MEAN_MAGNITUDE);
109+
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MIN);
110+
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MAX);
111+
}
112+
113+
private final double meanMagnitude;
114+
private final double min;
115+
private final double max;
116+
117+
public Importance(double meanMagnitude, double min, double max) {
118+
this.meanMagnitude = meanMagnitude;
119+
this.min = min;
120+
this.max = max;
121+
}
122+
123+
@Override
124+
public boolean equals(Object o) {
125+
if (this == o) return true;
126+
if (o == null || getClass() != o.getClass()) return false;
127+
Importance that = (Importance) o;
128+
return Double.compare(that.meanMagnitude, meanMagnitude) == 0 &&
129+
Double.compare(that.min, min) == 0 &&
130+
Double.compare(that.max, max) == 0;
131+
}
132+
133+
@Override
134+
public int hashCode() {
135+
return Objects.hash(meanMagnitude, min, max);
136+
}
137+
138+
@Override
139+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
140+
builder.startObject();
141+
builder.field(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude);
142+
builder.field(MIN.getPreferredName(), min);
143+
builder.field(MAX.getPreferredName(), max);
144+
builder.endObject();
145+
return builder;
146+
}
147+
}
148+
149+
public static class ClassImportance implements ToXContentObject {
150+
private static final String NAME = "total_class_importance";
151+
152+
public static final ParseField CLASS_NAME = new ParseField("class_name");
153+
public static final ParseField IMPORTANCE = new ParseField("importance");
154+
155+
public static final ConstructingObjectParser<ClassImportance, Void> PARSER = new ConstructingObjectParser<>(NAME,
156+
true,
157+
a -> new ClassImportance(a[0], (Importance)a[1]));
158+
159+
static {
160+
PARSER.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> {
161+
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
162+
return p.text();
163+
} else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) {
164+
return p.numberValue();
165+
} else if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) {
166+
return p.booleanValue();
167+
}
168+
throw new XContentParseException("Unsupported token [" + p.currentToken() + "]");
169+
}, CLASS_NAME, ObjectParser.ValueType.VALUE);
170+
PARSER.declareObject(ConstructingObjectParser.constructorArg(), Importance.PARSER, IMPORTANCE);
171+
}
172+
173+
public static ClassImportance fromXContent(XContentParser parser) {
174+
return PARSER.apply(parser, null);
175+
}
176+
177+
public final Object className;
178+
public final Importance importance;
179+
180+
ClassImportance(Object className, Importance importance) {
181+
this.className = className;
182+
this.importance = importance;
183+
}
184+
185+
@Override
186+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
187+
builder.startObject();
188+
builder.field(CLASS_NAME.getPreferredName(), className);
189+
builder.field(IMPORTANCE.getPreferredName(), importance);
190+
builder.endObject();
191+
return builder;
192+
}
193+
194+
@Override
195+
public boolean equals(Object o) {
196+
if (this == o) return true;
197+
if (o == null || getClass() != o.getClass()) return false;
198+
ClassImportance that = (ClassImportance) o;
199+
return Objects.equals(that.importance, importance) && Objects.equals(className, that.className);
200+
}
201+
202+
@Override
203+
public int hashCode() {
204+
return Objects.hash(className, importance);
205+
}
206+
207+
}
208+
}

client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -894,7 +894,7 @@ public void testGetTrainedModels() {
894894
GetTrainedModelsRequest getRequest = new GetTrainedModelsRequest(modelId1, modelId2, modelId3)
895895
.setAllowNoMatch(false)
896896
.setDecompressDefinition(true)
897-
.setIncludeDefinition(false)
897+
.includeDefinition()
898898
.setTags("tag1", "tag2")
899899
.setPageParams(new PageParams(100, 300));
900900

@@ -908,7 +908,7 @@ public void testGetTrainedModels() {
908908
hasEntry("allow_no_match", "false"),
909909
hasEntry("decompress_definition", "true"),
910910
hasEntry("tags", "tag1,tag2"),
911-
hasEntry("include_model_definition", "false")
911+
hasEntry("include", "definition")
912912
));
913913
assertNull(request.getEntity());
914914
}

client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

+10-3
Original file line numberDiff line numberDiff line change
@@ -2257,7 +2257,10 @@ public void testGetTrainedModels() throws Exception {
22572257

22582258
{
22592259
GetTrainedModelsResponse getTrainedModelsResponse = execute(
2260-
new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(true).setIncludeDefinition(true),
2260+
new GetTrainedModelsRequest(modelIdPrefix + 0)
2261+
.setDecompressDefinition(true)
2262+
.includeDefinition()
2263+
.includeTotalFeatureImportance(),
22612264
machineLearningClient::getTrainedModels,
22622265
machineLearningClient::getTrainedModelsAsync);
22632266

@@ -2268,7 +2271,10 @@ public void testGetTrainedModels() throws Exception {
22682271
assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdPrefix + 0));
22692272

22702273
getTrainedModelsResponse = execute(
2271-
new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(false).setIncludeDefinition(true),
2274+
new GetTrainedModelsRequest(modelIdPrefix + 0)
2275+
.setDecompressDefinition(false)
2276+
.includeTotalFeatureImportance()
2277+
.includeDefinition(),
22722278
machineLearningClient::getTrainedModels,
22732279
machineLearningClient::getTrainedModelsAsync);
22742280

@@ -2279,7 +2285,8 @@ public void testGetTrainedModels() throws Exception {
22792285
assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdPrefix + 0));
22802286

22812287
getTrainedModelsResponse = execute(
2282-
new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(false).setIncludeDefinition(false),
2288+
new GetTrainedModelsRequest(modelIdPrefix + 0)
2289+
.setDecompressDefinition(false),
22832290
machineLearningClient::getTrainedModels,
22842291
machineLearningClient::getTrainedModelsAsync);
22852292
assertThat(getTrainedModelsResponse.getCount(), equalTo(1L));

client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

+6-5
Original file line numberDiff line numberDiff line change
@@ -3694,11 +3694,12 @@ public void testGetTrainedModels() throws Exception {
36943694
// tag::get-trained-models-request
36953695
GetTrainedModelsRequest request = new GetTrainedModelsRequest("my-trained-model") // <1>
36963696
.setPageParams(new PageParams(0, 1)) // <2>
3697-
.setIncludeDefinition(false) // <3>
3698-
.setDecompressDefinition(false) // <4>
3699-
.setAllowNoMatch(true) // <5>
3700-
.setTags("regression") // <6>
3701-
.setForExport(false); // <7>
3697+
.includeDefinition() // <3>
3698+
.includeTotalFeatureImportance() // <4>
3699+
.setDecompressDefinition(false) // <5>
3700+
.setAllowNoMatch(true) // <6>
3701+
.setTags("regression") // <7>
3702+
.setForExport(false); // <8>
37023703
// end::get-trained-models-request
37033704
request.setTags((List<String>)null);
37043705

0 commit comments

Comments
 (0)