diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAction.java new file mode 100644 index 0000000000000..521070a959db6 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelAction.java @@ -0,0 +1,81 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.master.AcknowledgedRequest; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ToXContentFragment; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Objects; + +public class DeleteTrainedModelAction extends ActionType { + + public static final DeleteTrainedModelAction INSTANCE = new DeleteTrainedModelAction(); + public static final String NAME = "cluster:admin/xpack/ml/inference/delete"; + + private DeleteTrainedModelAction() { + super(NAME, AcknowledgedResponse::new); + } + + public static class Request extends AcknowledgedRequest implements ToXContentFragment { + + private String id; + + public Request(StreamInput in) throws IOException { + super(in); + id = in.readString(); + } + + public Request() {} + + public Request(String id) { + this.id = ExceptionsHelper.requireNonNull(id, TrainedModelConfig.MODEL_ID); + } + + public String getId() { + return id; + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), id); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DeleteTrainedModelAction.Request request = (DeleteTrainedModelAction.Request) o; + return Objects.equals(id, request.id); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(id); + } + + @Override + public int hashCode() { + return Objects.hash(id); + } + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java new file mode 100644 index 0000000000000..005f0d180cdc1 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java @@ -0,0 +1,77 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.ActionRequestBuilder; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.client.ElasticsearchClient; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest; +import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse; +import org.elasticsearch.xpack.core.action.util.QueryPage; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; + +import java.io.IOException; + +public class GetTrainedModelsAction extends ActionType { + + public static final GetTrainedModelsAction INSTANCE = new GetTrainedModelsAction(); + public static final String NAME = "cluster:monitor/xpack/ml/inference/get"; + + private GetTrainedModelsAction() { + super(NAME, Response::new); + } + + public static class Request extends AbstractGetResourcesRequest { + + public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); + + public Request() { + setAllowNoResources(true); + } + + public Request(String id) { + setResourceId(id); + setAllowNoResources(true); + } + + public Request(StreamInput in) throws IOException { + super(in); + } + + @Override + public String getResourceIdField() { + return TrainedModelConfig.MODEL_ID.getPreferredName(); + } + + } + + public static class Response extends AbstractGetResourcesResponse { + + public static final ParseField RESULTS_FIELD = new ParseField("trained_model_configs"); + + public Response(StreamInput in) throws IOException { + super(in); + } + + public Response(QueryPage trainedModels) { + super(trainedModels); + } + + @Override + protected Reader getReader() { + return TrainedModelConfig::new; + } + } + + public static class RequestBuilder extends ActionRequestBuilder { + + public RequestBuilder(ElasticsearchClient client) { + super(client, INSTANCE, new Request()); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/notifications/InferenceAuditMessage.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/notifications/InferenceAuditMessage.java new file mode 100644 index 0000000000000..7c1b93786bc91 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/notifications/InferenceAuditMessage.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.notifications; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.xpack.core.common.notifications.AbstractAuditMessage; +import org.elasticsearch.xpack.core.common.notifications.Level; +import org.elasticsearch.xpack.core.ml.job.config.Job; + +import java.util.Date; + + +public class InferenceAuditMessage extends AbstractAuditMessage { + + //TODO this should be MODEL_ID... + private static final ParseField JOB_ID = Job.ID; + public static final ConstructingObjectParser PARSER = + createParser("ml_inference_audit_message", InferenceAuditMessage::new, JOB_ID); + + public InferenceAuditMessage(String resourceId, String message, Level level, Date timestamp, String nodeName) { + super(resourceId, message, level, timestamp, nodeName); + } + + @Override + public final String getJobType() { + return "inference"; + } + + @Override + protected String getResourceField() { + return JOB_ID.getPreferredName(); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java index d2e5a207355f3..9cc4a5cdfbf53 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java @@ -43,6 +43,10 @@ public static ResourceAlreadyExistsException dataFrameAnalyticsAlreadyExists(Str return new ResourceAlreadyExistsException("A data frame analytics with id [{}] already exists", id); } + public static ResourceNotFoundException missingTrainedModel(String modelId) { + return new ResourceNotFoundException("No known trained model with model_id [{}]", modelId); + } + public static ElasticsearchException serverError(String msg) { return new ElasticsearchException(msg); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelsRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelsRequestTests.java new file mode 100644 index 0000000000000..0797b20d438be --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/DeleteTrainedModelsRequestTests.java @@ -0,0 +1,23 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction.Request; + +public class DeleteTrainedModelsRequestTests extends AbstractWireSerializingTestCase { + + @Override + protected Request createTestInstance() { + return new Request(randomAlphaOfLengthBetween(1, 20)); + } + + @Override + protected Writeable.Reader instanceReader() { + return Request::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java new file mode 100644 index 0000000000000..0abc0318e215e --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.action.util.PageParams; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request; + +public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCase { + + @Override + protected Request createTestInstance() { + Request request = new Request(randomAlphaOfLength(20)); + request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100))); + return request; + } + + @Override + protected Writeable.Reader instanceReader() { + return Request::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/AnomalyDetectionAuditMessageTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/AnomalyDetectionAuditMessageTests.java index c6a904228b6a7..f6a319dab7ab3 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/AnomalyDetectionAuditMessageTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/AnomalyDetectionAuditMessageTests.java @@ -6,19 +6,16 @@ package org.elasticsearch.xpack.core.ml.notifications; import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.xpack.core.common.notifications.Level; import org.elasticsearch.xpack.core.ml.job.config.Job; import java.util.Date; -import static org.hamcrest.Matchers.equalTo; +public class AnomalyDetectionAuditMessageTests extends AuditMessageTests { -public class AnomalyDetectionAuditMessageTests extends AbstractXContentTestCase { - - public void testGetJobType() { - AnomalyDetectionAuditMessage message = createTestInstance(); - assertThat(message.getJobType(), equalTo(Job.ANOMALY_DETECTOR_JOB_TYPE)); + @Override + public String getJobType() { + return Job.ANOMALY_DETECTOR_JOB_TYPE; } @Override @@ -26,11 +23,6 @@ protected AnomalyDetectionAuditMessage doParseInstance(XContentParser parser) { return AnomalyDetectionAuditMessage.PARSER.apply(parser, null); } - @Override - protected boolean supportsUnknownFields() { - return true; - } - @Override protected AnomalyDetectionAuditMessage createTestInstance() { return new AnomalyDetectionAuditMessage( diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/AuditMessageTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/AuditMessageTests.java new file mode 100644 index 0000000000000..2ccb1fbcbf4b3 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/AuditMessageTests.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.notifications; + +import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.xpack.core.common.notifications.AbstractAuditMessage; + + +import static org.hamcrest.Matchers.equalTo; + +public abstract class AuditMessageTests extends AbstractXContentTestCase { + + public abstract String getJobType(); + + public void testGetJobType() { + AbstractAuditMessage message = createTestInstance(); + assertThat(message.getJobType(), equalTo(getJobType())); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/DataFrameAnalyticsAuditMessageTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/DataFrameAnalyticsAuditMessageTests.java index 139e76160d4a6..9637af79a947c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/DataFrameAnalyticsAuditMessageTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/DataFrameAnalyticsAuditMessageTests.java @@ -6,28 +6,20 @@ package org.elasticsearch.xpack.core.ml.notifications; import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.xpack.core.common.notifications.Level; import java.util.Date; -import static org.hamcrest.Matchers.equalTo; +public class DataFrameAnalyticsAuditMessageTests extends AuditMessageTests { -public class DataFrameAnalyticsAuditMessageTests extends AbstractXContentTestCase { - - public void testGetJobType() { - DataFrameAnalyticsAuditMessage message = createTestInstance(); - assertThat(message.getJobType(), equalTo("data_frame_analytics")); - } - @Override - protected DataFrameAnalyticsAuditMessage doParseInstance(XContentParser parser) { - return DataFrameAnalyticsAuditMessage.PARSER.apply(parser, null); + public String getJobType() { + return "data_frame_analytics"; } @Override - protected boolean supportsUnknownFields() { - return true; + protected DataFrameAnalyticsAuditMessage doParseInstance(XContentParser parser) { + return DataFrameAnalyticsAuditMessage.PARSER.apply(parser, null); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/InferenceAuditMessageTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/InferenceAuditMessageTests.java new file mode 100644 index 0000000000000..5a9b86578ef59 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/notifications/InferenceAuditMessageTests.java @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.notifications; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.common.notifications.Level; + +import java.util.Date; + +public class InferenceAuditMessageTests extends AuditMessageTests { + + @Override + public String getJobType() { + return "inference"; + } + + @Override + protected InferenceAuditMessage doParseInstance(XContentParser parser) { + return InferenceAuditMessage.PARSER.apply(parser, null); + } + + @Override + protected InferenceAuditMessage createTestInstance() { + return new InferenceAuditMessage( + randomBoolean() ? null : randomAlphaOfLength(10), + randomAlphaOfLengthBetween(1, 20), + randomFrom(Level.values()), + new Date(), + randomBoolean() ? null : randomAlphaOfLengthBetween(1, 20) + ); + } +} diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index e330d032c0a0d..2dd63883b523a 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -125,6 +125,11 @@ integTest.runner { 'ml/filter_crud/Test get all filter given index exists but no mapping for filter_id', 'ml/get_datafeed_stats/Test get datafeed stats given missing datafeed_id', 'ml/get_datafeeds/Test get datafeed given missing datafeed_id', + 'ml/inference_crud/Test delete given used trained model', + 'ml/inference_crud/Test delete given unused trained model', + 'ml/inference_crud/Test delete with missing model', + 'ml/inference_crud/Test get given missing trained model', + 'ml/inference_crud/Test get given expression without matches and allow_no_match is false', 'ml/jobs_crud/Test cannot create job with existing categorizer state document', 'ml/jobs_crud/Test cannot create job with existing quantiles document', 'ml/jobs_crud/Test cannot create job with existing result document', diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java new file mode 100644 index 0000000000000..6d3fe32332a72 --- /dev/null +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java @@ -0,0 +1,168 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.integration; + +import org.apache.http.util.EntityUtils; +import org.elasticsearch.Version; +import org.elasticsearch.client.Request; +import org.elasticsearch.client.Response; +import org.elasticsearch.client.ResponseException; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.test.SecuritySettingsSourceField; +import org.elasticsearch.test.rest.ESRestTestCase; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; +import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests; +import org.junit.After; + +import java.io.IOException; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; + +import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; + +public class TrainedModelIT extends ESRestTestCase { + + private static final String BASIC_AUTH_VALUE = basicAuthHeaderValue("x_pack_rest_user", + SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING); + + @Override + protected Settings restClientSettings() { + return Settings.builder().put(super.restClientSettings()).put(ThreadContext.PREFIX + ".Authorization", BASIC_AUTH_VALUE).build(); + } + + @Override + protected boolean preserveTemplatesUponCompletion() { + return true; + } + + public void testGetTrainedModels() throws IOException { + String modelId = "test_regression_model"; + String modelId2 = "test_regression_model-2"; + Request model1 = new Request("PUT", + InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId); + model1.setJsonEntity(buildRegressionModel(modelId)); + assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201)); + + Request model2 = new Request("PUT", + InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId2); + model2.setJsonEntity(buildRegressionModel(modelId2)); + assertThat(client().performRequest(model2).getStatusLine().getStatusCode(), equalTo(201)); + + adminClient().performRequest(new Request("POST", InferenceIndexConstants.LATEST_INDEX_NAME + "/_refresh")); + Response getModel = client().performRequest(new Request("GET", + MachineLearning.BASE_PATH + "inference/" + modelId)); + + assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); + String response = EntityUtils.toString(getModel.getEntity()); + + assertThat(response, containsString("\"model_id\":\"test_regression_model\"")); + assertThat(response, containsString("\"count\":1")); + + getModel = client().performRequest(new Request("GET", + MachineLearning.BASE_PATH + "inference/test_regression*")); + assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); + + response = EntityUtils.toString(getModel.getEntity()); + assertThat(response, containsString("\"model_id\":\"test_regression_model\"")); + assertThat(response, containsString("\"model_id\":\"test_regression_model-2\"")); + assertThat(response, containsString("\"count\":2")); + + getModel = client().performRequest(new Request("GET", + MachineLearning.BASE_PATH + "inference/test_regression_model,test_regression_model-2")); + assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); + + response = EntityUtils.toString(getModel.getEntity()); + assertThat(response, containsString("\"model_id\":\"test_regression_model\"")); + assertThat(response, containsString("\"model_id\":\"test_regression_model-2\"")); + assertThat(response, containsString("\"count\":2")); + + getModel = client().performRequest(new Request("GET", + MachineLearning.BASE_PATH + "inference/classification*?allow_no_match=true")); + assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); + + response = EntityUtils.toString(getModel.getEntity()); + assertThat(response, containsString("\"count\":0")); + + ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(new Request("GET", + MachineLearning.BASE_PATH + "inference/classification*?allow_no_match=false"))); + assertThat(ex.getResponse().getStatusLine().getStatusCode(), equalTo(404)); + + getModel = client().performRequest(new Request("GET", MachineLearning.BASE_PATH + "inference?from=0&size=1")); + assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); + + response = EntityUtils.toString(getModel.getEntity()); + assertThat(response, containsString("\"count\":2")); + assertThat(response, containsString("\"model_id\":\"test_regression_model\"")); + assertThat(response, not(containsString("\"model_id\":\"test_regression_model-2\""))); + + getModel = client().performRequest(new Request("GET", MachineLearning.BASE_PATH + "inference?from=1&size=1")); + assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); + + response = EntityUtils.toString(getModel.getEntity()); + assertThat(response, containsString("\"count\":2")); + assertThat(response, not(containsString("\"model_id\":\"test_regression_model\""))); + assertThat(response, containsString("\"model_id\":\"test_regression_model-2\"")); + } + + public void testDeleteTrainedModels() throws IOException { + String modelId = "test_delete_regression_model"; + Request model1 = new Request("PUT", + InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId); + model1.setJsonEntity(buildRegressionModel(modelId)); + assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201)); + + adminClient().performRequest(new Request("POST", InferenceIndexConstants.LATEST_INDEX_NAME + "/_refresh")); + + Response delModel = client().performRequest(new Request("DELETE", + MachineLearning.BASE_PATH + "inference/" + modelId)); + String response = EntityUtils.toString(delModel.getEntity()); + assertThat(response, containsString("\"acknowledged\":true")); + + ResponseException responseException = expectThrows(ResponseException.class, + () -> client().performRequest(new Request("DELETE", MachineLearning.BASE_PATH + "inference/" + modelId))); + assertThat(responseException.getResponse().getStatusLine().getStatusCode(), equalTo(404)); + } + + private static String buildRegressionModel(String modelId) throws IOException { + try(XContentBuilder builder = XContentFactory.jsonBuilder()) { + TrainedModelConfig.builder() + .setModelId(modelId) + .setCreatedBy("ml_test") + .setDefinition(new TrainedModelDefinition.Builder() + .setInput(new TrainedModelDefinition.Input(Arrays.asList("col1", "col2", "col3"))) + .setPreProcessors(Collections.emptyList()) + .setTrainedModel(LocalModelTests.buildRegression())) + .setVersion(Version.CURRENT) + .setCreateTime(Instant.now()) + .build() + .toXContent(builder, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"))); + return XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON); + } + } + + + @After + public void clearMlState() throws Exception { + new MlRestTestStateCleaner(logger, adminClient()).clearMlMetadata(); + ESRestTestCase.waitForPendingTasks(adminClient()); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 46084836b6b89..500a71b3a9416 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -42,12 +42,14 @@ import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.analysis.TokenizerFactory; import org.elasticsearch.indices.analysis.AnalysisModule.AnalysisProvider; +import org.elasticsearch.ingest.Processor; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.monitor.os.OsProbe; import org.elasticsearch.monitor.os.OsStats; import org.elasticsearch.persistent.PersistentTasksExecutor; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.AnalysisPlugin; +import org.elasticsearch.plugins.IngestPlugin; import org.elasticsearch.plugins.PersistentTaskPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.rest.RestController; @@ -73,6 +75,7 @@ import org.elasticsearch.xpack.core.ml.action.DeleteForecastAction; import org.elasticsearch.xpack.core.ml.action.DeleteJobAction; import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction; +import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction; import org.elasticsearch.xpack.core.ml.action.EstimateMemoryUsageAction; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.action.FinalizeJobExecutionAction; @@ -94,6 +97,7 @@ import org.elasticsearch.xpack.core.ml.action.GetModelSnapshotsAction; import org.elasticsearch.xpack.core.ml.action.GetOverallBucketsAction; import org.elasticsearch.xpack.core.ml.action.GetRecordsAction; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.action.KillProcessAction; @@ -141,6 +145,7 @@ import org.elasticsearch.xpack.ml.action.TransportDeleteForecastAction; import org.elasticsearch.xpack.ml.action.TransportDeleteJobAction; import org.elasticsearch.xpack.ml.action.TransportDeleteModelSnapshotAction; +import org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAction; import org.elasticsearch.xpack.ml.action.TransportEstimateMemoryUsageAction; import org.elasticsearch.xpack.ml.action.TransportEvaluateDataFrameAction; import org.elasticsearch.xpack.ml.action.TransportFinalizeJobExecutionAction; @@ -163,6 +168,7 @@ import org.elasticsearch.xpack.ml.action.TransportGetOverallBucketsAction; import org.elasticsearch.xpack.ml.action.TransportGetRecordsAction; import org.elasticsearch.xpack.ml.action.TransportInferModelAction; +import org.elasticsearch.xpack.ml.action.TransportGetTrainedModelsAction; import org.elasticsearch.xpack.ml.action.TransportIsolateDatafeedAction; import org.elasticsearch.xpack.ml.action.TransportKillProcessAction; import org.elasticsearch.xpack.ml.action.TransportMlInfoAction; @@ -203,6 +209,7 @@ import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; +import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.persistence.InferenceInternalIndex; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.job.JobManager; @@ -225,6 +232,7 @@ import org.elasticsearch.xpack.ml.job.process.normalizer.NormalizerProcessFactory; import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; +import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; import org.elasticsearch.xpack.ml.process.DummyController; import org.elasticsearch.xpack.ml.process.MlController; import org.elasticsearch.xpack.ml.process.MlControllerHolder; @@ -263,6 +271,8 @@ import org.elasticsearch.xpack.ml.rest.filter.RestGetFiltersAction; import org.elasticsearch.xpack.ml.rest.filter.RestPutFilterAction; import org.elasticsearch.xpack.ml.rest.filter.RestUpdateFilterAction; +import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAction; +import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction; import org.elasticsearch.xpack.ml.rest.job.RestCloseJobAction; import org.elasticsearch.xpack.ml.rest.job.RestDeleteForecastAction; import org.elasticsearch.xpack.ml.rest.job.RestDeleteJobAction; @@ -303,7 +313,7 @@ import static java.util.Collections.emptyList; import static org.elasticsearch.index.mapper.MapperService.SINGLE_MAPPING_NAME; -public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlugin, PersistentTaskPlugin { +public class MachineLearning extends Plugin implements ActionPlugin, IngestPlugin, AnalysisPlugin, PersistentTaskPlugin { public static final String NAME = "ml"; public static final String BASE_PATH = "/_ml/"; public static final String PRE_V7_BASE_PATH = "/_xpack/ml/"; @@ -327,6 +337,15 @@ protected Setting roleSetting() { }; + @Override + public Map getProcessors(Processor.Parameters parameters) { + InferenceProcessor.Factory inferenceFactory = new InferenceProcessor.Factory(parameters.client, + parameters.ingestService.getClusterService(), + this.settings); + parameters.ingestService.addIngestClusterStateListener(inferenceFactory); + return Collections.singletonMap(InferenceProcessor.TYPE, inferenceFactory); + } + @Override public Set getRoles() { return Collections.singleton(ML_ROLE); @@ -487,6 +506,7 @@ public Collection createComponents(Client client, ClusterService cluster AnomalyDetectionAuditor anomalyDetectionAuditor = new AnomalyDetectionAuditor(client, clusterService.getNodeName()); DataFrameAnalyticsAuditor dataFrameAnalyticsAuditor = new DataFrameAnalyticsAuditor(client, clusterService.getNodeName()); + InferenceAuditor inferenceAuditor = new InferenceAuditor(client, clusterService.getNodeName()); this.dataFrameAnalyticsAuditor.set(dataFrameAnalyticsAuditor); JobResultsProvider jobResultsProvider = new JobResultsProvider(client, settings); JobResultsPersister jobResultsPersister = new JobResultsPersister(client); @@ -619,6 +639,7 @@ public Collection createComponents(Client client, ClusterService cluster datafeedManager, anomalyDetectionAuditor, dataFrameAnalyticsAuditor, + inferenceAuditor, mlAssignmentNotifier, memoryTracker, analyticsProcessManager, @@ -709,7 +730,9 @@ public List getRestHandlers(Settings settings, RestController restC new RestStartDataFrameAnalyticsAction(restController), new RestStopDataFrameAnalyticsAction(restController), new RestEvaluateDataFrameAction(restController), - new RestEstimateMemoryUsageAction(restController) + new RestEstimateMemoryUsageAction(restController), + new RestGetTrainedModelsAction(restController), + new RestDeleteTrainedModelAction(restController) ); } @@ -782,6 +805,8 @@ public List getRestHandlers(Settings settings, RestController restC new ActionHandler<>(EvaluateDataFrameAction.INSTANCE, TransportEvaluateDataFrameAction.class), new ActionHandler<>(EstimateMemoryUsageAction.INSTANCE, TransportEstimateMemoryUsageAction.class), new ActionHandler<>(InferModelAction.INSTANCE, TransportInferModelAction.class), + new ActionHandler<>(GetTrainedModelsAction.INSTANCE, TransportGetTrainedModelsAction.class), + new ActionHandler<>(DeleteTrainedModelAction.INSTANCE, TransportDeleteTrainedModelAction.class), usageAction, infoAction); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java new file mode 100644 index 0000000000000..aadcb9dd34708 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java @@ -0,0 +1,134 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.action; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.action.support.master.TransportMasterNodeAction; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.block.ClusterBlockException; +import org.elasticsearch.cluster.block.ClusterBlockLevel; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.ingest.IngestMetadata; +import org.elasticsearch.ingest.IngestService; +import org.elasticsearch.ingest.Pipeline; +import org.elasticsearch.ingest.PipelineConfiguration; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction; +import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; + +import java.io.IOException; +import java.util.HashSet; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + + +/** + * The action is a master node action to ensure it reads an up-to-date cluster + * state in order to determine if there is a processor referencing the trained model + */ +public class TransportDeleteTrainedModelAction + extends TransportMasterNodeAction { + + private static final Logger LOGGER = LogManager.getLogger(TransportDeleteTrainedModelAction.class); + + private final TrainedModelProvider trainedModelProvider; + private final InferenceAuditor auditor; + private final IngestService ingestService; + + @Inject + public TransportDeleteTrainedModelAction(TransportService transportService, ClusterService clusterService, + ThreadPool threadPool, ActionFilters actionFilters, + IndexNameExpressionResolver indexNameExpressionResolver, + TrainedModelProvider configProvider, InferenceAuditor auditor, + IngestService ingestService) { + super(DeleteTrainedModelAction.NAME, transportService, clusterService, threadPool, actionFilters, + DeleteTrainedModelAction.Request::new, indexNameExpressionResolver); + this.trainedModelProvider = configProvider; + this.ingestService = ingestService; + this.auditor = Objects.requireNonNull(auditor); + } + + @Override + protected String executor() { + return ThreadPool.Names.SAME; + } + + @Override + protected AcknowledgedResponse read(StreamInput in) throws IOException { + return new AcknowledgedResponse(in); + } + + @Override + protected void masterOperation(Task task, + DeleteTrainedModelAction.Request request, + ClusterState state, + ActionListener listener) { + String id = request.getId(); + IngestMetadata currentIngestMetadata = state.metaData().custom(IngestMetadata.TYPE); + Set referencedModels = getReferencedModelKeys(currentIngestMetadata); + + if (referencedModels.contains(id)) { + listener.onFailure(new ElasticsearchStatusException("Cannot delete model [{}] as it is still referenced by ingest processors", + RestStatus.CONFLICT, + id)); + return; + } + + trainedModelProvider.deleteTrainedModel(request.getId(), ActionListener.wrap( + r -> { + auditor.info(request.getId(), "trained model deleted"); + listener.onResponse(new AcknowledgedResponse(true)); + }, + listener::onFailure + )); + } + + private Set getReferencedModelKeys(IngestMetadata ingestMetadata) { + Set allReferencedModelKeys = new HashSet<>(); + if (ingestMetadata == null) { + return allReferencedModelKeys; + } + for(Map.Entry entry : ingestMetadata.getPipelines().entrySet()) { + String pipelineId = entry.getKey(); + Map config = entry.getValue().getConfigAsMap(); + try { + Pipeline pipeline = Pipeline.create(pipelineId, + config, + ingestService.getProcessorFactories(), + ingestService.getScriptService()); + pipeline.getProcessors().stream() + .filter(p -> p instanceof InferenceProcessor) + .map(p -> (InferenceProcessor) p) + .map(InferenceProcessor::getModelId) + .forEach(allReferencedModelKeys::add); + } catch (Exception ex) { + LOGGER.warn(new ParameterizedMessage("failed to load pipeline [{}]", pipelineId), ex); + } + } + return allReferencedModelKeys; + } + + + @Override + protected ClusterBlockException checkBlock(DeleteTrainedModelAction.Request request, ClusterState state) { + return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java new file mode 100644 index 0000000000000..ee95ddbd9670d --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java @@ -0,0 +1,90 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.sort.SortOrder; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.action.AbstractTransportGetResourcesAction; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Response; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; + +public class TransportGetTrainedModelsAction extends AbstractTransportGetResourcesAction { + + @Inject + public TransportGetTrainedModelsAction(TransportService transportService, ActionFilters actionFilters, Client client, + NamedXContentRegistry xContentRegistry) { + super(GetTrainedModelsAction.NAME, transportService, actionFilters, GetTrainedModelsAction.Request::new, client, + xContentRegistry); + } + + @Override + protected ParseField getResultsField() { + return GetTrainedModelsAction.Response.RESULTS_FIELD; + } + + @Override + protected String[] getIndices() { + return new String[] { InferenceIndexConstants.INDEX_PATTERN }; + } + + @Override + protected TrainedModelConfig parse(XContentParser parser) { + return TrainedModelConfig.LENIENT_PARSER.apply(parser, null).build(); + } + + @Override + protected ResourceNotFoundException notFoundException(String resourceId) { + return ExceptionsHelper.missingTrainedModel(resourceId); + } + + @Override + protected void doExecute(Task task, GetTrainedModelsAction.Request request, + ActionListener listener) { + searchResources(request, ActionListener.wrap( + queryPage -> listener.onResponse(new GetTrainedModelsAction.Response(queryPage)), + listener::onFailure + )); + } + + @Override + protected String executionOrigin() { + return ML_ORIGIN; + } + + @Override + protected String extractIdFromResource(TrainedModelConfig config) { + return config.getModelId(); + } + + @Override + protected SearchSourceBuilder customSearchOptions(SearchSourceBuilder searchSourceBuilder) { + return searchSourceBuilder.sort("_index", SortOrder.DESC); + } + + @Nullable + protected QueryBuilder additionalQuery() { + return QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelConfig.NAME); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java index 59f6c62a7f55e..b8cccc0d45e23 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -10,6 +10,14 @@ import org.elasticsearch.ingest.IngestDocument; import java.util.function.BiConsumer; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.ingest.ConfigurationUtils; +import org.elasticsearch.ingest.Processor; + +import java.util.Map; +import java.util.function.Consumer; public class InferenceProcessor extends AbstractProcessor { @@ -17,10 +25,16 @@ public class InferenceProcessor extends AbstractProcessor { public static final String MODEL_ID = "model_id"; private final Client client; + private final String modelId; - public InferenceProcessor(Client client, String tag) { + public InferenceProcessor(Client client, String tag, String modelId) { super(tag); this.client = client; + this.modelId = modelId; + } + + public String getModelId() { + return modelId; } @Override @@ -38,4 +52,27 @@ public IngestDocument execute(IngestDocument ingestDocument) { public String getType() { return TYPE; } + + public static class Factory implements Processor.Factory, Consumer { + + private final Client client; + private final ClusterService clusterService; + + public Factory(Client client, ClusterService clusterService, Settings settings) { + this.client = client; + this.clusterService = clusterService; + } + + @Override + public Processor create(Map processorFactories, String tag, Map config) + throws Exception { + String modelId = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_ID); + return new InferenceProcessor(client, tag, modelId); + } + + @Override + public void accept(ClusterState clusterState) { + + } + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index 6f1e543896c9d..3ad5004b9a032 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -28,9 +28,12 @@ import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.engine.VersionConflictEngineException; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.index.reindex.DeleteByQueryAction; +import org.elasticsearch.index.reindex.DeleteByQueryRequest; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; @@ -117,6 +120,30 @@ public void getTrainedModel(String modelId, ActionListener l listener::onFailure)); } + public void deleteTrainedModel(String modelId, ActionListener listener) { + DeleteByQueryRequest request = new DeleteByQueryRequest().setAbortOnVersionConflict(false); + + request.indices(InferenceIndexConstants.INDEX_PATTERN); + QueryBuilder query = QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId); + request.setQuery(query); + request.setRefresh(true); + + executeAsyncWithOrigin(client, ML_ORIGIN, DeleteByQueryAction.INSTANCE, request, ActionListener.wrap(deleteResponse -> { + if (deleteResponse.getDeleted() == 0) { + listener.onFailure(new ResourceNotFoundException( + Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); + return; + } + listener.onResponse(true); + }, e -> { + if (e.getClass() == IndexNotFoundException.class) { + listener.onFailure(new ResourceNotFoundException( + Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); + } else { + listener.onFailure(e); + } + })); + } private void parseInferenceDocLenientlyFromSource(BytesReference source, String modelId, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/notifications/InferenceAuditor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/notifications/InferenceAuditor.java new file mode 100644 index 0000000000000..dfce44af7c9a4 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/notifications/InferenceAuditor.java @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.notifications; + +import org.elasticsearch.client.Client; +import org.elasticsearch.xpack.core.common.notifications.AbstractAuditor; +import org.elasticsearch.xpack.core.ml.notifications.AuditorField; +import org.elasticsearch.xpack.core.ml.notifications.InferenceAuditMessage; + +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; + +public class InferenceAuditor extends AbstractAuditor { + + public InferenceAuditor(Client client, String nodeName) { + super(client, nodeName, AuditorField.NOTIFICATIONS_INDEX, ML_ORIGIN, InferenceAuditMessage::new); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestDeleteTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestDeleteTrainedModelAction.java new file mode 100644 index 0000000000000..e9675be4d29fd --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestDeleteTrainedModelAction.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.rest.inference; + +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestController; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.ml.MachineLearning; + +import java.io.IOException; + +import static org.elasticsearch.rest.RestRequest.Method.DELETE; + +public class RestDeleteTrainedModelAction extends BaseRestHandler { + + public RestDeleteTrainedModelAction(RestController controller) { + controller.registerHandler( + DELETE, MachineLearning.BASE_PATH + "inference/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}", this); + } + + @Override + public String getName() { + return "ml_delete_trained_models_action"; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName()); + DeleteTrainedModelAction.Request request = new DeleteTrainedModelAction.Request(modelId); + return channel -> client.execute(DeleteTrainedModelAction.INSTANCE, request, new RestToXContentListener<>(channel)); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java new file mode 100644 index 0000000000000..40ddd05827043 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java @@ -0,0 +1,52 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.rest.inference; + +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.cluster.metadata.MetaData; +import org.elasticsearch.common.Strings; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestController; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.action.util.PageParams; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.ml.MachineLearning; + +import java.io.IOException; + +import static org.elasticsearch.rest.RestRequest.Method.GET; +import static org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request.ALLOW_NO_MATCH; + +public class RestGetTrainedModelsAction extends BaseRestHandler { + + public RestGetTrainedModelsAction(RestController controller) { + controller.registerHandler( + GET, MachineLearning.BASE_PATH + "inference/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}", this); + controller.registerHandler(GET, MachineLearning.BASE_PATH + "inference", this); + } + + @Override + public String getName() { + return "ml_get_trained_models_action"; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName()); + if (Strings.isNullOrEmpty(modelId)) { + modelId = MetaData.ALL; + } + GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId); + if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) { + request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM), + restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE))); + } + request.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(), request.isAllowNoResources())); + return channel -> client.execute(GetTrainedModelsAction.INSTANCE, request, new RestToXContentListener<>(channel)); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java index c2cfbe4f15498..04dff88417abb 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -166,7 +166,7 @@ public void testInferMissingModel() { try { client().execute(InferModelAction.INSTANCE, request).actionGet(); } catch (ElasticsearchException ex) { - assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model, 0))); + assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model))); } } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.delete_trained_model.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.delete_trained_model.json new file mode 100644 index 0000000000000..edfc157646f91 --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.delete_trained_model.json @@ -0,0 +1,24 @@ +{ + "ml.delete_trained_model":{ + "documentation":{ + "url":"TODO" + }, + "stability":"experimental", + "url":{ + "paths":[ + { + "path":"/_ml/inference/{model_id}", + "methods":[ + "DELETE" + ], + "parts":{ + "model_id":{ + "type":"string", + "description":"The ID of the trained model to delete" + } + } + } + ] + } + } +} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json new file mode 100644 index 0000000000000..481f8b25975bb --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json @@ -0,0 +1,48 @@ +{ + "ml.get_trained_models":{ + "documentation":{ + "url":"TODO" + }, + "stability":"experimental", + "url":{ + "paths":[ + { + "path":"/_ml/inference/{model_id}", + "methods":[ + "GET" + ], + "parts":{ + "model_id":{ + "type":"string", + "description":"The ID of the trained models to fetch" + } + } + }, + { + "path":"/_ml/inference", + "methods":[ + "GET" + ] + } + ] + }, + "params":{ + "allow_no_match":{ + "type":"boolean", + "required":false, + "description":"Whether to ignore if a wildcard expression matches no trained models. (This includes `_all` string or when no trained models have been specified)", + "default":true + }, + "from":{ + "type":"int", + "description":"skips a number of trained models", + "default":0 + }, + "size":{ + "type":"int", + "description":"specifies a max number of trained models to get", + "default":100 + } + } + } +} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml new file mode 100644 index 0000000000000..a18b29487eac5 --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml @@ -0,0 +1,110 @@ +--- +"Test get-all given no trained models exist": + + - do: + ml.get_trained_models: + model_id: "_all" + - match: { count: 0 } + - match: { trained_model_configs: [] } + + - do: + ml.get_trained_models: + model_id: "*" + - match: { count: 0 } + - match: { trained_model_configs: [] } + +--- +"Test get given missing trained model": + + - do: + catch: missing + ml.get_trained_models: + model_id: "missing-trained-model" +--- +"Test get given expression without matches and allow_no_match is false": + + - do: + catch: missing + ml.get_trained_models: + model_id: "missing-trained-model*" + allow_no_match: false + +--- +"Test get given expression without matches and allow_no_match is true": + + - do: + ml.get_trained_models: + model_id: "missing-trained-model*" + allow_no_match: true + - match: { count: 0 } + - match: { trained_model_configs: [] } +--- +"Test delete given unused trained model": + + - do: + index: + id: trained_model_config-unused-regression-model-0 + index: .ml-inference-000001 + body: > + { + "model_id": "unused-regression-model", + "created_by": "ml_tests", + "version": "8.0.0", + "description": "empty model for tests", + "create_time": 0, + "model_version": 0, + "model_type": "local" + } + - do: + indices.refresh: {} + + - do: + ml.delete_trained_model: + model_id: "unused-regression-model" + - match: { acknowledged: true } + +--- +"Test delete with missing model": + - do: + catch: missing + ml.delete_trained_model: + model_id: "missing-trained-model" + +--- +"Test delete given used trained model": + - do: + index: + id: trained_model_config-used-regression-model-0 + index: .ml-inference-000001 + body: > + { + "model_id": "used-regression-model", + "created_by": "ml_tests", + "version": "8.0.0", + "description": "empty model for tests", + "create_time": 0, + "model_version": 0, + "model_type": "local" + } + - do: + indices.refresh: {} + + - do: + ingest.put_pipeline: + id: "regression-model-pipeline" + body: > + { + "processors": [ + { + "inference" : { + "model_id" : "used-regression-model" + } + } + ] + } + - match: { acknowledged: true } + + - do: + catch: conflict + ml.delete_trained_model: + model_id: "used-regression-model"