Skip to content

Commit 1c1d451

Browse files
authored
[ML][Inference] don't return inflated definition when storing trained models (#52573)
When `PUT` is called to store a trained model, it is useful to return the newly create model config. But, it is NOT useful to return the inflated definition. These definitions can be large and returning the inflated definition causes undo work on the server and client side.
1 parent 1326fdb commit 1c1d451

File tree

7 files changed

+99
-9
lines changed

7 files changed

+99
-9
lines changed

docs/java-rest/high-level/ml/put-trained-model.asciidoc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ include::../execution.asciidoc[]
4646
==== Response
4747

4848
The returned +{response}+ contains the newly created trained model.
49+
The +{response}+ will omit the model definition as a precaution against
50+
streaming large model definitions back to the client.
4951

5052
["source","java",subs="attributes,callouts,macros"]
5153
--------------------------------------------------

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
280280
builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
281281
// We don't store the definition in the same document as the configuration
282282
if ((params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) && definition != null) {
283-
if (params.paramAsBoolean(DECOMPRESS_DEFINITION, true)) {
283+
if (params.paramAsBoolean(DECOMPRESS_DEFINITION, false)) {
284284
builder.field(DEFINITION.getPreferredName(), definition);
285285
} else {
286286
builder.field(COMPRESSED_DEFINITION.getPreferredName(), definition.getCompressedString());
@@ -371,6 +371,9 @@ public Builder(TrainedModelConfig config) {
371371
this.tags = config.getTags();
372372
this.metadata = config.getMetadata();
373373
this.input = config.getInput();
374+
this.estimatedOperations = config.estimatedOperations;
375+
this.estimatedHeapMemory = config.estimatedHeapMemory;
376+
this.licenseLevel = config.licenseLevel.description();
374377
}
375378

376379
public Builder setModelId(String modelId) {

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -143,21 +143,21 @@ public void testToXContentWithParams() throws IOException {
143143
"platinum");
144144

145145
BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
146-
assertThat(reference.utf8ToString(), containsString("\"definition\""));
146+
assertThat(reference.utf8ToString(), containsString("\"compressed_definition\""));
147147

148148
reference = XContentHelper.toXContent(config,
149149
XContentType.JSON,
150150
new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")),
151151
false);
152152
assertThat(reference.utf8ToString(), not(containsString("definition")));
153+
assertThat(reference.utf8ToString(), not(containsString("compressed_definition")));
153154

154155
reference = XContentHelper.toXContent(config,
155156
XContentType.JSON,
156-
new ToXContent.MapParams(Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, "false")),
157+
new ToXContent.MapParams(Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, "true")),
157158
false);
158-
assertThat(reference.utf8ToString(), not(containsString("\"definition\"")));
159-
assertThat(reference.utf8ToString(), containsString("compressed_definition"));
160-
assertThat(reference.utf8ToString(), containsString(lazyModelDefinition.getCompressedString()));
159+
assertThat(reference.utf8ToString(), containsString("\"definition\""));
160+
assertThat(reference.utf8ToString(), not(containsString("compressed_definition")));
161161
}
162162

163163
public void testParseWithBothDefinitionAndCompressedSupplied() throws IOException {
@@ -180,7 +180,7 @@ public void testParseWithBothDefinitionAndCompressedSupplied() throws IOExceptio
180180
BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
181181
Map<String, Object> objectMap = XContentHelper.convertToMap(reference, true, XContentType.JSON).v2();
182182

183-
objectMap.put(TrainedModelConfig.COMPRESSED_DEFINITION.getPreferredName(), lazyModelDefinition.getCompressedString());
183+
objectMap.put(TrainedModelConfig.DEFINITION.getPreferredName(), config.getModelDefinition());
184184

185185
try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(objectMap);
186186
XContentParser parser = XContentType.JSON

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ public void testGetTrainedModels() throws IOException {
9393
assertThat(response, containsString("\"estimated_heap_memory_usage_bytes\""));
9494
assertThat(response, containsString("\"estimated_heap_memory_usage\""));
9595
assertThat(response, containsString("\"definition\""));
96+
assertThat(response, not(containsString("\"compressed_definition\"")));
9697
assertThat(response, containsString("\"count\":1"));
9798

9899
getModel = client().performRequest(new Request("GET",

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,10 @@ protected void masterOperation(Task task,
108108

109109
ActionListener<Void> tagsModelIdCheckListener = ActionListener.wrap(
110110
r -> trainedModelProvider.storeTrainedModel(trainedModelConfig, ActionListener.wrap(
111-
storedConfig -> listener.onResponse(new PutTrainedModelAction.Response(trainedModelConfig)),
111+
bool -> {
112+
TrainedModelConfig configToReturn = new TrainedModelConfig.Builder(trainedModelConfig).clearDefinition().build();
113+
listener.onResponse(new PutTrainedModelAction.Response(configToReturn));
114+
},
112115
listener::onFailure
113116
)),
114117
listener::onFailure

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,14 @@
88
import org.elasticsearch.client.node.NodeClient;
99
import org.elasticsearch.cluster.metadata.MetaData;
1010
import org.elasticsearch.common.Strings;
11+
import org.elasticsearch.common.xcontent.ToXContent;
12+
import org.elasticsearch.common.xcontent.ToXContentObject;
13+
import org.elasticsearch.common.xcontent.XContentBuilder;
1114
import org.elasticsearch.rest.BaseRestHandler;
15+
import org.elasticsearch.rest.BytesRestResponse;
16+
import org.elasticsearch.rest.RestChannel;
1217
import org.elasticsearch.rest.RestRequest;
18+
import org.elasticsearch.rest.RestResponse;
1319
import org.elasticsearch.rest.action.RestToXContentListener;
1420
import org.elasticsearch.xpack.core.action.util.PageParams;
1521
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
@@ -18,7 +24,9 @@
1824

1925
import java.io.IOException;
2026
import java.util.Collections;
27+
import java.util.HashMap;
2128
import java.util.List;
29+
import java.util.Map;
2230
import java.util.Set;
2331

2432
import static java.util.Arrays.asList;
@@ -34,6 +42,8 @@ public List<Route> routes() {
3442
new Route(GET, MachineLearning.BASE_PATH + "inference"));
3543
}
3644

45+
private static final Map<String, String> DEFAULT_TO_XCONTENT_VALUES =
46+
Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, Boolean.toString(true));
3747
@Override
3848
public String getName() {
3949
return "ml_get_trained_models_action";
@@ -56,12 +66,33 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
5666
restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE)));
5767
}
5868
request.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(), request.isAllowNoResources()));
59-
return channel -> client.execute(GetTrainedModelsAction.INSTANCE, request, new RestToXContentListener<>(channel));
69+
return channel -> client.execute(GetTrainedModelsAction.INSTANCE,
70+
request,
71+
new RestToXContentListenerWithDefaultValues<>(channel, DEFAULT_TO_XCONTENT_VALUES));
6072
}
6173

