diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningGetResultsIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningGetResultsIT.java index db7d9632a9c61..2612b4d976b11 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningGetResultsIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningGetResultsIT.java @@ -193,7 +193,7 @@ private void addModelSnapshotIndexRequests(BulkRequest bulkRequest) { @After public void deleteJob() throws IOException { - new MlTestStateCleaner(logger, highLevelClient().machineLearning()).clearMlMetadata(); + new MlTestStateCleaner(logger, highLevelClient()).clearMlMetadata(); } public void testGetModelSnapshots() throws IOException { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 73c16c71e4a9f..739f4d241a858 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -225,7 +225,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { @After public void cleanUp() throws IOException { - new MlTestStateCleaner(logger, highLevelClient().machineLearning()).clearMlMetadata(); + new MlTestStateCleaner(logger, highLevelClient()).clearMlMetadata(); } public void testPutJob() throws Exception { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MlTestStateCleaner.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MlTestStateCleaner.java index 4422b1cf4032e..e6ddcaef374d0 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MlTestStateCleaner.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MlTestStateCleaner.java @@ -8,16 +8,22 @@ package org.elasticsearch.client; import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.action.ingest.DeletePipelineRequest; +import org.elasticsearch.client.core.PageParams; import org.elasticsearch.client.ml.CloseJobRequest; import org.elasticsearch.client.ml.DeleteDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.DeleteDatafeedRequest; import org.elasticsearch.client.ml.DeleteJobRequest; +import org.elasticsearch.client.ml.DeleteTrainedModelRequest; import org.elasticsearch.client.ml.GetDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.GetDataFrameAnalyticsResponse; import org.elasticsearch.client.ml.GetDatafeedRequest; import org.elasticsearch.client.ml.GetDatafeedResponse; import org.elasticsearch.client.ml.GetJobRequest; import org.elasticsearch.client.ml.GetJobResponse; +import org.elasticsearch.client.ml.GetTrainedModelsRequest; +import org.elasticsearch.client.ml.GetTrainedModelsStatsRequest; import org.elasticsearch.client.ml.StopDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.StopDatafeedRequest; import org.elasticsearch.client.ml.datafeed.DatafeedConfig; @@ -25,26 +31,77 @@ import org.elasticsearch.client.ml.job.config.Job; import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Collections; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; /** * Cleans up and ML resources created during tests */ public class MlTestStateCleaner { + private static final Set NOT_DELETED_TRAINED_MODELS = Collections.singleton("lang_ident_model_1"); private final Logger logger; private final MachineLearningClient mlClient; + private final RestHighLevelClient client; - public MlTestStateCleaner(Logger logger, MachineLearningClient mlClient) { + public MlTestStateCleaner(Logger logger, RestHighLevelClient client) { this.logger = logger; - this.mlClient = mlClient; + this.mlClient = client.machineLearning(); + this.client = client; } public void clearMlMetadata() throws IOException { + deleteAllTrainedModels(); deleteAllDatafeeds(); deleteAllJobs(); deleteAllDataFrameAnalytics(); } + @SuppressWarnings("unchecked") + private void deleteAllTrainedModels() throws IOException { + Set pipelinesWithModels = mlClient.getTrainedModelsStats( + new GetTrainedModelsStatsRequest("_all").setPageParams(new PageParams(0, 10_000)), RequestOptions.DEFAULT + ).getTrainedModelStats() + .stream() + .flatMap(stats -> { + Map ingestStats = stats.getIngestStats(); + if (ingestStats == null || ingestStats.isEmpty()) { + return Stream.empty(); + } + Map pipelines = (Map)ingestStats.get("pipelines"); + if (pipelines == null || pipelines.isEmpty()) { + return Stream.empty(); + } + return pipelines.keySet().stream(); + }) + .collect(Collectors.toSet()); + for (String pipelineId : pipelinesWithModels) { + try { + client.ingest().deletePipeline(new DeletePipelineRequest(pipelineId), RequestOptions.DEFAULT); + } catch (Exception ex) { + logger.warn(() -> new ParameterizedMessage("failed to delete pipeline [{}]", pipelineId), ex); + } + } + + mlClient.getTrainedModels( + GetTrainedModelsRequest.getAllTrainedModelConfigsRequest().setPageParams(new PageParams(0, 10_000)), + RequestOptions.DEFAULT) + .getTrainedModels() + .stream() + .filter(trainedModelConfig -> NOT_DELETED_TRAINED_MODELS.contains(trainedModelConfig.getModelId()) == false) + .forEach(config -> { + try { + mlClient.deleteTrainedModel(new DeleteTrainedModelRequest(config.getModelId()), RequestOptions.DEFAULT); + } catch (IOException ex) { + throw new UncheckedIOException(ex); + } + }); + } + private void deleteAllDatafeeds() throws IOException { stopAllDatafeeds(); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index 4a623f751ceb3..cdd457403119e 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -242,7 +242,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { @After public void cleanUp() throws IOException { - new MlTestStateCleaner(logger, highLevelClient().machineLearning()).clearMlMetadata(); + new MlTestStateCleaner(logger, highLevelClient()).clearMlMetadata(); } public void testCreateJob() throws Exception { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/common/notifications/AbstractAuditor.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/common/notifications/AbstractAuditor.java index 16de5d6fcadfc..c89a5be82793e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/common/notifications/AbstractAuditor.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/common/notifications/AbstractAuditor.java @@ -184,7 +184,11 @@ private XContentBuilder toXContentBuilder(ToXContent toXContent) { } } - private void writeBacklog() { + protected void clearBacklog() { + backlog = null; + } + + protected void writeBacklog() { assert backlog != null; if (backlog == null) { logger.error("Message back log has already been written"); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlMetadata.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlMetadata.java index fc414cf25f37e..07ae6504814e0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlMetadata.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlMetadata.java @@ -51,8 +51,14 @@ public class MlMetadata implements XPackPlugin.XPackMetadataCustom { private static final ParseField JOBS_FIELD = new ParseField("jobs"); private static final ParseField DATAFEEDS_FIELD = new ParseField("datafeeds"); public static final ParseField UPGRADE_MODE = new ParseField("upgrade_mode"); - - public static final MlMetadata EMPTY_METADATA = new MlMetadata(Collections.emptySortedMap(), Collections.emptySortedMap(), false); + public static final ParseField RESET_MODE = new ParseField("reset_mode"); + + public static final MlMetadata EMPTY_METADATA = new MlMetadata( + Collections.emptySortedMap(), + Collections.emptySortedMap(), + false, + false + ); // This parser follows the pattern that metadata is parsed leniently (to allow for enhancements) public static final ObjectParser LENIENT_PARSER = new ObjectParser<>("ml_metadata", true, Builder::new); @@ -61,19 +67,21 @@ public class MlMetadata implements XPackPlugin.XPackMetadataCustom { LENIENT_PARSER.declareObjectArray(Builder::putDatafeeds, (p, c) -> DatafeedConfig.LENIENT_PARSER.apply(p, c).build(), DATAFEEDS_FIELD); LENIENT_PARSER.declareBoolean(Builder::isUpgradeMode, UPGRADE_MODE); - + LENIENT_PARSER.declareBoolean(Builder::isResetMode, RESET_MODE); } private final SortedMap jobs; private final SortedMap datafeeds; private final boolean upgradeMode; + private final boolean resetMode; private final GroupOrJobLookup groupOrJobLookup; - private MlMetadata(SortedMap jobs, SortedMap datafeeds, boolean upgradeMode) { + private MlMetadata(SortedMap jobs, SortedMap datafeeds, boolean upgradeMode, boolean resetMode) { this.jobs = Collections.unmodifiableSortedMap(jobs); this.datafeeds = Collections.unmodifiableSortedMap(datafeeds); this.groupOrJobLookup = new GroupOrJobLookup(jobs.values()); this.upgradeMode = upgradeMode; + this.resetMode = resetMode; } public Map getJobs() { @@ -105,6 +113,10 @@ public boolean isUpgradeMode() { return upgradeMode; } + public boolean isResetMode() { + return resetMode; + } + @Override public Version getMinimalSupportedVersion() { return Version.V_6_0_0_alpha1; @@ -144,6 +156,11 @@ public MlMetadata(StreamInput in) throws IOException { } else { this.upgradeMode = false; } + if (in.getVersion().onOrAfter(Version.V_7_13_0)) { + this.resetMode = in.readBoolean(); + } else { + this.resetMode = false; + } } @Override @@ -153,6 +170,9 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getVersion().onOrAfter(Version.V_6_7_0)) { out.writeBoolean(upgradeMode); } + if (out.getVersion().onOrAfter(Version.V_7_13_0)) { + out.writeBoolean(resetMode); + } } private static void writeMap(Map map, StreamOutput out) throws IOException { @@ -170,6 +190,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws mapValuesToXContent(JOBS_FIELD, jobs, builder, extendedParams); mapValuesToXContent(DATAFEEDS_FIELD, datafeeds, builder, extendedParams); builder.field(UPGRADE_MODE.getPreferredName(), upgradeMode); + builder.field(RESET_MODE.getPreferredName(), resetMode); return builder; } @@ -191,11 +212,13 @@ public static class MlMetadataDiff implements NamedDiff { final Diff> jobs; final Diff> datafeeds; final boolean upgradeMode; + final boolean resetMode; MlMetadataDiff(MlMetadata before, MlMetadata after) { this.jobs = DiffableUtils.diff(before.jobs, after.jobs, DiffableUtils.getStringKeySerializer()); this.datafeeds = DiffableUtils.diff(before.datafeeds, after.datafeeds, DiffableUtils.getStringKeySerializer()); this.upgradeMode = after.upgradeMode; + this.resetMode = after.resetMode; } public MlMetadataDiff(StreamInput in) throws IOException { @@ -208,6 +231,11 @@ public MlMetadataDiff(StreamInput in) throws IOException { } else { upgradeMode = false; } + if (in.getVersion().onOrAfter(Version.V_7_13_0)) { + resetMode = in.readBoolean(); + } else { + resetMode = false; + } } /** @@ -219,7 +247,7 @@ public MlMetadataDiff(StreamInput in) throws IOException { public Metadata.Custom apply(Metadata.Custom part) { TreeMap newJobs = new TreeMap<>(jobs.apply(((MlMetadata) part).jobs)); TreeMap newDatafeeds = new TreeMap<>(datafeeds.apply(((MlMetadata) part).datafeeds)); - return new MlMetadata(newJobs, newDatafeeds, upgradeMode); + return new MlMetadata(newJobs, newDatafeeds, upgradeMode, resetMode); } @Override @@ -229,6 +257,9 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getVersion().onOrAfter(Version.V_6_7_0)) { out.writeBoolean(upgradeMode); } + if (out.getVersion().onOrAfter(Version.V_7_13_0)) { + out.writeBoolean(resetMode); + } } @Override @@ -254,7 +285,8 @@ public boolean equals(Object o) { MlMetadata that = (MlMetadata) o; return Objects.equals(jobs, that.jobs) && Objects.equals(datafeeds, that.datafeeds) && - Objects.equals(upgradeMode, that.upgradeMode); + upgradeMode == that.upgradeMode && + resetMode == that.resetMode; } @Override @@ -264,7 +296,7 @@ public final String toString() { @Override public int hashCode() { - return Objects.hash(jobs, datafeeds, upgradeMode); + return Objects.hash(jobs, datafeeds, upgradeMode, resetMode); } public static class Builder { @@ -272,6 +304,11 @@ public static class Builder { private TreeMap jobs; private TreeMap datafeeds; private boolean upgradeMode; + private boolean resetMode; + + public static Builder from(@Nullable MlMetadata previous) { + return new Builder(previous); + } public Builder() { jobs = new TreeMap<>(); @@ -286,6 +323,7 @@ public Builder(@Nullable MlMetadata previous) { jobs = new TreeMap<>(previous.jobs); datafeeds = new TreeMap<>(previous.datafeeds); upgradeMode = previous.upgradeMode; + resetMode = previous.resetMode; } } @@ -353,8 +391,13 @@ public Builder isUpgradeMode(boolean upgradeMode) { return this; } + public Builder isResetMode(boolean resetMode) { + this.resetMode = resetMode; + return this; + } + public MlMetadata build() { - return new MlMetadata(jobs, datafeeds, upgradeMode); + return new MlMetadata(jobs, datafeeds, upgradeMode, resetMode); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/CloseJobAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/CloseJobAction.java index 978ca57d5ee03..83bdd3eae1e6f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/CloseJobAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/CloseJobAction.java @@ -111,44 +111,50 @@ public String getJobId() { return jobId; } - public void setJobId(String jobId) { + public Request setJobId(String jobId) { this.jobId = jobId; + return this; } public TimeValue getCloseTimeout() { return timeout; } - public void setCloseTimeout(TimeValue timeout) { + public Request setCloseTimeout(TimeValue timeout) { this.timeout = timeout; + return this; } public boolean isForce() { return force; } - public void setForce(boolean force) { + public Request setForce(boolean force) { this.force = force; + return this; } public boolean allowNoMatch() { return allowNoMatch; } - public void setAllowNoMatch(boolean allowNoMatch) { + public Request setAllowNoMatch(boolean allowNoMatch) { this.allowNoMatch = allowNoMatch; + return this; } public boolean isLocal() { return local; } - public void setLocal(boolean local) { + public Request setLocal(boolean local) { this.local = local; + return this; } public String[] getOpenJobIds() { return openJobIds; } - public void setOpenJobIds(String [] openJobIds) { + public Request setOpenJobIds(String[] openJobIds) { this.openJobIds = openJobIds; + return this; } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/SetResetModeAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/SetResetModeAction.java new file mode 100644 index 0000000000000..ade9e2337d4a6 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/SetResetModeAction.java @@ -0,0 +1,101 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.master.AcknowledgedRequest; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +public class SetResetModeAction extends ActionType { + + public static final SetResetModeAction INSTANCE = new SetResetModeAction(); + public static final String NAME = "cluster:internal/xpack/ml/reset_mode"; + + private SetResetModeAction() { + super(NAME, AcknowledgedResponse::readFrom); + } + + public static class Request extends AcknowledgedRequest implements ToXContentObject { + + public static Request enabled() { + return new Request(true); + } + + public static Request disabled() { + return new Request(false); + } + + private final boolean enabled; + + private static final ParseField ENABLED = new ParseField("enabled"); + public static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(NAME, a -> new Request((Boolean)a[0])); + + static { + PARSER.declareBoolean(ConstructingObjectParser.constructorArg(), ENABLED); + } + + Request(boolean enabled) { + this.enabled = enabled; + } + + public Request(StreamInput in) throws IOException { + super(in); + this.enabled = in.readBoolean(); + } + + public boolean isEnabled() { + return enabled; + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeBoolean(enabled); + } + + @Override + public int hashCode() { + return Objects.hash(enabled); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || obj.getClass() != getClass()) { + return false; + } + Request other = (Request) obj; + return Objects.equals(enabled, other.enabled); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ENABLED.getPreferredName(), enabled); + builder.endObject(); + return builder; + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsAction.java index 15358aca3343a..626a450d5e9a1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsAction.java @@ -95,8 +95,9 @@ public Request() { setTimeout(DEFAULT_TIMEOUT); } - public final void setId(String id) { + public final Request setId(String id) { this.id = ExceptionsHelper.requireNonNull(id, DataFrameAnalyticsConfig.ID); + return this; } public String getId() { @@ -107,16 +108,18 @@ public boolean allowNoMatch() { return allowNoMatch; } - public void setAllowNoMatch(boolean allowNoMatch) { + public Request setAllowNoMatch(boolean allowNoMatch) { this.allowNoMatch = allowNoMatch; + return this; } public boolean isForce() { return force; } - public void setForce(boolean force) { + public Request setForce(boolean force) { this.force = force; + return this; } @Nullable diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopDatafeedAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopDatafeedAction.java index f9eb9cd56d9ef..445c347d29cf6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopDatafeedAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopDatafeedAction.java @@ -109,24 +109,27 @@ public TimeValue getStopTimeout() { return stopTimeout; } - public void setStopTimeout(TimeValue stopTimeout) { + public Request setStopTimeout(TimeValue stopTimeout) { this.stopTimeout = ExceptionsHelper.requireNonNull(stopTimeout, TIMEOUT.getPreferredName()); + return this; } public boolean isForce() { return force; } - public void setForce(boolean force) { + public Request setForce(boolean force) { this.force = force; + return this; } public boolean allowNoMatch() { return allowNoMatch; } - public void setAllowNoMatch(boolean allowNoMatch) { + public Request setAllowNoMatch(boolean allowNoMatch) { this.allowNoMatch = allowNoMatch; + return this; } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/SetResetModeActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/SetResetModeActionRequestTests.java new file mode 100644 index 0000000000000..f1dbf1e970ef9 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/SetResetModeActionRequestTests.java @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.action.SetResetModeAction.Request; + +public class SetResetModeActionRequestTests extends AbstractSerializingTestCase { + + @Override + protected Request createTestInstance() { + return new Request(randomBoolean()); + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } + + @Override + protected Writeable.Reader instanceReader() { + return Request::new; + } + + @Override + protected Request doParseInstance(XContentParser parser) { + return Request.PARSER.apply(parser, null); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/integration/MlRestTestStateCleaner.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/integration/MlRestTestStateCleaner.java index 5a74298350f2c..bd69fed2febf5 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/integration/MlRestTestStateCleaner.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/integration/MlRestTestStateCleaner.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.integration; import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.client.Request; import org.elasticsearch.client.Response; import org.elasticsearch.client.RestClient; @@ -18,6 +19,8 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; + public class MlRestTestStateCleaner { @@ -40,6 +43,23 @@ public void clearMlMetadata() throws IOException { @SuppressWarnings("unchecked") private void deleteAllTrainedModels() throws IOException { + final Request getAllTrainedModelStats = new Request("GET", "/_ml/trained_models/_stats"); + getAllTrainedModelStats.addParameter("size", "10000"); + final Response trainedModelsStatsResponse = adminClient.performRequest(getAllTrainedModelStats); + + final List> pipelines = (List>) XContentMapValues.extractValue( + "trained_model_stats.ingest.pipelines", + ESRestTestCase.entityAsMap(trainedModelsStatsResponse) + ); + Set pipelineIds = pipelines.stream().flatMap(m -> m.keySet().stream()).collect(Collectors.toSet()); + for (String pipelineId : pipelineIds) { + try { + adminClient.performRequest(new Request("DELETE", "/_ingest/pipeline/" + pipelineId)); + } catch (Exception ex) { + logger.warn(() -> new ParameterizedMessage("failed to delete pipeline [{}]", pipelineId), ex); + } + } + final Request getTrainedModels = new Request("GET", "/_ml/trained_models"); getTrainedModels.addParameter("size", "10000"); final Response trainedModelsResponse = adminClient.performRequest(getTrainedModels); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index f66621e58b742..2f6a92c5a730d 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -85,18 +85,18 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { - private static final String BOOLEAN_FIELD = "boolean-field"; - private static final String NUMERICAL_FIELD = "numerical-field"; - private static final String DISCRETE_NUMERICAL_FIELD = "discrete-numerical-field"; - private static final String TEXT_FIELD = "text-field"; - private static final String KEYWORD_FIELD = "keyword-field"; - private static final String NESTED_FIELD = "outer-field.inner-field"; - private static final String ALIAS_TO_KEYWORD_FIELD = "alias-to-keyword-field"; - private static final String ALIAS_TO_NESTED_FIELD = "alias-to-nested-field"; - private static final List BOOLEAN_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(false, true)); - private static final List NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0)); - private static final List DISCRETE_NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(10, 20)); - private static final List KEYWORD_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList("cat", "dog")); + static final String BOOLEAN_FIELD = "boolean-field"; + static final String NUMERICAL_FIELD = "numerical-field"; + static final String DISCRETE_NUMERICAL_FIELD = "discrete-numerical-field"; + static final String TEXT_FIELD = "text-field"; + static final String KEYWORD_FIELD = "keyword-field"; + static final String NESTED_FIELD = "outer-field.inner-field"; + static final String ALIAS_TO_KEYWORD_FIELD = "alias-to-keyword-field"; + static final String ALIAS_TO_NESTED_FIELD = "alias-to-nested-field"; + static final List BOOLEAN_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(false, true)); + static final List NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0)); + static final List DISCRETE_NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(10, 20)); + static final List KEYWORD_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList("cat", "dog")); private String jobId; private String sourceIndex; @@ -958,7 +958,7 @@ private void initialize(String jobId, boolean isDatastream) { } } - private static void createIndex(String index, boolean isDatastream) { + static void createIndex(String index, boolean isDatastream) { String mapping = "{\n" + " \"properties\": {\n" + " \"@timestamp\": {\n" + @@ -1010,7 +1010,7 @@ private static void createIndex(String index, boolean isDatastream) { } } - private static void indexData(String sourceIndex, int numTrainingRows, int numNonTrainingRows, String dependentVariable) { + static void indexData(String sourceIndex, int numTrainingRows, int numNonTrainingRows, String dependentVariable) { BulkRequestBuilder bulkRequestBuilder = client().prepareBulk() .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); for (int i = 0; i < numTrainingRows; i++) { diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java index 5623e6dcce09b..d50802b478f50 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java @@ -200,8 +200,8 @@ protected PreviewDataFrameAnalyticsAction.Response previewDataFrame(String id) { ).actionGet(); } - protected static DataFrameAnalyticsConfig buildAnalytics(String id, String sourceIndex, String destIndex, - @Nullable String resultsField, DataFrameAnalysis analysis) throws Exception { + static DataFrameAnalyticsConfig buildAnalytics(String id, String sourceIndex, String destIndex, + @Nullable String resultsField, DataFrameAnalysis analysis) throws Exception { return buildAnalytics(id, sourceIndex, destIndex, resultsField, analysis, QueryBuilders.matchAllQuery()); } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureResetIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureResetIT.java new file mode 100644 index 0000000000000..875b20f4af752 --- /dev/null +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureResetIT.java @@ -0,0 +1,197 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +package org.elasticsearch.xpack.ml.integration; + +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.action.admin.cluster.snapshots.features.ResetFeatureStateAction; +import org.elasticsearch.action.admin.cluster.snapshots.features.ResetFeatureStateRequest; +import org.elasticsearch.action.ingest.DeletePipelineAction; +import org.elasticsearch.action.ingest.DeletePipelineRequest; +import org.elasticsearch.action.ingest.PutPipelineAction; +import org.elasticsearch.action.ingest.PutPipelineRequest; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.MlMetadata; +import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; +import org.elasticsearch.xpack.core.ml.job.config.Job; +import org.elasticsearch.xpack.core.ml.job.config.JobState; +import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.DataCounts; +import org.junit.After; + +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor.Factory.countNumberInferenceProcessors; +import static org.elasticsearch.xpack.ml.integration.ClassificationIT.KEYWORD_FIELD; +import static org.elasticsearch.xpack.ml.integration.MlNativeDataFrameAnalyticsIntegTestCase.buildAnalytics; +import static org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase.createDatafeed; +import static org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase.createScheduledJob; +import static org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase.getDataCounts; +import static org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase.indexDocs; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.emptyArray; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; + +public class TestFeatureResetIT extends MlNativeAutodetectIntegTestCase { + + private final Set createdPipelines = new HashSet<>(); + + @After + public void cleanup() throws Exception { + cleanUp(); + for (String pipeline : createdPipelines) { + try { + client().execute(DeletePipelineAction.INSTANCE, new DeletePipelineRequest(pipeline)).actionGet(); + } catch (Exception ex) { + logger.warn(() -> new ParameterizedMessage("error cleaning up pipeline [{}]", pipeline), ex); + } + } + } + + public void testMLFeatureReset() throws Exception { + startRealtime("feature_reset_anomaly_job"); + startDataFrameJob("feature_reset_data_frame_analytics_job"); + putTrainedModelIngestPipeline("feature_reset_inference_pipeline"); + createdPipelines.add("feature_reset_inference_pipeline"); + for(int i = 0; i < 100; i ++) { + indexDocForInference("feature_reset_inference_pipeline"); + } + client().execute(DeletePipelineAction.INSTANCE, new DeletePipelineRequest("feature_reset_inference_pipeline")).actionGet(); + createdPipelines.remove("feature_reset_inference_pipeline"); + + assertBusy(() -> + assertThat(countNumberInferenceProcessors(client().admin().cluster().prepareState().get().getState()), equalTo(0)) + ); + client().execute( + ResetFeatureStateAction.INSTANCE, + new ResetFeatureStateRequest() + ).actionGet(); + assertBusy(() -> assertThat(client().admin().indices().prepareGetIndex().addIndices(".ml*").get().indices(), emptyArray())); + assertThat(isResetMode(), is(false)); + } + + public void testMLFeatureResetFailureDueToPipelines() throws Exception { + putTrainedModelIngestPipeline("feature_reset_failure_inference_pipeline"); + createdPipelines.add("feature_reset_failure_inference_pipeline"); + Exception ex = expectThrows(Exception.class, () -> client().execute( + ResetFeatureStateAction.INSTANCE, + new ResetFeatureStateRequest() + ).actionGet()); + assertThat( + ex.getMessage(), + containsString( + "Unable to reset machine learning feature as there are ingest pipelines still referencing trained machine learning models" + ) + ); + client().execute(DeletePipelineAction.INSTANCE, new DeletePipelineRequest("feature_reset_failure_inference_pipeline")).actionGet(); + createdPipelines.remove("feature_reset_failure_inference_pipeline"); + assertThat(isResetMode(), is(false)); + } + + private boolean isResetMode() { + ClusterState state = client().admin().cluster().prepareState().get().getState(); + return MlMetadata.getMlMetadata(state).isResetMode(); + } + + private void startDataFrameJob(String jobId) throws Exception { + String sourceIndex = jobId + "-src"; + String destIndex = jobId + "-dest"; + ClassificationIT.createIndex(sourceIndex, false); + ClassificationIT.indexData(sourceIndex, 300, 50, KEYWORD_FIELD); + + DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, + new Classification( + KEYWORD_FIELD, + BoostedTreeParams.builder().setNumTopFeatureImportanceValues(1).build(), + null, + null, + null, + null, + null, + null, + null)); + PutDataFrameAnalyticsAction.Request request = new PutDataFrameAnalyticsAction.Request(config); + client().execute(PutDataFrameAnalyticsAction.INSTANCE, request).actionGet(); + + client().execute(StartDataFrameAnalyticsAction.INSTANCE, new StartDataFrameAnalyticsAction.Request(jobId)); + } + + private void startRealtime(String jobId) throws Exception { + client().admin().indices().prepareCreate("data") + .addMapping("type", "time", "type=date") + .get(); + long numDocs1 = randomIntBetween(32, 2048); + long now = System.currentTimeMillis(); + long lastWeek = now - 604800000; + indexDocs(logger, "data", numDocs1, lastWeek, now); + + Job.Builder job = createScheduledJob(jobId); + registerJob(job); + putJob(job); + openJob(job.getId()); + assertBusy(() -> assertEquals(getJobStats(job.getId()).get(0).getState(), JobState.OPENED)); + + DatafeedConfig datafeedConfig = createDatafeed(job.getId() + "-datafeed", job.getId(), Collections.singletonList("data")); + registerDatafeed(datafeedConfig); + putDatafeed(datafeedConfig); + startDatafeed(datafeedConfig.getId(), 0L, null); + assertBusy(() -> { + DataCounts dataCounts = getDataCounts(job.getId()); + assertThat(dataCounts.getProcessedRecordCount(), is(equalTo(numDocs1))); + assertThat(dataCounts.getOutOfOrderTimeStampCount(), is(equalTo(0L))); + }); + + long numDocs2 = randomIntBetween(2, 64); + now = System.currentTimeMillis(); + indexDocs(logger, "data", numDocs2, now + 5000, now + 6000); + assertBusy(() -> { + DataCounts dataCounts = getDataCounts(job.getId()); + assertThat(dataCounts.getProcessedRecordCount(), is(equalTo(numDocs1 + numDocs2))); + assertThat(dataCounts.getOutOfOrderTimeStampCount(), is(equalTo(0L))); + }, 30, TimeUnit.SECONDS); + } + + private void putTrainedModelIngestPipeline(String pipelineId) throws Exception { + client().execute( + PutPipelineAction.INSTANCE, + new PutPipelineRequest( + pipelineId, + new BytesArray( + "{\n" + + " \"processors\": [\n" + + " {\n" + + " \"inference\": {\n" + + " \"inference_config\": {\"classification\":{}},\n" + + " \"model_id\": \"lang_ident_model_1\",\n" + + " \"field_map\": {}\n" + + " }\n" + + " }\n" + + " ]\n" + + " }" + ), + XContentType.JSON + ) + ).actionGet(); + } + + private void indexDocForInference(String pipelineId) { + client().prepareIndex("type", "foo") + .setPipeline(pipelineId) + .setSource("{\"text\": \"this is some plain text.\"}", XContentType.JSON) + .get(); + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 4b0d7d3701168..d920d2d6f95ac 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -13,8 +13,10 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse; import org.elasticsearch.action.admin.cluster.snapshots.features.ResetFeatureStateResponse; import org.elasticsearch.action.support.ActionFilter; +import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.client.Client; import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.client.node.NodeClient; @@ -128,6 +130,7 @@ import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAliasAction; import org.elasticsearch.xpack.core.ml.action.RevertModelSnapshotAction; +import org.elasticsearch.xpack.core.ml.action.SetResetModeAction; import org.elasticsearch.xpack.core.ml.action.SetUpgradeModeAction; import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction; @@ -149,6 +152,7 @@ import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.notifications.NotificationsIndex; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.template.TemplateUtils; import org.elasticsearch.xpack.ml.action.TransportCloseJobAction; import org.elasticsearch.xpack.ml.action.TransportDeleteCalendarAction; @@ -203,6 +207,7 @@ import org.elasticsearch.xpack.ml.action.TransportPutTrainedModelAction; import org.elasticsearch.xpack.ml.action.TransportPutTrainedModelAliasAction; import org.elasticsearch.xpack.ml.action.TransportRevertModelSnapshotAction; +import org.elasticsearch.xpack.ml.action.TransportSetResetModeAction; import org.elasticsearch.xpack.ml.action.TransportSetUpgradeModeAction; import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction; import org.elasticsearch.xpack.ml.action.TransportStartDatafeedAction; @@ -358,6 +363,7 @@ import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndexFields.RESULTS_INDEX_PREFIX; import static org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndexFields.STATE_INDEX_PREFIX; +import static org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor.Factory.countNumberInferenceProcessors; public class MachineLearning extends Plugin implements SystemIndexPlugin, AnalysisPlugin, @@ -1027,7 +1033,8 @@ public List getRestHandlers(Settings settings, RestController restC new ActionHandler<>(UpgradeJobModelSnapshotAction.INSTANCE, TransportUpgradeJobModelSnapshotAction.class), new ActionHandler<>(PutTrainedModelAliasAction.INSTANCE, TransportPutTrainedModelAliasAction.class), new ActionHandler<>(DeleteTrainedModelAliasAction.INSTANCE, TransportDeleteTrainedModelAliasAction.class), - new ActionHandler<>(PreviewDataFrameAnalyticsAction.INSTANCE, TransportPreviewDataFrameAnalyticsAction.class) + new ActionHandler<>(PreviewDataFrameAnalyticsAction.INSTANCE, TransportPreviewDataFrameAnalyticsAction.class), + new ActionHandler<>(SetResetModeAction.INSTANCE, TransportSetResetModeAction.class) ); } @@ -1198,28 +1205,89 @@ public String getFeatureDescription() { return "Provides anomaly detection and forecasting functionality"; } - @Override public void cleanUpFeature( + @Override + public void cleanUpFeature( ClusterService clusterService, Client client, - ActionListener listener) { + ActionListener finalListener) { + logger.info("Starting machine learning feature reset"); + + ActionListener unsetResetModeListener = ActionListener.wrap( + success -> client.execute(SetResetModeAction.INSTANCE, SetResetModeAction.Request.disabled(), ActionListener.wrap( + resetSuccess -> finalListener.onResponse(success), + resetFailure -> { + logger.error("failed to disable reset mode after state otherwise successful machine learning reset", resetFailure); + finalListener.onFailure( + ExceptionsHelper.serverError( + "failed to disable reset mode after state otherwise successful machine learning reset", + resetFailure + ) + ); + }) + ), + failure -> client.execute(SetResetModeAction.INSTANCE, SetResetModeAction.Request.disabled(), ActionListener.wrap( + resetSuccess -> finalListener.onFailure(failure), + resetFailure -> { + logger.error("failed to disable reset mode after state clean up failure", resetFailure); + finalListener.onFailure(failure); + }) + ) + ); Map results = new ConcurrentHashMap<>(); + ActionListener afterWaitingForTasks = ActionListener.wrap( + listTasksResponse -> { + listTasksResponse.rethrowFailures("Waiting for indexing requests for .ml-* indices"); + if (results.values().stream().allMatch(b -> b)) { + // Call into the original listener to clean up the indices + SystemIndexPlugin.super.cleanUpFeature(clusterService, client, unsetResetModeListener); + } else { + final List failedComponents = results.entrySet().stream() + .filter(result -> result.getValue() == false) + .map(Map.Entry::getKey) + .collect(Collectors.toList()); + unsetResetModeListener.onFailure( + new RuntimeException("Some machine learning components failed to reset: " + failedComponents) + ); + } + }, + unsetResetModeListener::onFailure + ); + ActionListener afterDataframesStopped = ActionListener.wrap(dataFrameStopResponse -> { // Handle the response results.put("data_frame/analytics", dataFrameStopResponse.isStopped()); - if (results.values().stream().allMatch(b -> b)) { - // Call into the original listener to clean up the indices - SystemIndexPlugin.super.cleanUpFeature(clusterService, client, listener); + client.admin() + .cluster() + .prepareListTasks() + .setActions("xpack/ml/*") + .setWaitForCompletion(true) + .execute(ActionListener.wrap( + listMlTasks -> { + listMlTasks.rethrowFailures("Waiting for machine learning tasks"); + client.admin() + .cluster() + .prepareListTasks() + .setActions("indices:data/write/bulk") + .setDetailed(true) + .setWaitForCompletion(true) + .setDescriptions("*.ml-*") + .execute(afterWaitingForTasks); + }, + unsetResetModeListener::onFailure + )); } else { final List failedComponents = results.entrySet().stream() .filter(result -> result.getValue() == false) .map(Map.Entry::getKey) .collect(Collectors.toList()); - listener.onFailure(new RuntimeException("Some components failed to reset: " + failedComponents)); + unsetResetModeListener.onFailure( + new RuntimeException("Some machine learning components failed to reset: " + failedComponents) + ); } - }, listener::onFailure); + }, unsetResetModeListener::onFailure); ActionListener afterAnomalyDetectionClosed = ActionListener.wrap(closeJobResponse -> { @@ -1227,11 +1295,11 @@ public String getFeatureDescription() { results.put("anomaly_detectors", closeJobResponse.isClosed()); // Stop data frame analytics - StopDataFrameAnalyticsAction.Request stopDataFramesReq = new StopDataFrameAnalyticsAction.Request("_all"); - stopDataFramesReq.setForce(true); - stopDataFramesReq.setAllowNoMatch(true); + StopDataFrameAnalyticsAction.Request stopDataFramesReq = new StopDataFrameAnalyticsAction.Request("_all") + .setForce(true) + .setAllowNoMatch(true); client.execute(StopDataFrameAnalyticsAction.INSTANCE, stopDataFramesReq, afterDataframesStopped); - }, listener::onFailure); + }, unsetResetModeListener::onFailure); // Close anomaly detection jobs ActionListener afterDataFeedsStopped = ActionListener.wrap(datafeedResponse -> { @@ -1239,19 +1307,45 @@ public String getFeatureDescription() { results.put("datafeeds", datafeedResponse.isStopped()); // Close anomaly detection jobs - CloseJobAction.Request closeJobsRequest = new CloseJobAction.Request(); - closeJobsRequest.setForce(true); - closeJobsRequest.setAllowNoMatch(true); - closeJobsRequest.setJobId("_all"); + CloseJobAction.Request closeJobsRequest = new CloseJobAction.Request() + .setForce(true) + .setAllowNoMatch(true) + .setJobId("_all"); client.execute(CloseJobAction.INSTANCE, closeJobsRequest, afterAnomalyDetectionClosed); - }, listener::onFailure); + }, unsetResetModeListener::onFailure); // Stop data feeds - StopDatafeedAction.Request stopDatafeedsReq = new StopDatafeedAction.Request("_all"); - stopDatafeedsReq.setAllowNoMatch(true); - stopDatafeedsReq.setForce(true); - client.execute(StopDatafeedAction.INSTANCE, stopDatafeedsReq, - afterDataFeedsStopped); + ActionListener pipelineValidation = ActionListener.wrap( + acknowledgedResponse -> { + StopDatafeedAction.Request stopDatafeedsReq = new StopDatafeedAction.Request("_all") + .setAllowNoMatch(true) + .setForce(true); + client.execute(StopDatafeedAction.INSTANCE, stopDatafeedsReq, + afterDataFeedsStopped); + }, + unsetResetModeListener::onFailure + ); + + // validate no pipelines are using machine learning models + ActionListener afterResetModeSet = ActionListener.wrap( + acknowledgedResponse -> { + int numberInferenceProcessors = countNumberInferenceProcessors(clusterService.state()); + if (numberInferenceProcessors > 0) { + unsetResetModeListener.onFailure( + new RuntimeException( + "Unable to reset machine learning feature as there are ingest pipelines " + + "still referencing trained machine learning models" + ) + ); + return; + } + pipelineValidation.onResponse(AcknowledgedResponse.of(true)); + }, + finalListener::onFailure + ); + + // Indicate that a reset is now in progress + client.execute(SetResetModeAction.INSTANCE, SetResetModeAction.Request.enabled(), afterResetModeSet); } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlDailyMaintenanceService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlDailyMaintenanceService.java index 1e0be7a65c6e2..7c699af3945ba 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlDailyMaintenanceService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlDailyMaintenanceService.java @@ -151,6 +151,10 @@ private void triggerTasks() { LOGGER.warn("skipping scheduled [ML] maintenance tasks because upgrade mode is enabled"); return; } + if (MlMetadata.getMlMetadata(clusterService.state()).isResetMode()) { + LOGGER.warn("skipping scheduled [ML] maintenance tasks because machine learning feature reset is in progress"); + return; + } LOGGER.info("triggering scheduled [ML] maintenance tasks"); // Step 3: Log any error that could have happened diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportSetResetModeAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportSetResetModeAction.java new file mode 100644 index 0000000000000..31d87f556eb3f --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportSetResetModeAction.java @@ -0,0 +1,111 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +package org.elasticsearch.xpack.ml.action; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.ElasticsearchTimeoutException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.action.support.master.AcknowledgedTransportMasterNodeAction; +import org.elasticsearch.cluster.AckedClusterStateUpdateTask; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.block.ClusterBlockException; +import org.elasticsearch.cluster.block.ClusterBlockLevel; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.ml.MlMetadata; +import org.elasticsearch.xpack.core.ml.action.SetResetModeAction; + + +public class TransportSetResetModeAction extends AcknowledgedTransportMasterNodeAction { + + private static final Logger logger = LogManager.getLogger(TransportSetResetModeAction.class); + private final ClusterService clusterService; + + @Inject + public TransportSetResetModeAction(TransportService transportService, ThreadPool threadPool, ClusterService clusterService, + ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver) { + super(SetResetModeAction.NAME, transportService, clusterService, threadPool, actionFilters, SetResetModeAction.Request::new, + indexNameExpressionResolver, ThreadPool.Names.SAME); + this.clusterService = clusterService; + } + + @Override + protected void masterOperation(SetResetModeAction.Request request, + ClusterState state, + ActionListener listener) throws Exception { + + // Noop, nothing for us to do, simply return fast to the caller + if (request.isEnabled() == MlMetadata.getMlMetadata(state).isResetMode()) { + logger.debug("Reset mode noop"); + listener.onResponse(AcknowledgedResponse.TRUE); + return; + } + + logger.debug( + () -> new ParameterizedMessage( + "Starting to set [reset_mode] to [{}] from [{}]", request.isEnabled(), MlMetadata.getMlMetadata(state).isResetMode() + ) + ); + + ActionListener wrappedListener = ActionListener.wrap( + r -> { + logger.debug("Completed reset mode request"); + listener.onResponse(r); + }, + e -> { + logger.debug("Completed reset mode request but with failure", e); + listener.onFailure(e); + } + ); + + ActionListener clusterStateUpdateListener = ActionListener.wrap( + acknowledgedResponse -> { + if (acknowledgedResponse.isAcknowledged() == false) { + logger.info("Cluster state update is NOT acknowledged"); + wrappedListener.onFailure(new ElasticsearchTimeoutException("Unknown error occurred while updating cluster state")); + return; + } + wrappedListener.onResponse(acknowledgedResponse); + }, + wrappedListener::onFailure + ); + + clusterService.submitStateUpdateTask("ml-set-reset-mode", + new AckedClusterStateUpdateTask(request, clusterStateUpdateListener) { + + @Override + protected AcknowledgedResponse newResponse(boolean acknowledged) { + logger.trace(() -> new ParameterizedMessage("Cluster update response built: {}", acknowledged)); + return AcknowledgedResponse.of(acknowledged); + } + + @Override + public ClusterState execute(ClusterState currentState) { + logger.trace("Executing cluster state update"); + MlMetadata.Builder builder = MlMetadata.Builder + .from(currentState.metadata().custom(MlMetadata.TYPE)) + .isResetMode(request.isEnabled()); + ClusterState.Builder newState = ClusterState.builder(currentState); + newState.metadata(Metadata.builder(currentState.getMetadata()).putCustom(MlMetadata.TYPE, builder.build()).build()); + return newState.build(); + } + }); + } + + @Override + protected ClusterBlockException checkBlock(SetResetModeAction.Request request, ClusterState state) { + return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java index 761154fa54341..32857e788bd5d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -198,15 +198,17 @@ public Factory(Client client, ClusterService clusterService, Settings settings) @Override public void accept(ClusterState state) { minNodeVersion = state.nodes().getMinNodeVersion(); + currentInferenceProcessors = countNumberInferenceProcessors(state); + } + + public static int countNumberInferenceProcessors(ClusterState state) { Metadata metadata = state.getMetadata(); if (metadata == null) { - currentInferenceProcessors = 0; - return; + return 0; } IngestMetadata ingestMetadata = metadata.custom(IngestMetadata.TYPE); if (ingestMetadata == null) { - currentInferenceProcessors = 0; - return; + return 0; } int count = 0; @@ -219,14 +221,14 @@ public void accept(ClusterState state) { count += numInferenceProcessors(entry.getKey(), entry.getValue()); } } - // We cannot throw any exception here. It might break other pipelines. + // We cannot throw any exception here. It might break other pipelines. } catch (Exception ex) { logger.debug( () -> new ParameterizedMessage("failed gathering processors for pipeline [{}]", configuration.getId()), ex); } } - currentInferenceProcessors = count; + return count; } @SuppressWarnings("unchecked") diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/notifications/AbstractMlAuditor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/notifications/AbstractMlAuditor.java new file mode 100644 index 0000000000000..3072c02786462 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/notifications/AbstractMlAuditor.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.notifications; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.client.Client; +import org.elasticsearch.client.OriginSettingClient; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.xpack.core.common.notifications.AbstractAuditMessage; +import org.elasticsearch.xpack.core.common.notifications.AbstractAuditMessageFactory; +import org.elasticsearch.xpack.core.common.notifications.AbstractAuditor; +import org.elasticsearch.xpack.core.ml.MlMetadata; +import org.elasticsearch.xpack.core.ml.notifications.NotificationsIndex; +import org.elasticsearch.xpack.ml.MlIndexTemplateRegistry; + +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; + +abstract class AbstractMlAuditor extends AbstractAuditor { + + private static final Logger logger = LogManager.getLogger(AbstractMlAuditor.class); + private volatile boolean isResetMode; + + protected AbstractMlAuditor(Client client, AbstractAuditMessageFactory messageFactory, ClusterService clusterService) { + super( + new OriginSettingClient(client, ML_ORIGIN), + NotificationsIndex.NOTIFICATIONS_INDEX, + MlIndexTemplateRegistry.NOTIFICATIONS_TEMPLATE, + clusterService.getNodeName(), + messageFactory, + clusterService + ); + clusterService.addListener(event -> { + if (event.metadataChanged()) { + setResetMode(MlMetadata.getMlMetadata(event.state()).isResetMode()); + } + }); + } + + private void setResetMode(boolean value) { + isResetMode = value; + } + + @Override + protected void writeBacklog() { + if (isResetMode) { + logger.trace("Skipped writing the audit message backlog as reset_mode is enabled"); + clearBacklog(); + } else { + super.writeBacklog(); + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/notifications/AnomalyDetectionAuditor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/notifications/AnomalyDetectionAuditor.java index c2981335fe391..5c3079dd6c5f3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/notifications/AnomalyDetectionAuditor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/notifications/AnomalyDetectionAuditor.java @@ -7,21 +7,12 @@ package org.elasticsearch.xpack.ml.notifications; import org.elasticsearch.client.Client; -import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.xpack.core.common.notifications.AbstractAuditor; -import org.elasticsearch.xpack.core.ml.notifications.NotificationsIndex; import org.elasticsearch.xpack.core.ml.notifications.AnomalyDetectionAuditMessage; -import org.elasticsearch.xpack.ml.MlIndexTemplateRegistry; -import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; - -public class AnomalyDetectionAuditor extends AbstractAuditor { +public class AnomalyDetectionAuditor extends AbstractMlAuditor { public AnomalyDetectionAuditor(Client client, ClusterService clusterService) { - super(new OriginSettingClient(client, ML_ORIGIN), NotificationsIndex.NOTIFICATIONS_INDEX, - MlIndexTemplateRegistry.NOTIFICATIONS_TEMPLATE, - clusterService.getNodeName(), - AnomalyDetectionAuditMessage::new, clusterService); + super(client, AnomalyDetectionAuditMessage::new, clusterService); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/notifications/DataFrameAnalyticsAuditor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/notifications/DataFrameAnalyticsAuditor.java index de776154c6881..608f7876672d6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/notifications/DataFrameAnalyticsAuditor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/notifications/DataFrameAnalyticsAuditor.java @@ -7,21 +7,12 @@ package org.elasticsearch.xpack.ml.notifications; import org.elasticsearch.client.Client; -import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.xpack.core.common.notifications.AbstractAuditor; -import org.elasticsearch.xpack.core.ml.notifications.NotificationsIndex; import org.elasticsearch.xpack.core.ml.notifications.DataFrameAnalyticsAuditMessage; -import org.elasticsearch.xpack.ml.MlIndexTemplateRegistry; -import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; - -public class DataFrameAnalyticsAuditor extends AbstractAuditor { +public class DataFrameAnalyticsAuditor extends AbstractMlAuditor { public DataFrameAnalyticsAuditor(Client client, ClusterService clusterService) { - super(new OriginSettingClient(client, ML_ORIGIN), NotificationsIndex.NOTIFICATIONS_INDEX, - MlIndexTemplateRegistry.NOTIFICATIONS_TEMPLATE, - clusterService.getNodeName(), - DataFrameAnalyticsAuditMessage::new, clusterService); + super(client, DataFrameAnalyticsAuditMessage::new, clusterService); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/notifications/InferenceAuditor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/notifications/InferenceAuditor.java index 56851d1764526..fdcb86fd76400 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/notifications/InferenceAuditor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/notifications/InferenceAuditor.java @@ -7,19 +7,12 @@ package org.elasticsearch.xpack.ml.notifications; import org.elasticsearch.client.Client; -import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.xpack.core.common.notifications.AbstractAuditor; -import org.elasticsearch.xpack.core.ml.notifications.NotificationsIndex; import org.elasticsearch.xpack.core.ml.notifications.InferenceAuditMessage; -import org.elasticsearch.xpack.ml.MlIndexTemplateRegistry; -import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; - -public class InferenceAuditor extends AbstractAuditor { +public class InferenceAuditor extends AbstractMlAuditor { public InferenceAuditor(Client client, ClusterService clusterService) { - super(new OriginSettingClient(client, ML_ORIGIN), NotificationsIndex.NOTIFICATIONS_INDEX, - MlIndexTemplateRegistry.NOTIFICATIONS_TEMPLATE, clusterService.getNodeName(), InferenceAuditMessage::new, clusterService); + super(client, InferenceAuditMessage::new, clusterService); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java index fa62883e2f3e6..f5e3e977f9da7 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java @@ -9,8 +9,11 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.admin.cluster.snapshots.features.ResetFeatureStateResponse; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.TransportAction; +import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.breaker.NoopCircuitBreaker; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.settings.Settings; @@ -36,18 +39,19 @@ public class LocalStateMachineLearning extends LocalStateCompositeXPackPlugin { - public LocalStateMachineLearning(final Settings settings, final Path configPath) throws Exception { + private final MachineLearning mlPlugin; + public LocalStateMachineLearning(final Settings settings, final Path configPath) { super(settings, configPath); LocalStateMachineLearning thisVar = this; - MachineLearning plugin = new MachineLearning(settings, configPath){ + mlPlugin = new MachineLearning(settings, configPath){ @Override protected XPackLicenseState getLicenseState() { return thisVar.getLicenseState(); } }; - plugin.setCircuitBreaker(new NoopCircuitBreaker(TRAINED_MODEL_CIRCUIT_BREAKER_NAME)); + mlPlugin.setCircuitBreaker(new NoopCircuitBreaker(TRAINED_MODEL_CIRCUIT_BREAKER_NAME)); plugins.add(new Autoscaling()); - plugins.add(plugin); + plugins.add(mlPlugin); plugins.add(new Monitoring(settings) { @Override protected SSLService getSslService() { @@ -74,6 +78,15 @@ protected XPackLicenseState getLicenseState() { plugins.add(new MockedRollupPlugin()); } + @Override + public void cleanUpFeature( + ClusterService clusterService, + Client client, + ActionListener finalListener) { + mlPlugin.cleanUpFeature(clusterService, client, finalListener); + } + + /** * This is only required as we now have to have the GetRollupIndexCapsAction as a valid action in our node. * The MachineLearningLicenseTests attempt to create a datafeed referencing this LocalStateMachineLearning object. diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlMetadataTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlMetadataTests.java index d907157d9503d..f887c6e79f1cf 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlMetadataTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlMetadataTests.java @@ -56,7 +56,7 @@ protected MlMetadata createTestInstance() { builder.putJob(job, false); } } - return builder.build(); + return builder.isResetMode(randomBoolean()).isUpgradeMode(randomBoolean()).build(); } @Override @@ -81,6 +81,14 @@ protected NamedXContentRegistry xContentRegistry() { return new NamedXContentRegistry(searchModule.getNamedXContents()); } + public void testBuilderClone() { + for (int i = 0; i < NUMBER_OF_TEST_RUNS; i++) { + MlMetadata first = createTestInstance(); + MlMetadata cloned = MlMetadata.Builder.from(first).build(); + assertThat(cloned, equalTo(first)); + } + } + public void testPutJob() { Job job1 = buildJobBuilder("1").build(); Job job2 = buildJobBuilder("2").build(); @@ -146,6 +154,8 @@ private static MlMetadata.Builder newMlMetadataWithJobs(String... jobIds) { protected MlMetadata mutateInstance(MlMetadata instance) { Map jobs = instance.getJobs(); Map datafeeds = instance.getDatafeeds(); + boolean isUpgrade = instance.isUpgradeMode(); + boolean isReset = instance.isResetMode(); MlMetadata.Builder metadataBuilder = new MlMetadata.Builder(); for (Map.Entry entry : jobs.entrySet()) { @@ -155,7 +165,7 @@ protected MlMetadata mutateInstance(MlMetadata instance) { metadataBuilder.putDatafeed(entry.getValue(), Collections.emptyMap(), xContentRegistry()); } - switch (between(0, 1)) { + switch (between(0, 3)) { case 0: metadataBuilder.putJob(JobTests.createRandomizedJob(), true); break; @@ -175,6 +185,12 @@ protected MlMetadata mutateInstance(MlMetadata instance) { metadataBuilder.putJob(randomJob, false); metadataBuilder.putDatafeed(datafeedConfig, Collections.emptyMap(), xContentRegistry()); break; + case 2: + metadataBuilder.isUpgradeMode(isUpgrade == false); + break; + case 3: + metadataBuilder.isResetMode(isReset == false); + break; default: throw new AssertionError("Illegal randomisation branch"); } diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index c175754c3416b..984f919e50006 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -212,6 +212,7 @@ public class Constants { "cluster:internal/xpack/ml/job/finalize_job_execution", "cluster:internal/xpack/ml/job/kill/process", "cluster:internal/xpack/ml/job/update/process", + "cluster:internal/xpack/ml/reset_mode", "cluster:monitor/allocation/explain", "cluster:monitor/async_search/status", "cluster:monitor/ccr/follow_info",