From 5b5d2127ec1ea1eefb7e95f45945f614a8df6642 Mon Sep 17 00:00:00 2001
From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com>
Date: Thu, 14 Nov 2019 12:55:42 -0500
Subject: [PATCH 1/6] [ML][Inference] PUT API
---
.../client/MLRequestConverters.java | 11 +
.../client/MachineLearningClient.java | 44 ++++
.../client/ml/PutTrainedModelRequest.java | 66 ++++++
.../client/ml/PutTrainedModelResponse.java | 63 ++++++
.../ml/inference/TrainedModelConfig.java | 19 +-
.../ml/inference/TrainedModelInput.java | 5 +
.../client/MLRequestConvertersTests.java | 17 ++
.../client/MachineLearningIT.java | 85 +++-----
.../MlClientDocumentationIT.java | 89 ++++++--
.../ml/PutTrainedModelActionRequestTests.java | 52 +++++
.../PutTrainedModelActionResponseTests.java | 52 +++++
.../ml/inference/TrainedModelConfigTests.java | 35 ++--
.../trainedmodel/ensemble/EnsembleTests.java | 14 +-
.../trainedmodel/tree/TreeTests.java | 2 +-
.../high-level/ml/put-trained-model.asciidoc | 51 +++++
.../high-level/supported-apis.asciidoc | 2 +
.../core/ml/action/PutTrainedModelAction.java | 137 +++++++++++++
.../core/ml/inference/TrainedModelConfig.java | 117 ++++++++---
.../ml/inference/TrainedModelDefinition.java | 6 +
.../xpack/core/ml/job/messages/Messages.java | 4 +
.../PutTrainedModelActionRequestTests.java | 45 +++++
.../PutTrainedModelActionResponseTests.java | 45 +++++
.../ml/inference/TrainedModelConfigTests.java | 34 ++--
.../ml/qa/ml-with-security/build.gradle | 1 -
.../ml/integration/InferenceIngestIT.java | 81 ++++----
.../xpack/ml/integration/TrainedModelIT.java | 129 ++++++------
.../xpack/ml/MachineLearning.java | 7 +-
.../TransportPutTrainedModelAction.java | 190 ++++++++++++++++++
.../persistence/TrainedModelProvider.java | 2 +
.../inference/RestPutTrainedModelAction.java | 42 ++++
.../MachineLearningLicensingTests.java | 106 +++-------
.../integration/TrainedModelProviderIT.java | 1 -
.../api/ml.put_trained_model.json | 28 +++
.../rest-api-spec/test/ml/inference_crud.yml | 153 ++++++++++++--
34 files changed, 1383 insertions(+), 352 deletions(-)
create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutTrainedModelRequest.java
create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutTrainedModelResponse.java
create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutTrainedModelActionRequestTests.java
create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutTrainedModelActionResponseTests.java
create mode 100644 docs/java-rest/high-level/ml/put-trained-model.asciidoc
create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAction.java
create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionRequestTests.java
create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionResponseTests.java
create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java
create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelAction.java
create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_trained_model.json
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java
index 4967d8091c961..2e077f547e34f 100644
--- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java
@@ -73,6 +73,7 @@
import org.elasticsearch.client.ml.PutDatafeedRequest;
import org.elasticsearch.client.ml.PutFilterRequest;
import org.elasticsearch.client.ml.PutJobRequest;
+import org.elasticsearch.client.ml.PutTrainedModelRequest;
import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
import org.elasticsearch.client.ml.SetUpgradeModeRequest;
import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest;
@@ -792,6 +793,16 @@ static Request deleteTrainedModel(DeleteTrainedModelRequest deleteRequest) {
return new Request(HttpDelete.METHOD_NAME, endpoint);
}
+ static Request putTrainedModel(PutTrainedModelRequest putTrainedModelRequest) throws IOException {
+ String endpoint = new EndpointBuilder()
+ .addPathPartAsIs("_ml", "inference")
+ .addPathPart(putTrainedModelRequest.getTrainedModelConfig().getModelId())
+ .build();
+ Request request = new Request(HttpPut.METHOD_NAME, endpoint);
+ request.setEntity(createEntity(putTrainedModelRequest, REQUEST_BODY_CONTENT_TYPE));
+ return request;
+ }
+
static Request putFilter(PutFilterRequest putFilterRequest) throws IOException {
String endpoint = new EndpointBuilder()
.addPathPartAsIs("_ml")
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java
index 0a71b8ddb0172..3c0bcfd9230ea 100644
--- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java
@@ -100,6 +100,8 @@
import org.elasticsearch.client.ml.PutFilterResponse;
import org.elasticsearch.client.ml.PutJobRequest;
import org.elasticsearch.client.ml.PutJobResponse;
+import org.elasticsearch.client.ml.PutTrainedModelRequest;
+import org.elasticsearch.client.ml.PutTrainedModelResponse;
import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
import org.elasticsearch.client.ml.RevertModelSnapshotResponse;
import org.elasticsearch.client.ml.SetUpgradeModeRequest;
@@ -2340,6 +2342,48 @@ public Cancellable getTrainedModelsAsync(GetTrainedModelsRequest request,
Collections.emptySet());
}
+ /**
+ * Put trained model config
+ *
+ * For additional info
+ * see
+ * PUT Trained Model Config documentation
+ *
+ * @param request The {@link PutTrainedModelRequest}
+ * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
+ * @return {@link PutTrainedModelResponse} response object
+ */
+ public PutTrainedModelResponse putTrainedModel(PutTrainedModelRequest request, RequestOptions options) throws IOException {
+ return restHighLevelClient.performRequestAndParseEntity(request,
+ MLRequestConverters::putTrainedModel,
+ options,
+ PutTrainedModelResponse::fromXContent,
+ Collections.emptySet());
+ }
+
+ /**
+ * Get trained model config asynchronously and notifies listener upon completion
+ *
+ * For additional info
+ * see
+ * PUT Trained Model Config documentation
+ *
+ * @param request The {@link PutTrainedModelRequest}
+ * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
+ * @param listener Listener to be notified upon request completion
+ * @return cancellable that may be used to cancel the request
+ */
+ public Cancellable putTrainedModelAsync(PutTrainedModelRequest request,
+ RequestOptions options,
+ ActionListener listener) {
+ return restHighLevelClient.performRequestAsyncAndParseEntity(request,
+ MLRequestConverters::putTrainedModel,
+ options,
+ PutTrainedModelResponse::fromXContent,
+ listener,
+ Collections.emptySet());
+ }
+
/**
* Gets trained model stats
*
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutTrainedModelRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutTrainedModelRequest.java
new file mode 100644
index 0000000000000..780ec31771baa
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutTrainedModelRequest.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.client.Validatable;
+import org.elasticsearch.client.ml.inference.TrainedModelConfig;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.xcontent.ToXContent;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.Objects;
+
+
+public class PutTrainedModelRequest implements Validatable, ToXContentObject {
+
+ private final TrainedModelConfig config;
+
+ public PutTrainedModelRequest(TrainedModelConfig config) {
+ this.config = config;
+ }
+
+ public TrainedModelConfig getTrainedModelConfig() {
+ return config;
+ }
+
+ @Override
+ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+ return config.toXContent(builder, params);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ PutTrainedModelRequest request = (PutTrainedModelRequest) o;
+ return Objects.equals(config, request.config);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(config);
+ }
+
+ @Override
+ public final String toString() {
+ return Strings.toString(config);
+ }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutTrainedModelResponse.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutTrainedModelResponse.java
new file mode 100644
index 0000000000000..3bc81f1812940
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutTrainedModelResponse.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.client.ml.inference.TrainedModelConfig;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Objects;
+
+
+public class PutTrainedModelResponse implements ToXContentObject {
+
+ private final TrainedModelConfig trainedModelConfig;
+
+ public static PutTrainedModelResponse fromXContent(XContentParser parser) throws IOException {
+ return new PutTrainedModelResponse(TrainedModelConfig.PARSER.parse(parser, null).build());
+ }
+
+ public PutTrainedModelResponse(TrainedModelConfig trainedModelConfig) {
+ this.trainedModelConfig = trainedModelConfig;
+ }
+
+ public TrainedModelConfig getResponse() {
+ return trainedModelConfig;
+ }
+
+ @Override
+ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+ return trainedModelConfig.toXContent(builder, params);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ PutTrainedModelResponse response = (PutTrainedModelResponse) o;
+ return Objects.equals(trainedModelConfig, response.trainedModelConfig);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(trainedModelConfig);
+ }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java
index 23eb01fb3b153..9d2b323cf4880 100644
--- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java
@@ -30,6 +30,7 @@
import java.io.IOException;
import java.time.Instant;
+import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -111,7 +112,7 @@ public static TrainedModelConfig fromXContent(XContentParser parser) throws IOEx
this.modelId = modelId;
this.createdBy = createdBy;
this.version = version;
- this.createTime = Instant.ofEpochMilli(createTime.toEpochMilli());
+ this.createTime = createTime == null ? null : Instant.ofEpochMilli(createTime.toEpochMilli());
this.definition = definition;
this.compressedDefinition = compressedDefinition;
this.description = description;
@@ -293,12 +294,12 @@ public Builder setModelId(String modelId) {
return this;
}
- public Builder setCreatedBy(String createdBy) {
+ private Builder setCreatedBy(String createdBy) {
this.createdBy = createdBy;
return this;
}
- public Builder setVersion(Version version) {
+ private Builder setVersion(Version version) {
this.version = version;
return this;
}
@@ -312,7 +313,7 @@ public Builder setDescription(String description) {
return this;
}
- public Builder setCreateTime(Instant createTime) {
+ private Builder setCreateTime(Instant createTime) {
this.createTime = createTime;
return this;
}
@@ -322,6 +323,10 @@ public Builder setTags(List tags) {
return this;
}
+ public Builder setTags(String... tags) {
+ return setTags(Arrays.asList(tags));
+ }
+
public Builder setMetadata(Map metadata) {
this.metadata = metadata;
return this;
@@ -347,17 +352,17 @@ public Builder setInput(TrainedModelInput input) {
return this;
}
- public Builder setEstimatedHeapMemory(Long estimatedHeapMemory) {
+ private Builder setEstimatedHeapMemory(Long estimatedHeapMemory) {
this.estimatedHeapMemory = estimatedHeapMemory;
return this;
}
- public Builder setEstimatedOperations(Long estimatedOperations) {
+ private Builder setEstimatedOperations(Long estimatedOperations) {
this.estimatedOperations = estimatedOperations;
return this;
}
- public Builder setLicenseLevel(String licenseLevel) {
+ private Builder setLicenseLevel(String licenseLevel) {
this.licenseLevel = licenseLevel;
return this;
}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelInput.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelInput.java
index 10f849cac481a..9b19323023d81 100644
--- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelInput.java
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelInput.java
@@ -25,6 +25,7 @@
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
+import java.util.Arrays;
import java.util.List;
import java.util.Objects;
@@ -48,6 +49,10 @@ public TrainedModelInput(List fieldNames) {
this.fieldNames = fieldNames;
}
+ public TrainedModelInput(String... fieldNames) {
+ this(Arrays.asList(fieldNames));
+ }
+
public static TrainedModelInput fromXContent(XContentParser parser) throws IOException {
return PARSER.parse(parser, null);
}
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java
index 475dda254448f..34fdda50bd1e1 100644
--- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java
@@ -71,6 +71,7 @@
import org.elasticsearch.client.ml.PutDatafeedRequest;
import org.elasticsearch.client.ml.PutFilterRequest;
import org.elasticsearch.client.ml.PutJobRequest;
+import org.elasticsearch.client.ml.PutTrainedModelRequest;
import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
import org.elasticsearch.client.ml.SetUpgradeModeRequest;
import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest;
@@ -91,6 +92,8 @@
import org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
+import org.elasticsearch.client.ml.inference.TrainedModelConfig;
+import org.elasticsearch.client.ml.inference.TrainedModelConfigTests;
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
import org.elasticsearch.client.ml.job.config.Detector;
import org.elasticsearch.client.ml.job.config.Job;
@@ -874,6 +877,20 @@ public void testDeleteTrainedModel() {
assertNull(request.getEntity());
}
+ public void testPutTrainedModel() throws IOException {
+ TrainedModelConfig trainedModelConfig = TrainedModelConfigTests.createTestTrainedModelConfig();
+ PutTrainedModelRequest putTrainedModelRequest = new PutTrainedModelRequest(trainedModelConfig);
+
+ Request request = MLRequestConverters.putTrainedModel(putTrainedModelRequest);
+
+ assertEquals(HttpPut.METHOD_NAME, request.getMethod());
+ assertThat(request.getEndpoint(), equalTo("/_ml/inference/" + trainedModelConfig.getModelId()));
+ try (XContentParser parser = createParser(JsonXContent.jsonXContent, request.getEntity().getContent())) {
+ TrainedModelConfig parsedTrainedModelConfig = TrainedModelConfig.PARSER.apply(parser, null).build();
+ assertThat(parsedTrainedModelConfig, equalTo(trainedModelConfig));
+ }
+ }
+
public void testPutFilter() throws IOException {
MlFilter filter = MlFilterTests.createRandomBuilder("foo").build();
PutFilterRequest putFilterRequest = new PutFilterRequest(filter);
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
index 547521a089cc6..853103ff1f855 100644
--- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
@@ -101,6 +101,8 @@
import org.elasticsearch.client.ml.PutFilterResponse;
import org.elasticsearch.client.ml.PutJobRequest;
import org.elasticsearch.client.ml.PutJobResponse;
+import org.elasticsearch.client.ml.PutTrainedModelRequest;
+import org.elasticsearch.client.ml.PutTrainedModelResponse;
import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
import org.elasticsearch.client.ml.RevertModelSnapshotResponse;
import org.elasticsearch.client.ml.SetUpgradeModeRequest;
@@ -149,6 +151,7 @@
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
+import org.elasticsearch.client.ml.inference.TrainedModelInput;
import org.elasticsearch.client.ml.inference.TrainedModelStats;
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.client.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
@@ -162,14 +165,11 @@
import org.elasticsearch.client.ml.job.process.ModelSnapshot;
import org.elasticsearch.client.ml.job.stats.JobStats;
import org.elasticsearch.common.bytes.BytesArray;
-import org.elasticsearch.common.bytes.BytesReference;
-import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
-import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
@@ -178,11 +178,9 @@
import org.junit.After;
import java.io.IOException;
-import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
-import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
@@ -190,7 +188,6 @@
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
-import java.util.zip.GZIPOutputStream;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.anyOf;
@@ -2192,6 +2189,25 @@ public void testGetTrainedModels() throws Exception {
}
}
+ public void testPutTrainedModel() throws Exception {
+ String modelId = "test-put-trained-model";
+
+ MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
+
+ TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
+ TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder()
+ .setDefinition(definition)
+ .setModelId(modelId)
+ .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4")))
+ .setDescription("test model")
+ .build();
+ PutTrainedModelResponse putTrainedModelResponse = execute(new PutTrainedModelRequest(trainedModelConfig),
+ machineLearningClient::putTrainedModel,
+ machineLearningClient::putTrainedModelAsync);
+ TrainedModelConfig createdModel = putTrainedModelResponse.getResponse();
+ assertThat(createdModel.getModelId(), equalTo(modelId));
+ }
+
public void testGetTrainedModelsStats() throws Exception {
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
String modelIdPrefix = "a-get-trained-model-stats-";
@@ -2474,56 +2490,13 @@ private void openJob(Job job) throws IOException {
private void putTrainedModel(String modelId) throws IOException {
TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
- highLevelClient().index(
- new IndexRequest(".ml-inference-000001")
- .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
- .source(modelConfigString(modelId), XContentType.JSON)
- .id(modelId),
- RequestOptions.DEFAULT);
-
- highLevelClient().index(
- new IndexRequest(".ml-inference-000001")
- .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
- .source(modelDocString(compressDefinition(definition), modelId), XContentType.JSON)
- .id("trained_model_definition_doc-" + modelId + "-0"),
- RequestOptions.DEFAULT);
- }
-
- private String compressDefinition(TrainedModelDefinition definition) throws IOException {
- BytesReference reference = XContentHelper.toXContent(definition, XContentType.JSON, false);
- BytesStreamOutput out = new BytesStreamOutput();
- try (OutputStream compressedOutput = new GZIPOutputStream(out, 4096)) {
- reference.writeTo(compressedOutput);
- }
- return new String(Base64.getEncoder().encode(BytesReference.toBytes(out.bytes())), StandardCharsets.UTF_8);
- }
-
- private static String modelConfigString(String modelId) {
- return "{\n" +
- " \"doc_type\": \"trained_model_config\",\n" +
- " \"model_id\": \"" + modelId + "\",\n" +
- " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
- " \"description\": \"test model\",\n" +
- " \"version\": \"7.6.0\",\n" +
- " \"license_level\": \"platinum\",\n" +
- " \"created_by\": \"ml_test\",\n" +
- " \"estimated_heap_memory_usage_bytes\": 0," +
- " \"estimated_operations\": 0," +
- " \"created_time\": 0\n" +
- "}";
- }
-
- private static String modelDocString(String compressedDefinition, String modelId) {
- return "" +
- "{" +
- "\"model_id\": \"" + modelId + "\",\n" +
- "\"doc_num\": 0,\n" +
- "\"doc_type\": \"trained_model_definition_doc\",\n" +
- " \"compression_version\": " + 1 + ",\n" +
- " \"total_definition_length\": " + compressedDefinition.length() + ",\n" +
- " \"definition_length\": " + compressedDefinition.length() + ",\n" +
- "\"definition\": \"" + compressedDefinition + "\"\n" +
- "}";
+ TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder()
+ .setDefinition(definition)
+ .setModelId(modelId)
+ .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4")))
+ .setDescription("test model")
+ .build();
+ highLevelClient().machineLearning().putTrainedModel(new PutTrainedModelRequest(trainedModelConfig), RequestOptions.DEFAULT);
}
private void waitForJobToClose(String jobId) throws Exception {
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java
index 37ae59e9b992a..4ee22e6289db0 100644
--- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java
@@ -114,6 +114,8 @@
import org.elasticsearch.client.ml.PutFilterResponse;
import org.elasticsearch.client.ml.PutJobRequest;
import org.elasticsearch.client.ml.PutJobResponse;
+import org.elasticsearch.client.ml.PutTrainedModelRequest;
+import org.elasticsearch.client.ml.PutTrainedModelResponse;
import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
import org.elasticsearch.client.ml.RevertModelSnapshotResponse;
import org.elasticsearch.client.ml.SetUpgradeModeRequest;
@@ -165,7 +167,9 @@
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
+import org.elasticsearch.client.ml.inference.TrainedModelInput;
import org.elasticsearch.client.ml.inference.TrainedModelStats;
+import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
import org.elasticsearch.client.ml.job.config.AnalysisLimits;
import org.elasticsearch.client.ml.job.config.DataDescription;
@@ -3625,6 +3629,69 @@ public void onFailure(Exception e) {
}
}
+ public void testPutTrainedModel() throws Exception {
+ TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
+ // tag::put-trained-model-config
+ TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder()
+ .setDefinition(definition) // <1>
+ .setModelId("my-new-trained-model") // <2>
+ .setInput(new TrainedModelInput("col1", "col2", "col3", "col4")) // <3>
+ .setDescription("test model") // <4>
+ .setMetadata(new HashMap<>()) // <5>
+ .setTags("my_regression_models") // <6>
+ .build();
+ // end::put-trained-model-config
+
+ RestHighLevelClient client = highLevelClient();
+ {
+ // tag::put-trained-model-request
+ PutTrainedModelRequest request = new PutTrainedModelRequest(trainedModelConfig); // <1>
+ // end::put-trained-model-request
+
+ // tag::put-trained-model-execute
+ PutTrainedModelResponse response = client.machineLearning().putTrainedModel(request, RequestOptions.DEFAULT);
+ // end::put-trained-model-execute
+
+ // tag::put-trained-model-response
+ TrainedModelConfig model = response.getResponse();
+ // end::put-trained-model-response
+
+ assertThat(model.getModelId(), equalTo(trainedModelConfig.getModelId()));
+ highLevelClient().machineLearning()
+ .deleteTrainedModel(new DeleteTrainedModelRequest("my-new-trained-model"), RequestOptions.DEFAULT);
+ }
+ {
+ PutTrainedModelRequest request = new PutTrainedModelRequest(trainedModelConfig);
+
+ // tag::put-trained-model-execute-listener
+ ActionListener listener = new ActionListener<>() {
+ @Override
+ public void onResponse(PutTrainedModelResponse response) {
+ // <1>
+ }
+
+ @Override
+ public void onFailure(Exception e) {
+ // <2>
+ }
+ };
+ // end::put-trained-model-execute-listener
+
+ // Replace the empty listener by a blocking listener in test
+ CountDownLatch latch = new CountDownLatch(1);
+ listener = new LatchedActionListener<>(listener, latch);
+
+ // tag::put-trained-model-execute-async
+ client.machineLearning().putTrainedModelAsync(request, RequestOptions.DEFAULT, listener); // <1>
+ // end::put-trained-model-execute-async
+
+ assertTrue(latch.await(30L, TimeUnit.SECONDS));
+
+ highLevelClient().machineLearning()
+ .deleteTrainedModel(new DeleteTrainedModelRequest("my-new-trained-model"), RequestOptions.DEFAULT);
+ }
+ }
+
public void testGetTrainedModelsStats() throws Exception {
putTrainedModel("my-trained-model");
RestHighLevelClient client = highLevelClient();
@@ -4088,20 +4155,14 @@ private DataFrameAnalyticsState getAnalyticsState(String configId) throws IOExce
}
private void putTrainedModel(String modelId) throws IOException {
- TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build();
- highLevelClient().index(
- new IndexRequest(".ml-inference-000001")
- .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
- .source(modelConfigString(modelId), XContentType.JSON)
- .id(modelId),
- RequestOptions.DEFAULT);
-
- highLevelClient().index(
- new IndexRequest(".ml-inference-000001")
- .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
- .source(modelDocString(compressDefinition(definition), modelId), XContentType.JSON)
- .id("trained_model_definition_doc-" + modelId + "-0"),
- RequestOptions.DEFAULT);
+ TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
+ TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder()
+ .setDefinition(definition)
+ .setModelId(modelId)
+ .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4")))
+ .setDescription("test model")
+ .build();
+ highLevelClient().machineLearning().putTrainedModel(new PutTrainedModelRequest(trainedModelConfig), RequestOptions.DEFAULT);
}
private String compressDefinition(TrainedModelDefinition definition) throws IOException {
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutTrainedModelActionRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutTrainedModelActionRequestTests.java
new file mode 100644
index 0000000000000..b3956c5c6afe0
--- /dev/null
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutTrainedModelActionRequestTests.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.client.ml.inference.TrainedModelConfig;
+import org.elasticsearch.client.ml.inference.TrainedModelConfigTests;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+
+public class PutTrainedModelActionRequestTests extends AbstractXContentTestCase {
+
+ @Override
+ protected PutTrainedModelRequest createTestInstance() {
+ return new PutTrainedModelRequest(TrainedModelConfigTests.createTestTrainedModelConfig());
+ }
+
+ @Override
+ protected PutTrainedModelRequest doParseInstance(XContentParser parser) throws IOException {
+ return new PutTrainedModelRequest(TrainedModelConfig.PARSER.apply(parser, null).build());
+ }
+
+ @Override
+ protected boolean supportsUnknownFields() {
+ return false;
+ }
+
+ @Override
+ protected NamedXContentRegistry xContentRegistry() {
+ return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+ }
+
+}
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutTrainedModelActionResponseTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutTrainedModelActionResponseTests.java
new file mode 100644
index 0000000000000..61e1638547b3f
--- /dev/null
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutTrainedModelActionResponseTests.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.client.ml.inference.TrainedModelConfig;
+import org.elasticsearch.client.ml.inference.TrainedModelConfigTests;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+
+public class PutTrainedModelActionResponseTests extends AbstractXContentTestCase {
+
+ @Override
+ protected PutTrainedModelResponse createTestInstance() {
+ return new PutTrainedModelResponse(TrainedModelConfigTests.createTestTrainedModelConfig());
+ }
+
+ @Override
+ protected PutTrainedModelResponse doParseInstance(XContentParser parser) throws IOException {
+ return new PutTrainedModelResponse(TrainedModelConfig.PARSER.apply(parser, null).build());
+ }
+
+ @Override
+ protected boolean supportsUnknownFields() {
+ return false;
+ }
+
+ @Override
+ protected NamedXContentRegistry xContentRegistry() {
+ return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+ }
+
+}
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java
index 95ebbad837d69..43ab2e5993fde 100644
--- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java
@@ -37,6 +37,24 @@
public class TrainedModelConfigTests extends AbstractXContentTestCase {
+ public static TrainedModelConfig createTestTrainedModelConfig() {
+ return new TrainedModelConfig(
+ randomAlphaOfLength(10),
+ randomAlphaOfLength(10),
+ Version.CURRENT,
+ randomBoolean() ? null : randomAlphaOfLength(100),
+ Instant.ofEpochMilli(randomNonNegativeLong()),
+ randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(),
+ randomBoolean() ? null : randomAlphaOfLength(100),
+ randomBoolean() ? null :
+ Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()),
+ randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
+ randomBoolean() ? null : TrainedModelInputTests.createRandomInput(),
+ randomBoolean() ? null : randomNonNegativeLong(),
+ randomBoolean() ? null : randomNonNegativeLong(),
+ randomBoolean() ? null : randomFrom("platinum", "basic"));
+ }
+
@Override
protected TrainedModelConfig doParseInstance(XContentParser parser) throws IOException {
return TrainedModelConfig.fromXContent(parser);
@@ -54,22 +72,7 @@ protected Predicate getRandomFieldsExcludeFilter() {
@Override
protected TrainedModelConfig createTestInstance() {
- return new TrainedModelConfig(
- randomAlphaOfLength(10),
- randomAlphaOfLength(10),
- Version.CURRENT,
- randomBoolean() ? null : randomAlphaOfLength(100),
- Instant.ofEpochMilli(randomNonNegativeLong()),
- randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(),
- randomBoolean() ? null : randomAlphaOfLength(100),
- randomBoolean() ? null :
- Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()),
- randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
- randomBoolean() ? null : TrainedModelInputTests.createRandomInput(),
- randomBoolean() ? null : randomNonNegativeLong(),
- randomBoolean() ? null : randomNonNegativeLong(),
- randomBoolean() ? null : randomFrom("platinum", "basic"));
-
+ return createTestTrainedModelConfig();
}
@Override
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java
index f2448cbf4c8bb..a26adf5f09c71 100644
--- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java
@@ -67,15 +67,17 @@ public static Ensemble createRandom(TargetType targetType) {
.collect(Collectors.toList());
int numberOfModels = randomIntBetween(1, 10);
List models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6, targetType))
- .limit(numberOfFeatures)
+ .limit(numberOfModels)
.collect(Collectors.toList());
- OutputAggregator outputAggregator = null;
- if (randomBoolean()) {
- List weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList());
- outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights), new LogisticRegression(weights));
+ List weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList());
+ List possibleAggregators = new ArrayList<>(Arrays.asList(new WeightedMode(weights),
+ new LogisticRegression(weights)));
+ if (targetType.equals(TargetType.REGRESSION)) {
+ possibleAggregators.add(new WeightedSum(weights));
}
+ OutputAggregator outputAggregator = randomFrom(possibleAggregators.toArray(new OutputAggregator[0]));
List categoryLabels = null;
- if (randomBoolean()) {
+ if (randomBoolean() && targetType.equals(TargetType.CLASSIFICATION)) {
categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
}
return new Ensemble(featureNames,
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java
index febd1b98c2765..57cb2ba664d77 100644
--- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java
@@ -84,7 +84,7 @@ public static Tree buildRandomTree(List featureNames, int depth, TargetT
childNodes = nextNodes;
}
List categoryLabels = null;
- if (randomBoolean()) {
+ if (randomBoolean() && targetType.equals(TargetType.CLASSIFICATION)) {
categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
}
return builder.setClassificationLabels(categoryLabels)
diff --git a/docs/java-rest/high-level/ml/put-trained-model.asciidoc b/docs/java-rest/high-level/ml/put-trained-model.asciidoc
new file mode 100644
index 0000000000000..5e7694f8221e3
--- /dev/null
+++ b/docs/java-rest/high-level/ml/put-trained-model.asciidoc
@@ -0,0 +1,51 @@
+--
+:api: put-trained-model
+:request: PutTrainedModelRequest
+:response: PutTrainedModelResponse
+--
+[role="xpack"]
+[id="{upid}-{api}"]
+=== Put Trained Model API
+
+Creates a new trained model for inference.
+The API accepts a +{request}+ object as a request and returns a +{response}+.
+
+[id="{upid}-{api}-request"]
+==== Put Trained Model request
+
+A +{request}+ requires the following argument:
+
+["source","java",subs="attributes,callouts,macros"]
+--------------------------------------------------
+include-tagged::{doc-tests-file}[{api}-request]
+--------------------------------------------------
+<1> The configuration of the {infer} Trained Model to create
+
+[id="{upid}-{api}-config"]
+==== Trained Model configuration
+
+The `TrainedModelConfig` object contains all the details about the trained model
+configuration and contains the following arguments:
+
+["source","java",subs="attributes,callouts,macros"]
+--------------------------------------------------
+include-tagged::{doc-tests-file}[{api}-config]
+--------------------------------------------------
+<1> The {infer} definition for the model
+<2> The unique model id
+<3> The input field names for the model definition
+<4> Optionally, a human-readable description
+<5> Optionally, an object map contain metadata about the model
+<6> Optionally, an array of tags to organize the model
+
+include::../execution.asciidoc[]
+
+[id="{upid}-{api}-response"]
+==== Response
+
+The returned +{response}+ contains the newly created trained model.
+
+["source","java",subs="attributes,callouts,macros"]
+--------------------------------------------------
+include-tagged::{doc-tests-file}[{api}-response]
+--------------------------------------------------
diff --git a/docs/java-rest/high-level/supported-apis.asciidoc b/docs/java-rest/high-level/supported-apis.asciidoc
index e0d228b5d1e4b..4b848819702b4 100644
--- a/docs/java-rest/high-level/supported-apis.asciidoc
+++ b/docs/java-rest/high-level/supported-apis.asciidoc
@@ -304,6 +304,7 @@ The Java High Level REST Client supports the following Machine Learning APIs:
* <<{upid}-evaluate-data-frame>>
* <<{upid}-explain-data-frame-analytics>>
* <<{upid}-get-trained-models>>
+* <<{upid}-put-trained-model>>
* <<{upid}-get-trained-models-stats>>
* <<{upid}-delete-trained-model>>
* <<{upid}-put-filter>>
@@ -359,6 +360,7 @@ include::ml/stop-data-frame-analytics.asciidoc[]
include::ml/evaluate-data-frame.asciidoc[]
include::ml/explain-data-frame-analytics.asciidoc[]
include::ml/get-trained-models.asciidoc[]
+include::ml/put-trained-model.asciidoc[]
include::ml/get-trained-models-stats.asciidoc[]
include::ml/delete-trained-model.asciidoc[]
include::ml/put-filter.asciidoc[]
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAction.java
new file mode 100644
index 0000000000000..97045489001a5
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAction.java
@@ -0,0 +1,137 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.action.ActionResponse;
+import org.elasticsearch.action.ActionType;
+import org.elasticsearch.action.support.master.AcknowledgedRequest;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.job.messages.Messages;
+
+import java.io.IOException;
+import java.util.Objects;
+
+
+public class PutTrainedModelAction extends ActionType {
+
+ public static final PutTrainedModelAction INSTANCE = new PutTrainedModelAction();
+ public static final String NAME = "cluster:monitor/xpack/ml/inference/put";
+ private PutTrainedModelAction() {
+ super(NAME, Response::new);
+ }
+
+ public static class Request extends AcknowledgedRequest {
+
+ public static Request parseRequest(String modelId, XContentParser parser) {
+ TrainedModelConfig.Builder builder = TrainedModelConfig.STRICT_PARSER.apply(parser, null);
+
+ if (builder.getModelId() == null) {
+ builder.setModelId(modelId).build();
+ } else if (!Strings.isNullOrEmpty(modelId) && !modelId.equals(builder.getModelId())) {
+ // If we have model_id in both URI and body, they must be identical
+ throw new IllegalArgumentException(Messages.getMessage(Messages.INCONSISTENT_ID,
+ TrainedModelConfig.MODEL_ID.getPreferredName(),
+ builder.getModelId(),
+ modelId));
+ }
+ // Validations are done against the builder so we can build the full config object.
+ // This allows us to not worry about serializing a builder class between nodes.
+ return new Request(builder.validate().build());
+ }
+
+ private final TrainedModelConfig config;
+
+ public Request(TrainedModelConfig config) {
+ this.config = config;
+ }
+
+ public Request(StreamInput in) throws IOException {
+ super(in);
+ this.config = new TrainedModelConfig(in);
+ }
+
+ public TrainedModelConfig getTrainedModelConfig() {
+ return config;
+ }
+
+ @Override
+ public ActionRequestValidationException validate() {
+ return null;
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ super.writeTo(out);
+ config.writeTo(out);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ Request request = (Request) o;
+ return Objects.equals(config, request.config);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(config);
+ }
+
+ @Override
+ public final String toString() {
+ return Strings.toString(config);
+ }
+ }
+
+ public static class Response extends ActionResponse implements ToXContentObject {
+
+ private final TrainedModelConfig trainedModelConfig;
+
+ public Response(TrainedModelConfig trainedModelConfig) {
+ this.trainedModelConfig = trainedModelConfig;
+ }
+
+ public Response(StreamInput in) throws IOException {
+ super(in);
+ trainedModelConfig = new TrainedModelConfig(in);
+ }
+
+ public TrainedModelConfig getResponse() {
+ return trainedModelConfig;
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ trainedModelConfig.writeTo(out);
+ }
+
+ @Override
+ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+ return trainedModelConfig.toXContent(builder, params);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ Response response = (Response) o;
+ return Objects.equals(trainedModelConfig, response.trainedModelConfig);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(trainedModelConfig);
+ }
+ }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java
index be4d40efc8501..4d59bc336e86b 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java
@@ -7,6 +7,7 @@
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.Version;
+import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
@@ -34,6 +35,9 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
+import java.util.stream.Collectors;
+
+import static org.elasticsearch.action.ValidateActions.addValidationError;
public class TrainedModelConfig implements ToXContentObject, Writeable {
@@ -352,13 +356,31 @@ public static class Builder {
private Long estimatedHeapMemory;
private Long estimatedOperations;
private LazyModelDefinition definition;
- private String licenseLevel = License.OperationMode.PLATINUM.description();
+ private String licenseLevel;
+
+ public Builder() {}
+
+ public Builder(TrainedModelConfig config) {
+ this.modelId = config.getModelId();
+ this.createdBy = config.getCreatedBy();
+ this.version = config.getVersion();
+ this.createTime = config.getCreateTime();
+ this.definition = config.definition == null ? null : new LazyModelDefinition(config.definition);
+ this.description = config.getDescription();
+ this.tags = config.getTags();
+ this.metadata = config.getMetadata();
+ this.input = config.getInput();
+ }
public Builder setModelId(String modelId) {
this.modelId = modelId;
return this;
}
+ public String getModelId() {
+ return this.modelId;
+ }
+
public Builder setCreatedBy(String createdBy) {
this.createdBy = createdBy;
return this;
@@ -466,51 +488,89 @@ public Builder setLicenseLevel(String licenseLevel) {
return this;
}
- // TODO move to REST level instead of here in the builder
- public void validate() {
+ /**
+ * Runs validations against the builder.
+ * @return The current builder object if validations are successful
+ * @throws ActionRequestValidationException when there are validation failures.
+ */
+ public Builder validate() {
// We require a definition to be available here even though it will be stored in a different doc
- ExceptionsHelper.requireNonNull(definition, DEFINITION);
- ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
-
- if (MlStrings.isValidId(modelId) == false) {
- throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INVALID_ID, MODEL_ID.getPreferredName(), modelId));
+ ActionRequestValidationException validationException = null;
+ if (definition == null) {
+ validationException = addValidationError("[" + DEFINITION.getPreferredName() + "] must not be null.", validationException);
+ }
+ if (modelId == null) {
+ validationException = addValidationError("[" + MODEL_ID.getPreferredName() + "] must not be null.", validationException);
}
- if (MlStrings.hasValidLengthForId(modelId) == false) {
- throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.ID_TOO_LONG,
- MODEL_ID.getPreferredName(),
+ if (modelId != null && MlStrings.isValidId(modelId) == false) {
+ validationException = addValidationError(Messages.getMessage(Messages.INVALID_ID,
+ TrainedModelConfig.MODEL_ID.getPreferredName(),
+ modelId),
+ validationException);
+ }
+ if (modelId != null && MlStrings.hasValidLengthForId(modelId) == false) {
+ validationException = addValidationError(Messages.getMessage(Messages.ID_TOO_LONG,
+ TrainedModelConfig.MODEL_ID.getPreferredName(),
modelId,
- MlStrings.ID_LENGTH_LIMIT));
+ MlStrings.ID_LENGTH_LIMIT), validationException);
+ }
+ List badTags = tags.stream()
+ .filter(tag -> (MlStrings.isValidId(tag) && MlStrings.hasValidLengthForId(tag)) == false)
+ .collect(Collectors.toList());
+ if (badTags.isEmpty() == false) {
+ validationException = addValidationError(Messages.getMessage(Messages.INFERENCE_INVALID_TAGS,
+ badTags,
+ MlStrings.ID_LENGTH_LIMIT),
+ validationException);
+ }
+
+ for(String tag : tags) {
+ if (tag.equals(modelId)) {
+ validationException = addValidationError("none of the tags must equal the model_id", validationException);
+ break;
+ }
+ }
+
+ validationException = checkIllegalSetting(version, VERSION.getPreferredName(), validationException);
+ validationException = checkIllegalSetting(createdBy, CREATED_BY.getPreferredName(), validationException);
+ validationException = checkIllegalSetting(createTime, CREATE_TIME.getPreferredName(), validationException);
+ validationException = checkIllegalSetting(estimatedHeapMemory,
+ ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(),
+ validationException);
+ validationException = checkIllegalSetting(estimatedOperations, ESTIMATED_OPERATIONS.getPreferredName(), validationException);
+ validationException = checkIllegalSetting(licenseLevel, LICENSE_LEVEL.getPreferredName(), validationException);
+
+ if (validationException != null) {
+ throw validationException;
}
- checkIllegalSetting(version, VERSION.getPreferredName());
- checkIllegalSetting(createdBy, CREATED_BY.getPreferredName());
- checkIllegalSetting(createTime, CREATE_TIME.getPreferredName());
- checkIllegalSetting(estimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName());
- checkIllegalSetting(estimatedOperations, ESTIMATED_OPERATIONS.getPreferredName());
- checkIllegalSetting(licenseLevel, LICENSE_LEVEL.getPreferredName());
+ return this;
}
- private static void checkIllegalSetting(Object value, String setting) {
+ private static ActionRequestValidationException checkIllegalSetting(Object value,
+ String setting,
+ ActionRequestValidationException validationException) {
if (value != null) {
- throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation", setting);
+ return addValidationError("illegal to set [" + setting + "] at inference model creation", validationException);
}
+ return validationException;
}
public TrainedModelConfig build() {
return new TrainedModelConfig(
modelId,
- createdBy,
- version,
+ createdBy == null ? "user" : createdBy,
+ version == null ? Version.CURRENT : version,
description,
createTime == null ? Instant.now() : createTime,
definition,
tags,
metadata,
input,
- estimatedHeapMemory,
- estimatedOperations,
- licenseLevel);
+ estimatedHeapMemory == null ? 0 : estimatedHeapMemory,
+ estimatedOperations == null ? 0 : estimatedOperations,
+ licenseLevel == null ? License.OperationMode.PLATINUM.description() : licenseLevel);
}
}
@@ -531,6 +591,13 @@ public static LazyModelDefinition fromStreamInput(StreamInput input) throws IOEx
return new LazyModelDefinition(input.readString(), null);
}
+ private LazyModelDefinition(LazyModelDefinition definition) {
+ if (definition != null) {
+ this.compressedString = definition.compressedString;
+ this.parsedDefinition = definition.parsedDefinition;
+ }
+ }
+
private LazyModelDefinition(String compressedString, TrainedModelDefinition trainedModelDefinition) {
if (compressedString == null && trainedModelDefinition == null) {
throw new IllegalArgumentException("unexpected null model definition");
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java
index e176d9a288568..cf7a8b7d224c5 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java
@@ -179,6 +179,12 @@ public Builder() {
this(true);
}
+ public Builder(TrainedModelDefinition definition) {
+ this(true);
+ this.preProcessors = new ArrayList<>(definition.getPreProcessors());
+ this.trainedModel = definition.trainedModel;
+ }
+
public Builder setPreProcessors(List preProcessors) {
this.preProcessors = preProcessors;
return this;
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java
index 1b80d1963598e..ef0fcd4fdb172 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java
@@ -95,6 +95,10 @@ public final class Messages {
public static final String INFERENCE_TOO_MANY_DEFINITIONS_REQUESTED =
"Getting model definition is not supported when getting more than one model";
public static final String INFERENCE_WARNING_ALL_FIELDS_MISSING = "Model [{0}] could not be inferred as all fields were missing";
+ public static final String INFERENCE_INVALID_TAGS = "Invalid tags {0}; must only can contain lowercase alphanumeric (a-z and 0-9), " +
+ "hyphens or underscores, must start and end with alphanumeric, and must be less than {1} characters.";
+ public static final String INFERENCE_TAGS_AND_MODEL_IDS_UNIQUE = "The provided tags {0} must not match existing model_ids.";
+ public static final String INFERENCE_MODEL_ID_AND_TAGS_UNIQUE = "The provided model_id {0} must not match existing tags.";
public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again";
public static final String JOB_AUDIT_CREATED = "Job created";
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionRequestTests.java
new file mode 100644
index 0000000000000..0c39469c9029e
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionRequestTests.java
@@ -0,0 +1,45 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction.Request;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
+
+public class PutTrainedModelActionRequestTests extends AbstractWireSerializingTestCase {
+
+ @Override
+ protected Request createTestInstance() {
+ String modelId = randomAlphaOfLength(10);
+ return new Request(TrainedModelConfigTests.createTestInstance(modelId)
+ .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
+ .build());
+ }
+
+ @Override
+ protected Writeable.Reader instanceReader() {
+ return (in) -> {
+ Request request = new Request(in);
+ request.getTrainedModelConfig().ensureParsedDefinition(xContentRegistry());
+ return request;
+ };
+ }
+
+ @Override
+ protected NamedXContentRegistry xContentRegistry() {
+ return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+ }
+
+ @Override
+ protected NamedWriteableRegistry getNamedWriteableRegistry() {
+ return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables());
+ }
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionResponseTests.java
new file mode 100644
index 0000000000000..5813b13c8ad55
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionResponseTests.java
@@ -0,0 +1,45 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction.Response;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
+
+public class PutTrainedModelActionResponseTests extends AbstractWireSerializingTestCase {
+
+ @Override
+ protected Response createTestInstance() {
+ String modelId = randomAlphaOfLength(10);
+ return new Response(TrainedModelConfigTests.createTestInstance(modelId)
+ .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
+ .build());
+ }
+
+ @Override
+ protected Writeable.Reader instanceReader() {
+ return (in) -> {
+ Response response = new Response(in);
+ response.getResponse().ensureParsedDefinition(xContentRegistry());
+ return response;
+ };
+ }
+
+ @Override
+ protected NamedXContentRegistry xContentRegistry() {
+ return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+ }
+
+ @Override
+ protected NamedWriteableRegistry getNamedWriteableRegistry() {
+ return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables());
+ }
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java
index 8b267e10ca429..db570025c786a 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java
@@ -5,8 +5,8 @@
*/
package org.elasticsearch.xpack.core.ml.inference;
-import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
+import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
@@ -56,14 +56,16 @@ public static TrainedModelConfig.Builder createTestInstance(String modelId) {
return TrainedModelConfig.builder()
.setInput(TrainedModelInputTests.createRandomInput())
.setMetadata(randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)))
- .setCreateTime(Instant.ofEpochMilli(randomNonNegativeLong()))
+ .setCreateTime(Instant.ofEpochMilli(randomLongBetween(Instant.MIN.getEpochSecond(), Instant.MAX.getEpochSecond())))
.setVersion(Version.CURRENT)
.setModelId(modelId)
.setCreatedBy(randomAlphaOfLength(10))
.setDescription(randomBoolean() ? null : randomAlphaOfLength(100))
.setEstimatedHeapMemory(randomNonNegativeLong())
.setEstimatedOperations(randomNonNegativeLong())
- .setLicenseLevel(License.OperationMode.PLATINUM.description())
+ .setLicenseLevel(randomFrom(License.OperationMode.PLATINUM.description(),
+ License.OperationMode.GOLD.description(),
+ License.OperationMode.BASIC.description()))
.setTags(tags);
}
@@ -191,50 +193,52 @@ public void testParseWithBothDefinitionAndCompressedSupplied() throws IOExceptio
}
public void testValidateWithNullDefinition() {
- IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> TrainedModelConfig.builder().validate());
- assertThat(ex.getMessage(), equalTo("[definition] must not be null."));
+ ActionRequestValidationException ex = expectThrows(ActionRequestValidationException.class,
+ () -> TrainedModelConfig.builder().validate());
+ assertThat(ex.getMessage(), containsString("[definition] must not be null."));
}
public void testValidateWithInvalidID() {
String modelId = "InvalidID-";
- ElasticsearchException ex = expectThrows(ElasticsearchException.class,
+ ActionRequestValidationException ex = expectThrows(ActionRequestValidationException.class,
() -> TrainedModelConfig.builder()
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.setModelId(modelId).validate());
- assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INVALID_ID, "model_id", modelId)));
+ assertThat(ex.getMessage(), containsString(Messages.getMessage(Messages.INVALID_ID, "model_id", modelId)));
}
public void testValidateWithLongID() {
String modelId = IntStream.range(0, 100).mapToObj(x -> "a").collect(Collectors.joining());
- ElasticsearchException ex = expectThrows(ElasticsearchException.class,
+ ActionRequestValidationException ex = expectThrows(ActionRequestValidationException.class,
() -> TrainedModelConfig.builder()
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.setModelId(modelId).validate());
- assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.ID_TOO_LONG, "model_id", modelId, MlStrings.ID_LENGTH_LIMIT)));
+ assertThat(ex.getMessage(),
+ containsString(Messages.getMessage(Messages.ID_TOO_LONG, "model_id", modelId, MlStrings.ID_LENGTH_LIMIT)));
}
public void testValidateWithIllegallyUserProvidedFields() {
String modelId = "simplemodel";
- ElasticsearchException ex = expectThrows(ElasticsearchException.class,
+ ActionRequestValidationException ex = expectThrows(ActionRequestValidationException.class,
() -> TrainedModelConfig.builder()
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.setCreateTime(Instant.now())
.setModelId(modelId).validate());
- assertThat(ex.getMessage(), equalTo("illegal to set [create_time] at inference model creation"));
+ assertThat(ex.getMessage(), containsString("illegal to set [create_time] at inference model creation"));
- ex = expectThrows(ElasticsearchException.class,
+ ex = expectThrows(ActionRequestValidationException.class,
() -> TrainedModelConfig.builder()
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.setVersion(Version.CURRENT)
.setModelId(modelId).validate());
- assertThat(ex.getMessage(), equalTo("illegal to set [version] at inference model creation"));
+ assertThat(ex.getMessage(), containsString("illegal to set [version] at inference model creation"));
- ex = expectThrows(ElasticsearchException.class,
+ ex = expectThrows(ActionRequestValidationException.class,
() -> TrainedModelConfig.builder()
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.setCreatedBy("ml_user")
.setModelId(modelId).validate());
- assertThat(ex.getMessage(), equalTo("illegal to set [created_by] at inference model creation"));
+ assertThat(ex.getMessage(), containsString("illegal to set [created_by] at inference model creation"));
}
public void testSerializationWithLazyDefinition() throws IOException {
diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle
index 8c9ba6df7f009..7e19a4d606d25 100644
--- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle
+++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle
@@ -133,7 +133,6 @@ integTest.runner {
'ml/get_datafeed_stats/Test get datafeed stats given missing datafeed_id',
'ml/get_datafeeds/Test get datafeed given missing datafeed_id',
'ml/inference_crud/Test delete given used trained model',
- 'ml/inference_crud/Test delete given unused trained model',
'ml/inference_crud/Test delete with missing model',
'ml/inference_crud/Test get given missing trained model',
'ml/inference_crud/Test get given expression without matches and allow_no_match is false',
diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java
index 27e83e04b412b..216eac723115f 100644
--- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java
+++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java
@@ -9,15 +9,19 @@
import org.elasticsearch.action.ingest.SimulateDocumentBaseResult;
import org.elasticsearch.action.ingest.SimulatePipelineResponse;
import org.elasticsearch.action.search.SearchRequest;
-import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.common.xcontent.DeprecationHandler;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.query.QueryBuilders;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.builder.SearchSourceBuilder;
-import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
-import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
-import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
+import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.junit.After;
import org.junit.Before;
import java.io.IOException;
@@ -34,26 +38,14 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
@Before
public void createBothModels() throws Exception {
- assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME)
- .setId("test_classification")
- .setSource(CLASSIFICATION_CONFIG, XContentType.JSON)
- .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
- .get().status(), equalTo(RestStatus.CREATED));
- assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME)
- .setId(TrainedModelDefinitionDoc.docId("test_classification", 0))
- .setSource(buildClassificationModelDoc(), XContentType.JSON)
- .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
- .get().status(), equalTo(RestStatus.CREATED));
- assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME)
- .setId("test_regression")
- .setSource(REGRESSION_CONFIG, XContentType.JSON)
- .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
- .get().status(), equalTo(RestStatus.CREATED));
- assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME)
- .setId(TrainedModelDefinitionDoc.docId("test_regression", 0))
- .setSource(buildRegressionModelDoc(), XContentType.JSON)
- .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
- .get().status(), equalTo(RestStatus.CREATED));
+ client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(buildClassificationModel())).actionGet();
+ client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(buildRegressionModel())).actionGet();
+ }
+
+ @After
+ public void deleteBothModels() {
+ client().execute(DeleteTrainedModelAction.INSTANCE, new DeleteTrainedModelAction.Request("test_classification")).actionGet();
+ client().execute(DeleteTrainedModelAction.INSTANCE, new DeleteTrainedModelAction.Request("test_regression")).actionGet();
}
public void testPipelineCreationAndDeletion() throws Exception {
@@ -391,6 +383,7 @@ private Map generateSourceDoc() {
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
" \"description\": \"test model for regression\",\n" +
" \"version\": \"8.0.0\",\n" +
+ " \"definition\": " + REGRESSION_DEFINITION + ","+
" \"license_level\": \"platinum\",\n" +
" \"created_by\": \"ml_test\",\n" +
" \"estimated_heap_memory_usage_bytes\": 0," +
@@ -518,28 +511,27 @@ private Map generateSourceDoc() {
" }\n" +
"}";
- private static String buildClassificationModelDoc() throws IOException {
- String compressed =
- InferenceToXContentCompressor.deflate(new BytesArray(CLASSIFICATION_DEFINITION.getBytes(StandardCharsets.UTF_8)));
- return modelDocString(compressed, "test_classification");
+ private TrainedModelConfig buildClassificationModel() throws IOException {
+ try (XContentParser parser = XContentHelper.createParser(xContentRegistry(),
+ DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
+ new BytesArray(CLASSIFICATION_CONFIG),
+ XContentType.JSON)) {
+ return TrainedModelConfig.LENIENT_PARSER.apply(parser, null).build();
+ }
}
- private static String buildRegressionModelDoc() throws IOException {
- String compressed = InferenceToXContentCompressor.deflate(new BytesArray(REGRESSION_DEFINITION.getBytes(StandardCharsets.UTF_8)));
- return modelDocString(compressed, "test_regression");
+ private TrainedModelConfig buildRegressionModel() throws IOException {
+ try (XContentParser parser = XContentHelper.createParser(xContentRegistry(),
+ DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
+ new BytesArray(REGRESSION_CONFIG),
+ XContentType.JSON)) {
+ return TrainedModelConfig.LENIENT_PARSER.apply(parser, null).build();
+ }
}
- private static String modelDocString(String compressedDefinition, String modelId) {
- return "" +
- "{" +
- "\"model_id\": \"" + modelId + "\",\n" +
- "\"doc_num\": 0,\n" +
- "\"doc_type\": \"trained_model_definition_doc\",\n" +
- " \"compression_version\": " + 1 + ",\n" +
- " \"total_definition_length\": " + compressedDefinition.length() + ",\n" +
- " \"definition_length\": " + compressedDefinition.length() + ",\n" +
- "\"definition\": \"" + compressedDefinition + "\"\n" +
- "}";
+ @Override
+ protected NamedXContentRegistry xContentRegistry() {
+ return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
}
private static final String CLASSIFICATION_CONFIG = "" +
@@ -547,9 +539,10 @@ private static String modelDocString(String compressedDefinition, String modelId
" \"model_id\": \"test_classification\",\n" +
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
" \"description\": \"test model for classification\",\n" +
+ " \"definition\": " + CLASSIFICATION_DEFINITION + ","+
" \"version\": \"8.0.0\",\n" +
" \"license_level\": \"platinum\",\n" +
- " \"created_by\": \"benwtrent\",\n" +
+ " \"created_by\": \"es_test\",\n" +
" \"estimated_heap_memory_usage_bytes\": 0," +
" \"estimated_operations\": 0," +
" \"created_time\": 0\n" +
diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java
index 72c677d1aa40d..0aec6bc337412 100644
--- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java
+++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java
@@ -6,10 +6,18 @@
package org.elasticsearch.xpack.ml.integration;
import org.apache.http.util.EntityUtils;
-import org.elasticsearch.Version;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseException;
+import org.elasticsearch.client.ml.inference.TrainedModelConfig;
+import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
+import org.elasticsearch.client.ml.inference.TrainedModelInput;
+import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
+import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
+import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
+import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum;
+import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree;
+import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
@@ -18,26 +26,19 @@
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentType;
-import org.elasticsearch.license.License;
import org.elasticsearch.test.SecuritySettingsSourceField;
import org.elasticsearch.test.rest.ESRestTestCase;
-import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
-import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
-import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
-import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
-import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import org.elasticsearch.xpack.ml.MachineLearning;
-import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
import org.junit.After;
import java.io.IOException;
-import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
+import java.util.List;
import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
import static org.hamcrest.Matchers.containsString;
@@ -62,22 +63,8 @@ protected boolean preserveTemplatesUponCompletion() {
public void testGetTrainedModels() throws IOException {
String modelId = "a_test_regression_model";
String modelId2 = "a_test_regression_model-2";
- Request model1 = new Request("PUT",
- InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId);
- model1.setJsonEntity(buildRegressionModel(modelId));
- assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201));
-
- Request modelDefinition1 = new Request("PUT",
- InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinitionDoc.docId(modelId, 0));
- modelDefinition1.setJsonEntity(buildRegressionModelDefinitionDoc(modelId));
- assertThat(client().performRequest(modelDefinition1).getStatusLine().getStatusCode(), equalTo(201));
-
- Request model2 = new Request("PUT",
- InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId2);
- model2.setJsonEntity(buildRegressionModel(modelId2));
- assertThat(client().performRequest(model2).getStatusLine().getStatusCode(), equalTo(201));
-
- adminClient().performRequest(new Request("POST", InferenceIndexConstants.LATEST_INDEX_NAME + "/_refresh"));
+ putRegressionModel(modelId);
+ putRegressionModel(modelId2);
Response getModel = client().performRequest(new Request("GET",
MachineLearning.BASE_PATH + "inference/" + modelId));
@@ -164,17 +151,7 @@ public void testGetTrainedModels() throws IOException {
public void testDeleteTrainedModels() throws IOException {
String modelId = "test_delete_regression_model";
- Request model1 = new Request("PUT",
- InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId);
- model1.setJsonEntity(buildRegressionModel(modelId));
- assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201));
-
- Request modelDefinition1 = new Request("PUT",
- InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinitionDoc.docId(modelId, 0));
- modelDefinition1.setJsonEntity(buildRegressionModelDefinitionDoc(modelId));
- assertThat(client().performRequest(modelDefinition1).getStatusLine().getStatusCode(), equalTo(201));
-
- adminClient().performRequest(new Request("POST", InferenceIndexConstants.LATEST_INDEX_NAME + "/_refresh"));
+ putRegressionModel(modelId);
Response delModel = client().performRequest(new Request("DELETE",
MachineLearning.BASE_PATH + "inference/" + modelId));
@@ -208,42 +185,68 @@ public void testGetPrePackagedModels() throws IOException {
assertThat(response, containsString("\"definition\""));
}
- private static String buildRegressionModel(String modelId) throws IOException {
+ private void putRegressionModel(String modelId) throws IOException {
try(XContentBuilder builder = XContentFactory.jsonBuilder()) {
+ TrainedModelDefinition.Builder definition = new TrainedModelDefinition.Builder()
+ .setPreProcessors(Collections.emptyList())
+ .setTrainedModel(buildRegression());
TrainedModelConfig.builder()
+ .setDefinition(definition)
.setModelId(modelId)
.setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3")))
- .setCreatedBy("ml_test")
- .setVersion(Version.CURRENT)
- .setCreateTime(Instant.now())
- .setEstimatedOperations(0)
- .setLicenseLevel(License.OperationMode.PLATINUM.description())
- .setEstimatedHeapMemory(0)
- .build()
- .toXContent(builder, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")));
- return XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON);
+ .build().toXContent(builder, ToXContent.EMPTY_PARAMS);
+ Request model = new Request("PUT", "_ml/inference/" + modelId);
+ model.setJsonEntity(XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON));
+ assertThat(client().performRequest(model).getStatusLine().getStatusCode(), equalTo(200));
}
}
- private static String buildRegressionModelDefinitionDoc(String modelId) throws IOException {
- try(XContentBuilder builder = XContentFactory.jsonBuilder()) {
- TrainedModelDefinition definition = new TrainedModelDefinition.Builder()
- .setPreProcessors(Collections.emptyList())
- .setTrainedModel(LocalModelTests.buildRegression())
- .build();
- String compressedString = InferenceToXContentCompressor.deflate(definition);
- TrainedModelDefinitionDoc doc = new TrainedModelDefinitionDoc.Builder().setDocNum(0)
- .setCompressedString(compressedString)
- .setTotalDefinitionLength(compressedString.length())
- .setDefinitionLength(compressedString.length())
- .setCompressionVersion(1)
- .setModelId(modelId).build();
- doc.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")));
- return XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON);
- }
+ private static TrainedModel buildRegression() {
+ List featureNames = Arrays.asList("field.foo", "field.bar", "animal_cat", "animal_dog");
+ Tree tree1 = Tree.builder()
+ .setFeatureNames(featureNames)
+ .setNodes(TreeNode.builder(0)
+ .setLeftChild(1)
+ .setRightChild(2)
+ .setSplitFeature(0)
+ .setThreshold(0.5),
+ TreeNode.builder(1).setLeafValue(0.3),
+ TreeNode.builder(2)
+ .setThreshold(0.0)
+ .setSplitFeature(3)
+ .setLeftChild(3)
+ .setRightChild(4),
+ TreeNode.builder(3).setLeafValue(0.1),
+ TreeNode.builder(4).setLeafValue(0.2))
+ .build();
+ Tree tree2 = Tree.builder()
+ .setFeatureNames(featureNames)
+ .setNodes(TreeNode.builder(0)
+ .setLeftChild(1)
+ .setRightChild(2)
+ .setSplitFeature(2)
+ .setThreshold(1.0),
+ TreeNode.builder(1).setLeafValue(1.5),
+ TreeNode.builder(2).setLeafValue(0.9))
+ .build();
+ Tree tree3 = Tree.builder()
+ .setFeatureNames(featureNames)
+ .setNodes(TreeNode.builder(0)
+ .setLeftChild(1)
+ .setRightChild(2)
+ .setSplitFeature(1)
+ .setThreshold(0.2),
+ TreeNode.builder(1).setLeafValue(1.5),
+ TreeNode.builder(2).setLeafValue(0.9))
+ .build();
+ return Ensemble.builder()
+ .setTargetType(TargetType.REGRESSION)
+ .setFeatureNames(featureNames)
+ .setTrainedModels(Arrays.asList(tree1, tree2, tree3))
+ .setOutputAggregator(new WeightedSum(Arrays.asList(0.5, 0.5, 0.5)))
+ .build();
}
-
@After
public void clearMlState() throws Exception {
new MlRestTestStateCleaner(logger, adminClient()).clearMlMetadata();
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
index 5216cf6a267fc..b933ea65caed7 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
@@ -113,6 +113,7 @@
import org.elasticsearch.xpack.core.ml.action.PutDatafeedAction;
import org.elasticsearch.xpack.core.ml.action.PutFilterAction;
import org.elasticsearch.xpack.core.ml.action.PutJobAction;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.RevertModelSnapshotAction;
import org.elasticsearch.xpack.core.ml.action.SetUpgradeModeAction;
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
@@ -184,6 +185,7 @@
import org.elasticsearch.xpack.ml.action.TransportPutDatafeedAction;
import org.elasticsearch.xpack.ml.action.TransportPutFilterAction;
import org.elasticsearch.xpack.ml.action.TransportPutJobAction;
+import org.elasticsearch.xpack.ml.action.TransportPutTrainedModelAction;
import org.elasticsearch.xpack.ml.action.TransportRevertModelSnapshotAction;
import org.elasticsearch.xpack.ml.action.TransportSetUpgradeModeAction;
import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction;
@@ -276,6 +278,7 @@
import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAction;
import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction;
import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsStatsAction;
+import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelAction;
import org.elasticsearch.xpack.ml.rest.job.RestCloseJobAction;
import org.elasticsearch.xpack.ml.rest.job.RestDeleteForecastAction;
import org.elasticsearch.xpack.ml.rest.job.RestDeleteJobAction;
@@ -761,7 +764,8 @@ public List getRestHandlers(Settings settings, RestController restC
new RestExplainDataFrameAnalyticsAction(restController),
new RestGetTrainedModelsAction(restController),
new RestDeleteTrainedModelAction(restController),
- new RestGetTrainedModelsStatsAction(restController)
+ new RestGetTrainedModelsStatsAction(restController),
+ new RestPutTrainedModelAction(restController)
);
}
@@ -837,6 +841,7 @@ public List getRestHandlers(Settings settings, RestController restC
new ActionHandler<>(GetTrainedModelsAction.INSTANCE, TransportGetTrainedModelsAction.class),
new ActionHandler<>(DeleteTrainedModelAction.INSTANCE, TransportDeleteTrainedModelAction.class),
new ActionHandler<>(GetTrainedModelsStatsAction.INSTANCE, TransportGetTrainedModelsStatsAction.class),
+ new ActionHandler<>(PutTrainedModelAction.INSTANCE, TransportPutTrainedModelAction.class),
usageAction,
infoAction);
}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java
new file mode 100644
index 0000000000000..9fca82802b51d
--- /dev/null
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java
@@ -0,0 +1,190 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.ml.action;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.Version;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.master.TransportMasterNodeAction;
+import org.elasticsearch.client.Client;
+import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.block.ClusterBlockException;
+import org.elasticsearch.cluster.block.ClusterBlockLevel;
+import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.inject.Inject;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.license.License;
+import org.elasticsearch.license.LicenseUtils;
+import org.elasticsearch.license.XPackLicenseState;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.transport.TransportService;
+import org.elasticsearch.xpack.core.XPackField;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction.Request;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction.Response;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
+import org.elasticsearch.xpack.core.ml.job.messages.Messages;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
+
+import java.io.IOException;
+import java.time.Instant;
+import java.util.List;
+
+import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
+import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
+
+public class TransportPutTrainedModelAction extends TransportMasterNodeAction {
+
+ private final TrainedModelProvider trainedModelProvider;
+ private final XPackLicenseState licenseState;
+ private final NamedXContentRegistry xContentRegistry;
+ private final Client client;
+
+ @Inject
+ public TransportPutTrainedModelAction(TransportService transportService, ClusterService clusterService,
+ ThreadPool threadPool, XPackLicenseState licenseState, ActionFilters actionFilters,
+ IndexNameExpressionResolver indexNameExpressionResolver, Client client,
+ TrainedModelProvider trainedModelProvider, NamedXContentRegistry xContentRegistry) {
+ super(PutTrainedModelAction.NAME, transportService, clusterService, threadPool, actionFilters, Request::new,
+ indexNameExpressionResolver);
+ this.licenseState = licenseState;
+ this.trainedModelProvider = trainedModelProvider;
+ this.xContentRegistry = xContentRegistry;
+ this.client = client;
+ }
+
+ @Override
+ protected String executor() {
+ return ThreadPool.Names.SAME;
+ }
+
+ @Override
+ protected Response read(StreamInput in) throws IOException {
+ return new Response(in);
+ }
+
+ @Override
+ protected void masterOperation(Task task,
+ PutTrainedModelAction.Request request,
+ ClusterState state,
+ ActionListener listener) {
+ try {
+ request.getTrainedModelConfig().ensureParsedDefinition(xContentRegistry);
+ request.getTrainedModelConfig().getModelDefinition().getTrainedModel().validate();
+ } catch (IOException ex) {
+ listener.onFailure(ExceptionsHelper.badRequestException("Failed to parse definition for [{}]",
+ ex,
+ request.getTrainedModelConfig().getModelId()));
+ return;
+ } catch (ElasticsearchException ex) {
+ listener.onFailure(ExceptionsHelper.badRequestException("Definition for [{}] has validation failures.",
+ ex,
+ request.getTrainedModelConfig().getModelId()));
+ return;
+ }
+
+ TrainedModelConfig trainedModelConfig = new TrainedModelConfig.Builder(request.getTrainedModelConfig())
+ .setVersion(Version.CURRENT)
+ .setCreateTime(Instant.now())
+ .setCreatedBy("user")
+ .setLicenseLevel(License.OperationMode.PLATINUM.description())
+ .setEstimatedHeapMemory(request.getTrainedModelConfig().getModelDefinition().ramBytesUsed())
+ .setEstimatedOperations(request.getTrainedModelConfig().getModelDefinition().getTrainedModel().estimatedNumOperations())
+ .build();
+
+ ActionListener tagsModelIdCheckListener = ActionListener.wrap(
+ r -> trainedModelProvider.storeTrainedModel(trainedModelConfig, ActionListener.wrap(
+ storedConfig -> listener.onResponse(new PutTrainedModelAction.Response(trainedModelConfig)),
+ listener::onFailure
+ )),
+ listener::onFailure
+ );
+
+ ActionListener modelIdTagCheckListener = ActionListener.wrap(
+ r -> checkTagsAgainstModelIds(request.getTrainedModelConfig().getTags(), tagsModelIdCheckListener),
+ listener::onFailure
+ );
+
+ checkModelIdAgainstTags(request.getTrainedModelConfig().getModelId(), modelIdTagCheckListener);
+ }
+
+ private void checkModelIdAgainstTags(String modelId, ActionListener listener) {
+ QueryBuilder builder = QueryBuilders.constantScoreQuery(
+ QueryBuilders.boolQuery()
+ .filter(QueryBuilders.termQuery(TrainedModelConfig.TAGS.getPreferredName(), modelId)));
+ SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(builder).size(0).trackTotalHitsUpTo(1);
+ SearchRequest searchRequest = new SearchRequest(InferenceIndexConstants.INDEX_PATTERN).source(sourceBuilder);
+ executeAsyncWithOrigin(client.threadPool().getThreadContext(),
+ ML_ORIGIN,
+ searchRequest,
+ ActionListener.wrap(
+ response -> {
+ if (response.getHits().getTotalHits().value > 0) {
+ listener.onFailure(
+ ExceptionsHelper.badRequestException(
+ Messages.getMessage(Messages.INFERENCE_MODEL_ID_AND_TAGS_UNIQUE, modelId)));
+ return;
+ }
+ listener.onResponse(null);
+ },
+ listener::onFailure
+ ),
+ client::search);
+ }
+
+ private void checkTagsAgainstModelIds(List tags, ActionListener listener) {
+ if (tags.isEmpty()) {
+ listener.onResponse(null);
+ return;
+ }
+
+ QueryBuilder builder = QueryBuilders.constantScoreQuery(
+ QueryBuilders.boolQuery()
+ .filter(QueryBuilders.termsQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), tags)));
+ SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(builder).size(0).trackTotalHitsUpTo(1);
+ SearchRequest searchRequest = new SearchRequest(InferenceIndexConstants.INDEX_PATTERN).source(sourceBuilder);
+ executeAsyncWithOrigin(client.threadPool().getThreadContext(),
+ ML_ORIGIN,
+ searchRequest,
+ ActionListener.wrap(
+ response -> {
+ if (response.getHits().getTotalHits().value > 0) {
+ listener.onFailure(
+ ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INFERENCE_TAGS_AND_MODEL_IDS_UNIQUE, tags)));
+ return;
+ }
+ listener.onResponse(null);
+ },
+ listener::onFailure
+ ),
+ client::search);
+ }
+
+ @Override
+ protected ClusterBlockException checkBlock(Request request, ClusterState state) {
+ return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE);
+ }
+
+ @Override
+ protected void doExecute(Task task, Request request, ActionListener listener) {
+ if (licenseState.isMachineLearningAllowed()) {
+ super.doExecute(task, request, listener);
+ } else {
+ listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
+ }
+ }
+}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java
index 993de87b07193..7ae8004f14a25 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java
@@ -174,10 +174,12 @@ private void storeTrainedModelAndDefinition(TrainedModelConfig trainedModelConfi
r -> {
assert r.getItems().length == 2;
if (r.getItems()[0].isFailed()) {
+
logger.error(new ParameterizedMessage(
"[{}] failed to store trained model config for inference",
trainedModelConfig.getModelId()),
r.getItems()[0].getFailure().getCause());
+
wrappedListener.onFailure(r.getItems()[0].getFailure().getCause());
return;
}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelAction.java
new file mode 100644
index 0000000000000..83ec70629bc78
--- /dev/null
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelAction.java
@@ -0,0 +1,42 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.ml.rest.inference;
+
+import org.elasticsearch.client.node.NodeClient;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.rest.BaseRestHandler;
+import org.elasticsearch.rest.RestController;
+import org.elasticsearch.rest.RestRequest;
+import org.elasticsearch.rest.action.RestToXContentListener;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.ml.MachineLearning;
+
+import java.io.IOException;
+
+public class RestPutTrainedModelAction extends BaseRestHandler {
+
+ public RestPutTrainedModelAction(RestController controller) {
+ controller.registerHandler(RestRequest.Method.PUT,
+ MachineLearning.BASE_PATH + "inference/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}",
+ this);
+ }
+
+ @Override
+ public String getName() {
+ return "xpack_ml_put_data_frame_analytics_action";
+ }
+
+ @Override
+ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
+ String id = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName());
+ XContentParser parser = restRequest.contentParser();
+ PutTrainedModelAction.Request putRequest = PutTrainedModelAction.Request.parseRequest(id, parser);
+ putRequest.timeout(restRequest.paramAsTime("timeout", putRequest.timeout()));
+
+ return channel -> client.execute(PutTrainedModelAction.INSTANCE, putRequest, new RestToXContentListener<>(channel));
+ }
+}
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java
index f819cff871637..a80e8ed709673 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java
@@ -13,7 +13,6 @@
import org.elasticsearch.action.ingest.SimulatePipelineRequest;
import org.elasticsearch.action.ingest.SimulatePipelineResponse;
import org.elasticsearch.action.support.PlainActionFuture;
-import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.common.bytes.BytesArray;
@@ -32,24 +31,28 @@
import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
import org.elasticsearch.xpack.core.ml.action.PutDatafeedAction;
import org.elasticsearch.xpack.core.ml.action.PutJobAction;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction;
import org.elasticsearch.xpack.core.ml.action.StopDatafeedAction;
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
-import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
-import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.xpack.core.ml.job.config.JobState;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
-import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase;
import org.junit.Before;
import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
import java.util.Collections;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
-import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
@@ -481,12 +484,7 @@ public void testMachineLearningCreateInferenceProcessorRestricted() throws Excep
" \"target_field\": \"regression_value\",\n" +
" \"model_id\": \"modelprocessorlicensetest\",\n" +
" \"inference_config\": {\"regression\": {}},\n" +
- " \"field_mappings\": {\n" +
- " \"col1\": \"col1\",\n" +
- " \"col2\": \"col2\",\n" +
- " \"col3\": \"col3\",\n" +
- " \"col4\": \"col4\"\n" +
- " }\n" +
+ " \"field_mappings\": {}\n" +
" }\n" +
" }]}\n";
// Creating a pipeline should work
@@ -668,76 +666,22 @@ public void testMachineLearningInferModelRestricted() throws Exception {
assertThat(listener.actionGet().getInferenceResults(), is(not(empty())));
}
- private void putInferenceModel(String modelId) throws Exception {
- String config = "" +
- "{\n" +
- " \"model_id\": \"" + modelId + "\",\n" +
- " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
- " \"description\": \"test model for classification\",\n" +
- " \"version\": \"8.0.0\",\n" +
- " \"created_by\": \"benwtrent\",\n" +
- " \"license_level\": \"platinum\",\n" +
- " \"estimated_heap_memory_usage_bytes\": 0,\n" +
- " \"estimated_operations\": 0,\n" +
- " \"created_time\": 0\n" +
- "}";
- String definition = "" +
- "{" +
- " \"trained_model\": {\n" +
- " \"tree\": {\n" +
- " \"feature_names\": [\n" +
- " \"col1_male\",\n" +
- " \"col1_female\",\n" +
- " \"col2_encoded\",\n" +
- " \"col3_encoded\",\n" +
- " \"col4\"\n" +
- " ],\n" +
- " \"tree_structure\": [\n" +
- " {\n" +
- " \"node_index\": 0,\n" +
- " \"split_feature\": 0,\n" +
- " \"split_gain\": 12.0,\n" +
- " \"threshold\": 10.0,\n" +
- " \"decision_type\": \"lte\",\n" +
- " \"default_left\": true,\n" +
- " \"left_child\": 1,\n" +
- " \"right_child\": 2\n" +
- " },\n" +
- " {\n" +
- " \"node_index\": 1,\n" +
- " \"leaf_value\": 1\n" +
- " },\n" +
- " {\n" +
- " \"node_index\": 2,\n" +
- " \"leaf_value\": 2\n" +
- " }\n" +
- " ],\n" +
- " \"target_type\": \"regression\"\n" +
- " }\n" +
- " }" +
- "}";
- String compressedDefinitionString =
- InferenceToXContentCompressor.deflate(new BytesArray(definition.getBytes(StandardCharsets.UTF_8)));
- String compressedDefinition = "" +
- "{" +
- " \"model_id\": \"" + modelId + "\",\n" +
- " \"doc_type\": \"" + TrainedModelDefinitionDoc.NAME + "\",\n" +
- " \"doc_num\": " + 0 + ",\n" +
- " \"compression_version\": " + 1 + ",\n" +
- " \"total_definition_length\": " + compressedDefinitionString.length() + ",\n" +
- " \"definition_length\": " + compressedDefinitionString.length() + ",\n" +
- " \"definition\": \"" + compressedDefinitionString + "\"\n" +
- "}";
- assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME)
- .setId(modelId)
- .setSource(config, XContentType.JSON)
- .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
- .get().status(), equalTo(RestStatus.CREATED));
- assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME)
- .setId(TrainedModelDefinitionDoc.docId(modelId, 0))
- .setSource(compressedDefinition, XContentType.JSON)
- .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
- .get().status(), equalTo(RestStatus.CREATED));
+ private void putInferenceModel(String modelId) {
+ TrainedModelConfig config = TrainedModelConfig.builder()
+ .setParsedDefinition(
+ new TrainedModelDefinition.Builder()
+ .setTrainedModel(
+ Tree.builder()
+ .setTargetType(TargetType.REGRESSION)
+ .setFeatureNames(Arrays.asList("feature1"))
+ .setNodes(TreeNode.builder(0).setLeafValue(1.0))
+ .build())
+ .setPreProcessors(Collections.emptyList()))
+ .setModelId(modelId)
+ .setDescription("test model for classification")
+ .setInput(new TrainedModelInput(Arrays.asList("feature1")))
+ .build();
+ client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config)).actionGet();
}
private static OperationMode randomInvalidLicenseType() {
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java
index bfd92beefddad..a687124066d5c 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java
@@ -199,7 +199,6 @@ private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String
return TrainedModelConfig.builder()
.setCreatedBy("ml_test")
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
-
.setDescription("trained model config for test")
.setModelId(modelId)
.setVersion(Version.CURRENT)
diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_trained_model.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_trained_model.json
new file mode 100644
index 0000000000000..a58fa13540748
--- /dev/null
+++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_trained_model.json
@@ -0,0 +1,28 @@
+{
+ "ml.put_trained_model":{
+ "documentation":{
+ "url":"TODO"
+ },
+ "stability":"experimental",
+ "url":{
+ "paths":[
+ {
+ "path":"/_ml/inference/{model_id}",
+ "methods":[
+ "PUT"
+ ],
+ "parts":{
+ "model_id":{
+ "type":"string",
+ "description":"The ID of the trained models to store"
+ }
+ }
+ }
+ ]
+ },
+ "body": {
+ "description":"The trained model configuration",
+ "required":true
+ }
+ }
+}
diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml
index f72fd1120d81e..c2da1d7ba2da3 100644
--- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml
+++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml
@@ -24,25 +24,130 @@
- match: { count: 0 }
- match: { trained_model_configs: [] }
---
-"Test delete given unused trained model":
+"Test get models":
+ - do:
+ ml.put_trained_model:
+ model_id: regression-model-0
+ body: >
+ {
+ "description": "empty model for tests",
+ "input": {"field_names": ["field1", "field2"]},
+ "definition": {
+ "preprocessors": [],
+ "trained_model": {
+ "tree": {
+ "feature_names": ["field1", "field2"],
+ "tree_structure": [
+ {"node_index": 0, "leaf_value": 1}
+ ],
+ "target_type": "regression"
+ }
+ }
+ }
+ }
+ - match: { model_id: "regression-model-0" }
- do:
- index:
- id: trained_model_config-unused-regression-model-0
- index: .ml-inference-000001
+ ml.put_trained_model:
+ model_id: regression-model-1
body: >
{
- "model_id": "unused-regression-model",
- "created_by": "ml_tests",
- "version": "8.0.0",
"description": "empty model for tests",
- "create_time": 0,
- "model_version": 0,
- "model_type": "local"
+ "input": {"field_names": ["field1", "field2"]},
+ "definition": {
+ "preprocessors": [],
+ "trained_model": {
+ "tree": {
+ "feature_names": ["field1", "field2"],
+ "tree_structure": [
+ {"node_index": 0, "leaf_value": 1}
+ ],
+ "target_type": "regression"
+ }
+ }
+ }
}
+ - match: { model_id: "regression-model-1" }
+
- do:
- indices.refresh: {}
+ ml.put_trained_model:
+ model_id: classification-model
+ body: >
+ {
+ "description": "empty model for tests",
+ "input": {"field_names": ["field1", "field2"]},
+ "definition": {
+ "preprocessors": [],
+ "trained_model": {
+ "tree": {
+ "feature_names": ["field1", "field2"],
+ "tree_structure": [
+ {"node_index": 0, "leaf_value": 1}
+ ],
+ "target_type": "classification",
+ "classification_labels": ["no", "yes"]
+ }
+ }
+ }
+ }
+ - match: { model_id: "classification-model" }
+
+ - do:
+ ml.get_trained_models:
+ model_id: "*"
+ - match: { count: 3 }
+ - match: { trained_model_configs.0.model_id: "classification-model" }
+ - match: { trained_model_configs.1.model_id: "regression-model-0" }
+ - match: { trained_model_configs.2.model_id: "regression-model-1" }
+ - do:
+ ml.get_trained_models:
+ model_id: "regression*"
+ - match: { count: 2 }
+ - match: { trained_model_configs.0.model_id: "regression-model-0" }
+ - match: { trained_model_configs.1.model_id: "regression-model-1" }
+
+ - do:
+ ml.get_trained_models:
+ model_id: "*"
+ from: 0
+ size: 2
+ - match: { count: 3 }
+ - match: { trained_model_configs.0.model_id: "classification-model" }
+ - match: { trained_model_configs.1.model_id: "regression-model-0" }
+
+ - do:
+ ml.get_trained_models:
+ model_id: "*"
+ from: 1
+ size: 1
+ - match: { count: 3 }
+ - match: { trained_model_configs.0.model_id: "regression-model-0" }
+---
+"Test delete given unused trained model":
+ - do:
+ ml.put_trained_model:
+ model_id: unused-regression-model
+ body: >
+ {
+ "model_id": "unused-regression-model",
+ "input": {"field_names": ["field1", "field2"]},
+ "description": "empty model for tests",
+ "definition": {
+ "preprocessors": [],
+ "trained_model": {
+ "tree": {
+ "feature_names": ["field1", "field2"],
+ "tree_structure": [
+ {"node_index": 0, "leaf_value": 1}
+ ],
+ "target_type": "regression"
+ }
+ }
+ }
+ }
+
+ - match: { model_id: "unused-regression-model" }
- do:
ml.delete_trained_model:
model_id: "unused-regression-model"
@@ -58,22 +163,28 @@
---
"Test delete given used trained model":
- do:
- index:
- id: trained_model_config-used-regression-model-0
- index: .ml-inference-000001
+ ml.put_trained_model:
+ model_id: used-regression-model
body: >
{
"model_id": "used-regression-model",
- "created_by": "ml_tests",
- "version": "8.0.0",
+ "input": {"field_names": ["field1", "field2"]},
"description": "empty model for tests",
- "create_time": 0,
- "model_version": 0,
- "model_type": "local"
+ "definition": {
+ "preprocessors": [],
+ "trained_model": {
+ "tree": {
+ "feature_names": ["field1", "field2"],
+ "tree_structure": [
+ {"node_index": 0, "leaf_value": 1}
+ ],
+ "target_type": "regression"
+ }
+ }
+ }
}
- - do:
- indices.refresh: {}
+ - match: { model_id: "used-regression-model" }
- do:
ingest.put_pipeline:
id: "regression-model-pipeline"
From 206738642c1a8825dfcda7c970790f5205d0a501 Mon Sep 17 00:00:00 2001
From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com>
Date: Fri, 10 Jan 2020 08:15:46 -0500
Subject: [PATCH 2/6] fixing tests
---
.../client/MLRequestConvertersTests.java | 2 +
.../client/MachineLearningIT.java | 7 +++
.../MlClientDocumentationIT.java | 45 +++----------------
3 files changed, 14 insertions(+), 40 deletions(-)
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java
index 34fdda50bd1e1..c137fbc464d56 100644
--- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java
@@ -92,6 +92,7 @@
import org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
+import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
import org.elasticsearch.client.ml.inference.TrainedModelConfigTests;
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
@@ -1063,6 +1064,7 @@ protected NamedXContentRegistry xContentRegistry() {
namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents());
namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
namedXContent.addAll(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
+ namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
return new NamedXContentRegistry(namedXContent);
}
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
index 853103ff1f855..728799dfccacd 100644
--- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
@@ -148,6 +148,7 @@
import org.elasticsearch.client.ml.dataframe.explain.FieldSelection;
import org.elasticsearch.client.ml.dataframe.explain.MemoryEstimation;
import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
+import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
@@ -168,6 +169,7 @@
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentType;
@@ -2741,4 +2743,9 @@ public void testEnableUpgradeMode() throws Exception {
mlInfoResponse = machineLearningClient.getMlInfo(new MlInfoRequest(), RequestOptions.DEFAULT);
assertThat(mlInfoResponse.getInfo().get("upgrade_mode"), equalTo(false));
}
+
+ @Override
+ protected NamedXContentRegistry xContentRegistry() {
+ return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+ }
}
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java
index 4ee22e6289db0..6a8a0f5db990d 100644
--- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java
@@ -164,6 +164,7 @@
import org.elasticsearch.client.ml.dataframe.explain.FieldSelection;
import org.elasticsearch.client.ml.dataframe.explain.MemoryEstimation;
import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
+import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
@@ -190,12 +191,11 @@
import org.elasticsearch.client.ml.job.results.OverallBucket;
import org.elasticsearch.client.ml.job.stats.JobStats;
import org.elasticsearch.common.bytes.BytesReference;
-import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentFactory;
-import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
@@ -206,12 +206,10 @@
import org.junit.After;
import java.io.IOException;
-import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
-import java.util.Base64;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
@@ -220,7 +218,6 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
-import java.util.zip.GZIPOutputStream;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.closeTo;
@@ -4165,41 +4162,9 @@ private void putTrainedModel(String modelId) throws IOException {
highLevelClient().machineLearning().putTrainedModel(new PutTrainedModelRequest(trainedModelConfig), RequestOptions.DEFAULT);
}
- private String compressDefinition(TrainedModelDefinition definition) throws IOException {
- BytesReference reference = XContentHelper.toXContent(definition, XContentType.JSON, false);
- BytesStreamOutput out = new BytesStreamOutput();
- try (OutputStream compressedOutput = new GZIPOutputStream(out, 4096)) {
- reference.writeTo(compressedOutput);
- }
- return new String(Base64.getEncoder().encode(BytesReference.toBytes(out.bytes())), StandardCharsets.UTF_8);
- }
-
- private static String modelConfigString(String modelId) {
- return "{\n" +
- " \"doc_type\": \"trained_model_config\",\n" +
- " \"model_id\": \"" + modelId + "\",\n" +
- " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
- " \"description\": \"test model for\",\n" +
- " \"version\": \"7.6.0\",\n" +
- " \"license_level\": \"platinum\",\n" +
- " \"created_by\": \"ml_test\",\n" +
- " \"estimated_heap_memory_usage_bytes\": 0," +
- " \"estimated_operations\": 0," +
- " \"created_time\": 0\n" +
- "}";
- }
-
- private static String modelDocString(String compressedDefinition, String modelId) {
- return "" +
- "{" +
- "\"model_id\": \"" + modelId + "\",\n" +
- "\"doc_num\": 0,\n" +
- "\"doc_type\": \"trained_model_definition_doc\",\n" +
- " \"compression_version\": " + 1 + ",\n" +
- " \"total_definition_length\": " + compressedDefinition.length() + ",\n" +
- " \"definition_length\": " + compressedDefinition.length() + ",\n" +
- "\"definition\": \"" + compressedDefinition + "\"\n" +
- "}";
+ @Override
+ protected NamedXContentRegistry xContentRegistry() {
+ return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
}
private static final DataFrameAnalyticsConfig DF_ANALYTICS_CONFIG =
From 325ec13df2875cbfcf60898e13dd681ec0ef7e5b Mon Sep 17 00:00:00 2001
From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com>
Date: Fri, 10 Jan 2020 08:45:52 -0500
Subject: [PATCH 3/6] adding compression logic to hlrc for inference
definitions
---
.../InferenceToXContentCompressor.java | 81 +++++++++++++++++++
.../inference/SimpleBoundedInputStream.java | 68 ++++++++++++++++
.../client/MachineLearningIT.java | 26 ++++++
.../MlClientDocumentationIT.java | 21 +++--
.../InferenceToXContentCompressorTests.java | 70 ++++++++++++++++
.../high-level/ml/put-trained-model.asciidoc | 12 +--
6 files changed, 268 insertions(+), 10 deletions(-)
create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/InferenceToXContentCompressor.java
create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/SimpleBoundedInputStream.java
create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/InferenceToXContentCompressorTests.java
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/InferenceToXContentCompressor.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/InferenceToXContentCompressor.java
new file mode 100644
index 0000000000000..9bec4c4eb5d52
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/InferenceToXContentCompressor.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml.inference;
+
+import org.elasticsearch.common.CheckedFunction;
+import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.io.Streams;
+import org.elasticsearch.common.io.stream.BytesStreamOutput;
+import org.elasticsearch.common.xcontent.DeprecationHandler;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.common.xcontent.XContentType;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.nio.charset.StandardCharsets;
+import java.util.Base64;
+import java.util.zip.GZIPInputStream;
+import java.util.zip.GZIPOutputStream;
+
+/**
+ * Collection of helper methods. Similar to CompressedXContent, but this utilizes GZIP.
+ */
+public final class InferenceToXContentCompressor {
+ private static final int BUFFER_SIZE = 4096;
+ private static final long MAX_INFLATED_BYTES = 1_000_000_000; // 1 gb maximum
+
+ private InferenceToXContentCompressor() {}
+
+ public static String deflate(T objectToCompress) throws IOException {
+ BytesReference reference = XContentHelper.toXContent(objectToCompress, XContentType.JSON, false);
+ return deflate(reference);
+ }
+
+ public static T inflate(String compressedString,
+ CheckedFunction parserFunction,
+ NamedXContentRegistry xContentRegistry) throws IOException {
+ try(XContentParser parser = XContentHelper.createParser(xContentRegistry,
+ DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
+ inflate(compressedString, MAX_INFLATED_BYTES),
+ XContentType.JSON)) {
+ return parserFunction.apply(parser);
+ }
+ }
+
+ static BytesReference inflate(String compressedString, long streamSize) throws IOException {
+ byte[] compressedBytes = Base64.getDecoder().decode(compressedString.getBytes(StandardCharsets.UTF_8));
+ InputStream gzipStream = new GZIPInputStream(new BytesArray(compressedBytes).streamInput(), BUFFER_SIZE);
+ InputStream inflateStream = new SimpleBoundedInputStream(gzipStream, streamSize);
+ return Streams.readFully(inflateStream);
+ }
+
+ private static String deflate(BytesReference reference) throws IOException {
+ BytesStreamOutput out = new BytesStreamOutput();
+ try (OutputStream compressedOutput = new GZIPOutputStream(out, BUFFER_SIZE)) {
+ reference.writeTo(compressedOutput);
+ }
+ return new String(Base64.getEncoder().encode(BytesReference.toBytes(out.bytes())), StandardCharsets.UTF_8);
+ }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/SimpleBoundedInputStream.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/SimpleBoundedInputStream.java
new file mode 100644
index 0000000000000..683e23dc9d7cf
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/SimpleBoundedInputStream.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.inference;
+
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.Objects;
+
+/**
+ * This is a pared down bounded input stream.
+ * Only read is specifically enforced.
+ */
+final class SimpleBoundedInputStream extends InputStream {
+
+ private final InputStream in;
+ private final long maxBytes;
+ private long numBytes;
+
+ SimpleBoundedInputStream(InputStream inputStream, long maxBytes) {
+ this.in = Objects.requireNonNull(inputStream, "inputStream");
+ if (maxBytes < 0) {
+ throw new IllegalArgumentException("[maxBytes] must be greater than or equal to 0");
+ }
+ this.maxBytes = maxBytes;
+ }
+
+
+ /**
+ * A simple wrapper around the injected input stream that restricts the total number of bytes able to be read.
+ * @return The byte read. -1 on internal stream completion or when maxBytes is exceeded.
+ * @throws IOException on failure
+ */
+ @Override
+ public int read() throws IOException {
+ // We have reached the maximum, signal stream completion.
+ if (numBytes >= maxBytes) {
+ return -1;
+ }
+ numBytes++;
+ return in.read();
+ }
+
+ /**
+ * Delegates `close` to the wrapped InputStream
+ * @throws IOException on failure
+ */
+ @Override
+ public void close() throws IOException {
+ in.close();
+ }
+}
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
index 728799dfccacd..0664d49f76841 100644
--- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
@@ -148,6 +148,7 @@
import org.elasticsearch.client.ml.dataframe.explain.FieldSelection;
import org.elasticsearch.client.ml.dataframe.explain.MemoryEstimation;
import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
+import org.elasticsearch.client.ml.inference.InferenceToXContentCompressor;
import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
@@ -2193,6 +2194,7 @@ public void testGetTrainedModels() throws Exception {
public void testPutTrainedModel() throws Exception {
String modelId = "test-put-trained-model";
+ String modelIdCompressed = "test-put-trained-model-compressed-definition";
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
@@ -2208,6 +2210,30 @@ public void testPutTrainedModel() throws Exception {
machineLearningClient::putTrainedModelAsync);
TrainedModelConfig createdModel = putTrainedModelResponse.getResponse();
assertThat(createdModel.getModelId(), equalTo(modelId));
+
+ definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
+ trainedModelConfig = TrainedModelConfig.builder()
+ .setCompressedDefinition(InferenceToXContentCompressor.deflate(definition))
+ .setModelId(modelIdCompressed)
+ .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4")))
+ .setDescription("test model")
+ .build();
+ putTrainedModelResponse = execute(new PutTrainedModelRequest(trainedModelConfig),
+ machineLearningClient::putTrainedModel,
+ machineLearningClient::putTrainedModelAsync);
+ createdModel = putTrainedModelResponse.getResponse();
+ assertThat(createdModel.getModelId(), equalTo(modelIdCompressed));
+
+ GetTrainedModelsResponse getTrainedModelsResponse = execute(
+ new GetTrainedModelsRequest(modelIdCompressed).setDecompressDefinition(true).setIncludeDefinition(true),
+ machineLearningClient::getTrainedModels,
+ machineLearningClient::getTrainedModelsAsync);
+
+ assertThat(getTrainedModelsResponse.getCount(), equalTo(1L));
+ assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(1));
+ assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getCompressedDefinition(), is(nullValue()));
+ assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getDefinition(), is(not(nullValue())));
+ assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdCompressed));
}
public void testGetTrainedModelsStats() throws Exception {
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java
index 6a8a0f5db990d..0db9dbf222f49 100644
--- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java
@@ -164,6 +164,7 @@
import org.elasticsearch.client.ml.dataframe.explain.FieldSelection;
import org.elasticsearch.client.ml.dataframe.explain.MemoryEstimation;
import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
+import org.elasticsearch.client.ml.inference.InferenceToXContentCompressor;
import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
@@ -3631,14 +3632,24 @@ public void testPutTrainedModel() throws Exception {
// tag::put-trained-model-config
TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder()
.setDefinition(definition) // <1>
- .setModelId("my-new-trained-model") // <2>
- .setInput(new TrainedModelInput("col1", "col2", "col3", "col4")) // <3>
- .setDescription("test model") // <4>
- .setMetadata(new HashMap<>()) // <5>
- .setTags("my_regression_models") // <6>
+ .setCompressedDefinition(InferenceToXContentCompressor.deflate(definition)) // <2>
+ .setModelId("my-new-trained-model") // <3>
+ .setInput(new TrainedModelInput("col1", "col2", "col3", "col4")) // <4>
+ .setDescription("test model") // <5>
+ .setMetadata(new HashMap<>()) // <6>
+ .setTags("my_regression_models") // <7>
.build();
// end::put-trained-model-config
+ trainedModelConfig = TrainedModelConfig.builder()
+ .setDefinition(definition)
+ .setModelId("my-new-trained-model")
+ .setInput(new TrainedModelInput("col1", "col2", "col3", "col4"))
+ .setDescription("test model")
+ .setMetadata(new HashMap<>())
+ .setTags("my_regression_models")
+ .build();
+
RestHighLevelClient client = highLevelClient();
{
// tag::put-trained-model-request
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/InferenceToXContentCompressorTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/InferenceToXContentCompressorTests.java
new file mode 100644
index 0000000000000..11747638a2c15
--- /dev/null
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/InferenceToXContentCompressorTests.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.inference;
+
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.common.xcontent.XContentType;
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.IOException;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class InferenceToXContentCompressorTests extends ESTestCase {
+
+ public void testInflateAndDeflate() throws IOException {
+ for(int i = 0; i < 10; i++) {
+ TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build();
+ String firstDeflate = InferenceToXContentCompressor.deflate(definition);
+ TrainedModelDefinition inflatedDefinition = InferenceToXContentCompressor.inflate(firstDeflate,
+ parser -> TrainedModelDefinition.fromXContent(parser).build(),
+ xContentRegistry());
+
+ // Did we inflate to the same object?
+ assertThat(inflatedDefinition, equalTo(definition));
+ }
+ }
+
+ public void testInflateTooLargeStream() throws IOException {
+ TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build();
+ String firstDeflate = InferenceToXContentCompressor.deflate(definition);
+ BytesReference inflatedBytes = InferenceToXContentCompressor.inflate(firstDeflate, 10L);
+ assertThat(inflatedBytes.length(), equalTo(10));
+ try(XContentParser parser = XContentHelper.createParser(xContentRegistry(),
+ LoggingDeprecationHandler.INSTANCE,
+ inflatedBytes,
+ XContentType.JSON)) {
+ expectThrows(IOException.class, () -> TrainedModelConfig.fromXContent(parser));
+ }
+ }
+
+ public void testInflateGarbage() {
+ expectThrows(IOException.class, () -> InferenceToXContentCompressor.inflate(randomAlphaOfLength(10), 100L));
+ }
+
+ @Override
+ protected NamedXContentRegistry xContentRegistry() {
+ return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+ }
+
+}
diff --git a/docs/java-rest/high-level/ml/put-trained-model.asciidoc b/docs/java-rest/high-level/ml/put-trained-model.asciidoc
index 5e7694f8221e3..dadc8dcf65a4f 100644
--- a/docs/java-rest/high-level/ml/put-trained-model.asciidoc
+++ b/docs/java-rest/high-level/ml/put-trained-model.asciidoc
@@ -32,11 +32,13 @@ configuration and contains the following arguments:
include-tagged::{doc-tests-file}[{api}-config]
--------------------------------------------------
<1> The {infer} definition for the model
-<2> The unique model id
-<3> The input field names for the model definition
-<4> Optionally, a human-readable description
-<5> Optionally, an object map contain metadata about the model
-<6> Optionally, an array of tags to organize the model
+<2> Optionally, if the {infer} definition is large, you may choose to compress it for transport.
+ Do not supply both the compressed and uncompressed definitions.
+<3> The unique model id
+<4> The input field names for the model definition
+<5> Optionally, a human-readable description
+<6> Optionally, an object map contain metadata about the model
+<7> Optionally, an array of tags to organize the model
include::../execution.asciidoc[]
From 48f2e5fef6a83661fd1fb32664cfd425210e297f Mon Sep 17 00:00:00 2001
From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com>
Date: Fri, 10 Jan 2020 09:36:05 -0500
Subject: [PATCH 4/6] fixing yaml test
---
.../rest-api-spec/test/ml/inference_crud.yml | 36 +++++++++----------
1 file changed, 18 insertions(+), 18 deletions(-)
diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml
index c2da1d7ba2da3..ece97ae9d9a7a 100644
--- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml
+++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml
@@ -27,7 +27,7 @@
"Test get models":
- do:
ml.put_trained_model:
- model_id: regression-model-0
+ model_id: a-regression-model-0
body: >
{
"description": "empty model for tests",
@@ -45,11 +45,11 @@
}
}
}
- - match: { model_id: "regression-model-0" }
+ - match: { model_id: "a-regression-model-0" }
- do:
ml.put_trained_model:
- model_id: regression-model-1
+ model_id: a-regression-model-1
body: >
{
"description": "empty model for tests",
@@ -67,11 +67,11 @@
}
}
}
- - match: { model_id: "regression-model-1" }
+ - match: { model_id: "a-regression-model-1" }
- do:
ml.put_trained_model:
- model_id: classification-model
+ model_id: a-classification-model
body: >
{
"description": "empty model for tests",
@@ -90,39 +90,39 @@
}
}
}
- - match: { model_id: "classification-model" }
+ - match: { model_id: "a-classification-model" }
- do:
ml.get_trained_models:
model_id: "*"
- - match: { count: 3 }
- - match: { trained_model_configs.0.model_id: "classification-model" }
- - match: { trained_model_configs.1.model_id: "regression-model-0" }
- - match: { trained_model_configs.2.model_id: "regression-model-1" }
+ - match: { count: 4 }
+ - match: { trained_model_configs.0.model_id: "a-classification-model" }
+ - match: { trained_model_configs.1.model_id: "a-regression-model-0" }
+ - match: { trained_model_configs.2.model_id: "a-regression-model-1" }
- do:
ml.get_trained_models:
- model_id: "regression*"
+ model_id: "a-regression*"
- match: { count: 2 }
- - match: { trained_model_configs.0.model_id: "regression-model-0" }
- - match: { trained_model_configs.1.model_id: "regression-model-1" }
+ - match: { trained_model_configs.0.model_id: "a-regression-model-0" }
+ - match: { trained_model_configs.1.model_id: "a-regression-model-1" }
- do:
ml.get_trained_models:
model_id: "*"
from: 0
size: 2
- - match: { count: 3 }
- - match: { trained_model_configs.0.model_id: "classification-model" }
- - match: { trained_model_configs.1.model_id: "regression-model-0" }
+ - match: { count: 4 }
+ - match: { trained_model_configs.0.model_id: "a-classification-model" }
+ - match: { trained_model_configs.1.model_id: "a-regression-model-0" }
- do:
ml.get_trained_models:
model_id: "*"
from: 1
size: 1
- - match: { count: 3 }
- - match: { trained_model_configs.0.model_id: "regression-model-0" }
+ - match: { count: 4 }
+ - match: { trained_model_configs.0.model_id: "a-regression-model-0" }
---
"Test delete given unused trained model":
- do:
From f04193f245dfe487e807961c6fbfba339f0c934a Mon Sep 17 00:00:00 2001
From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com>
Date: Fri, 10 Jan 2020 10:46:02 -0500
Subject: [PATCH 5/6] fixing yaml tests more
---
.../rest-api-spec/test/ml/inference_crud.yml | 122 ++++++------------
1 file changed, 39 insertions(+), 83 deletions(-)
diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml
index ece97ae9d9a7a..f5f9a56bab8d8 100644
--- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml
+++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml
@@ -1,31 +1,9 @@
----
-"Test get given missing trained model":
-
- - do:
- catch: missing
- ml.get_trained_models:
- model_id: "missing-trained-model"
----
-"Test get given expression without matches and allow_no_match is false":
-
- - do:
- catch: missing
- ml.get_trained_models:
- model_id: "missing-trained-model*"
- allow_no_match: false
-
----
-"Test get given expression without matches and allow_no_match is true":
-
- - do:
- ml.get_trained_models:
- model_id: "missing-trained-model*"
- allow_no_match: true
- - match: { count: 0 }
- - match: { trained_model_configs: [] }
----
-"Test get models":
+setup:
+ - skip:
+ features: headers
- do:
+ headers:
+ Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
ml.put_trained_model:
model_id: a-regression-model-0
body: >
@@ -45,9 +23,10 @@
}
}
}
- - match: { model_id: "a-regression-model-0" }
- do:
+ headers:
+ Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
ml.put_trained_model:
model_id: a-regression-model-1
body: >
@@ -67,9 +46,9 @@
}
}
}
- - match: { model_id: "a-regression-model-1" }
-
- do:
+ headers:
+ Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
ml.put_trained_model:
model_id: a-classification-model
body: >
@@ -90,8 +69,33 @@
}
}
}
- - match: { model_id: "a-classification-model" }
+---
+"Test get given missing trained model":
+
+ - do:
+ catch: missing
+ ml.get_trained_models:
+ model_id: "missing-trained-model"
+---
+"Test get given expression without matches and allow_no_match is false":
+
+ - do:
+ catch: missing
+ ml.get_trained_models:
+ model_id: "missing-trained-model*"
+ allow_no_match: false
+
+---
+"Test get given expression without matches and allow_no_match is true":
+ - do:
+ ml.get_trained_models:
+ model_id: "missing-trained-model*"
+ allow_no_match: true
+ - match: { count: 0 }
+ - match: { trained_model_configs: [] }
+---
+"Test get models":
- do:
ml.get_trained_models:
model_id: "*"
@@ -125,66 +129,18 @@
- match: { trained_model_configs.0.model_id: "a-regression-model-0" }
---
"Test delete given unused trained model":
- - do:
- ml.put_trained_model:
- model_id: unused-regression-model
- body: >
- {
- "model_id": "unused-regression-model",
- "input": {"field_names": ["field1", "field2"]},
- "description": "empty model for tests",
- "definition": {
- "preprocessors": [],
- "trained_model": {
- "tree": {
- "feature_names": ["field1", "field2"],
- "tree_structure": [
- {"node_index": 0, "leaf_value": 1}
- ],
- "target_type": "regression"
- }
- }
- }
- }
-
- - match: { model_id: "unused-regression-model" }
- do:
ml.delete_trained_model:
- model_id: "unused-regression-model"
+ model_id: "a-classification-model"
- match: { acknowledged: true }
-
---
"Test delete with missing model":
- do:
catch: missing
ml.delete_trained_model:
model_id: "missing-trained-model"
-
---
"Test delete given used trained model":
- - do:
- ml.put_trained_model:
- model_id: used-regression-model
- body: >
- {
- "model_id": "used-regression-model",
- "input": {"field_names": ["field1", "field2"]},
- "description": "empty model for tests",
- "definition": {
- "preprocessors": [],
- "trained_model": {
- "tree": {
- "feature_names": ["field1", "field2"],
- "tree_structure": [
- {"node_index": 0, "leaf_value": 1}
- ],
- "target_type": "regression"
- }
- }
- }
- }
-
- - match: { model_id: "used-regression-model" }
- do:
ingest.put_pipeline:
id: "regression-model-pipeline"
@@ -193,7 +149,7 @@
"processors": [
{
"inference" : {
- "model_id" : "used-regression-model",
+ "model_id" : "a-regression-model-0",
"inference_config": {"regression": {}},
"target_field": "regression_field",
"field_mappings": {}
@@ -206,12 +162,12 @@
- do:
catch: conflict
ml.delete_trained_model:
- model_id: "used-regression-model"
+ model_id: "a-regression-model-0"
---
"Test get pre-packaged trained models":
- do:
ml.get_trained_models:
- model_id: "_all"
+ model_id: "lang_ident_model_1"
allow_no_match: false
- match: { count: 1 }
- match: { trained_model_configs.0.model_id: "lang_ident_model_1" }
From b102d375b26f0a22e3b2bb819c01091b38882cc4 Mon Sep 17 00:00:00 2001
From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com>
Date: Sat, 11 Jan 2020 15:14:44 -0500
Subject: [PATCH 6/6] addressing PR comments
---
.../client/MachineLearningClient.java | 2 +-
.../core/ml/action/PutTrainedModelAction.java | 2 +-
.../core/ml/inference/TrainedModelConfig.java | 27 ++++++++++++-------
.../ml/inference/TrainedModelConfigTests.java | 6 ++---
.../TransportPutTrainedModelAction.java | 2 +-
.../inference/RestPutTrainedModelAction.java | 2 +-
6 files changed, 24 insertions(+), 17 deletions(-)
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java
index 3c0bcfd9230ea..bdb2f22f3b3fa 100644
--- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java
@@ -2362,7 +2362,7 @@ public PutTrainedModelResponse putTrainedModel(PutTrainedModelRequest request, R
}
/**
- * Get trained model config asynchronously and notifies listener upon completion
+ * Put trained model config asynchronously and notifies listener upon completion
*
* For additional info
* see
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAction.java
index 97045489001a5..06fbb6401a082 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAction.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAction.java
@@ -46,7 +46,7 @@ public static Request parseRequest(String modelId, XContentParser parser) {
}
// Validations are done against the builder so we can build the full config object.
// This allows us to not worry about serializing a builder class between nodes.
- return new Request(builder.validate().build());
+ return new Request(builder.validate(true).build());
}
private final TrainedModelConfig config;
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java
index 4d59bc336e86b..95589ac8b61fe 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java
@@ -488,12 +488,16 @@ public Builder setLicenseLevel(String licenseLevel) {
return this;
}
+ public Builder validate() {
+ return validate(false);
+ }
+
/**
* Runs validations against the builder.
* @return The current builder object if validations are successful
* @throws ActionRequestValidationException when there are validation failures.
*/
- public Builder validate() {
+ public Builder validate(boolean forCreation) {
// We require a definition to be available here even though it will be stored in a different doc
ActionRequestValidationException validationException = null;
if (definition == null) {
@@ -531,15 +535,18 @@ public Builder validate() {
break;
}
}
-
- validationException = checkIllegalSetting(version, VERSION.getPreferredName(), validationException);
- validationException = checkIllegalSetting(createdBy, CREATED_BY.getPreferredName(), validationException);
- validationException = checkIllegalSetting(createTime, CREATE_TIME.getPreferredName(), validationException);
- validationException = checkIllegalSetting(estimatedHeapMemory,
- ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(),
- validationException);
- validationException = checkIllegalSetting(estimatedOperations, ESTIMATED_OPERATIONS.getPreferredName(), validationException);
- validationException = checkIllegalSetting(licenseLevel, LICENSE_LEVEL.getPreferredName(), validationException);
+ if (forCreation) {
+ validationException = checkIllegalSetting(version, VERSION.getPreferredName(), validationException);
+ validationException = checkIllegalSetting(createdBy, CREATED_BY.getPreferredName(), validationException);
+ validationException = checkIllegalSetting(createTime, CREATE_TIME.getPreferredName(), validationException);
+ validationException = checkIllegalSetting(estimatedHeapMemory,
+ ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(),
+ validationException);
+ validationException = checkIllegalSetting(estimatedOperations,
+ ESTIMATED_OPERATIONS.getPreferredName(),
+ validationException);
+ validationException = checkIllegalSetting(licenseLevel, LICENSE_LEVEL.getPreferredName(), validationException);
+ }
if (validationException != null) {
throw validationException;
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java
index db570025c786a..67b67a45500f1 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java
@@ -223,21 +223,21 @@ public void testValidateWithIllegallyUserProvidedFields() {
() -> TrainedModelConfig.builder()
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.setCreateTime(Instant.now())
- .setModelId(modelId).validate());
+ .setModelId(modelId).validate(true));
assertThat(ex.getMessage(), containsString("illegal to set [create_time] at inference model creation"));
ex = expectThrows(ActionRequestValidationException.class,
() -> TrainedModelConfig.builder()
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.setVersion(Version.CURRENT)
- .setModelId(modelId).validate());
+ .setModelId(modelId).validate(true));
assertThat(ex.getMessage(), containsString("illegal to set [version] at inference model creation"));
ex = expectThrows(ActionRequestValidationException.class,
() -> TrainedModelConfig.builder()
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.setCreatedBy("ml_user")
- .setModelId(modelId).validate());
+ .setModelId(modelId).validate(true));
assertThat(ex.getMessage(), containsString("illegal to set [created_by] at inference model creation"));
}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java
index 9fca82802b51d..575b8ac00dfb5 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java
@@ -100,7 +100,7 @@ protected void masterOperation(Task task,
TrainedModelConfig trainedModelConfig = new TrainedModelConfig.Builder(request.getTrainedModelConfig())
.setVersion(Version.CURRENT)
.setCreateTime(Instant.now())
- .setCreatedBy("user")
+ .setCreatedBy("api_user")
.setLicenseLevel(License.OperationMode.PLATINUM.description())
.setEstimatedHeapMemory(request.getTrainedModelConfig().getModelDefinition().ramBytesUsed())
.setEstimatedOperations(request.getTrainedModelConfig().getModelDefinition().getTrainedModel().estimatedNumOperations())
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelAction.java
index 83ec70629bc78..cb3f4e0eddee2 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelAction.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelAction.java
@@ -27,7 +27,7 @@ public RestPutTrainedModelAction(RestController controller) {
@Override
public String getName() {
- return "xpack_ml_put_data_frame_analytics_action";
+ return "xpack_ml_put_trained_model_action";
}
@Override