6274
@Override
6375
protected Set<String> responseParams() {
6476
return Collections.singleton(TrainedModelConfig.DECOMPRESS_DEFINITION);
6577
}
6678

79+
private static class RestToXContentListenerWithDefaultValues<T extends ToXContentObject> extends RestToXContentListener<T> {
80+
private final Map<String, String> defaultToXContentParamValues;
81+
82+
private RestToXContentListenerWithDefaultValues(RestChannel channel, Map<String, String> defaultToXContentParamValues) {
83+
super(channel);
84+
this.defaultToXContentParamValues = defaultToXContentParamValues;
85+
}
86+
87+
@Override
88+
public RestResponse buildResponse(T response, XContentBuilder builder) throws Exception {
89+
assert response.isFragment() == false; //would be nice if we could make default methods final
90+
Map<String, String> params = new HashMap<>(channel.request().params());
91+
defaultToXContentParamValues.forEach((k, v) ->
92+
params.computeIfAbsent(k, defaultToXContentParamValues::get)
93+
);
94+
response.toXContent(builder, new ToXContent.MapParams(params));
95+
return new BytesRestResponse(getStatus(response), builder);
96+
}
97+
}
6798
}

x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,3 +460,53 @@ setup:
460460
}
461461
}
462462
}
463+
---
464+
"Test put model":
465+
- do:
466+
ml.put_trained_model:
467+
model_id: my-regression-model
468+
body: >
469+
{
470+
"description": "model for tests",
471+
"input": {"field_names": ["field1", "field2"]},
472+
"definition": {
473+
"preprocessors": [],
474+
"trained_model": {
475+
"ensemble": {
476+
"target_type": "regression",
477+
"trained_models": [
478+
{
479+
"tree": {
480+
"feature_names": ["field1", "field2"],
481+
"tree_structure": [
482+
{"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2},
483+
{"node_index": 1, "leaf_value": 0},
484+
{"node_index": 2, "leaf_value": 1}
485+
],
486+
"target_type": "regression"
487+
}
488+
},
489+
{
490+
"tree": {
491+
"feature_names": ["field1", "field2"],
492+
"tree_structure": [
493+
{"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2},
494+
{"node_index": 1, "leaf_value": 0},
495+
{"node_index": 2, "leaf_value": 1}
496+
],
497+
"target_type": "regression"
498+
}
499+
}
500+
]
501+
}
502+
}
503+
}
504+
}
505+
- match: { model_id: my-regression-model }
506+
- match: { estimated_operations: 6 }
507+
- is_false: definition
508+
- is_false: compressed_definition
509+
- is_true: license_level
510+
- is_true: create_time
511+
- is_true: version
512+
- is_true: estimated_heap_memory_usage_bytes

0 commit comments

Comments
 (0)